diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..338e2200f70de1601b59d1320bf1be169b737ff3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,145 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+.vscode
+.vscode/*
+.DS_Store
+
+output/*
+work_dirs
+work_dirs/*
+work_dirs/
+
+data/temp/*
+slurm_tools/
+slurm_run.sh
+
+core*
+dist_url_*
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..ade4360be1c7c6e21cccca653f18c961cbfcc2b5
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,224 @@
+Copyright (c) 2022 - present, SenseTime. All Rights Reserved.
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2022 - present, SenseTime
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+
+
+X-modaler
+
+Copyright 2021 Jingdong Technology Information Technology Co., Ltd
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+
+
+
diff --git a/README.md b/README.md
index 154df8298fab5ecf322016157858e08cd1bccbe1..52e2a4b59e34f4059b609d19a7ce726ba7e92baa 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,201 @@
----
-license: apache-2.0
----
+# Uni-Perceiver
+
+This repository contains training (pre-training, fine-tuning, prompt-tuning), evaluation code and pretrained models for the following papers:
+
+
+> [Uni-Perceiver](https://arxiv.org/abs/2112.01522): Pre-training Unified Architecture for Generic Perception for Zero-shot and Few-shot Tasks, CVPR 2022.
+
+> [Uni-Perceiver-MoE](https://arxiv.org/abs/2206.04674): Learning Sparse Generalist Models with Conditional MoEs, NeurIPS 2022.
+
+
+
+## Introduction
+
+__Uni-Perceiver__ is a generalist model (generic perception model) that can process a variety of modalities and tasks with unified modeling and shared
+parameters. Different perception tasks are modeled as the same formulation, that is, finding the maximum likelihood target for each input through the similarity of their representations. Meanwhile, Uni-Perceiver is pre-trained on several uni-modal and multi-modal tasks, and evaluated on a variety of downstream tasks, including novel tasks that did not appear in the pre-training stage.
+Thanks to the unified formulation, it shows the ability of zero-shot inference on novel tasks, and shows promising performance close to or on par with SOTA results by prompt tuning or finetuning.
+
+![UnPerceiver-intro](./figs/overview.png)
+
+In __Uni-Perceiver-MoE__, we found that the interference among different tasks and modalities can lead to performance degradation of generalist models on some tasks compared with task-specialized models. We introduce the Conditional Mixture-of-Experts (Conditional MoEs) to mitigate such interference. By incorporating the proposed Conditional MoEs, Uni-Perceiver-MoE can effectively mitigate the interference across tasks and modalities, and achieves state-of-the-art results on a series of downstream tasks via prompt tuning on 1% of downstream data. Moreover, the introduction of Conditional MoEs still holds the generalization ability of generalist models to conduct zero-shot inference on new tasks,
+
+![UnPerceiver-moe-intro](./figs/overview_moe.png)
+
+
+## Main Results and Pretrained Models
+
+### Base Models
+
+
+
+ Task |
+ Image Classification |
+ Image Caption |
+ Image Retrieval |
+ Video Classification | Video Caption | Video Retrieval |
+
+
+ Dataset | ImageNet-1k | MSCOCO | Flickr30k | MSCOCO | Flickr30k | Kinetics-400 | MSVD | MSVD |
+
+
+ Split | ILSVRC 2012 val | Karpathy test | test | Karpathy test | test | test-dev | val | val | val |
+
+
+ Metric | Acc@1 | BLEU-4 | BLEU-4 | R@1 i2t | R@1 t2i | R@1 i2t | R@1 t2i | Acc@1 | BLEU-4 | R@1 v2t | R@1 t2v |
+
+
+
+ Uni-PerceiverBASE w/o Tuning | 79.2 | 32.0 | 14.7 | 64.9 | 50.7 | 82.3 | 71.1 | 74.5 | 22.6 | 50.3 | 38.7 |
+
+
+ Uni-PerceiverBASE PT (1%) | 80.9 | 35.5 | 30.2 | 68.4 | 51.9 | 91.0 | 76.0 | 74.8 | 59.5 | 62.7 | 43.8 |
+
+
+ Uni-PerceiverBASE FT (100%) | 84.0 | 36.4 | 31.2 | 69.8 | 53.9 | 92.7 | 77.5 | 77.7 | 63.3 | 62.8 | 45.8 |
+
+
+
+ Uni-Perceiver-MoEBASE w/o Tuning | 80.3 | 33.2 | 15.9 | 64.6 | 51.6 | 82.1 | 75.8 | 76.8 | 23.4 | 52.8 | 40.0 |
+
+
+ Uni-Perceiver-MoEBASE PT (1%) | 82.0 | 36.8 | 30.7 | 68.9 | 52.6 | 91.3 | 78.5 | 77.2 | 60.0 | 65.6 | 45.3 |
+
+
+ Uni-Perceiver-MoEBASE FT (100%) | 84.5 | 37.3 | 32.4 | 70.5 | 54.1 | 93.6 | 79.8 | 79.3 | 65.4 | 65.0 | 47.8 |
+
+
+
+
+### Large Models
+
+
+
+ Task |
+ Image Classification |
+ Image Caption |
+ Image Retrieval |
+ Video Classification | Video Caption | Video Retrieval |
+
+
+ Dataset | ImageNet-1k | MSCOCO | Flickr30k | MSCOCO | Flickr30k | Kinetics-400 | MSVD | MSVD |
+
+
+ Split | ILSVRC 2012 val | Karpathy test | test | Karpathy test | test | test-dev | val | val | val |
+
+
+ Metric | Acc@1 | BLEU-4 | BLEU-4 | R@1 i2t | R@1 t2i | R@1 i2t | R@1 t2i | Acc@1 | BLEU-4 | R@1 v2t | R@1 t2v |
+
+
+ Uni-PerceiverLARGE w/o Tuning | 82.7 | 35.3 | 15.1 | 67.8 | 54.1 | 83.7 | 74.2 | 79.5 | 24.7 | 45.4 | 34.2 |
+
+
+ Uni-PerceiverLARGE PT (1%) | 84.2 | 38.6 | 32.9 | 73.3 | 56.2 | 92.1 | 80.0 | 80.0 | 67.2 | 65.5 | 48.6 |
+
+
+ Uni-PerceiverLARGE FT (100%) | 86.2 | 39.2 | 35.5 | 74.4 | 57.9 | 94.7 | 82.1 | 81.9 | 68.3 | 65.2 | 50.8 |
+
+
+
+ Uni-Perceiver-MoELARGE w/o Tuning | 83.4 | 35.5 | 15.8 | 67.9 | 55.3 | 83.6 | 75.9 | 82.1 | 24.6 | 45.7 | 41.9 |
+
+
+ Uni-Perceiver-MoELARGE PT (1%) | 84.9 | 39.3 | 33.7 | 73.3 | 57.1 | 92.4 | 80.6 | 83.0 | 67.6 | 66.4 | 50.3 |
+
+
+ Uni-Perceiver-MoELARGE FT (100%) | 86.4 | 40.5 | 36.2 | 74.7 | 58.3 | 94.1 | 83.7 | 84.2 | 68.9 | 67.6 | 52.3 |
+
+
+
+ * The numbers are slightly better than the original paper of Uni-Perceiver, which are from the reproduced version of Uni-Perceiver used as the baseline of [Uni-Perceiver-MoE](https://arxiv.org/abs/2206.04674).
+ * The image resolution for all tasks is `224x224`.
+ * See [OtherResults.md](data/other_results.md) for results on more tasks and datasets.
+
+
+
+## Usage
+### Requirements
+* Linux, CUDA>=10.1, GCC>=5.4
+
+* Python >=3.7
+
+* pytorch >= 1.8.0
+
+* JAVA >= 1.8 (for caption task evaluation)
+
+
+### Installation
+```bash
+git clone https://github.com/fundamentalvision/Uni-Perceiver
+cd Uni-Perceiver
+pip install -r requirements.txt
+```
+
+
+### Data
+See [prepare_data.md](data/prepare_data.md).
+
+### Pre-trained Model Weights
+See [checkpoints.md](data/checkpoints.md).
+
+
+### Pre-training
+See [pretraining.md](data/pretraining.md).
+
+### Fine-tuning
+See [finetuning.md](data/finetuning.md).
+
+### Prompt-tuning
+
+See [prompt_tuning.md](data/prompt_tuning.md).
+
+
+### Inference
+
+See [inference.md](data/inference.md).
+
+### TODO
+
+* release more pretrained models
+ - [ ] Uni-Perceiver Tiny model
+ - [ ] Uni-Perceiver Small model
+ - [ ] Uni-Perceiver Huge model
+
+* support more datasets and tasks
+
+
+
+## License
+Uni-Perceiver is licensed under the [Apache-2.0 License](./LICENSE).
+
+
+
+
+## Citing Uni-Perceiver
+If you find Uni-Perceiver useful in your research, please consider giving a star ⭐ and citing:
+```bibtex
+ @article{zhu2021uni,
+ title={Uni-Perceiver: Pre-training Unified Architecture for Generic Perception for Zero-shot and Few-shot Tasks},
+ author={Zhu, Xizhou and Zhu, Jinguo and Li, Hao and Wu, Xiaoshi and Wang, Xiaogang and Li, Hongsheng and Wang, Xiaohua and Dai, Jifeng},
+ journal={arXiv preprint arXiv:2112.01522},
+ year={2021}
+
+}
+```
+
+```bibtex
+@article{zhu2022uni,
+ title={Uni-Perceiver-MoE: Learning Sparse Generalist Models with Conditional MoEs},
+ author={Zhu, Jinguo and Zhu, Xizhou and Wang, Wenhai and Wang, Xiaohua and Li, Hongsheng and Wang, Xiaogang and Dai, Jifeng},
+ journal={arXiv preprint arXiv:2206.04674},
+ year={2022}
+}
+```
+
+### Acknowledgements
+Many thanks to following codes that help us a lot in building this codebase:
+* [Detectron2](https://github.com/facebookresearch/detectron2)
+* [X-modaler](https://github.com/YehLi/xmodaler)
+* [deit](https://github.com/facebookresearch/deit)
+* [VL-BERT](https://github.com/jackroos/VL-BERT)
+* [TimeSformer](https://github.com/facebookresearch/TimeSformer)
+* [CLIP](https://github.com/openai/CLIP)
diff --git a/configs/BERT_L12_H192_experiments/4tasks_training.yaml b/configs/BERT_L12_H192_experiments/4tasks_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0253bc07cbd4199bfa4ca51b4f63fe5e9e7ce946
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/4tasks_training.yaml
@@ -0,0 +1,729 @@
+_BASE_: "base_model_bert_l12_h192.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ # VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 4
+ # TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ -
+ NAME: yfcc_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'YFCC'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: True
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC12M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC3M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: sbu_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'SBU'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'VG'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: True
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 100
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ NAME: 'RetrievalEvaler'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ GENERATION_MODE: False
+
+ -
+ NAME: yfcc_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'YFCC'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: True
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC12M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC3M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+
+ -
+ NAME: vg_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'VG'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: sbu_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'SBU'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: flickr30k_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ TEST: 'ImageTextPairDataset'
+ DATASET_NAME: 'FLICKR'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 128
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: "s3://open_dataset/flickr30k/flickr30k_images"
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 77
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ NAME: 'RetrievalEvaler'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ GENERATION_MODE: False
+
+ -
+ NAME: flickr30k_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ TEST: 'ImageTextPairDataset'
+ DATASET_NAME: 'FLICKR'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: "s3://open_dataset/flickr30k/flickr30k_images"
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ TASK_TYPE: caption
+ # DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/captions_val.json'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/captions_test.json'
+ GENERATION_MODE: True
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ NUM_HIDDEN_LAYERS: 1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 150000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H192_experiments/4tasks_training_small_datasets.yaml b/configs/BERT_L12_H192_experiments/4tasks_training_small_datasets.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..898e9451359f64afdf0d453afd81adb5f6b7deeb
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/4tasks_training_small_datasets.yaml
@@ -0,0 +1,292 @@
+_BASE_: "base_model_bert_l12_h192.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'small_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'small_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 4
+ # TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'small_source_dataset/imagenet'
+ ANNO_FOLDER: 'small_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'small_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ -
+ NAME: bookswiki_pretrain
+ DATASETS:
+ TRAIN: 'GeneralCorpusDataset'
+ TASK_TYPE: 'text_mlm'
+ DATASET_NAME: 'BooksWiki'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 128
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'small_source_dataset/bert_pretrain_data/bookswiki'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ SEQ_PER_SAMPLE: 128
+ MIN_SEQ_PER_SAMPLE: 128
+ APPEND_EOS: True
+ ONE_STREAM: False
+ SAMPLING_WEIGHT: 1.0
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 128
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'small_source_dataset/mscoco_caption/coco_origin'
+ ANNO_FOLDER: 'small_source_dataset/mscoco_caption/annotations'
+ SEQ_PER_SAMPLE: 1
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'small_source_dataset/mscoco_caption/annotations/captions_val5k.json'
+ TEST_ANNFILE: 'small_source_dataset/mscoco_caption/annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 100
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'small_source_dataset/mscoco_caption/coco_origin'
+ ANNO_FOLDER: 'small_source_dataset/mscoco_caption/annotations'
+ SEQ_PER_SAMPLE: 1
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ NUM_HIDDEN_LAYERS: 1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 150000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H192_experiments/7tasks_berttiny_training.yaml b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..04c58939c78a4ffc781c6c20845bcd5b4338170e
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training.yaml
@@ -0,0 +1,416 @@
+_BASE_: "base_model_bert_l12_h192.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 720
+ TEST_BATCH_SIZE: 256
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 2.5
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ -
+ NAME: bookswiki_pretrain
+ DATASETS:
+ TRAIN: 'GeneralCorpusDataset'
+ TASK_TYPE: 'text_mlm'
+ DATASET_NAME: 'BooksWiki'
+ TARGET_SET: ['Vocab_Word']
+ VERSION: 'v2'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'open_source_dataset/text_corpus' # 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ SEQ_PER_SAMPLE: 128
+ MIN_SEQ_PER_SAMPLE: 128
+ APPEND_EOS: True
+ ONE_STREAM: False
+ SAMPLING_WEIGHT: 3.5
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 128
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ ########## Image Captioning ###########
+
+
+ -
+ NAME: cc12m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC12M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.6889
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC3M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.8780
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'VG'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5895
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: True
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.3817
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ -
+ NAME: sbu_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'SBU'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.4618
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 150000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_apex_o2.yaml b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_apex_o2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b621cf05226a371d7bf72a2fc9201cbfdf046d15
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_apex_o2.yaml
@@ -0,0 +1,9 @@
+_BASE_: "7tasks_berttiny_training.yaml"
+
+####################################### Optimizer #######################################
+SOLVER:
+
+ AMP_FP16: False
+ APEX_FP16: True # dangerous
+ APEX_OPT_LEVEL: 'O2'
+ CHECKPOINT_PERIOD: 100000
diff --git a/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_lamb.yaml b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_lamb.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d90fef8986fc3fc80d20980b266ed3fb348b6f64
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_lamb.yaml
@@ -0,0 +1,418 @@
+_BASE_: "base_model_bert_l12_h192.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 720
+ TEST_BATCH_SIZE: 256
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 2.5
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ -
+ NAME: bookswiki_pretrain
+ DATASETS:
+ TRAIN: 'GeneralCorpusDataset'
+ TASK_TYPE: 'text_mlm'
+ DATASET_NAME: 'BooksWiki'
+ TARGET_SET: ['Vocab_Word']
+ VERSION: 'v2'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'open_source_dataset/text_corpus' # 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ SEQ_PER_SAMPLE: 128
+ MIN_SEQ_PER_SAMPLE: 128
+ APPEND_EOS: True
+ ONE_STREAM: False
+ SAMPLING_WEIGHT: 3.5
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 128
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ ########## Image Captioning ###########
+
+
+ -
+ NAME: cc12m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC12M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.6889
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC3M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.8780
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'VG'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5895
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: True
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.3817
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ -
+ NAME: sbu_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'SBU'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.4618
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'LAMB'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 150000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.01
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_moe.yaml b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_moe.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..20d76047fdef481f79d07d34976d0e0fb7a6fe6b
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_moe.yaml
@@ -0,0 +1,25 @@
+_BASE_: "7tasks_berttiny_training.yaml"
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'all' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_moe_lsfp32_gate_softmax_layernorm_fp16.yaml b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_moe_lsfp32_gate_softmax_layernorm_fp16.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1a2a387f5b5d9813385ce845d05d96cbcdd596de
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_moe_lsfp32_gate_softmax_layernorm_fp16.yaml
@@ -0,0 +1,42 @@
+_BASE_: "7tasks_berttiny_training.yaml"
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'all' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
+
+MODEL:
+ LAYER_SCALE_FP32: True
+ GATE_FP32: False
+ TAG_TRANSFORM_FP32: False
+
+
+SOLVER:
+
+
+ FORCE_SOFTMAX_FP16: True
+ FORCE_LN_FP16: True
+ FORCE_NORM_FP16: True
+ # FORCE_TEMP_FP16: True
+ FORCE_EMBED_FP16: True
+
+ # FORCE_EXPERT_ADDING_FP16: True
diff --git a/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_moe_scale_before.yaml b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_moe_scale_before.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a6779bee5567deb9489e6359745fb25bab6ebb2d
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/7tasks_berttiny_training_moe_scale_before.yaml
@@ -0,0 +1,444 @@
+_BASE_: "base_model_bert_l12_h192.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 720
+ # TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 2.5
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ -
+ NAME: bookswiki_pretrain
+ DATASETS:
+ TRAIN: 'GeneralCorpusDataset'
+ TASK_TYPE: 'text_mlm'
+ DATASET_NAME: 'BooksWiki'
+ TARGET_SET: ['Vocab_Word']
+ VERSION: 'v2'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'open_source_dataset/text_corpus' # 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ SEQ_PER_SAMPLE: 128
+ MIN_SEQ_PER_SAMPLE: 128
+ APPEND_EOS: True
+ ONE_STREAM: False
+ SAMPLING_WEIGHT: 3.5
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 128
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ ########## Image Captioning ###########
+
+
+ -
+ NAME: cc12m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC12M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.6889
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC3M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.8780
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'VG'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5895
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: True
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.3817
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ -
+ NAME: sbu_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'SBU'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.4618
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+ SCALE_MULTI_BEFORE: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 150000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'all' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
diff --git a/configs/BERT_L12_H192_experiments/base_model_bert_l12_h192.yaml b/configs/BERT_L12_H192_experiments/base_model_bert_l12_h192.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..87cfae8c5c53cdf6ad08269635874852c4e6a2c4
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/base_model_bert_l12_h192.yaml
@@ -0,0 +1,73 @@
+
+######################################### MODEL #########################################
+MODEL:
+ VOCAB_SIZE: 49411 # include /
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder_v3'
+ ENCODER_DIM: 192
+ DECODER: ''
+ DECODER_DIM: 192
+
+ PREDICTOR: 'EmbedClsAsRetrievalPredictor'
+ FEATURE_GATHER: True
+ LEARN_TEMP: True
+ PRED_USE_NORM: True
+ PRED_TEMPERATURE: 0.07
+
+ BertParamsInit: True
+
+ CLS_TOKEN: False
+
+ QUEUE_LEN: 1024
+ MAX_LABEL_LEN: 12
+
+ OUTPUT_PROJ: True # output projection
+
+
+# #################################### Token embedding ####################################
+ TOKEN_EMBED:
+ NAME: 'TokenBaseEmbedding'
+ DIM: 192
+ ACTIVATION: 'none'
+ USE_NORM: True
+ DROPOUT: 0.0
+ POSITION: 'NNEmbeddingEncoding'
+ POSITION_MAX_LEN: 512
+ TYPE_VOCAB_SIZE: 2
+
+# #################################### Visual embedding ####################################
+ VISUAL_EMBED:
+ NAME: 'none'
+
+# #################################### video embedding ####################################
+ VIDEO_EMBED:
+ NAME: 'VideoBaseEmbedding'
+ IN_DIM: 768
+ OUT_DIM: 192
+ ACTIVATION: 'none'
+ USE_NORM: True
+ DROPOUT: 0.0
+ TYPE_SIZE: 1 # video to encoder
+ POSITION: 'NNEmbeddingEncoding'
+ MAX_LENGTH: 1600
+ PATCH_SIZE_S: 16
+ PATCH_SIZE_T: 1
+ DIVIDE_ST_POS: True
+ USE_VISUAL_TOKENIZER: True
+ USE_VISUAL_POS: True
+ MAX_FRAMES: 8
+
+####################################### BERT ############################################
+ BERT:
+ DROP_PATH_PROB: 0.0
+ HIDDEN_SIZE: 192
+ HIDDEN_DROPOUT_PROB: 0.
+ HIDDEN_ACT: "gelu"
+ NUM_ATTENTION_HEADS: 3
+ INTERMEDIATE_SIZE: 768
+ INTERMEDIATE_DROP: 0.
+ FFN_DROPOUT_PROB: 0.
+ ATTENTION_PROBS_DROPOUT_PROB: 0.
+ NUM_HIDDEN_LAYERS: 12
+ NUM_GENERATION_LAYERS: 0
+
\ No newline at end of file
diff --git a/configs/BERT_L12_H192_experiments/in1k_training.yaml b/configs/BERT_L12_H192_experiments/in1k_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f667ba49473c082ac7daf008cb5edbe6e0c5ad6d
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/in1k_training.yaml
@@ -0,0 +1,197 @@
+_BASE_: "base_model_bert_l12_h192.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ # -
+ # NAME: 'Vocab_Word'
+ # SHARED_TARGETS_CFG:
+ # FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ # DISTRIBUTED: True
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 4
+ TEST_BATCH_SIZE: 4
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ NUM_HIDDEN_LAYERS: 1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+
+ OLD_CHECKPONT: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+ # POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ # CHECKPOINT_FILETER: False
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 150000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H192_experiments/in1k_training_moe.yaml b/configs/BERT_L12_H192_experiments/in1k_training_moe.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..58e728fda725436df4c0e6b7afa087355f16078e
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/in1k_training_moe.yaml
@@ -0,0 +1,219 @@
+_BASE_: "base_model_bert_l12_h192.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ # -
+ # NAME: 'Vocab_Word'
+ # SHARED_TARGETS_CFG:
+ # FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ # DISTRIBUTED: True
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 4
+ TEST_BATCH_SIZE: 4
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+
+ OLD_CHECKPONT: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+ # POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ # CHECKPOINT_FILETER: False
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 150000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'all' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
diff --git a/configs/BERT_L12_H192_experiments/moe_debug.yaml b/configs/BERT_L12_H192_experiments/moe_debug.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aac92d63cd4dfbf860be8a2a6d333337f3439f01
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/moe_debug.yaml
@@ -0,0 +1,536 @@
+_BASE_: "base_model_bert_l12_h192.yaml"
+
+SHARED_TARGETS:
+
+ # -
+ # NAME: 'ImageNet1k'
+ # SHARED_TARGETS_CFG:
+ # FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ # DISTRIBUTED: False
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ # -
+ # NAME: 'Kinetics400'
+ # SHARED_TARGETS_CFG:
+ # FILE_PATH: 'open_source_dataset/k400_class_name_CLIP_with_endoftext.pkl'
+ # DISTRIBUTED: False
+
+
+
+TASKS:
+
+ # -
+ # NAME: imagenet
+ # DATASETS:
+ # TRAIN: 'ImageNetDataset'
+ # VAL: 'ImageNetDataset'
+ # TASK_TYPE: 'image_classification'
+ # DATASET_NAME: 'ImageNet1k'
+ # TARGET_SET: ['ImageNet1k']
+
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 720
+ # # TEST_BATCH_SIZE: 2
+ # NUM_WORKERS: 4
+ # FEATS_FOLDER: 'cluster2:s3://imagenet'
+ # ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ # SAMPLING_WEIGHT: 2.5
+ # CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ # MIXUP: 0.8
+ # CUTMIX: 1.0
+ # MIXUP_PROB: 1.0
+ # MIXUP_SWITCH_PROB: 0.5
+ # MIXUP_MODE: 'batch'
+ # MIXUP_LABEL_SMOOTHING: 0.1
+ # MODEL:
+ # MAX_SEQ_LEN: -1
+ # LABELS_NUM: 1000
+ # TEMP_NAME: logit_scale_img_cls
+ # LOSSES:
+ # NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 1.0
+ # REDUCTION: 'mean'
+ # # LOSS_FP32: True
+ # INFERENCE:
+ # NAME: 'ImageNetEvaler'
+ # ID_KEY: 'image_id'
+ # VALUE: 'cls_logits'
+ # VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ # TEST_ANNFILE: ''
+ # GENERATION_MODE: False
+
+ # -
+ # NAME: K400_retrieve
+ # DATASETS:
+ # TRAIN: 'VideoDataSet'
+ # VAL: 'VideoDataSet'
+ # TASK_TYPE: 'video_classification'
+ # DATASET_NAME: 'K400'
+ # TARGET_SET: ['Kinetics400']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 12 # 256
+ # TEST_BATCH_SIZE: 4 # debug
+ # NUM_WORKERS: 4 # debug 4
+ # FEATS_FOLDER: 'open_source_dataset/K400_official'
+ # ANNO_FOLDER: 'open_source_dataset/K400_official'
+ # S3_PATH: 's3://K400/'
+ # FRAMES_PER_CLIP: 8
+ # STRIDE: 32
+ # FILE_EXTENSION: ''
+ # ANNO_FILE: 'annotation.json'
+ # TIMESFORMER_AUG: True
+ # SAMPLING_WEIGHT: 1.0
+ # MODEL:
+ # MAX_SEQ_LEN: -1
+ # TEMP_NAME: logit_scale_video_cls
+ # LOSSES:
+ # NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1.0
+ # INFERENCE:
+ # NAME: 'MiTEvaler'
+ # ID_KEY: 'video_name'
+ # VALUE: 'label'
+ # VAL_ANNFILE: 'open_source_dataset/K400_official/annotation.json'
+ # TEST_ANNFILE: ''
+ # GENERATION_MODE: False
+ # NUM_VIEWS: 1
+
+ # -
+ # NAME: bookswiki_pretrain
+ # DATASETS:
+ # TRAIN: 'GeneralCorpusDataset'
+ # TASK_TYPE: 'text_mlm'
+ # DATASET_NAME: 'BooksWiki'
+ # TARGET_SET: ['Vocab_Word']
+ # VERSION: 'v2'
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 512
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 2
+ # ANNO_FOLDER: 'open_source_dataset/text_corpus' # 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # SEQ_PER_SAMPLE: 1
+ # SAMPLER: NodeDistributed
+ # CACHE_MODE: True
+ # SEQ_PER_SAMPLE: 128
+ # MIN_SEQ_PER_SAMPLE: 128
+ # APPEND_EOS: True
+ # ONE_STREAM: False
+ # SAMPLING_WEIGHT: 3.5
+ # RANDOM_MASK: True
+ # MODEL:
+ # MAX_SEQ_LEN: 128
+ # TEMP_NAME: logit_scale_text_mlm
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # GENERATION_MODE: False
+ # -
+ # NAME: mscoco_retrieve
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_retrieval'
+ # DATASET_NAME: 'MSCOCO'
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 100
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 1
+ # FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ # ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ # S3_PATH: 's3://coco/'
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 1.0
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 50
+ # TEMP_NAME: logit_scale_retrieve
+ # LOSSES:
+ # NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1.0
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # ID_KEY: 'image_id'
+ # VALUE: 'caption'
+ # NAME: 'RetrievalEvaler'
+ # VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ # TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ # GENERATION_MODE: False
+
+ ########## Image Captioning ###########
+
+
+ # -
+ # NAME: cc12m_caption
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_caption'
+ # DATASET_NAME: 'CC12M'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 300
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 2
+ # S3_ANNO_FOLDER: 's3://cc12m/'
+ # ANNO_FOLDER: 'open_source_dataset/c12m/'
+ # ANNO_FILENAME: 'train_available.json'
+ # FEATS_FOLDER: 'open_source_dataset/c12m/'
+ # S3_PATH: 's3://cc12m/'
+ # SEQ_PER_SAMPLE: 1
+ # SAMPLER: NodeDistributed
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 1.6889
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 50
+ # TEMP_NAME: logit_scale_caption
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # GENERATION_MODE: False
+
+ # -
+ # NAME: cc3m_caption
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_caption'
+ # DATASET_NAME: 'CC3M'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 300
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 2
+ # ANNO_FOLDER: 's3://cc3m/'
+ # ANNO_FILENAME: 'train_spacy.json'
+ # FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ # S3_PATH: 's3://cc3m/'
+ # SEQ_PER_SAMPLE: 1
+ # SAMPLER: NodeDistributed
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.8780
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 50
+ # TEMP_NAME: logit_scale_caption
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # GENERATION_MODE: False
+
+ # -
+ # NAME: vg_caption
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_caption'
+ # DATASET_NAME: 'VG'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 300
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 2
+ # FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ # ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ # S3_PATH: 's3://visual_genome/images'
+ # ANNO_FILENAME: 'vg_captions_128filter.json'
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.5895
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 30
+ # TEMP_NAME: logit_scale_caption
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # GENERATION_MODE: True
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.3817
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ # -
+ # NAME: sbu_caption
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_caption'
+ # DATASET_NAME: 'SBU'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 300
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 1
+ # S3_ANNO_FOLDER: 's3://SBU/annotations'
+ # ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ # ANNO_FILENAME: 'subcaption.json'
+ # FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ # S3_PATH: 's3://SBU/images'
+ # SEQ_PER_SAMPLE: 1
+ # SAMPLER: NodeDistributed
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.4618
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 50
+ # TEMP_NAME: logit_scale_caption
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+ LAYER_SCALE_FP32: True
+ GATE_FP32: False
+ TAG_TRANSFORM_FP32: False
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+ STRATEGY: 'turn'
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 150000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+ FORCE_SOFTMAX_FP16: True
+ FORCE_LN_FP16: True
+ FORCE_NORM_FP16: True
+ # FORCE_TEMP_FP16: True
+ FORCE_EMBED_FP16: True
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'all' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H192_experiments/moe_debug_load_ds_checkpoint.yaml b/configs/BERT_L12_H192_experiments/moe_debug_load_ds_checkpoint.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d5a1effcc890943775e710151a94af90466950ce
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/moe_debug_load_ds_checkpoint.yaml
@@ -0,0 +1,541 @@
+_BASE_: "base_model_bert_l12_h192.yaml"
+
+SHARED_TARGETS:
+
+ # -
+ # NAME: 'ImageNet1k'
+ # SHARED_TARGETS_CFG:
+ # FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ # DISTRIBUTED: False
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ # -
+ # NAME: 'Kinetics400'
+ # SHARED_TARGETS_CFG:
+ # FILE_PATH: 'open_source_dataset/k400_class_name_CLIP_with_endoftext.pkl'
+ # DISTRIBUTED: False
+
+
+
+TASKS:
+
+ # -
+ # NAME: imagenet
+ # DATASETS:
+ # TRAIN: 'ImageNetDataset'
+ # VAL: 'ImageNetDataset'
+ # TASK_TYPE: 'image_classification'
+ # DATASET_NAME: 'ImageNet1k'
+ # TARGET_SET: ['ImageNet1k']
+
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 720
+ # # TEST_BATCH_SIZE: 2
+ # NUM_WORKERS: 4
+ # FEATS_FOLDER: 'cluster2:s3://imagenet'
+ # ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ # SAMPLING_WEIGHT: 2.5
+ # CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ # MIXUP: 0.8
+ # CUTMIX: 1.0
+ # MIXUP_PROB: 1.0
+ # MIXUP_SWITCH_PROB: 0.5
+ # MIXUP_MODE: 'batch'
+ # MIXUP_LABEL_SMOOTHING: 0.1
+ # MODEL:
+ # MAX_SEQ_LEN: -1
+ # LABELS_NUM: 1000
+ # TEMP_NAME: logit_scale_img_cls
+ # LOSSES:
+ # NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 1.0
+ # REDUCTION: 'mean'
+ # # LOSS_FP32: True
+ # INFERENCE:
+ # NAME: 'ImageNetEvaler'
+ # ID_KEY: 'image_id'
+ # VALUE: 'cls_logits'
+ # VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ # TEST_ANNFILE: ''
+ # GENERATION_MODE: False
+
+ # -
+ # NAME: K400_retrieve
+ # DATASETS:
+ # TRAIN: 'VideoDataSet'
+ # VAL: 'VideoDataSet'
+ # TASK_TYPE: 'video_classification'
+ # DATASET_NAME: 'K400'
+ # TARGET_SET: ['Kinetics400']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 12 # 256
+ # TEST_BATCH_SIZE: 4 # debug
+ # NUM_WORKERS: 4 # debug 4
+ # FEATS_FOLDER: 'open_source_dataset/K400_official'
+ # ANNO_FOLDER: 'open_source_dataset/K400_official'
+ # S3_PATH: 's3://K400/'
+ # FRAMES_PER_CLIP: 8
+ # STRIDE: 32
+ # FILE_EXTENSION: ''
+ # ANNO_FILE: 'annotation.json'
+ # TIMESFORMER_AUG: True
+ # SAMPLING_WEIGHT: 1.0
+ # MODEL:
+ # MAX_SEQ_LEN: -1
+ # TEMP_NAME: logit_scale_video_cls
+ # LOSSES:
+ # NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1.0
+ # INFERENCE:
+ # NAME: 'MiTEvaler'
+ # ID_KEY: 'video_name'
+ # VALUE: 'label'
+ # VAL_ANNFILE: 'open_source_dataset/K400_official/annotation.json'
+ # TEST_ANNFILE: ''
+ # GENERATION_MODE: False
+ # NUM_VIEWS: 1
+
+ # -
+ # NAME: bookswiki_pretrain
+ # DATASETS:
+ # TRAIN: 'GeneralCorpusDataset'
+ # TASK_TYPE: 'text_mlm'
+ # DATASET_NAME: 'BooksWiki'
+ # TARGET_SET: ['Vocab_Word']
+ # VERSION: 'v2'
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 512
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 2
+ # ANNO_FOLDER: 'open_source_dataset/text_corpus' # 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # SEQ_PER_SAMPLE: 1
+ # SAMPLER: NodeDistributed
+ # CACHE_MODE: True
+ # SEQ_PER_SAMPLE: 128
+ # MIN_SEQ_PER_SAMPLE: 128
+ # APPEND_EOS: True
+ # ONE_STREAM: False
+ # SAMPLING_WEIGHT: 3.5
+ # RANDOM_MASK: True
+ # MODEL:
+ # MAX_SEQ_LEN: 128
+ # TEMP_NAME: logit_scale_text_mlm
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # GENERATION_MODE: False
+ # -
+ # NAME: mscoco_retrieve
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_retrieval'
+ # DATASET_NAME: 'MSCOCO'
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 100
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 1
+ # FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ # ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ # S3_PATH: 's3://coco/'
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 1.0
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 50
+ # TEMP_NAME: logit_scale_retrieve
+ # LOSSES:
+ # NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1.0
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # ID_KEY: 'image_id'
+ # VALUE: 'caption'
+ # NAME: 'RetrievalEvaler'
+ # VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ # TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ # GENERATION_MODE: False
+
+ ########## Image Captioning ###########
+
+
+ # -
+ # NAME: cc12m_caption
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_caption'
+ # DATASET_NAME: 'CC12M'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 300
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 2
+ # S3_ANNO_FOLDER: 's3://cc12m/'
+ # ANNO_FOLDER: 'open_source_dataset/c12m/'
+ # ANNO_FILENAME: 'train_available.json'
+ # FEATS_FOLDER: 'open_source_dataset/c12m/'
+ # S3_PATH: 's3://cc12m/'
+ # SEQ_PER_SAMPLE: 1
+ # SAMPLER: NodeDistributed
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 1.6889
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 50
+ # TEMP_NAME: logit_scale_caption
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # GENERATION_MODE: False
+
+ # -
+ # NAME: cc3m_caption
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_caption'
+ # DATASET_NAME: 'CC3M'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 300
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 2
+ # S3_ANNO_FOLDER: 's3://cc3m/'
+ # ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ # ANNO_FILENAME: 'train_spacy.json'
+ # FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ # S3_PATH: 's3://cc3m/'
+ # SEQ_PER_SAMPLE: 1
+ # SAMPLER: NodeDistributed
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.8780
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 50
+ # TEMP_NAME: logit_scale_caption
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # GENERATION_MODE: False
+
+ # -
+ # NAME: vg_caption
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_caption'
+ # DATASET_NAME: 'VG'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 300
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 2
+ # FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ # ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ # S3_PATH: 's3://visual_genome/images'
+ # ANNO_FILENAME: 'vg_captions_128filter.json'
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.5895
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 30
+ # TEMP_NAME: logit_scale_caption
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # GENERATION_MODE: True
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.3817
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ # -
+ # NAME: sbu_caption
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_caption'
+ # DATASET_NAME: 'SBU'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 300
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 1
+ # S3_ANNO_FOLDER: 's3://SBU/annotations'
+ # ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ # ANNO_FILENAME: 'subcaption.json'
+ # FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ # S3_PATH: 's3://SBU/images'
+ # SEQ_PER_SAMPLE: 1
+ # SAMPLER: NodeDistributed
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.4618
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 50
+ # TEMP_NAME: logit_scale_caption
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ VIDEO_EMBED:
+ ADD_TYPE_EMBED: True
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+ STRATEGY: 'turn'
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 150000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+ FORCE_SOFTMAX_FP16: True
+ FORCE_LN_FP16: True
+ FORCE_NORM_FP16: True
+ # FORCE_TEMP_FP16: True
+ FORCE_EMBED_FP16: True
+
+# # used for debug only
+ FORCE_WG_RECAST: True
+ FORCE_EXPERT_ADDING_FP16: True
+
+ # !!! note that the VIDEO_EMBED.ADD_TYPE_EMBED=True is current config
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'all' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H192_experiments/mscoco_caption_debug.yaml b/configs/BERT_L12_H192_experiments/mscoco_caption_debug.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fd2d1dea31ad48580b6f3a6a3caf2a5ee5167a4c
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/mscoco_caption_debug.yaml
@@ -0,0 +1,234 @@
+_BASE_: "base_model_bert_l12_h192.yaml"
+
+SHARED_TARGETS:
+
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+
+
+TASKS:
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 100
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ NAME: 'RetrievalEvaler'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ GENERATION_MODE: False
+
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.3817
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.33333
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 150000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H192_experiments/vqa_debug.yaml b/configs/BERT_L12_H192_experiments/vqa_debug.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f6550264665961d39dccd591be60fc601be8aac4
--- /dev/null
+++ b/configs/BERT_L12_H192_experiments/vqa_debug.yaml
@@ -0,0 +1,189 @@
+_BASE_: "base_model_bert_l12_h192.yaml"
+
+SHARED_TARGETS:
+
+
+ -
+ NAME: 'VQA_Answer'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/VQA_Answers_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: vqa
+ DATASETS:
+ TRAIN: 'VQADataset'
+ VAL: 'VQADataset'
+ DATASET_NAME: 'VQA'
+ TASK_TYPE: 'vqa'
+ TARGET_SET: ['VQA_Answer']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/VQA'
+ SEQ_PER_SAMPLE: 1
+ MAX_FEAT_NUM: 51
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ DO_AS_GEN: True
+ SINGLE_CLASS: True
+ MODEL:
+ # VOCAB_SIZE: 49409 # include /
+ PREDICTOR: 'MLPClassifer'
+ # MM_PREDICTOR:
+ # LABELS_NUM: 3129
+ # PREDICT: 'first_one'
+ # PRED_DROPOUT: 0.5
+ MAX_SEQ_LEN: 23
+ # QUERY_EMBED:
+ # NAME: QueryBaseEmbedding
+ # DIM: 512
+ # QUERY_SIZE: 10 # more than 1 is ok
+ # ACTIVATION: 'none'
+ # USE_NORM: True
+ # DROPOUT: 0.1
+ # POSITION: 'none' # must be none now
+ # TYPE_VOCAB_SIZE: -1 # must < 0
+ LOSSES:
+ # not single class
+ # NAMES: ['BCEWithLogits']
+ # LOSS_WEIGHT: 0.05
+ # for single class
+ NAMES: ['CrossEntropy']
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ VOCAB: 'CLIP'
+ NAME: 'VQAEvaler'
+ ID_KEY: 'question_id'
+ VALUE: 'answer'
+ VAL_ANNFILE: 'open_source_dataset/VQA/val_target.pkl'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+######################################### Engine #########################################
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+######################################### Scheduled sampling #########################################
+SCHEDULED_SAMPLING:
+ START_EPOCH: 0
+ INC_EVERY_EPOCH: 5
+ INC_PROB: 0.05
+ MAX_PROB: 0.25
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+
+######################################### MODEL #########################################
+MODEL:
+ TEMP_NAME: logit_scale_downstream
+ # VOCAB_SIZE: 49409 # include /
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+ # ENCODER_DIM: 512
+ # DECODER: 'UnifiedTransformerDecoder'
+ # DECODER_DIM: 512
+
+ BertParamsInit: True
+ # WEIGHTS: open_source_dataset/our_model/cc3m_encoder_decoder_warm1w_150k_retrivetask_gatherfeature_caption_mlm/model_Epoch_90000_Iter_0089999.pth
+
+ CLS_TOKEN: True
+ # PREDICTOR: 'BasePredictor'
+ # PRED_DROPOUT: 0.5
+ # MAX_SEQ_LEN: 20
+
+# #################################### Token embedding ####################################
+ # TOKEN_EMBED:
+ # NAME: 'TokenBaseEmbedding'
+ # DIM: 512
+ # ACTIVATION: 'none'
+ # USE_NORM: True
+ # DROPOUT: 0.1
+ # POSITION: 'NNEmbeddingEncoding'
+ # POSITION_MAX_LEN: 512
+ # TYPE_VOCAB_SIZE: 2
+
+# #################################### Visual embedding ####################################
+ # VISUAL_EMBED:
+ # NAME: 'VisualPatchEmbedding'
+ # IN_DIM: 3
+ # OUT_DIM: 512
+ # ACTIVATION: 'none'
+ # USE_NORM: True
+ # DROPOUT: 0.0
+ # PATCH_SIZE: 16
+
+####################################### BERT ############################################
+ BERT:
+ DROP_PATH_PROB: 0.05
+ # HIDDEN_SIZE: 512
+ HIDDEN_SIZE: 192
+ HIDDEN_DROPOUT_PROB: 0.
+ HIDDEN_ACT: "gelu"
+ NUM_ATTENTION_HEADS: 8
+ INTERMEDIATE_SIZE: 2048
+ INTERMEDIATE_DROP: 0.
+ FFN_DROPOUT_PROB: 0.
+ ATTENTION_PROBS_DROPOUT_PROB: 0.
+ NUM_HIDDEN_LAYERS: 6
+ NUM_GENERATION_LAYERS: 6
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'AdamW'
+ # EPOCH: 1
+ MAX_ITER: 30000
+ CHECKPOINT_PERIOD: 5000
+ CHECKPOINT_MAX_SAVE: 5
+ EVAL_PERIOD: 1000
+ BASE_LR: 0.00005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.01
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.999]
+ EPS: 1e-8
+ GRAD_CLIP: 5.0
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ CHECKPOINT_MAPPING:
+ # -
+ # ORIGIN: cc3m_caption
+ # DEST: mscoco
+ -
+ ORIGIN: cc3m_retrieve
+ DEST: flickr30k
+
+ CHECKPOINT_MAP: True
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 1000
+ MIN_LR: 0.00000001
+
+# ####################################### losses #######################################
+# LOSSES:
+# NAMES: ['LabelSmoothing']
+# LABELSMOOTHING: 0.1
+
+####################################### decode strategy #######################################
+# DECODE_STRATEGY:
+# NAME: 'BeamSearcher'
+# BEAM_SIZE: 2
+
+####################################### evaluation #######################################
+INFERENCE:
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H384_experiments/base_model_bert_l12_h384.yaml b/configs/BERT_L12_H384_experiments/base_model_bert_l12_h384.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4cc5b5a6270526886460421dc33d925b84f36226
--- /dev/null
+++ b/configs/BERT_L12_H384_experiments/base_model_bert_l12_h384.yaml
@@ -0,0 +1,80 @@
+
+######################################### MODEL #########################################
+MODEL:
+ VOCAB_SIZE: 49411 # include /
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: ''
+ ENCODER_DIM: 384
+ DECODER: ''
+ DECODER_DIM: 384
+
+ PREDICTOR: 'EmbedClsAsRetrievalPredictor'
+ FEATURE_GATHER: True
+ LEARN_TEMP: True
+ PRED_USE_NORM: True
+ PRED_TEMPERATURE: 0.07
+
+ BertParamsInit: True
+
+ CLS_TOKEN: False
+
+ QUEUE_LEN: 1024
+ MAX_LABEL_LEN: 12
+
+ OUTPUT_PROJ: True # output projection
+
+
+# #################################### Token embedding ####################################
+ TOKEN_EMBED:
+ NAME: 'TokenBaseEmbedding'
+ DIM: 384
+ ACTIVATION: 'none'
+ USE_NORM: True
+ DROPOUT: 0.0
+ POSITION: 'NNEmbeddingEncoding'
+ POSITION_MAX_LEN: 512
+ TYPE_VOCAB_SIZE: 2
+
+# #################################### Visual embedding ####################################
+ VISUAL_EMBED:
+ NAME: 'VisualPatchEmbedding'
+ IN_DIM: 3
+ OUT_DIM: 384
+ ACTIVATION: 'none'
+ USE_NORM: True
+ DROPOUT: 0.0
+ PATCH_SIZE: 16
+ TYPE_SIZE: 1 # image to encoder
+
+# #################################### video embedding ####################################
+ VIDEO_EMBED:
+ NAME: 'VideoBaseEmbedding'
+ IN_DIM: 768
+ OUT_DIM: 384
+ ACTIVATION: 'none'
+ USE_NORM: True
+ DROPOUT: 0.0
+ TYPE_SIZE: 1 # video to encoder
+ POSITION: 'NNEmbeddingEncoding'
+ MAX_LENGTH: 1600
+ PATCH_SIZE_S: 16
+ PATCH_SIZE_T: 1
+ DIVIDE_ST_POS: True
+ USE_VISUAL_TOKENIZER: True
+ USE_VISUAL_POS: True
+ MAX_FRAMES: 8
+
+####################################### BERT ############################################
+ BERT:
+ DROP_PATH_PROB: 0.1
+ HIDDEN_SIZE: 384
+ HIDDEN_DROPOUT_PROB: 0.
+ HIDDEN_ACT: "gelu"
+ NUM_ATTENTION_HEADS: 6
+ INTERMEDIATE_SIZE: 1536
+ INTERMEDIATE_DROP: 0.
+ FFN_DROPOUT_PROB: 0.
+ ATTENTION_PROBS_DROPOUT_PROB: 0.
+ NUM_HIDDEN_LAYERS: 12
+ NUM_GENERATION_LAYERS: 0
+
\ No newline at end of file
diff --git a/configs/BERT_L12_H384_experiments/in1k_training.yaml b/configs/BERT_L12_H384_experiments/in1k_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5865f1d7c0657fa32c9a067d485828217b99fb40
--- /dev/null
+++ b/configs/BERT_L12_H384_experiments/in1k_training.yaml
@@ -0,0 +1,189 @@
+_BASE_: "base_model_bert_l12_h384.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ # -
+ # NAME: 'Vocab_Word'
+ # SHARED_TARGETS_CFG:
+ # FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ # DISTRIBUTED: True
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 200000
+ CHECKPOINT_PERIOD: 10
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.3
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/16tasks_training.yaml b/configs/BERT_L12_H768_experiments/16tasks_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c15835ed4a5e09c8f86c00ca8d963e496710ddf1
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/16tasks_training.yaml
@@ -0,0 +1,738 @@
+_BASE_: "base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet22k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_22k_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'MomentsInTime'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/MiT_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ -
+ NAME: 'Kinetics700'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/k700_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+TASKS:
+
+ -
+ NAME: imagenet22k
+ DATASETS:
+ TRAIN: 'ImageNet22KDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet22k'
+ TARGET_SET: ['ImageNet22k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 720
+ # TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/imagenet22k'
+ S3_PATH: 'cluster2:s3://imagenet22k'
+ ANNO_FOLDER: 'open_source_dataset/'
+ SAMPLING_WEIGHT: 2.486
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 21842
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+
+ -
+ NAME: K700_retrieve
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'K700'
+ TARGET_SET: ['Kinetics700']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 24
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/K700'
+ ANNO_FOLDER: 'open_source_dataset/K700'
+ S3_PATH: 's3://K700/'
+ FRAMES_PER_CLIP: 4
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 0.76
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: MomentsInTime
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'MiT'
+ TARGET_SET: ['MomentsInTime']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 112
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/MomentsInTime'
+ ANNO_FOLDER: 'open_source_dataset/MomentsInTime'
+ S3_PATH: 's3://MomentsInTime/'
+ FRAMES_PER_CLIP: 3
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 0.44
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ NAME: 'MiTEvaler'
+ ID_KEY: 'video_name'
+ VALUE: 'label'
+ VAL_ANNFILE: 'open_source_dataset/MomentsInTime/annotation.json'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+ NUM_VIEWS: 1
+
+ -
+ NAME: bookswiki_pretrain
+ DATASETS:
+ TRAIN: 'GeneralCorpusDataset'
+ TASK_TYPE: 'text_mlm'
+ DATASET_NAME: 'BooksWiki'
+ TARGET_SET: ['Vocab_Word']
+ VERSION: 'v2'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'open_source_dataset/text_corpus' # 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ SEQ_PER_SAMPLE: 128
+ MIN_SEQ_PER_SAMPLE: 128
+ APPEND_EOS: True
+ ONE_STREAM: False
+ SAMPLING_WEIGHT: 2.75
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 128
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ -
+ NAME: yfcc_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'YFCC'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: True
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC12M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC3M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'VG'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: True
+
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ -
+ NAME: sbu_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'SBU'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: yfcc_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'YFCC'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: True
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC12M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC3M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'VG'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ NAME: 'RetrievalEvaler'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ GENERATION_MODE: False
+
+ -
+ NAME: sbu_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'SBU'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 160
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 400000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 10000000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.2
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 10000
+ MIN_LR: 0.000001
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/16tasks_training_apex_o2.yaml b/configs/BERT_L12_H768_experiments/16tasks_training_apex_o2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1ffcab94316ed54329cbec477dbfb8251486d5c4
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/16tasks_training_apex_o2.yaml
@@ -0,0 +1,11 @@
+_BASE_: "16tasks_training.yaml"
+
+####################################### Optimizer #######################################
+SOLVER:
+
+ AMP_FP16: False
+ APEX_FP16: True # dangerous
+ APEX_OPT_LEVEL: 'O2'
+ MIN_LOSS_SCLE: 128.0
+ CHECKPOINT_PERIOD: 10000
+
diff --git a/configs/BERT_L12_H768_experiments/16tasks_training_basedense_stage1_64gpu.yaml b/configs/BERT_L12_H768_experiments/16tasks_training_basedense_stage1_64gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bcff2de8772a3d093d44cf75790278c020a0fadb
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/16tasks_training_basedense_stage1_64gpu.yaml
@@ -0,0 +1,739 @@
+_BASE_: "base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet22k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_22k_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'MomentsInTime'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/MiT_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ -
+ NAME: 'Kinetics700'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/k700_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+TASKS:
+
+ -
+ NAME: imagenet22k
+ DATASETS:
+ TRAIN: 'ImageNet22KDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet22k'
+ TARGET_SET: ['ImageNet22k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 720
+ # TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/imagenet22k'
+ S3_PATH: 'cluster2:s3://imagenet22k'
+ ANNO_FOLDER: 'open_source_dataset/'
+ SAMPLING_WEIGHT: 2.486
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 21842
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+
+ -
+ NAME: K700_retrieve
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'K700'
+ TARGET_SET: ['Kinetics700']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 24
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/K700'
+ ANNO_FOLDER: 'open_source_dataset/K700'
+ S3_PATH: 's3://K700/'
+ FRAMES_PER_CLIP: 4
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 0.76
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: MomentsInTime
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'MiT'
+ TARGET_SET: ['MomentsInTime']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 112
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/MomentsInTime'
+ ANNO_FOLDER: 'open_source_dataset/MomentsInTime'
+ S3_PATH: 's3://MomentsInTime/'
+ FRAMES_PER_CLIP: 3
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 0.44
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ NAME: 'MiTEvaler'
+ ID_KEY: 'video_name'
+ VALUE: 'label'
+ VAL_ANNFILE: 'open_source_dataset/MomentsInTime/annotation.json'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+ NUM_VIEWS: 1
+
+ -
+ NAME: bookswiki_pretrain
+ DATASETS:
+ TRAIN: 'GeneralCorpusDataset'
+ TASK_TYPE: 'text_mlm'
+ DATASET_NAME: 'BooksWiki'
+ TARGET_SET: ['Vocab_Word']
+ VERSION: 'v2'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'open_source_dataset/text_corpus' # 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ SEQ_PER_SAMPLE: 128
+ MIN_SEQ_PER_SAMPLE: 128
+ APPEND_EOS: True
+ ONE_STREAM: False
+ SAMPLING_WEIGHT: 2.75
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 128
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ -
+ NAME: yfcc_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'YFCC'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: True
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC12M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC3M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'VG'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: True
+
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ -
+ NAME: sbu_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'SBU'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: yfcc_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'YFCC'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: True
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC12M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC3M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'VG'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ NAME: 'RetrievalEvaler'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ GENERATION_MODE: False
+
+ -
+ NAME: sbu_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'SBU'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 160
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+ OLD_CHECKPONT: True
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 200000
+ CHECKPOINT_PERIOD: 10000
+ EVAL_PERIOD: 10000000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.2
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 10000
+ MIN_LR: 0.000001
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/16tasks_training_basedense_stage2_64gpu.yaml b/configs/BERT_L12_H768_experiments/16tasks_training_basedense_stage2_64gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b6e74236aa5ce2f050ae90741436cc3ba3a35d6a
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/16tasks_training_basedense_stage2_64gpu.yaml
@@ -0,0 +1,750 @@
+_BASE_: "base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet22k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_22k_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'MomentsInTime'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/MiT_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ -
+ NAME: 'Kinetics700'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/k700_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+TASKS:
+
+ -
+ NAME: imagenet22k
+ DATASETS:
+ TRAIN: 'ImageNet22KDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet22k'
+ TARGET_SET: ['ImageNet22k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 440
+ # TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/imagenet22k'
+ S3_PATH: 'cluster2:s3://imagenet22k'
+ ANNO_FOLDER: 'open_source_dataset/'
+ SAMPLING_WEIGHT: 2.486
+ MIXUP: 0.0
+ CUTMIX: 0.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 21842
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ LABELSMOOTHING: 0.1
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ # VAL_ANNFILE: '/mnt/lustrenew/lihao2/projects/xmodaler_2/val_debug.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ -
+ NAME: K700_retrieve
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'K700'
+ TARGET_SET: ['Kinetics700']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 12
+ TEST_BATCH_SIZE: 24
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/K700'
+ ANNO_FOLDER: 'open_source_dataset/K700'
+ S3_PATH: 's3://K700/'
+ FRAMES_PER_CLIP: 8
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.05
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: MomentsInTime
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'MiT'
+ TARGET_SET: ['MomentsInTime']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 68
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/MomentsInTime'
+ ANNO_FOLDER: 'open_source_dataset/MomentsInTime'
+ S3_PATH: 's3://MomentsInTime/'
+ FRAMES_PER_CLIP: 3
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 0.2
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.05
+ INFERENCE:
+ NAME: 'MiTEvaler'
+ ID_KEY: 'video_name'
+ VALUE: 'label'
+ VAL_ANNFILE: 'open_source_dataset/MomentsInTime/annotation.json'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+ NUM_VIEWS: 1
+
+ -
+ NAME: bookswiki_pretrain
+ DATASETS:
+ TRAIN: 'GeneralCorpusDataset'
+ TASK_TYPE: 'text_mlm'
+ DATASET_NAME: 'BooksWiki'
+ TARGET_SET: ['Vocab_Word']
+ VERSION: 'v2'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'open_source_dataset/text_corpus' # 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ SEQ_PER_SAMPLE: 128
+ MIN_SEQ_PER_SAMPLE: 128
+ APPEND_EOS: True
+ ONE_STREAM: False
+ SAMPLING_WEIGHT: 2.75
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 128
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.25
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ -
+ NAME: yfcc_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'YFCC'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC12M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC3M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'VG'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: True
+
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ -
+ NAME: sbu_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'SBU'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: yfcc_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'YFCC'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.25
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC12M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.25
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC3M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.25
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'VG'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.25
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.25
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ NAME: 'RetrievalEvaler'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ GENERATION_MODE: False
+
+ -
+ NAME: sbu_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'SBU'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.25
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+ OLD_CHECKPONT: True
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 45000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 10000000
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/16tasks_training_basemoe_stage1_56gpu.yaml b/configs/BERT_L12_H768_experiments/16tasks_training_basemoe_stage1_56gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0b316a5ea8bb64e0dc18551d34844eae88f7aebf
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/16tasks_training_basemoe_stage1_56gpu.yaml
@@ -0,0 +1,733 @@
+_BASE_: "base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet22k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_22k_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'MomentsInTime'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/MiT_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ -
+ NAME: 'Kinetics700'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/k700_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+TASKS:
+
+ -
+ NAME: imagenet22k
+ DATASETS:
+ TRAIN: 'ImageNet22KDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet22k'
+ TARGET_SET: ['ImageNet22k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 720
+ # TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/imagenet22k'
+ S3_PATH: 'cluster2:s3://imagenet22k'
+ ANNO_FOLDER: 'open_source_dataset/'
+ SAMPLING_WEIGHT: 2.486
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 21842
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+
+ -
+ NAME: K700_retrieve
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'K700'
+ TARGET_SET: ['Kinetics700']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 24
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/K700'
+ ANNO_FOLDER: 'open_source_dataset/K700'
+ S3_PATH: 's3://K700/'
+ FRAMES_PER_CLIP: 4
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 0.76
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: MomentsInTime
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'MiT'
+ TARGET_SET: ['MomentsInTime']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 112
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/MomentsInTime'
+ ANNO_FOLDER: 'open_source_dataset/MomentsInTime'
+ S3_PATH: 's3://MomentsInTime/'
+ FRAMES_PER_CLIP: 3
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 0.44
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ NAME: 'MiTEvaler'
+ ID_KEY: 'video_name'
+ VALUE: 'label'
+ VAL_ANNFILE: 'open_source_dataset/MomentsInTime/annotation.json'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+ NUM_VIEWS: 1
+
+ -
+ NAME: bookswiki_pretrain
+ DATASETS:
+ TRAIN: 'GeneralCorpusDataset'
+ TASK_TYPE: 'text_mlm'
+ DATASET_NAME: 'BooksWiki'
+ TARGET_SET: ['Vocab_Word']
+ VERSION: 'v2'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'open_source_dataset/text_corpus' # 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ SEQ_PER_SAMPLE: 128
+ MIN_SEQ_PER_SAMPLE: 128
+ APPEND_EOS: True
+ ONE_STREAM: False
+ SAMPLING_WEIGHT: 2.75
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 128
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ -
+ NAME: yfcc_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'YFCC'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: True
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC12M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC3M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'VG'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: True
+
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ -
+ NAME: sbu_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'SBU'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 300
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: yfcc_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'YFCC'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: True
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC12M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC3M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'VG'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ NAME: 'RetrievalEvaler'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ GENERATION_MODE: False
+
+ -
+ NAME: sbu_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'SBU'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 160
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE_FP32: True
+ GATE_FP32: False
+ TAG_TRANSFORM_FP32: False
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 230000
+ CHECKPOINT_PERIOD: 10000
+ EVAL_PERIOD: 10000000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.2
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 10000
+ MIN_LR: 0.000001
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/16tasks_training_basemoe_stage2_56gpu.yaml b/configs/BERT_L12_H768_experiments/16tasks_training_basemoe_stage2_56gpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ed525b33cbf9f78278a5dd8feda49aadf29fd79e
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/16tasks_training_basemoe_stage2_56gpu.yaml
@@ -0,0 +1,744 @@
+_BASE_: "base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet22k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_22k_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'MomentsInTime'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/MiT_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ -
+ NAME: 'Kinetics700'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/k700_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+TASKS:
+
+ -
+ NAME: imagenet22k
+ DATASETS:
+ TRAIN: 'ImageNet22KDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet22k'
+ TARGET_SET: ['ImageNet22k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 440
+ # TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/imagenet22k'
+ S3_PATH: 'cluster2:s3://imagenet22k'
+ ANNO_FOLDER: 'open_source_dataset/'
+ SAMPLING_WEIGHT: 2.486
+ MIXUP: 0.0
+ CUTMIX: 0.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 21842
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ LABELSMOOTHING: 0.1
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ # VAL_ANNFILE: '/mnt/lustrenew/lihao2/projects/xmodaler_2/val_debug.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ -
+ NAME: K700_retrieve
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'K700'
+ TARGET_SET: ['Kinetics700']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 12
+ TEST_BATCH_SIZE: 24
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/K700'
+ ANNO_FOLDER: 'open_source_dataset/K700'
+ S3_PATH: 's3://K700/'
+ FRAMES_PER_CLIP: 8
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: MomentsInTime
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'MiT'
+ TARGET_SET: ['MomentsInTime']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 68
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/MomentsInTime'
+ ANNO_FOLDER: 'open_source_dataset/MomentsInTime'
+ S3_PATH: 's3://MomentsInTime/'
+ FRAMES_PER_CLIP: 3
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 0.2
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ NAME: 'MiTEvaler'
+ ID_KEY: 'video_name'
+ VALUE: 'label'
+ VAL_ANNFILE: 'open_source_dataset/MomentsInTime/annotation.json'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+ NUM_VIEWS: 1
+
+ -
+ NAME: bookswiki_pretrain
+ DATASETS:
+ TRAIN: 'GeneralCorpusDataset'
+ TASK_TYPE: 'text_mlm'
+ DATASET_NAME: 'BooksWiki'
+ TARGET_SET: ['Vocab_Word']
+ VERSION: 'v2'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'open_source_dataset/text_corpus' # 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ SEQ_PER_SAMPLE: 128
+ MIN_SEQ_PER_SAMPLE: 128
+ APPEND_EOS: True
+ ONE_STREAM: False
+ SAMPLING_WEIGHT: 2.75
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 128
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ -
+ NAME: yfcc_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'YFCC'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC12M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC3M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'VG'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: True
+
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ -
+ NAME: sbu_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'SBU'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: yfcc_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'YFCC'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC12M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC3M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'VG'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ NAME: 'RetrievalEvaler'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ GENERATION_MODE: False
+
+ -
+ NAME: sbu_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'SBU'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE_FP32: True
+ GATE_FP32: False
+ TAG_TRANSFORM_FP32: False
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 50000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 10000000
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/16tasks_training_stage2_64gpu_v1.yaml b/configs/BERT_L12_H768_experiments/16tasks_training_stage2_64gpu_v1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..009470c7dec847430a0161007159042553356f49
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/16tasks_training_stage2_64gpu_v1.yaml
@@ -0,0 +1,750 @@
+_BASE_: "base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet22k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_22k_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+ -
+ NAME: 'MomentsInTime'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/MiT_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ -
+ NAME: 'Kinetics700'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/k700_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+TASKS:
+
+ -
+ NAME: imagenet22k
+ DATASETS:
+ TRAIN: 'ImageNet22KDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet22k'
+ TARGET_SET: ['ImageNet22k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 440
+ # TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/imagenet22k'
+ S3_PATH: 'cluster2:s3://imagenet22k'
+ ANNO_FOLDER: 'open_source_dataset/'
+ SAMPLING_WEIGHT: 2.486
+ MIXUP: 0.0
+ CUTMIX: 0.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 21842
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ LABELSMOOTHING: 0.1
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ # VAL_ANNFILE: '/mnt/lustrenew/lihao2/projects/xmodaler_2/val_debug.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ -
+ NAME: K700_retrieve
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'K700'
+ TARGET_SET: ['Kinetics700']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 12
+ TEST_BATCH_SIZE: 24
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/K700'
+ ANNO_FOLDER: 'open_source_dataset/K700'
+ S3_PATH: 's3://K700/'
+ FRAMES_PER_CLIP: 8
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: MomentsInTime
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'MiT'
+ TARGET_SET: ['MomentsInTime']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 68
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/MomentsInTime'
+ ANNO_FOLDER: 'open_source_dataset/MomentsInTime'
+ S3_PATH: 's3://MomentsInTime/'
+ FRAMES_PER_CLIP: 3
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 0.2
+
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ NAME: 'MiTEvaler'
+ ID_KEY: 'video_name'
+ VALUE: 'label'
+ VAL_ANNFILE: 'open_source_dataset/MomentsInTime/annotation.json'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+ NUM_VIEWS: 1
+
+ -
+ NAME: bookswiki_pretrain
+ DATASETS:
+ TRAIN: 'GeneralCorpusDataset'
+ TASK_TYPE: 'text_mlm'
+ DATASET_NAME: 'BooksWiki'
+ TARGET_SET: ['Vocab_Word']
+ VERSION: 'v2'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 512
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'open_source_dataset/text_corpus' # 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ SEQ_PER_SAMPLE: 128
+ MIN_SEQ_PER_SAMPLE: 128
+ APPEND_EOS: True
+ ONE_STREAM: False
+ SAMPLING_WEIGHT: 2.75
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 128
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ -
+ NAME: yfcc_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'YFCC'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC12M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'CC3M'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'VG'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: True
+
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 50
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ -
+ NAME: sbu_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'SBU'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: yfcc_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'YFCC'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 'cluster2:s3://yfcc'
+ ANNO_FOLDER: 'open_source_dataset/yfcc'
+ ANNO_FILENAME: 'yfcc100m_subset_available_untokenized.json'
+ FEATS_FOLDER: 'open_source_dataset/yfcc/'
+ S3_PATH: 'cluster2:s3://yfcc/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5840
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc12m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC12M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc12m/'
+ ANNO_FOLDER: 'open_source_dataset/c12m/'
+ ANNO_FILENAME: 'train_available.json'
+ FEATS_FOLDER: 'open_source_dataset/c12m/'
+ S3_PATH: 's3://cc12m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5057
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: cc3m_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'CC3M'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ S3_ANNO_FOLDER: 's3://cc3m/'
+ ANNO_FOLDER: 'open_source_dataset/cc3m/'
+ ANNO_FILENAME: 'train_spacy.json'
+ FEATS_FOLDER: 'open_source_dataset/cc3m/'
+ S3_PATH: 's3://cc3m/'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.26295
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: vg_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'VG'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/visual_genome/images'
+ ANNO_FOLDER: 'open_source_dataset/visual_genome/annotations'
+ S3_PATH: 's3://visual_genome/images'
+ ANNO_FILENAME: 'vg_captions_128filter.json'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1766
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1144
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ NAME: 'RetrievalEvaler'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ GENERATION_MODE: False
+
+ -
+ NAME: sbu_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'SBU'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 320
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 1
+ S3_ANNO_FOLDER: 's3://SBU/annotations'
+ ANNO_FOLDER: 'open_source_dataset/sbucaption/annotations'
+ ANNO_FILENAME: 'subcaption.json'
+ FEATS_FOLDER: 'open_source_dataset/sbucaption/'
+ S3_PATH: 's3://SBU/images'
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.1383
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 50
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+ OLD_CHECKPONT: True
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 32
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 45000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 10000000
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 5000
+ MIN_LR: 0.000001
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/base_model_bert_l12_h768.yaml b/configs/BERT_L12_H768_experiments/base_model_bert_l12_h768.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bc3040352ae5aa111e6e459824e11138404444c5
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/base_model_bert_l12_h768.yaml
@@ -0,0 +1,73 @@
+
+######################################### MODEL #########################################
+MODEL:
+ VOCAB_SIZE: 49411 # include /
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder_v3'
+ ENCODER_DIM: 768
+ DECODER: ''
+ DECODER_DIM: 768
+
+ PREDICTOR: 'EmbedClsAsRetrievalPredictor'
+ FEATURE_GATHER: True
+ LEARN_TEMP: True
+ PRED_USE_NORM: True
+ PRED_TEMPERATURE: 0.07
+
+ BertParamsInit: True
+
+ CLS_TOKEN: False
+
+ QUEUE_LEN: 1024
+ MAX_LABEL_LEN: 12
+
+ OUTPUT_PROJ: True # output projection
+
+
+# #################################### Token embedding ####################################
+ TOKEN_EMBED:
+ NAME: 'TokenBaseEmbedding'
+ DIM: 768
+ ACTIVATION: 'none'
+ USE_NORM: True
+ DROPOUT: 0.0
+ POSITION: 'NNEmbeddingEncoding'
+ POSITION_MAX_LEN: 512
+ TYPE_VOCAB_SIZE: 2
+
+# #################################### Visual embedding ####################################
+ VISUAL_EMBED:
+ NAME: 'none'
+
+# #################################### video embedding ####################################
+ VIDEO_EMBED:
+ NAME: 'VideoBaseEmbedding'
+ IN_DIM: 768
+ OUT_DIM: 768
+ ACTIVATION: 'none'
+ USE_NORM: True
+ DROPOUT: 0.0
+ TYPE_SIZE: 1 # video to encoder
+ POSITION: 'NNEmbeddingEncoding'
+ MAX_LENGTH: 1600
+ PATCH_SIZE_S: 16
+ PATCH_SIZE_T: 1
+ DIVIDE_ST_POS: True
+ USE_VISUAL_TOKENIZER: True
+ USE_VISUAL_POS: True
+ MAX_FRAMES: 8
+
+####################################### BERT ############################################
+ BERT:
+ DROP_PATH_PROB: 0.1
+ HIDDEN_SIZE: 768
+ HIDDEN_DROPOUT_PROB: 0.
+ HIDDEN_ACT: "gelu"
+ NUM_ATTENTION_HEADS: 12
+ INTERMEDIATE_SIZE: 3072
+ INTERMEDIATE_DROP: 0.
+ FFN_DROPOUT_PROB: 0.
+ ATTENTION_PROBS_DROPOUT_PROB: 0.
+ NUM_HIDDEN_LAYERS: 12
+ NUM_GENERATION_LAYERS: 0
+
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/bw_mlm_training.yaml b/configs/BERT_L12_H768_experiments/bw_mlm_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c3ebb7acc0c6185ed36178491263bbd54e84f3ee
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/bw_mlm_training.yaml
@@ -0,0 +1,309 @@
+_BASE_: "base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ # -
+ # NAME: 'ImageNet1k'
+ # SHARED_TARGETS_CFG:
+ # FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ # DISTRIBUTED: False
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+
+ # -
+ # NAME: imagenet
+ # DATASETS:
+ # TRAIN: 'ImageNetDataset'
+ # # VAL: 'ImageNetDataset'
+ # TASK_TYPE: 'image_classification'
+ # DATASET_NAME: 'ImageNet1k'
+ # TARGET_SET: ['ImageNet1k']
+
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 256
+ # # TEST_BATCH_SIZE: 2
+ # NUM_WORKERS: 4
+ # FEATS_FOLDER: 'cluster2:s3://imagenet'
+ # ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ # SAMPLING_WEIGHT: 1.0
+ # CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ # MIXUP: 0.8
+ # CUTMIX: 1.0
+ # MIXUP_PROB: 1.0
+ # MIXUP_SWITCH_PROB: 0.5
+ # MIXUP_MODE: 'batch'
+ # MIXUP_LABEL_SMOOTHING: 0.1
+ # MODEL:
+ # MAX_SEQ_LEN: -1
+ # LABELS_NUM: 1000
+ # TEMP_NAME: logit_scale_img_cls
+ # LOSSES:
+ # NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 1.0
+ # REDUCTION: 'mean'
+ # # LOSS_FP32: True
+ # INFERENCE:
+ # NAME: 'ImageNetEvaler'
+ # ID_KEY: 'image_id'
+ # VALUE: 'cls_logits'
+ # VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ # TEST_ANNFILE: ''
+ # GENERATION_MODE: False
+
+ -
+ NAME: bookswiki_pretrain
+ DATASETS:
+ TRAIN: 'GeneralCorpusDataset'
+ TASK_TYPE: 'text_mlm'
+ DATASET_NAME: 'BooksWiki'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 2
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ SEQ_PER_SAMPLE: 1
+ SAMPLER: NodeDistributed
+ CACHE_MODE: True
+ SEQ_PER_SAMPLE: 128
+ MIN_SEQ_PER_SAMPLE: 128
+ APPEND_EOS: True
+ ONE_STREAM: False
+ SAMPLING_WEIGHT: 1.0
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 128
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ # -
+ # NAME: mscoco_caption
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # # VAL: 'ImageTextPairDataset'
+ # # TEST: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_caption'
+ # DATASET_NAME: 'MSCOCO'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 200
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 4
+ # FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ # ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ # S3_PATH: 's3://coco/'
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.5
+ # TRANSFORM: 'clip_transforms'
+ # RANDOM_MASK: True
+ # MODEL:
+ # MAX_SEQ_LEN: 30
+ # EVAL_MAX_SEQ_LEN: 21
+ # TEMP_NAME: logit_scale_caption
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.5
+ # REDUCTION: 'mean'
+ # DECODE_STRATEGY:
+ # NAME: 'CaptionBeamSearcherV3'
+ # BEAM_SIZE: 2
+ # # LEN_PENALTY: 1.0
+ # INFERENCE:
+ # NAME: 'COCOEvaler'
+ # VOCAB: 'CLIP'
+ # ID_KEY: 'image_id'
+ # VALUE: 'caption'
+ # VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ # TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ # GENERATION_MODE: True
+
+ # -
+ # NAME: mscoco_retrieve
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # # TEST: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_retrieval'
+ # DATASET_NAME: 'MSCOCO'
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 256
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 1
+ # FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ # ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ # S3_PATH: 's3://coco/'
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SEQ_PER_SAMPLE: 1
+ # SAMPLING_WEIGHT: 0.5
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 30
+ # TEMP_NAME: logit_scale_retrieve
+ # LOSSES:
+ # NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 0.5
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # ID_KEY: 'image_id'
+ # VALUE: 'caption'
+ # NAME: 'RetrievalEvaler'
+ # VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ # TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ # GENERATION_MODE: False
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 450000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_CoLA_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_CoLA_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5e75f702357c9beea48999f6f60ec6ef83897227
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_CoLA_mlm_finetune.yaml
@@ -0,0 +1,89 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'CoLA'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/CoLA_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+TASKS:
+ -
+ NAME: CoLA
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'CoLA'
+ TARGET_SET: ['CoLA']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ # EPOCH: 1
+ MAX_ITER: 5600
+ CHECKPOINT_PERIOD: 1000000
+ EVAL_PERIOD: 200
+ CHECKPOINT_MAX_SAVE: 1
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 400
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_MNLI_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_MNLI_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fd278551b04f7b19cbc94811bbf3068fa665c380
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_MNLI_mlm_finetune.yaml
@@ -0,0 +1,89 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'MNLI'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/MNLI_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+TASKS:
+ -
+ NAME: MNLI
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'MNLI_Match'
+ TARGET_SET: ['MNLI']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 125000
+ CHECKPOINT_PERIOD: 125000
+ EVAL_PERIOD: 5000
+ CHECKPOINT_MAX_SAVE: 1
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 7500
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_MRPC_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_MRPC_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e02e1767e38d80fb7995069c20f235511c078ed9
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_MRPC_mlm_finetune.yaml
@@ -0,0 +1,88 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'MRPC'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/MRPC_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+TASKS:
+ -
+ NAME: MRPC
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'MRPC'
+ TARGET_SET: ['MRPC']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 2500
+ CHECKPOINT_PERIOD: 10000
+ EVAL_PERIOD: 100
+ CHECKPOINT_MAX_SAVE: 2
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 150
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_QNLI_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_QNLI_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..190e8c9d497eafc3c5d91c9b1e37143307976cbb
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_QNLI_mlm_finetune.yaml
@@ -0,0 +1,85 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'QNLI'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/QNLI_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+TASKS:
+ -
+ NAME: QNLI
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'QNLI'
+ TARGET_SET: ['QNLI']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 34000
+ CHECKPOINT_PERIOD: 200000
+ EVAL_PERIOD: 2000
+ CHECKPOINT_MAX_SAVE: 2
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 2000
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_QQP_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_QQP_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c9de949820210150000e3e6bb4c5649289827d37
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_QQP_mlm_finetune.yaml
@@ -0,0 +1,84 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'QQP'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/QQP_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+TASKS:
+ -
+ NAME: QQP
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'QQP'
+ TARGET_SET: ['QQP']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 115000
+ CHECKPOINT_PERIOD: 200000
+ EVAL_PERIOD: 5000
+ CHECKPOINT_MAX_SAVE: 2
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 28000
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_RTE_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_RTE_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7ff8e503ebc32c873cd764d67db3b4092bf71a5b
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_RTE_mlm_finetune.yaml
@@ -0,0 +1,92 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'RTE'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/RTE_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+TASKS:
+ -
+ NAME: RTE
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'RTE'
+ TARGET_SET: ['RTE']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 2500
+ CHECKPOINT_PERIOD: 10000
+ EVAL_PERIOD: 100
+ CHECKPOINT_MAX_SAVE: 2
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 150
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_SST2_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_SST2_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4f8e494746458d1b955f1ae6f7626926a8e8b9c9
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/GLUE_SST2_mlm_finetune.yaml
@@ -0,0 +1,89 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'SST-2'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/SST-2_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+TASKS:
+ -
+ NAME: SST-2
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'SST-2'
+ TARGET_SET: ['SST-2']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 22000
+ CHECKPOINT_PERIOD: 100000
+ EVAL_PERIOD: 1000
+ CHECKPOINT_MAX_SAVE: 2
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 1500
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/base.yaml b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f57ff0a46f78b94c7b7af2f6add49ca6c1f03e89
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/GLUE_finetuning_experiments/base.yaml
@@ -0,0 +1,22 @@
+_BASE_: "../../base_model_bert_l12_h768.yaml"
+
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/finetuning/flickr30k_caption_finetuning.yaml b/configs/BERT_L12_H768_experiments/finetuning/flickr30k_caption_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..34ae29b9c0aa624b3ae05f1ab10ab4378672ed7d
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/flickr30k_caption_finetuning.yaml
@@ -0,0 +1,151 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: flickr30k_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'FLICKR'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: "s3://open_dataset/flickr30k/flickr30k_images"
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/captions_val.json'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/captions_test.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 4000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.000002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/finetuning/flickr30k_retrieval_finetuning.yaml b/configs/BERT_L12_H768_experiments/finetuning/flickr30k_retrieval_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..30279d488deb73452aaff812ef7ee79109e6f5a8
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/flickr30k_retrieval_finetuning.yaml
@@ -0,0 +1,132 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: flickr30k_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'FLICKR'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: 's3://open_dataset/flickr30k/flickr30k_images'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 5000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.000005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 200
+ MIN_LR: 0.000001
+
+find_unused_parameters: true
+
diff --git a/configs/BERT_L12_H768_experiments/finetuning/in1k_training.yaml b/configs/BERT_L12_H768_experiments/finetuning/in1k_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..13aa76317fb0c76195f4c19fa411c7120c96f902
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/in1k_training.yaml
@@ -0,0 +1,135 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 128
+ TEST_BATCH_SIZE: 256
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.0
+ CUTMIX: 0.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ LABELSMOOTHING: 0.1
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 20000
+ CHECKPOINT_PERIOD: 20000
+ EVAL_PERIOD: 2000
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.00000001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.999]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 2000
+ MIN_LR: 0.00000001
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/finetuning/in1k_training_384inputsize.yaml b/configs/BERT_L12_H768_experiments/finetuning/in1k_training_384inputsize.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f1c4a53a8cc72eed0f268d3d59dc0b5d8e440060
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/in1k_training_384inputsize.yaml
@@ -0,0 +1,134 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 256
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.0
+ CUTMIX: 0.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ LABELSMOOTHING: 0.1
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 384
+ PATCH_SIZE: 16
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/384"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 40000
+ CHECKPOINT_PERIOD: 40000
+ EVAL_PERIOD: 2000
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.000001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.999]
+ EPS: 1e-6
+ GRAD_CLIP: 0.0
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 4000
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/finetuning/k400_training.yaml b/configs/BERT_L12_H768_experiments/finetuning/k400_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6b87c25f4c3d1eb24016df097c182264cbc773ca
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/k400_training.yaml
@@ -0,0 +1,133 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Kinetics400'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/k400_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: K400_retrieve
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ VAL: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'K400'
+ TARGET_SET: ['Kinetics400']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 8 # 256
+ TEST_BATCH_SIZE: 4 # debug
+ NUM_WORKERS: 4 # debug 4
+ FEATS_FOLDER: 'open_source_dataset/K400_official'
+ ANNO_FOLDER: 'open_source_dataset/K400_official'
+ S3_PATH: 's3://K400/'
+ FRAMES_PER_CLIP: 8
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ MULTI_VEIW_NUM: 4
+ MULTI_VEIW: 'v2'
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ INFERENCE:
+ NAME: 'MiTEvaler'
+ ID_KEY: 'video_name'
+ VALUE: 'label'
+ VAL_ANNFILE: 'open_source_dataset/K400_official/annotation.json'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+ NUM_VIEWS: 1
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 40000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 2000
+ BASE_LR: 0.000005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 2000
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/finetuning/mscoco_caption_finetuning.yaml b/configs/BERT_L12_H768_experiments/finetuning/mscoco_caption_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c344904a7d294ff673cb6893d14feaacaf24b4a7
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/mscoco_caption_finetuning.yaml
@@ -0,0 +1,150 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 2.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 10000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/finetuning/mscoco_retrieval_finetuning.yaml b/configs/BERT_L12_H768_experiments/finetuning/mscoco_retrieval_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fde185032d19c35d0a6a0bf4981c772337f53319
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/mscoco_retrieval_finetuning.yaml
@@ -0,0 +1,132 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 10000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.000005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+find_unused_parameters: true
+
diff --git a/configs/BERT_L12_H768_experiments/finetuning/msvd_caption_finetuning.yaml b/configs/BERT_L12_H768_experiments/finetuning/msvd_caption_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..86ee3a0cb591381a495b3547df69ea35e2e3bc3d
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/msvd_caption_finetuning.yaml
@@ -0,0 +1,144 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: msvd_caption
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_caption'
+ DATASET_NAME: 'MSVDDataset'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 2 #6
+ TEST_BATCH_SIZE: 4
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 4
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/msvd_dataset/new_annotations/caption_msvd_val_cocostyle.json'
+ TEST_ANNFILE: 'open_source_dataset/msvd_dataset/new_annotations/caption_msvd_test_cocostyle.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 1000
+ CHECKPOINT_PERIOD: 500
+ EVAL_PERIOD: 200
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 100
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/finetuning/msvd_retrieval_finetuning.yaml b/configs/BERT_L12_H768_experiments/finetuning/msvd_retrieval_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..528fc8062a7630772902b4fdf403e9bf1bfa9d67
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/msvd_retrieval_finetuning.yaml
@@ -0,0 +1,129 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: msvd_retrieval
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_retrieval'
+ DATASET_NAME: 'MSVDDataset'
+ # TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 8
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 4
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ VIDEO_EMBED:
+ MAX_FRAMES: 8
+
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ # POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 8
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 2000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 200
+ BASE_LR: 0.000005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 200
+ MIN_LR: 0.000001
+
+find_unused_parameters: true
+
diff --git a/configs/BERT_L12_H768_experiments/finetuning/msvd_retrieval_finetuning_frames8.yaml b/configs/BERT_L12_H768_experiments/finetuning/msvd_retrieval_finetuning_frames8.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..70ba1d14d34f99e3b9bd73ab2786daa1aafb0ebb
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/msvd_retrieval_finetuning_frames8.yaml
@@ -0,0 +1,125 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: msvd_retrieval
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_retrieval'
+ DATASET_NAME: 'MSVDDataset'
+ # TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 8
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 8
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 5000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.000005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 200
+ MIN_LR: 0.000001
+
+find_unused_parameters: true
+
diff --git a/configs/BERT_L12_H768_experiments/finetuning/vqa_finetuning_debug.yaml b/configs/BERT_L12_H768_experiments/finetuning/vqa_finetuning_debug.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..308e72b4153a9cd4d1942aac5e2d09ef088aeb9f
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/finetuning/vqa_finetuning_debug.yaml
@@ -0,0 +1,127 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+
+ -
+ NAME: 'VQA_Answer'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/VQA_Answers_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: vqa
+ DATASETS:
+ TRAIN: 'VQADataset'
+ VAL: 'VQADataset'
+ # TEST: 'VQADataset'
+ DATASET_NAME: 'VQA'
+ TASK_TYPE: 'vqa'
+ TARGET_SET: ['VQA_Answer']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/VQA'
+ SEQ_PER_SAMPLE: 1
+ MAX_FEAT_NUM: 51
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ DO_AS_GEN: True
+ SINGLE_CLASS: True
+ MODEL:
+ MAX_SEQ_LEN: 23
+ TEMP_NAME: logit_scale_downstream
+ LOSSES:
+ # not single class
+ # NAMES: ['BCEWithLogits']
+ # LOSS_WEIGHT: 0.05
+ # for single class
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.1
+ INFERENCE:
+ VOCAB: 'CLIP'
+ NAME: 'VQAEvaler'
+ ID_KEY: 'question_id'
+ VALUE: 'answer'
+ VAL_ANNFILE: 'open_source_dataset/VQA/val_target.pkl'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+######################################### Engine #########################################
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ BERT:
+ DROP_PATH_PROB: 0.1
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+ TEMP_NAME: logit_scale_downstream
+ PRED_TEMPERATURE: 0.03
+ LEARN_TEMP: False
+ CLS_TOKEN: True
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ # EPOCH: 1
+ MAX_ITER: 20000
+ CHECKPOINT_PERIOD: 1000
+ EVAL_PERIOD: 1000
+ CHECKPOINT_MAX_SAVE: 2
+ BASE_LR: 0.00004
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.999]
+ EPS: 1e-8
+ GRAD_CLIP: 0.0
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ LOSS_SCALE_WINDOW: 500
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 1000
+ MIN_LR: 0.00000001
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/in1k_training.yaml b/configs/BERT_L12_H768_experiments/in1k_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4e599b04f524674d9728383dcb2e66ad0cc4dd11
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/in1k_training.yaml
@@ -0,0 +1,310 @@
+_BASE_: "base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ # -
+ # NAME: 'Vocab_Word'
+ # SHARED_TARGETS_CFG:
+ # FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ # DISTRIBUTED: True
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 128
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ # -
+ # NAME: bookswiki_pretrain
+ # DATASETS:
+ # TRAIN: 'GeneralCorpusDataset'
+ # TASK_TYPE: 'text_mlm'
+ # DATASET_NAME: 'BooksWiki'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 128
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 2
+ # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # # ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/bookswiki'
+ # SEQ_PER_SAMPLE: 1
+ # SAMPLER: NodeDistributed
+ # CACHE_MODE: True
+ # SEQ_PER_SAMPLE: 128
+ # MIN_SEQ_PER_SAMPLE: 128
+ # APPEND_EOS: True
+ # ONE_STREAM: False
+ # SAMPLING_WEIGHT: 1.0
+ # RANDOM_MASK: True
+ # MODEL:
+ # MAX_SEQ_LEN: 128
+ # TEMP_NAME: logit_scale_text_mlm
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # GENERATION_MODE: False
+
+ # -
+ # NAME: mscoco_caption
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # # VAL: 'ImageTextPairDataset'
+ # # TEST: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_caption'
+ # DATASET_NAME: 'MSCOCO'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 64
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 4
+ # FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ # ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ # S3_PATH: 's3://coco/'
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 1.0
+ # TRANSFORM: 'clip_transforms'
+ # RANDOM_MASK: True
+ # MODEL:
+ # MAX_SEQ_LEN: 50
+ # EVAL_MAX_SEQ_LEN: 21
+ # TEMP_NAME: logit_scale_caption
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.33333
+ # REDUCTION: 'mean'
+ # DECODE_STRATEGY:
+ # NAME: 'CaptionBeamSearcherV3'
+ # BEAM_SIZE: 2
+ # # LEN_PENALTY: 1.0
+ # INFERENCE:
+ # NAME: 'COCOEvaler'
+ # VOCAB: 'CLIP'
+ # ID_KEY: 'image_id'
+ # VALUE: 'caption'
+ # VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ # TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ # GENERATION_MODE: True
+
+ # -
+ # NAME: mscoco_retrieve
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # # TEST: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_retrieval'
+ # DATASET_NAME: 'MSCOCO'
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 100
+ # TEST_BATCH_SIZE: 32
+ # NUM_WORKERS: 1
+ # FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ # ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ # S3_PATH: 's3://coco/'
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 1.0
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 50
+ # TEMP_NAME: logit_scale_retrieve
+ # LOSSES:
+ # NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1.0
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # ID_KEY: 'image_id'
+ # VALUE: 'caption'
+ # NAME: 'RetrievalEvaler'
+ # VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ # TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ # GENERATION_MODE: False
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 200000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_CoLA_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_CoLA_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..42a7335cff4f1992a7014672bd0ebe5459763215
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_CoLA_mlm_finetune.yaml
@@ -0,0 +1,89 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'CoLA-target'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/CoLA_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+TASKS:
+ -
+ NAME: CoLA
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'CoLA'
+ TARGET_SET: ['CoLA-target']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ # EPOCH: 1
+ MAX_ITER: 5600
+ CHECKPOINT_PERIOD: 1000000
+ EVAL_PERIOD: 200
+ CHECKPOINT_MAX_SAVE: 1
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 400
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_MNLI_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_MNLI_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..612dc5ccf3c7e78b4395a956df90c341dfebc69a
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_MNLI_mlm_finetune.yaml
@@ -0,0 +1,89 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'MNLI-target'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/MNLI_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+TASKS:
+ -
+ NAME: MNLI
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'MNLI_Match'
+ TARGET_SET: ['MNLI-target']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 125000
+ CHECKPOINT_PERIOD: 125000
+ EVAL_PERIOD: 5000
+ CHECKPOINT_MAX_SAVE: 1
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 7500
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_MRPC_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_MRPC_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..38154820a319d9064515482901857d75e3eba9ab
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_MRPC_mlm_finetune.yaml
@@ -0,0 +1,88 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'MRPC-target'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/MRPC_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+TASKS:
+ -
+ NAME: MRPC
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'MRPC'
+ TARGET_SET: ['MRPC-target']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 2500
+ CHECKPOINT_PERIOD: 10000
+ EVAL_PERIOD: 100
+ CHECKPOINT_MAX_SAVE: 2
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 150
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_QNLI_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_QNLI_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..858e1487f05f88883b237265a0f43c38739236d7
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_QNLI_mlm_finetune.yaml
@@ -0,0 +1,85 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'QNLI-target'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/QNLI_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+TASKS:
+ -
+ NAME: QNLI
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'QNLI'
+ TARGET_SET: ['QNLI-target']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 34000
+ CHECKPOINT_PERIOD: 200000
+ EVAL_PERIOD: 2000
+ CHECKPOINT_MAX_SAVE: 2
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 2000
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_QQP_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_QQP_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..abf98c29ca135ff0690d0d741951f95e5035eedf
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_QQP_mlm_finetune.yaml
@@ -0,0 +1,84 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'QQP-target'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/QQP_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+TASKS:
+ -
+ NAME: QQP
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'QQP'
+ TARGET_SET: ['QQP-target']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 115000
+ CHECKPOINT_PERIOD: 200000
+ EVAL_PERIOD: 5000
+ CHECKPOINT_MAX_SAVE: 2
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 28000
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_RTE_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_RTE_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c43f368a7649701151087d99ecc39a0fbe6f3669
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_RTE_mlm_finetune.yaml
@@ -0,0 +1,92 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'RTE-target'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/RTE_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+TASKS:
+ -
+ NAME: RTE
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'RTE'
+ TARGET_SET: ['RTE-target']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 2500
+ CHECKPOINT_PERIOD: 10000
+ EVAL_PERIOD: 100
+ CHECKPOINT_MAX_SAVE: 2
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 150
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_SST2_mlm_finetune.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_SST2_mlm_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..17aa99c9f60ed0add0e4b00e79a0cf824b1f1d42
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/GLUE_SST2_mlm_finetune.yaml
@@ -0,0 +1,89 @@
+_BASE_: "base.yaml"
+
+SHARED_TARGETS:
+ -
+ NAME: 'SST-2-target'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/GLUE_classnames/SST-2_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+TASKS:
+ -
+ NAME: SST-2
+ DATASETS:
+ TRAIN: 'GLUEDataset'
+ # TEST: 'GLUEDataset'
+ VAL: 'GLUEDataset'
+ TASK_TYPE: 'text_classification'
+ DATASET_NAME: 'SST-2'
+ TARGET_SET: ['SST-2-target']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 4
+ ANNO_FOLDER: 'open_source_dataset/bert_pretrain_data/glue_data/'
+
+ MODEL:
+ MAX_SEQ_LEN: 256
+ TEMP_NAME: logit_scale_text_mlm
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 1
+ REDUCTION: 'mean'
+ LOSS_FP32: False
+ INFERENCE:
+ NAME: 'GLUEEvaler'
+ VOCAB: 'CLIP'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+######################################### MODEL #########################################
+MODEL:
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ MAX_ITER: 22000
+ CHECKPOINT_PERIOD: 100000
+ EVAL_PERIOD: 1000
+ CHECKPOINT_MAX_SAVE: 2
+ BASE_LR: 0.00001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.1
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.98]
+ EPS: 1e-8
+ GRAD_CLIP: 0.5
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 20
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 1500
+ MIN_LR: 0.00000001
+
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/base.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/base.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..685ef589089b33c71df29c6df2d739d295432bcd
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/GLUE_finetuning_experiments/base.yaml
@@ -0,0 +1,46 @@
+_BASE_: "../../base_model_bert_l12_h768.yaml"
+
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/flickr30k_caption_finetuning.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/flickr30k_caption_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..93800b02a69bb171407791123a180d8546bfbc49
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/flickr30k_caption_finetuning.yaml
@@ -0,0 +1,175 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: flickr30k_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'FLICKR'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: "s3://open_dataset/flickr30k/flickr30k_images"
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/captions_val.json'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/captions_test.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 4000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.000002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/flickr30k_retrieval_finetuning.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/flickr30k_retrieval_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6da58a03bf14b780f06f397a576755e0c8272c52
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/flickr30k_retrieval_finetuning.yaml
@@ -0,0 +1,155 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: flickr30k_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'FLICKR'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: 's3://open_dataset/flickr30k/flickr30k_images'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 5000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.000005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 200
+ MIN_LR: 0.000001
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/in1k_training.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/in1k_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4998f338eb4964d4673d66d0b9599950626ea2e3
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/in1k_training.yaml
@@ -0,0 +1,159 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 128
+ TEST_BATCH_SIZE: 256
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.0
+ CUTMIX: 0.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ LABELSMOOTHING: 0.1
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 20000
+ CHECKPOINT_PERIOD: 20000
+ EVAL_PERIOD: 2000
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.00000001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.999]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 2000
+ MIN_LR: 0.00000001
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/in1k_training_384inputsize.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/in1k_training_384inputsize.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e9af6b007b1addce62aecd160a54fa44e5060589
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/in1k_training_384inputsize.yaml
@@ -0,0 +1,158 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 256
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.0
+ CUTMIX: 0.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ LABELSMOOTHING: 0.1
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 384
+ PATCH_SIZE: 16
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/384"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 40000
+ CHECKPOINT_PERIOD: 40000
+ EVAL_PERIOD: 2000
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.00000001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.999]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 4000
+ MIN_LR: 0.00000001
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/k400_training.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/k400_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e99127de00c23ad0f4691b9308fda9facb12e4c5
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/k400_training.yaml
@@ -0,0 +1,158 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Kinetics400'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/k400_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: K400_retrieve
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ VAL: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'K400'
+ TARGET_SET: ['Kinetics400']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 8 # 256
+ TEST_BATCH_SIZE: 4 # debug
+ NUM_WORKERS: 4 # debug 4
+ FEATS_FOLDER: 'open_source_dataset/K400_official'
+ ANNO_FOLDER: 'open_source_dataset/K400_official'
+ S3_PATH: 's3://K400/'
+ FRAMES_PER_CLIP: 8
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ MULTI_VEIW_NUM: 4
+ MULTI_VEIW: 'v2'
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ INFERENCE:
+ NAME: 'MiTEvaler'
+ ID_KEY: 'video_name'
+ VALUE: 'label'
+ VAL_ANNFILE: 'open_source_dataset/K400_official/annotation.json'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+ NUM_VIEWS: 1
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 40000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 2000
+ BASE_LR: 0.000005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 2000
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
+
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/mscoco_caption_finetuning.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/mscoco_caption_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..47d1dbacf57286437d88c949e080181439c6a6cf
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/mscoco_caption_finetuning.yaml
@@ -0,0 +1,174 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 2.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 10000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/mscoco_retrieval_finetuning.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/mscoco_retrieval_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b37c7bec96804f43de9c34dfc6906651c9de26c5
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/mscoco_retrieval_finetuning.yaml
@@ -0,0 +1,155 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 10000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.000005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/msvd_caption_finetuning.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/msvd_caption_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..11e1f60eaef1ddf1113e166556dd48d9b8987dee
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/msvd_caption_finetuning.yaml
@@ -0,0 +1,168 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: msvd_caption
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_caption'
+ DATASET_NAME: 'MSVDDataset'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 2 #6
+ TEST_BATCH_SIZE: 4
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 4
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/msvd_dataset/new_annotations/caption_msvd_val_cocostyle.json'
+ TEST_ANNFILE: 'open_source_dataset/msvd_dataset/new_annotations/caption_msvd_test_cocostyle.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ # POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 1000
+ CHECKPOINT_PERIOD: 500
+ EVAL_PERIOD: 200
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 100
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_finetuning/msvd_retrieval_finetuning.yaml b/configs/BERT_L12_H768_experiments/moe_finetuning/msvd_retrieval_finetuning.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..77388c89644c714f047d11c415bcb9578476b6d4
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_finetuning/msvd_retrieval_finetuning.yaml
@@ -0,0 +1,152 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: msvd_retrieval
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_retrieval'
+ DATASET_NAME: 'MSVDDataset'
+ # TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 8
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 4
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ VIDEO_EMBED:
+ MAX_FRAMES: 8
+
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ # POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 8
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 2000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 200
+ BASE_LR: 0.000005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 200
+ MIN_LR: 0.000001
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_prompt_tuning/flickr30k_caption_prompt_tuning_0.01data.yaml b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/flickr30k_caption_prompt_tuning_0.01data.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8d4705f9de45ea0ff13df5749533e86a1c1acbd2
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/flickr30k_caption_prompt_tuning_0.01data.yaml
@@ -0,0 +1,201 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: flickr30k_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'FLICKR'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: "s3://open_dataset/flickr30k/flickr30k_images"
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/captions_val.json'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/captions_test.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 500
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 50
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 50
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_prompt_tuning/flickr30k_retrieval_prompt_tuning_0.01data.yaml b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/flickr30k_retrieval_prompt_tuning_0.01data.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5912cc00eca1afa6662b45c6a786b0a26d125d07
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/flickr30k_retrieval_prompt_tuning_0.01data.yaml
@@ -0,0 +1,181 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: flickr30k_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'FLICKR'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: 's3://open_dataset/flickr30k/flickr30k_images'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: False
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 1000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 50
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 50
+ MIN_LR: 0.000001
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
diff --git a/configs/BERT_L12_H768_experiments/moe_prompt_tuning/in1k_prompt_tuning_0.01data_lr1e-4.yaml b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/in1k_prompt_tuning_0.01data_lr1e-4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2ed10c580ea5dcc068c63174be9af307016f1ff6
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/in1k_prompt_tuning_0.01data_lr1e-4.yaml
@@ -0,0 +1,182 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 512
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: True
+ FC_PROMPT_OUT: 1000
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: True
+ LABEL_SIZE: 1000
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 2500
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.0001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
diff --git a/configs/BERT_L12_H768_experiments/moe_prompt_tuning/k400_prompt_tuning_0.01data.yaml b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/k400_prompt_tuning_0.01data.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..593227b225112482b400d781efdc5f8562f8fe65
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/k400_prompt_tuning_0.01data.yaml
@@ -0,0 +1,184 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Kinetics400'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/k400_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: K400_retrieve
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ VAL: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'K400'
+ TARGET_SET: ['Kinetics400']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 4 # 256
+ TEST_BATCH_SIZE: 4 # debug
+ NUM_WORKERS: 4 # debug 4
+ FEATS_FOLDER: 'open_source_dataset/K400_official'
+ ANNO_FOLDER: 'open_source_dataset/K400_official'
+ S3_PATH: 's3://K400/'
+ FRAMES_PER_CLIP: 8
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ MULTI_VEIW_NUM: 4
+ MULTI_VEIW: 'v2'
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ INFERENCE:
+ NAME: 'MiTEvaler'
+ ID_KEY: 'video_name'
+ VALUE: 'label'
+ VAL_ANNFILE: 'open_source_dataset/K400_official/annotation.json'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+ NUM_VIEWS: 1
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: True
+ FC_PROMPT_OUT: 400
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 400
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 2000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.0005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 200
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
diff --git a/configs/BERT_L12_H768_experiments/moe_prompt_tuning/mscoco_caption_prompt_tuning_0.01data_lr1e-3 copy.yaml b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/mscoco_caption_prompt_tuning_0.01data_lr1e-3 copy.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a02867835526947022992dac488d6013c3cb5d57
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/mscoco_caption_prompt_tuning_0.01data_lr1e-3 copy.yaml
@@ -0,0 +1,177 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 2.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 1000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 200
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 100
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_prompt_tuning/mscoco_caption_prompt_tuning_0.01data_lr1e-3.yaml b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/mscoco_caption_prompt_tuning_0.01data_lr1e-3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e30c79a4c208070558823a0eec6b7cdc35f65dfe
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/mscoco_caption_prompt_tuning_0.01data_lr1e-3.yaml
@@ -0,0 +1,201 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 2.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 1000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 200
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 100
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
diff --git a/configs/BERT_L12_H768_experiments/moe_prompt_tuning/mscoco_retrieval_prompt_tuning_0.01data_lr1e-4.yaml b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/mscoco_retrieval_prompt_tuning_0.01data_lr1e-4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f9967390b06953245e98d99a071ee8f20ce95a59
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/mscoco_retrieval_prompt_tuning_0.01data_lr1e-4.yaml
@@ -0,0 +1,181 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: False
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 500
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 50
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 50
+ MIN_LR: 0.000001
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
diff --git a/configs/BERT_L12_H768_experiments/moe_prompt_tuning/msvd_caption_prompt_tuning_0.01data_lr1e-3.yaml b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/msvd_caption_prompt_tuning_0.01data_lr1e-3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..147037531d580475676bb2aa191f86478881b301
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/msvd_caption_prompt_tuning_0.01data_lr1e-3.yaml
@@ -0,0 +1,194 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: msvd_caption
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_caption'
+ DATASET_NAME: 'MSVDDataset'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 2 #6
+ TEST_BATCH_SIZE: 4
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 4
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/msvd_dataset/new_annotations/caption_msvd_val_cocostyle.json'
+ TEST_ANNFILE: 'open_source_dataset/msvd_dataset/new_annotations/caption_msvd_test_cocostyle.json'
+ GENERATION_MODE: True
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 200
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 50
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 25
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_prompt_tuning/msvd_retrieval_prompt_tuning_0.01data_lr1e-4.yaml b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/msvd_retrieval_prompt_tuning_0.01data_lr1e-4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c5ddc6627e373b30ec2c7aa43cff8e5417eac519
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_prompt_tuning/msvd_retrieval_prompt_tuning_0.01data_lr1e-4.yaml
@@ -0,0 +1,201 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: msvd_retrieval
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_retrieval'
+ DATASET_NAME: 'MSVDDataset'
+ # TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 8
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 4
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: False
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 100
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 20
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 5
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 10
+ MIN_LR: 0.000001
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/moe_zeroshot/mscoco_caption.yaml b/configs/BERT_L12_H768_experiments/moe_zeroshot/mscoco_caption.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3704a93b44a251c77092753faa8cfc107d620a18
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/moe_zeroshot/mscoco_caption.yaml
@@ -0,0 +1,168 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: msvd_caption
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_caption'
+ DATASET_NAME: 'MSVDDataset'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 2 #6
+ TEST_BATCH_SIZE: 4
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 4
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/msvd_dataset/new_annotations/caption_msvd_val_cocostyle.json'
+ TEST_ANNFILE: 'open_source_dataset/msvd_dataset/new_annotations/caption_msvd_test_cocostyle.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ # POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 1000
+ CHECKPOINT_PERIOD: 500
+ EVAL_PERIOD: 200
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 100
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/mscoco_debug.yaml b/configs/BERT_L12_H768_experiments/mscoco_debug.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e5153edf633763210f1d1339cf7b647508ac2081
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/mscoco_debug.yaml
@@ -0,0 +1,235 @@
+_BASE_: "base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+
+
+
+ # -
+ # NAME: mscoco_retrieve
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_retrieval'
+ # DATASET_NAME: 'MSCOCO'
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 256
+ # TEST_BATCH_SIZE: 64
+ # NUM_WORKERS: 1
+ # FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ # ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ # S3_PATH: 's3://coco/'
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.5
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 30
+ # TEMP_NAME: logit_scale_retrieve
+ # LOSSES:
+ # NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 0.5
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # ID_KEY: 'image_id'
+ # VALUE: 'caption'
+ # NAME: 'RetrievalEvaler'
+ # VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ # TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ # GENERATION_MODE: False
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 2
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 1.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 450000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/mscoco_debug_moe.yaml b/configs/BERT_L12_H768_experiments/mscoco_debug_moe.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..21d2375a84366dcd4fc8768f66c41a78eb771baf
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/mscoco_debug_moe.yaml
@@ -0,0 +1,262 @@
+_BASE_: "base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+
+
+
+ # -
+ # NAME: mscoco_retrieve
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_retrieval'
+ # DATASET_NAME: 'MSCOCO'
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 256
+ # TEST_BATCH_SIZE: 64
+ # NUM_WORKERS: 1
+ # FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ # ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ # S3_PATH: 's3://coco/'
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.5
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 30
+ # TEMP_NAME: logit_scale_retrieve
+ # LOSSES:
+ # NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 0.5
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # ID_KEY: 'image_id'
+ # VALUE: 'caption'
+ # NAME: 'RetrievalEvaler'
+ # VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ # TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ # GENERATION_MODE: False
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 2
+ TEST_BATCH_SIZE: 4
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 450000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/prompt_tuning/flickr30k_caption_prompt_tuning_0.01data.yaml b/configs/BERT_L12_H768_experiments/prompt_tuning/flickr30k_caption_prompt_tuning_0.01data.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3e575ade4a86caf6decfcad1be8b6ab0b21b1d78
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/prompt_tuning/flickr30k_caption_prompt_tuning_0.01data.yaml
@@ -0,0 +1,177 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: flickr30k_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'FLICKR'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: "s3://open_dataset/flickr30k/flickr30k_images"
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/captions_val.json'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/captions_test.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 500
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 50
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 50
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/prompt_tuning/flickr30k_retrieval_prompt_tuning_0.01data.yaml b/configs/BERT_L12_H768_experiments/prompt_tuning/flickr30k_retrieval_prompt_tuning_0.01data.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..69af009307975a05b8658fbff3443333dd1ed551
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/prompt_tuning/flickr30k_retrieval_prompt_tuning_0.01data.yaml
@@ -0,0 +1,158 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: flickr30k_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'FLICKR'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: 's3://open_dataset/flickr30k/flickr30k_images'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: False
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 1000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 50
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 50
+ MIN_LR: 0.000001
+
diff --git a/configs/BERT_L12_H768_experiments/prompt_tuning/in1k_prompt_tuning_0.01data_labelprompt_lr1e-3.yaml b/configs/BERT_L12_H768_experiments/prompt_tuning/in1k_prompt_tuning_0.01data_labelprompt_lr1e-3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..232280d27d0565981c0449d87c02cd34e58d1523
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/prompt_tuning/in1k_prompt_tuning_0.01data_labelprompt_lr1e-3.yaml
@@ -0,0 +1,162 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 512
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+ FC_PROMPT_OUT: 1000
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: True
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: True
+ LABEL_SIZE: 1000
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 2500
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+
diff --git a/configs/BERT_L12_H768_experiments/prompt_tuning/in1k_prompt_tuning_0.01data_lr1e-4.yaml b/configs/BERT_L12_H768_experiments/prompt_tuning/in1k_prompt_tuning_0.01data_lr1e-4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dbe3273b26429abc20b0e5c76feeab3d02fbaa15
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/prompt_tuning/in1k_prompt_tuning_0.01data_lr1e-4.yaml
@@ -0,0 +1,162 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 512
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: True
+ FC_PROMPT_OUT: 1000
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: True
+ LABEL_SIZE: 1000
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 2500
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.0001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+
diff --git a/configs/BERT_L12_H768_experiments/prompt_tuning/in1k_prompt_tuning_0.01data_lr1e-4_notshare.yaml b/configs/BERT_L12_H768_experiments/prompt_tuning/in1k_prompt_tuning_0.01data_lr1e-4_notshare.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5e4f53c26babb237ded81f1581b940ab51482a34
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/prompt_tuning/in1k_prompt_tuning_0.01data_lr1e-4_notshare.yaml
@@ -0,0 +1,162 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 64
+ TEST_BATCH_SIZE: 512
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: True
+ FC_PROMPT_OUT: 1000
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: False
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: True
+ LABEL_SIZE: 1000
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 2500
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.0001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+
diff --git a/configs/BERT_L12_H768_experiments/prompt_tuning/k400_prompt_tuning_0.01data.yaml b/configs/BERT_L12_H768_experiments/prompt_tuning/k400_prompt_tuning_0.01data.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..33e8578f0e1ebf4fd28ca53abc3e7d6b2091919f
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/prompt_tuning/k400_prompt_tuning_0.01data.yaml
@@ -0,0 +1,160 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Kinetics400'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/k400_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: K400_retrieve
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ VAL: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'K400'
+ TARGET_SET: ['Kinetics400']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 4 # 256
+ TEST_BATCH_SIZE: 4 # debug
+ NUM_WORKERS: 4 # debug 4
+ FEATS_FOLDER: 'open_source_dataset/K400_official'
+ ANNO_FOLDER: 'open_source_dataset/K400_official'
+ S3_PATH: 's3://K400/'
+ FRAMES_PER_CLIP: 8
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ MULTI_VEIW_NUM: 4
+ MULTI_VEIW: 'v2'
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ INFERENCE:
+ NAME: 'MiTEvaler'
+ ID_KEY: 'video_name'
+ VALUE: 'label'
+ VAL_ANNFILE: 'open_source_dataset/K400_official/annotation.json'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+ NUM_VIEWS: 1
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: True
+ FC_PROMPT_OUT: 400
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 400
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 2000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.0005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 200
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/prompt_tuning/mscoco_caption_prompt_tuning_0.01data_lr1e-3.yaml b/configs/BERT_L12_H768_experiments/prompt_tuning/mscoco_caption_prompt_tuning_0.01data_lr1e-3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a02867835526947022992dac488d6013c3cb5d57
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/prompt_tuning/mscoco_caption_prompt_tuning_0.01data_lr1e-3.yaml
@@ -0,0 +1,177 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 2.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 1000
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 200
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 100
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/prompt_tuning/mscoco_retrieval_prompt_tuning_0.01data_lr1e-4.yaml b/configs/BERT_L12_H768_experiments/prompt_tuning/mscoco_retrieval_prompt_tuning_0.01data_lr1e-4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..891a7827b3616ea6383727de62ecac83bb51315c
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/prompt_tuning/mscoco_retrieval_prompt_tuning_0.01data_lr1e-4.yaml
@@ -0,0 +1,158 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: False
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 500
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 50
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 50
+ MIN_LR: 0.000001
+
diff --git a/configs/BERT_L12_H768_experiments/prompt_tuning/msvd_caption_prompt_tuning_0.01data_lr1e-3.yaml b/configs/BERT_L12_H768_experiments/prompt_tuning/msvd_caption_prompt_tuning_0.01data_lr1e-3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..164153a341d2303be1cdcd1ccd47a64f11d49b03
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/prompt_tuning/msvd_caption_prompt_tuning_0.01data_lr1e-3.yaml
@@ -0,0 +1,170 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: msvd_caption
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_caption'
+ DATASET_NAME: 'MSVDDataset'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 2 #6
+ TEST_BATCH_SIZE: 4
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 4
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ # LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/msvd_dataset/new_annotations/caption_msvd_val_cocostyle.json'
+ TEST_ANNFILE: 'open_source_dataset/msvd_dataset/new_annotations/caption_msvd_test_cocostyle.json'
+ GENERATION_MODE: True
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: True
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 200
+ CHECKPOINT_PERIOD: 5000
+ EVAL_PERIOD: 50
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 25
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 25
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/prompt_tuning/msvd_retrieval_prompt_tuning_0.01data_lr1e-4.yaml b/configs/BERT_L12_H768_experiments/prompt_tuning/msvd_retrieval_prompt_tuning_0.01data_lr1e-4.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6fa03a2f668279732050b474d3b3bf0b048ad65c
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/prompt_tuning/msvd_retrieval_prompt_tuning_0.01data_lr1e-4.yaml
@@ -0,0 +1,154 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: msvd_retrieval
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_retrieval'
+ DATASET_NAME: 'MSVDDataset'
+ # TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 16
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 8
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 4
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ DATA_PERCENTAGE: 0.01
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+ # prompt
+ PROMPT: True
+ PROMPT_PARAM: ["s_token_bias", "norm", "prompt_embed", "deep_prompt_embedding", "fc_prompt", "similarity_weight", "ln_post"] # ["s_token_bias", "LayerNorm", "prompt_embed", "prompt_fc", "similarity_weight"]
+ FC_PROMPT: False
+
+
+ # #################################### prompt embedding ####################################
+ PROMPT_EMBED: #### activated only when the
+ NAME: 'PrefixPromptEmbedding'
+ ACTIVATION: 'none'
+ ELU_ALPHA: 0.5
+ USE_NORM: False
+ DROPOUT: 0.0
+ WITH_POS: False
+ INPUT_PROMPT: False
+ TARGET_PROMPT: False
+ DEEP_PROMPT: True
+ TARGET_DEEP_PROMPT: True
+ SHARE_DEEP_PROMPT: False
+ PROMPT_LENGTH: 10
+ TARGET_PROMPT_LENGTH: 1
+ INPUT_DEEP_PROMPT_LENGTH: 10
+ TARGET_DEEP_PROMPT_LENGTH: 10
+ LABLE_PROMPT: False
+ LABEL_SIZE: 1000
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 100
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 20
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 5
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 10
+ MIN_LR: 0.000001
+
diff --git a/configs/BERT_L12_H768_experiments/zeroshot_config/flickr30k_caption.yaml b/configs/BERT_L12_H768_experiments/zeroshot_config/flickr30k_caption.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1b95d6a9d6d11154480ac0032f724ec1cbda14c3
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/zeroshot_config/flickr30k_caption.yaml
@@ -0,0 +1,150 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: flickr30k_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'FLICKR'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: "s3://open_dataset/flickr30k/flickr30k_images"
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/captions_val.json'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/captions_test.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 4000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.000002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/zeroshot_config/flickr30k_retrieval.yaml b/configs/BERT_L12_H768_experiments/zeroshot_config/flickr30k_retrieval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..30279d488deb73452aaff812ef7ee79109e6f5a8
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/zeroshot_config/flickr30k_retrieval.yaml
@@ -0,0 +1,132 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+
+
+ -
+ NAME: flickr30k_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'FLICKR'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 2
+ FEATS_FOLDER: 'open_source_dataset/flickr30k_images/flickr30k_images/flickr30k_images'
+ ANNO_FOLDER: 'open_source_dataset/flickr30k'
+ S3_PATH: 's3://open_dataset/flickr30k/flickr30k_images'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 5000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.000005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 200
+ MIN_LR: 0.000001
+
+find_unused_parameters: true
+
diff --git a/configs/BERT_L12_H768_experiments/zeroshot_config/in1k_training.yaml b/configs/BERT_L12_H768_experiments/zeroshot_config/in1k_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..71fc9b06b8b612eed9392794dc4e3b93d80515af
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/zeroshot_config/in1k_training.yaml
@@ -0,0 +1,195 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+ # -
+ # NAME: 'Vocab_Word'
+ # SHARED_TARGETS_CFG:
+ # FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ # DISTRIBUTED: True
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 128
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 200000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.3
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/zeroshot_config/k400_eval.yaml b/configs/BERT_L12_H768_experiments/zeroshot_config/k400_eval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bbd5a64660a25a443ca8203f128466b6ba18251c
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/zeroshot_config/k400_eval.yaml
@@ -0,0 +1,134 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Kinetics400'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/k400_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: K400_retrieve
+ DATASETS:
+ TRAIN: 'VideoDataSet'
+ VAL: 'VideoDataSet'
+ TASK_TYPE: 'video_classification'
+ DATASET_NAME: 'K400'
+ TARGET_SET: ['Kinetics400']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 4 # 256
+ TEST_BATCH_SIZE: 4 # debug
+ NUM_WORKERS: 4 # debug 4
+ FEATS_FOLDER: 'open_source_dataset/K400_official'
+ ANNO_FOLDER: 'open_source_dataset/K400_official'
+ S3_PATH: 's3://K400/'
+ FRAMES_PER_CLIP: 8
+ STRIDE: 32
+ FILE_EXTENSION: ''
+ ANNO_FILE: 'annotation.json'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+ MULTI_VEIW_NUM: 4
+ MULTI_VEIW: 'v2'
+ MODEL:
+ MAX_SEQ_LEN: -1
+ TEMP_NAME: logit_scale_video_cls
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ INFERENCE:
+ NAME: 'MiTEvaler'
+ ID_KEY: 'video_name'
+ VALUE: 'label'
+ VAL_ANNFILE: 'open_source_dataset/K400_official/annotation.json'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+ NUM_VIEWS: 1
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 20000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 2000
+ BASE_LR: 0.000005
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ LOSS_SCALE_WINDOW: 200
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 2000
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
diff --git a/configs/BERT_L12_H768_experiments/zeroshot_config/mscoco_caption.yaml b/configs/BERT_L12_H768_experiments/zeroshot_config/mscoco_caption.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a5201c11696f6ceaf70d6807f14753345e63e8cf
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/zeroshot_config/mscoco_caption.yaml
@@ -0,0 +1,239 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+
+
+
+ # -
+ # NAME: mscoco_retrieve
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_retrieval'
+ # DATASET_NAME: 'MSCOCO'
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 256
+ # TEST_BATCH_SIZE: 64
+ # NUM_WORKERS: 1
+ # FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ # ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ # S3_PATH: 's3://coco/'
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.5
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 30
+ # TEMP_NAME: logit_scale_retrieve
+ # LOSSES:
+ # NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 0.5
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # ID_KEY: 'image_id'
+ # VALUE: 'caption'
+ # NAME: 'RetrievalEvaler'
+ # VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ # TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ # GENERATION_MODE: False
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 450000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/zeroshot_config/mscoco_caption_finetuning_customattnmodule_womoe.yaml b/configs/BERT_L12_H768_experiments/zeroshot_config/mscoco_caption_finetuning_customattnmodule_womoe.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e7759fdf3ea6c7e0379462376199fa298dfa7560
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/zeroshot_config/mscoco_caption_finetuning_customattnmodule_womoe.yaml
@@ -0,0 +1,174 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 32
+ TEST_BATCH_SIZE: 8
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 1.0
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 2.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.0
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 10000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500
+ BASE_LR: 0.00002
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.0001
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 500
+ MIN_LR: 0.000001
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'none' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
\ No newline at end of file
diff --git a/configs/BERT_L12_H768_experiments/zeroshot_config/mscoco_retrieval.yaml b/configs/BERT_L12_H768_experiments/zeroshot_config/mscoco_retrieval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..afdeb55e080356b24302d81d23fede1fcbbfb91c
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/zeroshot_config/mscoco_retrieval.yaml
@@ -0,0 +1,191 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+# SHARED_TARGETS:
+
+
+
+# -
+# NAME: 'Vocab_Word'
+# SHARED_TARGETS_CFG:
+# FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+# DISTRIBUTED: True
+
+TASKS:
+
+
+
+ -
+ NAME: mscoco_retrieve
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_retrieval'
+ DATASET_NAME: 'MSCOCO'
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 256
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 1
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ MODEL:
+ MAX_SEQ_LEN: 30
+ TEMP_NAME: logit_scale_retrieve
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ NAME: 'RetrievalEvaler'
+ VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ GENERATION_MODE: False
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 450000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/zeroshot_config/msrvtt.yaml b/configs/BERT_L12_H768_experiments/zeroshot_config/msrvtt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6c4e49301d15aa3fb787ff83cb05d337a333a7e6
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/zeroshot_config/msrvtt.yaml
@@ -0,0 +1,240 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+
+
+ # -
+ # NAME: 'Vocab_Word'
+ # SHARED_TARGETS_CFG:
+ # FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ # DISTRIBUTED: True
+
+TASKS:
+
+ -
+ NAME: msrvtt_retrieval
+ DATASETS:
+ TRAIN: 'MSRVTTDataset'
+ # VAL: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'video_retrieval'
+ DATASET_NAME: 'MSRVTTDataset'
+ # TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 6
+ TEST_BATCH_SIZE: 6
+ NUM_WORKERS: 0
+ FEATS_FOLDER: 'open_source_dataset/msrvtt_dataset/videos'
+ ANNO_FOLDER: 'open_source_dataset/msrvtt_dataset/annotations_new'
+ STRIDE: 32
+ S3_PATH: 's3://coco/'
+ TIMESFORMER_AUG: True
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 77
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+ # -
+ # NAME: msrvtt_caption
+ # DATASETS:
+ # TRAIN: 'MSRVTTDataset'
+ # # VAL: 'ImageTextPairDataset'
+ # # TEST: 'ImageTextPairDataset'
+ # TASK_TYPE: 'video_caption'
+ # DATASET_NAME: 'MSRVTTDataset'
+ # TARGET_SET: ['Vocab_Word']
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 6
+ # TEST_BATCH_SIZE: 6
+ # NUM_WORKERS: 0
+ # FEATS_FOLDER: 'open_source_dataset/msrvtt_dataset/videos'
+ # ANNO_FOLDER: 'open_source_dataset/msrvtt_dataset/annotations_new'
+ # STRIDE: 32
+ # S3_PATH: 's3://coco/'
+ # TIMESFORMER_AUG: True
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.5
+ # TRANSFORM: 'clip_transforms'
+ # RANDOM_MASK: True
+ # MODEL:
+ # MAX_SEQ_LEN: 77
+ # EVAL_MAX_SEQ_LEN: 21
+ # TEMP_NAME: logit_scale_caption
+ # LOSSES:
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ # LOSS_WEIGHT: 0.5
+ # REDUCTION: 'mean'
+ # DECODE_STRATEGY:
+ # NAME: 'CaptionBeamSearcherV3'
+ # BEAM_SIZE: 2
+ # # LEN_PENALTY: 1.0
+ # INFERENCE:
+ # NAME: 'COCOEvaler'
+ # VOCAB: 'CLIP'
+ # ID_KEY: 'image_id'
+ # VALUE: 'caption'
+ # VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ # TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ # GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ IN_TUNING: True # use IN1k instead of 22k
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 0
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 450000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/zeroshot_config/msvd_caption.yaml b/configs/BERT_L12_H768_experiments/zeroshot_config/msvd_caption.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9dbf810c788265691a11f4e3fbe0d9c2f7fb76dd
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/zeroshot_config/msvd_caption.yaml
@@ -0,0 +1,190 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+SHARED_TARGETS:
+
+
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+
+ -
+ NAME: msvd_caption
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_caption'
+ DATASET_NAME: 'MSVDDataset'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 6
+ TEST_BATCH_SIZE: 6
+ NUM_WORKERS: 6
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 4
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 1.0
+
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ TEST_ANNFILE: 'open_source_dataset/msvd_dataset/new_annotations/caption_msvd_test_cocostyle.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 450000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L12_H768_experiments/zeroshot_config/msvd_retrieval.yaml b/configs/BERT_L12_H768_experiments/zeroshot_config/msvd_retrieval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..53a1888c39f42f49a3803d44917a7c0eb5a2c5d3
--- /dev/null
+++ b/configs/BERT_L12_H768_experiments/zeroshot_config/msvd_retrieval.yaml
@@ -0,0 +1,172 @@
+_BASE_: "../base_model_bert_l12_h768.yaml"
+
+
+TASKS:
+
+ -
+ NAME: msvd_retrieval
+ DATASETS:
+ TRAIN: 'MSVDDataset'
+ TEST: 'MSVDDataset'
+ TASK_TYPE: 'video_retrieval'
+ DATASET_NAME: 'MSVDDataset'
+ # TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 6
+ TEST_BATCH_SIZE: 64
+ NUM_WORKERS: 6
+ FEATS_FOLDER: 'open_source_dataset/msvd_dataset/YouTubeClips'
+ ANNO_FOLDER: 'open_source_dataset/msvd_dataset/new_annotations'
+ STRIDE: 32
+ FRAMES_PER_CLIP: 4
+ S3_PATH: 's3://msvd/YouTubeClips/'
+ TIMESFORMER_AUG: True
+ SAMPLING_WEIGHT: 0.5
+ MODEL:
+ MAX_SEQ_LEN: 77
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ LABELSMOOTHING: 0.1
+ # NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ INFERENCE:
+ NAME: 'RetrievalEvaler'
+ GENERATION_MODE: False
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.1
+ DROP_PATH_PROB_FIXED: True
+
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ # POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 6
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 450000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L24_H1024_experiments/base_model_bert_l24_h1024.yaml b/configs/BERT_L24_H1024_experiments/base_model_bert_l24_h1024.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..77f8f138db4bf010686f7fd204b24475b3b39f93
--- /dev/null
+++ b/configs/BERT_L24_H1024_experiments/base_model_bert_l24_h1024.yaml
@@ -0,0 +1,73 @@
+
+######################################### MODEL #########################################
+MODEL:
+ VOCAB_SIZE: 49411 # include /
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder_v3'
+ ENCODER_DIM: 1024
+ DECODER: ''
+ DECODER_DIM: 1024
+
+ PREDICTOR: 'EmbedClsAsRetrievalPredictor'
+ FEATURE_GATHER: True
+ LEARN_TEMP: True
+ PRED_USE_NORM: True
+ PRED_TEMPERATURE: 0.07
+
+ BertParamsInit: True
+
+ CLS_TOKEN: False
+
+ QUEUE_LEN: 1024
+ MAX_LABEL_LEN: 12
+
+ OUTPUT_PROJ: True # output projection
+
+
+# #################################### Token embedding ####################################
+ TOKEN_EMBED:
+ NAME: 'TokenBaseEmbedding'
+ DIM: 1024
+ ACTIVATION: 'none'
+ USE_NORM: True
+ DROPOUT: 0.0
+ POSITION: 'NNEmbeddingEncoding'
+ POSITION_MAX_LEN: 512
+ TYPE_VOCAB_SIZE: 2
+
+# #################################### Visual embedding ####################################
+ VISUAL_EMBED:
+ NAME: 'none'
+
+# #################################### video embedding ####################################
+ VIDEO_EMBED:
+ NAME: 'VideoBaseEmbedding'
+ IN_DIM: 768
+ OUT_DIM: 1024
+ ACTIVATION: 'none'
+ USE_NORM: True
+ DROPOUT: 0.0
+ TYPE_SIZE: 1 # video to encoder
+ POSITION: 'NNEmbeddingEncoding'
+ MAX_LENGTH: 1600
+ PATCH_SIZE_S: 16
+ PATCH_SIZE_T: 1
+ DIVIDE_ST_POS: True
+ USE_VISUAL_TOKENIZER: True
+ USE_VISUAL_POS: True
+ MAX_FRAMES: 8
+
+####################################### BERT ############################################
+ BERT:
+ DROP_PATH_PROB: 0.1
+ HIDDEN_SIZE: 1024
+ HIDDEN_DROPOUT_PROB: 0.
+ HIDDEN_ACT: "gelu"
+ NUM_ATTENTION_HEADS: 16
+ INTERMEDIATE_SIZE: 4096
+ INTERMEDIATE_DROP: 0.
+ FFN_DROPOUT_PROB: 0.
+ ATTENTION_PROBS_DROPOUT_PROB: 0.
+ NUM_HIDDEN_LAYERS: 24
+ NUM_GENERATION_LAYERS: 0
+
\ No newline at end of file
diff --git a/configs/BERT_L24_H1024_experiments/zeroshot_config/in1k_training.yaml b/configs/BERT_L24_H1024_experiments/zeroshot_config/in1k_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..573296a24967249cab1b3f7fcc4abada19db93ff
--- /dev/null
+++ b/configs/BERT_L24_H1024_experiments/zeroshot_config/in1k_training.yaml
@@ -0,0 +1,190 @@
+_BASE_: "../base_model_bert_l24_h1024.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 128
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 200000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.3
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/configs/BERT_L24_H1024_experiments/zeroshot_config/in1k_training_moe.yaml b/configs/BERT_L24_H1024_experiments/zeroshot_config/in1k_training_moe.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b35b1c38a4913f77f20b380ca51f6dd7aa8a8dd1
--- /dev/null
+++ b/configs/BERT_L24_H1024_experiments/zeroshot_config/in1k_training_moe.yaml
@@ -0,0 +1,180 @@
+_BASE_: "../base_model_bert_l24_h1024.yaml"
+
+SHARED_TARGETS:
+
+ -
+ NAME: 'ImageNet1k'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: False
+
+
+
+TASKS:
+
+ -
+ NAME: imagenet
+ DATASETS:
+ TRAIN: 'ImageNetDataset'
+ VAL: 'ImageNetDataset'
+ TASK_TYPE: 'image_classification'
+ DATASET_NAME: 'ImageNet1k'
+ TARGET_SET: ['ImageNet1k']
+
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 128
+ TEST_BATCH_SIZE: 128
+ NUM_WORKERS: 4 # will be used as numworker for testing loader
+ FEATS_FOLDER: 'open_source_dataset/imagenet'
+ S3_PATH: 'cluster2:s3://imagenet'
+ ANNO_FOLDER: 'open_source_dataset/imagenet/meta'
+ SAMPLING_WEIGHT: 1.0
+ CLASS_NAME_FILE: 'open_source_dataset/imagenet_class_name.pkl'
+ MIXUP: 0.8
+ CUTMIX: 1.0
+ MIXUP_PROB: 1.0
+ MIXUP_SWITCH_PROB: 0.5
+ MIXUP_MODE: 'batch'
+ MIXUP_LABEL_SMOOTHING: 0.1
+ MODEL:
+ MAX_SEQ_LEN: -1
+ LABELS_NUM: 1000
+ TEMP_NAME: logit_scale_img_cls
+ LOSSES:
+ NAMES: ['SoftTargetCrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 1.0
+ REDUCTION: 'mean'
+ # LOSS_FP32: True
+ INFERENCE:
+ NAME: 'ImageNetEvaler'
+ ID_KEY: 'image_id'
+ VALUE: 'cls_logits'
+ VAL_ANNFILE: 'open_source_dataset/imagenet/meta/val.txt'
+ TEST_ANNFILE: ''
+ GENERATION_MODE: False
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 200000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.3
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+MOE:
+ MOE: True
+ MOE_TYPE: 'attribute'
+ TAG_Transform: True
+ ATTRIBUTE_LENGTH: 8
+ EP_WORLD_SIZE: 1 # tag moe only
+ NUM_EXPERTS: 8
+ TOP_K: 2
+ CAPACITY_FACTOR: 3.0
+ EVAL_MIN_CAPACITY: 4.0
+ MIN_CAPACITY: 4
+ NOISY_GATE_POLICY: 'vmoe'
+ MOE_PARAM_GROUP: True
+ MOE_EXPERT_TYPE: 'FFN,SA'
+ SA_LINEAR_OUT_MOE: True
+ MOE_EXPERT_LOCATION: 'odd' # 'odd'
+ # MOE_LAYER_START_IDX: 3
+ # MOE_LAYER_END_IDX: 21
+ # MOE_LAYER_START_IDX: 18
+ # MOE_LAYER_END_IDX: 12
+ BATCH_PRIO: True
+ USE_TUTEL: True
+ FFN_SHARE_GATE_DECISION: True
diff --git a/configs/BERT_L24_H1024_experiments/zeroshot_config/mscoco_retrieval.yaml b/configs/BERT_L24_H1024_experiments/zeroshot_config/mscoco_retrieval.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b1c02c349cf7e22b1ed90d8b892df6730064fa80
--- /dev/null
+++ b/configs/BERT_L24_H1024_experiments/zeroshot_config/mscoco_retrieval.yaml
@@ -0,0 +1,239 @@
+_BASE_: "../base_model_bert_l24_h1024.yaml"
+
+SHARED_TARGETS:
+
+
+
+ -
+ NAME: 'Vocab_Word'
+ SHARED_TARGETS_CFG:
+ FILE_PATH: 'open_source_dataset/vocabulary_CLIP_with_endoftext.pkl'
+ DISTRIBUTED: True
+
+TASKS:
+
+
+
+ # -
+ # NAME: mscoco_retrieve
+ # DATASETS:
+ # TRAIN: 'ImageTextPairDataset'
+ # TEST: 'ImageTextPairDataset'
+ # TASK_TYPE: 'image_retrieval'
+ # DATASET_NAME: 'MSCOCO'
+ # DATALOADER:
+ # TRAIN_BATCH_SIZE: 256
+ # TEST_BATCH_SIZE: 64
+ # NUM_WORKERS: 1
+ # FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ # ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ # S3_PATH: 's3://coco/'
+ # SEQ_PER_SAMPLE: 1
+ # CACHE_MODE: True
+ # CIRCULAR_CACHE_MODE: False
+ # ZIP_MODE: False
+ # CACHE_ORIGIN_IMAGE: False
+ # RANDOM_CAPTION: False
+ # AS_NUMPY_AS_POSSIBLE: False
+ # SAMPLING_WEIGHT: 0.5
+ # TRANSFORM: 'clip_transforms'
+ # MODEL:
+ # MAX_SEQ_LEN: 30
+ # TEMP_NAME: logit_scale_retrieve
+ # LOSSES:
+ # NAMES: ['LabelSmoothingCrossEntropy', 'Accuracy']
+ # LABELSMOOTHING: 0.1
+ # LOSS_WEIGHT: 0.5
+ # REDUCTION: 'mean'
+ # INFERENCE:
+ # VOCAB: 'CLIP'
+ # ID_KEY: 'image_id'
+ # VALUE: 'caption'
+ # NAME: 'RetrievalEvaler'
+ # VAL_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_val_set0_2014.jsonline'
+ # TEST_ANNFILE: 'open_source_dataset/flickr30k/all_data_final_test_set0_2014.jsonline'
+ # GENERATION_MODE: False
+
+ -
+ NAME: mscoco_caption
+ DATASETS:
+ TRAIN: 'ImageTextPairDataset'
+ # VAL: 'ImageTextPairDataset'
+ TEST: 'ImageTextPairDataset'
+ TASK_TYPE: 'image_caption'
+ DATASET_NAME: 'MSCOCO'
+ TARGET_SET: ['Vocab_Word']
+ DATALOADER:
+ TRAIN_BATCH_SIZE: 200
+ TEST_BATCH_SIZE: 32
+ NUM_WORKERS: 4
+ FEATS_FOLDER: 'open_source_dataset/mscoco_dataset/coco_origin'
+ ANNO_FOLDER: 'open_source_dataset/mscoco_dataset/new_annotations'
+ S3_PATH: 's3://coco/'
+ SEQ_PER_SAMPLE: 1
+ CACHE_MODE: True
+ CIRCULAR_CACHE_MODE: False
+ ZIP_MODE: False
+ CACHE_ORIGIN_IMAGE: False
+ RANDOM_CAPTION: False
+ AS_NUMPY_AS_POSSIBLE: False
+ SAMPLING_WEIGHT: 0.5
+ TRANSFORM: 'clip_transforms'
+ RANDOM_MASK: True
+ MODEL:
+ MAX_SEQ_LEN: 30
+ EVAL_MAX_SEQ_LEN: 21
+ TEMP_NAME: logit_scale_caption
+ LOSSES:
+ NAMES: ['CrossEntropy', 'Accuracy']
+ LOSS_WEIGHT: 0.5
+ REDUCTION: 'mean'
+ DECODE_STRATEGY:
+ NAME: 'CaptionBeamSearcherV3'
+ BEAM_SIZE: 2
+ LEN_PENALTY: 2.0
+ INFERENCE:
+ NAME: 'COCOEvaler'
+ VOCAB: 'CLIP'
+ ID_KEY: 'image_id'
+ VALUE: 'caption'
+ VAL_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_val5k.json'
+ TEST_ANNFILE: 'open_source_dataset/mscoco_dataset/new_annotations/captions_test5k.json'
+ GENERATION_MODE: True
+
+
+
+
+ENGINE:
+ NAME: 'UnifiedTrainer'
+
+MODEL:
+ META_ARCHITECTURE: 'MultiTaskTransformerEncoder'
+ ENCODER: 'UnifiedBertEncoder'
+
+
+ SHARE_LAYERNORM: True
+ BERT:
+ NORMALIZE_DECISION: "BERTPre"
+ DROP_PATH_PROB: 0.2
+ DROP_PATH_PROB_FIXED: True
+
+ UNIFY_QKV: True
+
+ MODEL_EMA: False
+ MODEL_EMA_DECAY: 0.9999
+
+ MAEParamsInit: True
+ POSEMBEDFIX: True
+
+
+ IMG_INPUT_SIZE: 224
+ PATCH_SIZE: 16
+
+ POSEMBED_SCALE: !!python/object/apply:eval ["160/224"]
+ CHECKPOINT_FILETER: False
+ OLD_CHECKPONT: True
+
+ LAYER_SCALE: True
+ LAYER_SCALE_INIT: 1e-3
+
+
+DATALOADER:
+ USE_WEIGHTED_SAMPLER: True
+ UNIFIED_DATASET: True
+ NUM_WORKERS: 16
+
+ PADDING_TO_MAX: False # True for debugging or token moe with distributed moe
+
+
+
+####################################### Optimizer #######################################
+SOLVER:
+ NAME: 'Adam'
+ TORCH_OPTIMIZER: True
+ PARAMS_SEPERATE: True
+ # PARAMS_GROUP: True
+ # EPOCH: 1
+ MAX_ITER: 450000
+ CHECKPOINT_PERIOD: 50000
+ EVAL_PERIOD: 500000
+ BASE_LR: 0.001
+ BIAS_LR_FACTOR: 1.0
+ WEIGHT_DECAY: 0.05
+ WEIGHT_DECAY_NORM: 0.0
+ WEIGHT_DECAY_BIAS: 0.0
+ WEIGHT_DECAY_EMBEDDING: 0.0
+ MOMENTUM: 0.9
+ DAMPENING: 0.0
+ NESTEROV: 0.0
+ BETAS: [0.9, 0.95]
+ EPS: 1e-6
+ GRAD_CLIP: 0.1
+ GRAD_CLIP_TYPE: 'norm'
+ ACCUM_ITER: 0
+ AMP_FP16: True
+ APEX_FP16: False # dangerous
+
+ WRITE_PERIOD: 50
+ MIN_LOSS_SCLE: 2048.0
+ # BF16: False # True
+ # ZEROSTAGE: 2
+
+ LOSS_SCALE_WINDOW: 200
+
+
+
+
+
+
+####################################### lr scheduler #######################################
+LR_SCHEDULER:
+ NAME: 'WarmupCosine'
+ WARMUP: 20000
+ MIN_LR: 0.000001
+
+
+
+
+####################################### evaluation #######################################
+INFERENCE:
+
+ VOCAB: 'CLIP'
+ ITER_BASED: True
+
+
+find_unused_parameters: true
+
+# ENCODERS:
+# -
+# NAME: VisualEncoder
+# TYPE: VisualEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
+# -
+# NAME: TextEncoder
+# TYPE: TextEncoder
+# DROP_PATH_PROB: 0.0
+# HIDDEN_SIZE: 192
+# HIDDEN_DROPOUT_PROB: 0.
+# HIDDEN_ACT: "gelu"
+# NUM_ATTENTION_HEADS: 3
+# INTERMEDIATE_SIZE: 768
+# INTERMEDIATE_DROP: 0.
+# FFN_DROPOUT_PROB: 0.
+# ATTENTION_PROBS_DROPOUT_PROB: 0.
+# NUM_HIDDEN_LAYERS: 6
+# NUM_GENERATION_LAYERS: 0
+# DROP_PATH_PROB_FIXED: True
+
diff --git a/data/checkpoints.md b/data/checkpoints.md
new file mode 100644
index 0000000000000000000000000000000000000000..66ac7f4c01c49418c680214e3ce44f4065c787be
--- /dev/null
+++ b/data/checkpoints.md
@@ -0,0 +1,30 @@
+# Checkpoints
+We provide links for you to download the pre-trained weights of some Uni-Perceiver models.
+
+
+
+
+ Model | Link | Params | Hidden size | Intermediate size | Num. of heads | Enc layers |
+
+
+ Uni Perceiverbase | Downloadtorch
+ Downloadds | 124M | 768 | 3072 | 12 | 12 |
+
+
+ Uni Perceiver MoEbase | Downloadds | 167M | 768 | 3072 | 12 | 12 |
+
+
+ Uni Perceiverlarge | Downloadds | 354M | 1024 | 4096 | 16 | 24 |
+
+
+ Uni Perceiver MoElarge | Downloadds | 505M | 1024 | 4096 | 16 | 24 |
+
+
+
+
+* Models torch are pretrained in our current codebase, while models ds are pretrained in our previous codebase with DeepSpeed engine.
+
+* `Params` in the table is the parameters required during model deployment for image-text tasks. Note that it may be different on other tasks.
+
+
+* More pre-trained weights will be released.
\ No newline at end of file
diff --git a/data/data_structure.md b/data/data_structure.md
new file mode 100644
index 0000000000000000000000000000000000000000..01596e4cbe0894d93a3ff67442297e174249eef6
--- /dev/null
+++ b/data/data_structure.md
@@ -0,0 +1,199 @@
+## data structure
+
+* imagenet 1k
+
+```
+data = {
+ 'input_sample_list': [
+ {
+ 'data':
+ torch.rand(bs, 3, 224, 224, dtype=torch.float32),
+ 'invalid_mask':
+ None,
+ 'modality':
+ 'image',
+ 'data_type': 'input',
+ 'sample_info': {
+ 'id': list(range(bs)),
+ 'path': ['hah' for _ in range(bs)]
+ }
+ },
+ ],
+ 'target_sample_list': [],
+ 'target_idx_list': [torch.randint(0, 1000, (bs, ))],
+ 'target_set_list': ['ImageNet22k'],
+ 'shared_target_sets': {
+ 'ImageNet22k': [{
+ 'data':
+ torch.randint(0, 49411, (1000, 11)),
+ 'invalid_mask':
+ torch.zeros(1000, 11, dtype=torch.bool),
+ 'modality':
+ 'text',
+ 'data_type': 'target',
+ 'sample_info': {
+ 'distributed': True,
+ 'total_num': 1000,
+ }
+ }]
+ },
+ 'task_info': {
+ 'task_name': 'imagenet',
+ 'task_type': 'image_classification',
+ 'dataset_name': 'ImageNet22k',
+ 'batchsize': None,
+ 'sampling_ratio': None
+ }
+}
+```
+* mscoco caption
+``` data = {
+ 'input_sample_list': [
+ {
+ 'data':
+ torch.rand(bs, 3, 224, 224, dtype=torch.float32),
+ 'invalid_mask':
+ None,
+ 'modality':
+ 'image',
+ 'data_type': 'input',
+ 'sample_info': [{
+ 'id': id,
+ 'path': 'hahah',
+ 'bs': bs
+ } for _ in range(bs)]
+ },
+ {
+ 'data':
+ torch.randint(0, 49411, (bs, 31 * 2)),
+ 'invalid_mask':
+ torch.zeros(bs, 31 * 2, dtype=torch.bool),
+ 'modality':
+ 'text',
+ 'data_type': 'input',
+ 'sample_info': [{
+ 'pe_index':
+ torch.cat([torch.arange(31),
+ torch.arange(31)],
+ dim=0)
+ } for _ in range(bs)]
+ },
+ ],
+ 'target_sample_list': [],
+ 'target_idx_list': [torch.randint(0, 49411, (bs, 31))],
+ 'target_set_list': ['Vocab_Word'],
+ 'shared_target_sets': {
+ 'Vocab_Word': [{
+ 'data': torch.randint(0, 49411, (49411, 2)),
+ 'invalid_mask': None,
+ 'modality': 'text',
+ 'data_type': 'target',
+ 'sample_info': {
+ 'distributed': True,
+ 'total_num': 49411,
+ }
+ }]
+ },
+ 'task_info': {
+ 'task_name': 'mscoco_caption',
+ 'task_type': 'image_caption',
+ 'dataset_name': 'MSCOCO',
+ 'batchsize': None,
+ 'sampling_ratio': None
+ }
+}
+```
+
+
+* text_mlm
+```
+data = {
+ 'input_sample_list': [
+ {
+ 'data': torch.randint(0, 49411, (bs, 128)),
+ 'invalid_mask': torch.zeros(bs, 128, dtype=torch.bool),
+ 'modality': 'text',
+ 'data_type': 'input',
+ 'sample_info': {
+ 'seq_length': 128
+ }
+ },
+ ],
+ 'target_sample_list': [],
+ 'target_idx_list': [torch.randint(0, 49411,
+ (bs, 128))], # most are -1,
+ 'target_set_list': ['Vocab_Word'],
+ 'shared_target_sets': {
+ 'Vocab_Word': [{
+ 'data': torch.randint(0, 49411, (49411, 2)),
+ 'invalid_mask': None,
+ 'modality': 'text',
+ 'data_type': 'target',
+ 'sample_info': {
+ 'distributed': True,
+ 'total_num': 49411,
+ }
+ }]
+ },
+ 'task_info': {
+ 'task_name': 'bookswiki_pretrain',
+ 'task_type': 'text_mlm',
+ 'dataset_name': 'BooksWiki',
+ 'batchsize': None,
+ 'sampling_ratio': None
+ }
+}
+```
+
+
+ * mscoco retrieval
+ ```
+data = {
+ 'input_sample_list': [
+ {
+ 'data':
+ torch.rand(bs, 3, 224, 224, dtype=torch.float32),
+ 'invalid_mask':
+ None,
+ 'modality':
+ 'image',
+ 'sample_info': {
+ 'id': list(range(bs)),
+ 'path': ['hah' for _ in range(bs)]
+ }
+ },
+ ],
+ 'target_sample_list': [
+ {
+ 'data': torch.randint(0, 49411, (bs, 30)),
+ 'invalid_mask': torch.zeros(bs, 30,
+ dtype=torch.bool),
+ 'modality': 'text',
+ 'sample_info': {}
+ },
+ ],
+ 'target_idx_list': [],
+ 'target_set_list': [],
+ 'shared_target_sets': {
+ 'ImageNet22k': [{
+ 'data':
+ torch.randint(0, 49411, (1000, 11)),
+ 'invalid_mask':
+ torch.zeros(1000, 11, dtype=torch.bool),
+ 'modality':
+ 'text',
+ 'sample_info': {
+ 'distributed': True,
+ 'total_num': 1000,
+ }
+ }]
+ },
+ 'task_info': {
+ 'task_name': 'mscoco_retrieve',
+ 'task_type': 'image_retrieval',
+ 'dataset_name': 'MSCOCO',
+ 'batchsize': None,
+ 'sampling_ratio': None
+ }
+}
+```
\ No newline at end of file
diff --git a/data/finetuning.md b/data/finetuning.md
new file mode 100644
index 0000000000000000000000000000000000000000..a3d0afda4a8e4a764790c80a626a01f5687243ff
--- /dev/null
+++ b/data/finetuning.md
@@ -0,0 +1,27 @@
+# Fine-tuning
+
+For reproducing the fine-tuning results in our paper, we provide the corresponding fine-tuning configs in `configs/BERT_L12_H768_experiments/finetuning` and `configs/BERT_L12_H768_experiments/moe_finetuning` for Uni-Perceiver-Base and Uni-Perceiver-MoE-Base, respectively.
+
+
+Specifically, we fine-tuned the ImageNet-1K dataset with image classification task. For video classification, we fine-tuned Kinetics-400. We also employed image caption and image-text retrieval tasks on MSCOCO caption and FLicker-30K datasets.
+In addition, language understand tasks are fine-tuned on GLUE benchmarks, and video caption and video-text retrieval tasks are conducted on MSVD dataset.
+Please perpare the dataset following [PREPARE_DATA.md](prepare_data.md)
+
+---
+
+In our experiments, fine-tuning on all datasets exception GLUE benchmarks is performed on 16 NVIDIA-V100 GPUs with 80GB memory.
+GLUE tasks are all performed on 1 GPU.
+Taking Imagenet-1K as an example, the __Uni-Perceiver-Base__ can be fine-tuned as
+```
+
+sh run.sh configs/BERT_L12_H768_experiments/finetuning/in1k_training.yaml in1k-ft 16 partitionname MODEL.WEIGHTS work_dirs/pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained.pth
+
+```
+The __Uni-Perceiver-MoE-Base__ can also be fine-tuned in a similar way:
+```
+sh run.sh configs/BERT_L12_H768_experiments/moe_finetuning/in1k_training.yaml in1k-moe-ft 16 partitionname MODEL.WEIGHTS work_dirs/pretrained_models/uni-perceiver-moe-base-L12-H768-224size-pretrained.pth
+```
+
+
+Note that we used only a few sets of hyperparameters in those task and did not adjust them carefully. Maybe hyper-parameter search can lead to further performance improvement.
+
diff --git a/data/inference.md b/data/inference.md
new file mode 100644
index 0000000000000000000000000000000000000000..a5d46bf853b55806d5b88e55b7004d957890d467
--- /dev/null
+++ b/data/inference.md
@@ -0,0 +1,12 @@
+# Inference
+
+Uni-Perceiver models have excellent generalization ability. Thus you can evalute the pre-trained model for any task and dataset.
+For any experiment, inference mode can be activated when argument `--eval-only` is passed.
+
+For example, you can conduct zero-shot inference on MSVD caption with a pre-trained checkpoint:
+```
+sh run.sh configs/BERT_L12_H768_experiments/zeroshot_config/msvd_caption.yaml msvd_cap_infer 8 partitionname --eval-only MODEL.WEIGHTS work_dirs/pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained.pth
+```
+
+
+More inference configs are provided in floder `configs/BERT_L12_H768_experiments/zeroshot_config`.
\ No newline at end of file
diff --git a/data/other_results.md b/data/other_results.md
new file mode 100644
index 0000000000000000000000000000000000000000..96fd7570c8ae08c9f17c21c165fad2a7ee8010cd
--- /dev/null
+++ b/data/other_results.md
@@ -0,0 +1,48 @@
+## GLUE results
+We also evalute the language understanding performance of Uni-Perceiver on GLUE benchmarks.
+The results are listed as below.
+
+
+
+ Dataset |
+ MNLI |
+ QNLI | QQP | RTE | SST-2 | MRPC | CoLA |
+
+
+
+ Metric | Acc | Acc | F1 | Acc | Acc | F1 | Acc |
+
+
+
+
+ Uni-PerceiverBASE | 79.7 | 87.3 | 86.7 | 71.1 | 89.3 | 86.0 | 43.1 |
+
+
+ Uni-Perceiver-MoEBASE | 81.5 | 88.2 | 87.8 | 75.8 | 90.9 | 87.1 | 52.2 |
+
+
+ Uni-PerceiverLARGE | 82.5 | 89.2 | 87.7 | 73.7 | 91.2 | 90.2 | 52.0 |
+
+
+ Uni-Perceiver-MoELARGE | 85.7 | 91.9 | 89.5 | 78.4 | 93.4 | 91.2 | 57.4 |
+
+
+
+ ---
+
+* All fine-tuning experiments are performed on 1 GPU.
+
+* We use the hyper-parameters for GLUE tasks from [fair-seq](https://github.com/facebookresearch/fairseq/blob/main/examples/bart/README.glue.md)
+
+Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
+---|---|---|---|---|---|---|---|---
+`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
+`--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
+`bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
+`--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
+`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107 | 1334 | 1799
+`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107
+
+* Following RoBerta, we finetune RTE, STS and MRPC starting from the
+MNLI single-task model, rather than the baseline
+pretrained model.
diff --git a/data/prepare_data.md b/data/prepare_data.md
new file mode 100644
index 0000000000000000000000000000000000000000..f106ce7d10581311baba612aa8c731f080805ff7
--- /dev/null
+++ b/data/prepare_data.md
@@ -0,0 +1,510 @@
+# Prepare Data
+
+* By default, all training data used for this repository will be searched in the directory `DATA_PATH`. Please specify or change your data location before code running as following:
+
+ ```
+ export DATA_PATH='/mnt/lustre/share_data/zhujinguo/open_source_dataset'
+ ```
+
+
+* To make it very easy for you to run the training code, we provide [a toy dataset](https://drive.google.com/file/d/14GZPYqVLiXVYjxRGvC9WO4WK3WiwBPRY/view?usp=sharing), which is a small subset of our pretraing data.
+You should download this and unzip this file to `DATA_PATH`.
+
+ With the data of this subset, you can train Uni-Perceiver with config file [configs/BERT_L12_H192_experiments/4tasks_training_small_datasets.yaml](../configs/BERT_L12_H192_experiments/4tasks_training_small_datasets.yaml).
+ Please refer to [pretraining.md](./pretraining.md) for training usage.
+
+* For tasks with a fixed candidate target sets, such as image / video classification (where the target sets are the category labels) and masked language modeling (where the target set is the vocabulary), you also need to perpare the target set file. Please refer to the jupyter notebook [tools/generate_target_sets.ipynb](../tools/generate_target_sets.ipynb) for details.
+
+* For the complete datasets for training our models, please download datasets according to the instructions below:
+
+## Different datasets
+
+### Todo List:
+- [x] Imagenet-21k and Imagenet-1k
+- [x] books&wiki
+- [x] MSCOCO Caption
+- [x] YFCC
+- [x] CC12M
+- [x] CC3M
+- [x] Visual Genome
+- [x] SBU
+- [x] Kinetics-400 & Kinetics-700
+- [x] Moments in Time
+- [x] Flickr30k
+- [x] MSVD
+- [x] MSR-VTT
+- [x] GLUE
+- [x] VQA
+
+### Imagenet-1k
+
+1. Please download the images of imagenet dataset from the official website [Imagenet](https://image-net.org/).
+
+2. We provide the annotation files (including train.txt, val.txt and test.txt) on [meta](https://drive.google.com/file/d/1piqII0qGHmK1pop0RjdoFx927hcm1Mny/view).
+
+3. a) Tokenizing imagenet class names to generate "imagenet_class_name_CLIP_with_endoftext.pkl" using [generate_target_sets.ipynb](../tools/generate_target_sets.ipynb)
+
+ b) Or using generated file we provide from [here](https://drive.google.com/file/d/1bgFohNsppe7kksxTbWgSFMoEZuX0szc_/view?usp=sharing)
+4. Organize them as follows:
+ ```
+ DATA_PATH/
+ └── imagenet/
+ ├── imagenet_class_name_CLIP_with_endoftext.pkl
+ ├── meta
+ │ ├── test.txt
+ │ ├── train.txt
+ │ └── val.txt
+ ├── test
+ │ ├── ILSVRC2012_test_00000001.JPEG
+ │ ├── ILSVRC2012_test_00000002.JPEG
+ │ ├── ILSVRC2012_test_00000003.JPEG
+ │ ├── ILSVRC2012_test_00000004.JPEG
+ │ └── ...
+ ├── train
+ │ ├── n01440764
+ │ │ ├── n01440764_10026.JPEG
+ | │ ├── n01440764_10027.JPEG
+ | │ ├── n01440764_10029.JPEG
+ | │ └── ...
+ │ ├── n01443537
+ | │ └── ...
+ │ ├── n01484850
+ | │ └── ...
+ | └── ...
+ └─── val
+ ├── ILSVRC2012_val_00000001.JPEG
+ ├── ILSVRC2012_val_00000002.JPEG
+ ├── ILSVRC2012_val_00000003.JPEG
+ └── ...
+
+ ```
+
+
+
+
+
+
+### Imagenet-22k
+1. Please refer to Imagenet-1K dataset.
+
+2. Meta file is provided from [here](https://drive.google.com/file/d/1TDF0i8tXTB-K-zYOVhsmmocAtgrKpOG8/view?usp=sharing)
+
+3. Imagenet class name file in [generate_target_sets.ipynb](../tools/generate_target_sets.ipynb) for tokenizing is provided from [here](https://drive.google.com/file/d/1cJHD5Ysxfr4tRMqAAwjOah2glfFktiT1/view?usp=sharing). Or you can directly use the CLIP-tokenized imagenet-22K class name files is provided from [here](https://drive.google.com/file/d/1juSGVP8IjERXoM-AwxKRDtLk65p9FTds/view?usp=sharing)
+
+### Books&wiki
+1. please download files [wiki.doc](https://drive.google.com/file/d/1rZJ-Nj_SSqwu85tME3wbN8tfGhljfAsf/view) abd [bc1g.doc](https://drive.google.com/file/d/16T5EYqIjO-tAj1OFxz6bnnzEABCusCcv/view).
+And put them together into a file:
+ ```
+ cat wiki.doc bc1g.doc > bookswiki.doc
+ ```
+2. a) Tokenizing vocabularies to generate "vocabulary_CLIP_with_endoftext.pkl" using [generate_target_sets.ipynb](../tools/generate_target_sets.ipynb)
+
+ b) Or using generated file we provide from [here](https://drive.google.com/file/d/1omEahjKjeWe0a4PSXEHaGE_WVdiZLf4W/view?usp=sharing)
+
+3. Then put this files in `DATA_PATH`
+ ```
+ DATA_PATH/
+ ├── vocabulary_CLIP_with_endoftext.pkl
+ └── bert_pretrain_data/
+ └─ bookswiki/
+ └── bookswiki.doc
+
+ ```
+4. you can also download the plain text dataset from [huggingface.co/datasets/wikipedia](https://huggingface.co/datasets/wikipedia) and [huggingface.co/datasets/bookcorpus](https://huggingface.co/datasets/bookcorpus).
+
+### MSCOCO
+
+1. Please download the images of COCO 2014 from [MSCOCO](https://cocodataset.org/#download).
+2. Download preprocessed coco captions from Karpathy's homepage: [link](http://cs.stanford.edu/people/karpathy/deepimagesent/caption_datasets.zip) and extract "dataset_coco.json" from zip file.
+
+3. a) You can run the [coco_preprocess.py](./preprocess/coco_preprocess.py) file to split the dataset_coco.json file into train, val and test part:
+
+ 1. walk into the /data/preprocess folder and open the [coco_preprocess.py](./preprocess/coco_preprocess.py) file;
+ 2. fill the 'original_json' variable with the path you download the dataset_coco.json file.
+ 3. fill the 'savepath' with the path you want to save the splited json file.
+ 4. run the [coco_preprocess.py](./preprocess/coco_preprocess.py) file.
+
+ b) Or you can directly use the generated json file we provide from [here](https://drive.google.com/file/d/12XUh4-Lb82RXg7Sa-Vgtut2dhIqrp7Sy/view)
+
+4. Generating tokenized vocabularies as mentioned in [bookswiki part](#vocab)
+5. Organize the files into following structure:
+ ```
+ DATA_PATH/
+ ├── vocabulary_CLIP_with_endoftext.pkl
+ └── mscoco_dataset/
+ ├── new_annotations
+ │ ├── captions_test5k.json
+ │ ├── captions_train113k.json
+ │ ├── captions_val5k.json
+ | └── dataset_coco.json
+ └── coco_origin
+ ├── train2014
+ │ ├── COCO_train2014_000000000009.jpg
+ | ├── COCO_train2014_000000000025.jpg
+ | ├── COCO_train2014_000000000030.jpg
+ │ └── ...
+ └── val2014
+ ├── COCO_val2014_000000000042.jpg
+ ├── COCO_val2014_000000000073.jpg
+ ├── COCO_val2014_000000000074.jpg
+ └── ...
+
+
+ ```
+
+### Visual Genome
+1. Please download the images and region decriptions of visual genome from [VG](https://visualgenome.org/api/v0/api_home.html).
+
+2. a) You can run the [region_descriptions.ipynb](./preprocess/region_descriptions.ipynb) to preprocess the downloaded "region_descriptions.json" file:
+
+ 1. walk into the /data/preprocess folder and open the [region_descriptions.ipynb](./preprocess/region_descriptions.ipynb);
+ 2. fill the path of downloaded 'region_descriptions.json' and the path you want to save the processed file.
+ 3. run the [region_descriptions.ipynb](./preprocess/region_descriptions.ipynb).
+
+ b) Or you can directly use the generated json file we provide from [here](https://drive.google.com/file/d/1pnl30qAPr03RpKbdbH13YZI9GtWseEHf/view)
+3. Generating tokenized vocabularies as mentioned in [bookswiki part](#vocab)
+4. Organize the files into following structure:
+
+ ```
+ DATA_PATH/
+ ├── vocabulary_CLIP_with_endoftext.pkl
+ └── visual_genome/
+ ├── annotations
+ │ ├── region_descriptions.json
+ │ ├── vg_captions_128filter.json
+ └── images
+ ├── VG_100K
+ │ ├── 2.jpg
+ | ├── 3.jpg
+ | ├── 4.jpg
+ │ └── ...
+ └── VG_100K_2
+ ├── 1.jpg
+ ├── 51.jpg
+ ├── 52.jpg
+ └── ...
+
+
+ ```
+
+### Flickr30k
+1. Please download the images of filckr30k according to the instruction of [Flickr30k](http://shannon.cs.illinois.edu/DenotationGraph/).
+2. Download [flickr_jsons](https://drive.google.com/file/d/1_dJsD8_YXWtR0124X_RiEgcx1c6B_BUM/view) which provides the annotations of flickr30k images.
+3. a) You can run the [process_flickr_caption_json.py](./preprocess/process_flickr_caption_json.py) to preprocess the json file:
+
+ 1. walk into the /data/preprocess folder and open the [process_flickr_caption_json.py](./preprocess/process_flickr_caption_json.py);
+ 2. fill the path of downloaded json files and fill the path you want to save the processed json files.
+ 3. run the [process_flickr_caption_json.py](./preprocess/process_flickr_caption_json.py).
+
+ b) Or you can directly use the generated json files (including captions_test.json, captions_train.json and captions_val.json) we provide from [here](https://drive.google.com/file/d/1WIWUKbXfBJd1S0izTe_OuP7bjCFDk2wk/view)
+
+4. Generating tokenized vocabularies as mentioned in [bookswiki part](#vocab)
+5. Organize the files into following structure:
+
+ ```
+ DATA_PATH/
+ ├── vocabulary_CLIP_with_endoftext.pkl
+ ├── flickr30k/
+ │ ├── captions_test.json
+ │ ├── captions_train.json
+ │ └── captions_val.json
+ └── flickr30k_images
+ └── flickr30k_images
+ └── flickr30k_images
+ ├── 36979.jpg
+ ├── 65567.jpg
+ └── ...
+
+ ```
+
+### SBU
+1. Please download the SBU [url](https://drive.google.com/file/d/1Hfbw8DVSnE3ZAaWZ7C6d6hlUde7Pr_YN/view) and [caption](https://drive.google.com/file/d/1GY_kFyiFqOHAYvjfRdM98LMlAsnFmxic/view?usp=sharing) files.
+2. Filling the path of above files in [sbu_download_list.py](./preprocess/sbu/sbu_download_list.py) and run it for generating the download_list.
+3. Running the script [sbu_download.sh](./preprocess/sbu/sbu_download.sh) to download the sbu images.
+4. a) You can run the [make_sbu_json.py](./preprocess/sbu/make_sbu_json.py) to get the annotation file:
+
+ b) Or you can directly download the generated json file [sbucaption.json](https://drive.google.com/file/d/1xFJPvyJNlH0jzqzHRN16Hk5DmKGiGEJE/view) we provide.
+5. Generating tokenized vocabularies as mentioned in [bookswiki part](#vocab)
+6. Organize the files into following structure:
+
+ ```
+ DATA_PATH/
+ ├── vocabulary_CLIP_with_endoftext.pkl
+ └── sbucaption/
+ ├── annotations
+ │ └── sbucaptions.json
+ └── images
+ ├── 4385058960_b0f291553e.jpg
+ ├── 5148648301_1174ef59bc.jpg
+ └── ...
+
+ ```
+### CC3M
+1. Please download "Train_GCC-training.tsv" and "Validation_GCC-1.1.0-Validation.tsv" from [here](https://ai.google.com/research/ConceptualCaptions/download)
+2. Filling the path of "Train_GCC-training.tsv" in [cc3m_train_download_list.py](./preprocess/cc3m/cc3m_train_download_list.py) and run it for generating the training download list.
+3. Filling the path of "Validation_GCC-1.1.0-Validation.tsv" in [cc3m_val_download_list.py](./preprocess/cc3m/cc3m_val_download_list.py) and run it for generating the validation download list.
+4. Running the script [cc3m_train_download.sh](./preprocess/cc3m/cc3m_train_download.sh) and [cc3m_val_download.sh](./preprocess/cc3m/cc3m_val_download.sh) to download the cc3m images.
+5. Zip (without compression) "train_image", "val_image" by:
+ ```
+ zip -0 ../train_image.zip ./*
+ zip -0 ../val_image.zip ./*
+
+ ```
+6. a) You can run the [make_cc3m_train_json.py](./preprocess/cc3m/make_cc3m_train_json.py) and [make_cc3m_val_json.py](./preprocess/cc3m/make_cc3m_val_json.py) to get the annotation file:
+
+ b) Or you can directly download the generated json files [train_spacy.json](https://drive.google.com/file/d/1_bqx0xQOQC3bd40GLMC27TyLRi1tHRlC/view) and [val_spacy.json](https://drive.google.com/file/d/11ibsX_K-hgdHiomk9c6JvuAl2kYW8tjt/view) we provide.
+7. Generating tokenized vocabularies as mentioned in [bookswiki part](#vocab)
+8. Organize the files into following structure:
+
+ ```
+ DATA_PATH/
+ ├── vocabulary_CLIP_with_endoftext.pkl
+ └── cc3m/
+ ├── train_spacy.json
+ ├── val_spacy.json
+ ├──train_image
+ │ ├── 00000000.jpg
+ │ └── ...
+ └── val_image
+ ├── 00000000.jpg
+ └── ...
+
+ ```
+
+### CC12M
+1. Please download "cc12m.tsv" from [here](https://github.com/google-research-datasets/conceptual-12m)
+2. Filling the path of "cc12m.tsv" in [cc12m_train_download_list.py](./preprocess/cc12m/cc12m_train_download_list.py) and run it for generating the training download list.
+3. Running the script [cc12m_train_download.sh](./preprocess/cc12m/cc12m_train_download.sh) to download the cc12m images.
+5. Zip (without compression) "train_image" by:
+ ```
+ zip -0 ../train_image.zip ./*
+ ```
+5. a) You can run the [make_cc12m_train_json.py](./preprocess/cc12m/make_cc12m_train_json.py) to get the annotation file:
+
+ b) Or you can directly download the generated json file [train_available.json](https://drive.google.com/file/d/1SVHmHpewvmpCbWDCsLSbwQ8lhusQXEIt/view) we provide.
+6. Generating tokenized vocabularies as mentioned in [bookswiki part](#vocab)
+7. Organize the files into following structure:
+
+ ```
+ DATA_PATH/
+ ├── vocabulary_CLIP_with_endoftext.pkl
+ └── c12m/
+ ├── train_available.json
+ └── train_image
+ ├── 00000000.jpg
+ └── ...
+
+ ```
+
+### Kinetics-400 & Kinetics-700
+1. Please download the Kinectics-400 & Kinetics-700 videos according to the instructions of [this](https://github.com/cvdfoundation/kinetics-dataset)
+
+2. a)
+
+ i. Filling the path of K400's "training" and "validation" folder you download in [k400_construct_csv.py](./preprocess/k400_construct_csv.py) and run it for generating the K400 related files (K400_val.csv, K400_train.csv, categories.txt, annotation.json).
+
+ ii. Filling the path of K700's "training" and "validation" folder you download in [k700_construct_csv.py](./preprocess/k700_construct_csv.py) and run it for generating the K700 related files (K700_val.csv, K700_train.csv, categories.txt, annotation.json).
+
+ iii. Running script [video_categories.ipynb](../tools/video_categories.ipynb) to generate "category_mapping.txt".
+
+ b) Or you can directly download the processed files we provide: [K400](https://drive.google.com/file/d/1YqchifEjoovZYJ77Egn5pHv3E1olRIpq/view?usp=sharing), [K700](https://drive.google.com/file/d/1fHdcBRdU27w7OfNijP0ZBsNbxLQSLfRa/view?usp=sharing)
+
+3. a) Tokenizing K400, K700 class names to generate "k400_class_name_CLIP_with_endoftext.pkl" and "k700_class_name_CLIP_with_endoftext.pkl" using [generate_target_sets.ipynb](../tools/generate_target_sets.ipynb)
+
+ b) Or using generated file we provide from [K400-CLIP](https://drive.google.com/file/d/1V-SpRzugmFgHR6j7ifFLqh7Ao5VW-gM8/view?usp=sharing) and [K700-CLIP](https://drive.google.com/file/d/1lq9WaEWh1lmfBv4pOs8Aj9yTrTxkQc3Z/view?usp=sharing)
+
+
+4. Organize the files into following structure:
+ ```
+ DATA_PATH/
+ ├── k400_class_name_CLIP_with_endoftext.pkl
+ └── K400/
+ ├── training
+ │ ├── abseiling
+ │ │ ├── _4YTwq0-73Y_000044_000054.mp4
+ │ │ └── ...
+ │ ├── air_drumming
+ │ └── ...
+ ├── validation/
+ │ ├── abseiling
+ │ │ ├── __NrybzYzUg.mkv
+ │ │ └── ...
+ │ ├── air_drumming
+ │ └── ...
+ ├── annotation.json
+ ├── category_mapping.txt
+ ├── categories.txt
+ ├── K400_train.csv
+ └── K400_val.csv
+ ```
+ K700 is similar.
+
+### MomentsInTime
+
+1. Please download the MomentsInTime videos according to the instructions of [Official Website](http://moments.csail.mit.edu/)
+
+2. a)
+
+ i. Filling the path of "training" folder you download in [moments_construct_csv.py](./preprocess/moments_construct_csv.py) and run it for generating the training files (moments_train.csv, categories.txt, annotation.json).
+
+ ii. Running script [video_categories.ipynb](../tools/video_categories.ipynb) to generate "category_mapping.txt".
+
+ b) Or you can directly download the processed files we provide: [moments](https://drive.google.com/file/d/1aXVCBKrocatZfT8TRKv4TuxHkTa7SxMz/view?usp=sharing).
+3. a) Tokenizing momentsInTime class names to generate "MiT_class_name_CLIP_with_endoftext.pkl" using [generate_target_sets.ipynb](../tools/generate_target_sets.ipynb)
+
+ b) Or using generated file we provide from [MiT-CLIP](https://drive.google.com/file/d/1xNC8Dld-0x735nO60cwUBYG3gTVvoPC8/view?usp=sharing)
+
+4. Organize the files into following structure:
+ ```
+ DATA_PATH/
+ ├── MiT_class_name_CLIP_with_endoftext.pkl
+ └── MomentsInTime/
+ ├── training
+ │ ├── adult+female+singing
+ │ │ ├── 0a2b81cb0ec5fde79b8c.mp4
+ │ │ └── ...
+ │ ├── adult+female+speaking
+ │ └── ...
+ ├── annotation.json
+ ├── categories.txt
+ ├── category_mapping.txt
+ └── moments_train.csv
+ ```
+
+
+### MSVD
+1. Download MSVD videos "YoutTubeClips.tar" from [here](https://www.cs.utexas.edu/users/ml/clamp/videoDescription/) and preprocessed "txt_labels" from [here](https://github.com/nasib-ullah/video-captioning-models-in-Pytorch/tree/main/MSVD/captions).
+2. a) Fill the path of downloaded files in [msvd_preprocess.py](./preprocess/msvd_preprocess.py) to generate the annotation files (caption_msvd_train_cocostyle.json, caption_msvd_val_cocostyle.json, caption_msvd_test_cocostyle.json)
+
+ b) Or directly download the annotation files we provide [new_annotations](https://drive.google.com/file/d/1VHT8waNVp8LUFlfY_YACCbVrzBQW3sWy/view)
+
+3. Generating tokenized vocabularies as mentioned in [bookswiki part](#vocab)
+4. Organize the files into following structure:
+ ```
+ DATA_PATH/
+ ├── vocabulary_CLIP_with_endoftext.pkl
+ └── msvd_dataset/
+ ├── new_annotations
+ │ ├── caption_msvd_test_cocostyle.json
+ │ ├── caption_msvd_train_cocostyle
+ │ └── caption_msvd_val_cocostyle
+ ├── txt_labels
+ │ ├── sents_test_lc_nopunc.txt
+ │ ├── sents_train_lc_nopunc.txt
+ │ ├── sents_train_lc_nopunc.txt
+ │ └── youtube_mapping.txt
+ └── YouTubeClips
+ ├── _0nX-El-ySo_83_93.avi
+ └── ...
+ ```
+### MSR-VTT
+1. Download MSRVTT videos ("train_val_videos.zip", "test_videos.zip") and annotation files ("train_val_annotation.zip", "test_videodatainfo.zip") from [here](https://www.mediafire.com/folder/h14iarbs62e7p/shared) and download dataset split info from [here](https://github.com/ArrowLuo/CLIP4Clip/releases/download/v0.0/msrvtt_data.zip).
+2. Unzip downloaded files above, fill the paths of "test_videodatainfo.json", "train_val_videodatainfo.json", "MSRVTT_train.9k.csv", "MSRVTT_JSFUSION_test.csv" in the [msrvtt_dataprocess_1k.ipynb](./preprocess/msrvtt_dataprocess_1k.ipynb)
+
+ b) Or directly download the annotation files ("caption_msrvtt_1k_trainval_cocostyle.json","caption_msrvtt_1k_test_cocostyle.json") we provide [annotations_new](https://drive.google.com/file/d/1ZnA4hEic6x9D7dfaEUPoa6MlQ30rITom/view)
+3. Generating tokenized vocabularies as mentioned in [bookswiki part](#vocab)
+4. Organize the files into following structure:
+ ```
+ DATA_PATH/
+ ├── vocabulary_CLIP_with_endoftext.pkl
+ └── msrvtt_dataset/
+ ├── annotations_new
+ │ ├── caption_msrvtt_1k_trainval_cocostyle.json
+ │ └── caption_msrvtt_1k_test_cocostyle.json
+ └── videos
+ ├── video0.mp4
+ └── ...
+ ```
+
+### VQA
+
+1. Download VQA meta data from the datalink [vilbert](https://github.com/jiasenlu/vilbert_beta/tree/master/data) provided, files including:
+ - dictionary.pkl
+ - train_ids.pkl
+ - val_ids.pkl
+ - train_target.pkl
+ - trainval_ans2label.pkl
+ - val_target.pkl
+ - trainval_label2ans.pkl
+
+2. Download VG questions and answers from [here](https://drive.google.com/drive/folders/10XHRXg07lNbdZQrREhOLYVM3N0LrkTxB)
+
+
+
+
+
+
+3. Download VQA annotations from the [link](https://visualqa.org/download.html) xmodaler provided, files including:
+ - vg_target.pkl
+ - VG_questions2.json
+ - download
+ - VG_annotations.json
+4. Download VQA annotations from [VQA](https://visualqa.org/download.html) website, files including:
+ - v2_OpenEnded_mscoco_test2015_questions.json
+ - v2_OpenEnded_mscoco_train2014_questions.json
+ - v2_OpenEnded_mscoco_val2014_questions.json
+
+5. a) Tokenizing all the possible answers using [generate_target_sets.ipynb](../tools/generate_target_sets.ipynb).
+
+ b) Or you can use the tokenized answers we provide [VQA_Answers](https://drive.google.com/file/d/1X-1blHh2MrYhDq9bkdndNVRZ-49VCsuz/view?usp=sharing).
+6. Organize the files into following structure:
+ ```
+ DATA_PATH/
+ ├── vocabulary_CLIP_with_endoftext.pkl
+ ├── mscoco_dataset/
+ | └── coco_origin
+ | ├── train2014
+ | │ ├── COCO_train2014_000000000009.jpg
+ | | ├── COCO_train2014_000000000025.jpg
+ | | ├── COCO_train2014_000000000030.jpg
+ | │ └── ...
+ | └── val2014
+ | ├── COCO_val2014_000000000042.jpg
+ | ├── COCO_val2014_000000000073.jpg
+ | ├── COCO_val2014_000000000074.jpg
+ | └── ...
+ └── VQA
+ ├── trainval_ans2label.pkl
+ ├── trainval_label2ans.pkl
+ ├── v2_OpenEnded_mscoco_train2014_questions.json
+ ├── v2_OpenEnded_mscoco_val2014_questions.json
+ ├── v2_OpenEnded_mscoco_test-dev2015_questions.json
+ ├── val_target.pkl
+ ├── VG_questions2.json
+ ├── vg_target.pkl
+ └── coco_map.json
+
+
+ ```
+### GLUE
+1. Follow the instructions of [this](https://github.com/nyu-mll/GLUE-baselines) to download GLUE benchmark data and refer to [fairseq](https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.glue.md) to preprocess datasets.
+2. a) Tokenizing GLUE datasets using [generate_target_sets.ipynb](../tools/generate_target_sets.ipynb).
+
+ b) Or you can use the tokenized answers we provide [GLUE_classnames](https://drive.google.com/file/d/1HR7xHsIRsS4iUwGr3CX6h5z_dVn-EJt-/view?usp=sharing).
+
+3. Organize the files into following structure:
+```
+ DATA_PATH/
+ ├── GLUE_classnames
+ └── bert_pretrain_data/
+ └── glue_data
+ ├── CoLA
+ ├── CoLA-bin
+ ├── diagnostic
+ ├── MNLI
+ ├── MNLI-bin
+ ├── MRPC
+ ├── MRPC-bin
+ ├── QNLI
+ ├── QNLI-bin
+ ├── QQP
+ ├── QQP-bin
+ ├── RTE
+ ├── RTE-bin
+ ├── SST-2
+ ├── SST-2-bin
+ ├── STS-B
+ ├── STS-B-bin
+ └── WNLI
+```
+
diff --git a/data/preprocess/cc12m/cc12m_train_download.sh b/data/preprocess/cc12m/cc12m_train_download.sh
new file mode 100644
index 0000000000000000000000000000000000000000..218a56f9652f4d794bc90d142330b3589e065fc7
--- /dev/null
+++ b/data/preprocess/cc12m/cc12m_train_download.sh
@@ -0,0 +1,8 @@
+# use 20 threads
+
+cat train4download.txt | xargs -n 2 -P 20 wget -nc -U 'Mozilla/5.0 (X11; Linux i686) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.27 Safari/537.17' --timeout=1 --waitretry=0 --tries=5 --retry-connrefused -nv -O
+find ../train_image -type f -size -1c -exec rm {} \;
+ls -d ../train_image/* | xargs -n 1 -P 20 python check_valid.py | tee train_size_invalid.txt
+xargs rm < train_size_invalid.txt
+rm train_size_invalid.txt
+ls ../train_image > train_valid.txt
\ No newline at end of file
diff --git a/data/preprocess/cc12m/cc12m_train_download_list.py b/data/preprocess/cc12m/cc12m_train_download_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..984b60e1c4b3d44a7a75ca74fc363d9826dbb6a8
--- /dev/null
+++ b/data/preprocess/cc12m/cc12m_train_download_list.py
@@ -0,0 +1,17 @@
+import os
+
+captions = []
+urls = []
+
+with open('cc12m.tsv') as fp:
+ for cnt, line in enumerate(fp):
+ s = line.split('\t')
+ captions.append(s[0].split(' '))
+ urls.append(s[1][:-1])
+
+with open('train4download.txt', 'w') as fp:
+ for cnt, url in enumerate(urls):
+ fp.write("../train_image/{:08d}.jpg\t\"{}\"\n".format(cnt, url))
+
+if not os.path.exists('../train_image'):
+ os.makedirs('../train_image')
\ No newline at end of file
diff --git a/data/preprocess/cc12m/check_valid.py b/data/preprocess/cc12m/check_valid.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba0fd4ce4dd811996ecc164af2fc2ff3924bf170
--- /dev/null
+++ b/data/preprocess/cc12m/check_valid.py
@@ -0,0 +1,14 @@
+import sys
+from PIL import Image
+
+import warnings
+
+warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
+
+try:
+ im = Image.open(sys.argv[1]).convert('RGB')
+ # remove images with too small or too large size
+ if (im.size[0] < 10 or im.size[1] < 10 or im.size[0] > 10000 or im.size[1] > 10000):
+ raise Exception('')
+except:
+ print(sys.argv[1])
\ No newline at end of file
diff --git a/data/preprocess/cc12m/make_cc12m_train_json.py b/data/preprocess/cc12m/make_cc12m_train_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d285729076a6d5dc3194399cc3f87d73e56a478
--- /dev/null
+++ b/data/preprocess/cc12m/make_cc12m_train_json.py
@@ -0,0 +1,32 @@
+captions = []
+urls = []
+
+with open('cc12m.tsv') as fp:
+ for cnt, line in enumerate(fp):
+ s = line.split('\t')
+ captions.append(s[0].split(' '))
+ urls.append(s[1][:-1])
+
+valids = set([])
+with open('train_valid.txt') as fp:
+ for cnt, line in enumerate(fp):
+ valids.add(line[:-1])
+
+import json
+with open('train.json', 'w') as outfile:
+ for cnt, (cap, url) in enumerate(zip(captions, urls)):
+ im = "{:08d}.jpg".format(cnt)
+ if (im in valids):
+ d = {'image':"train_image.zip@/{}".format(im), 'caption':cap}
+ json.dump(d, outfile)
+ outfile.write('\n')
+
+
+import json
+with open('train_frcnn.json', 'w') as outfile:
+ for cnt, (cap, url) in enumerate(zip(captions, urls)):
+ im = "{:08d}.jpg".format(cnt)
+ if (im in valids):
+ d = {'image':"train_image.zip@/{}".format(im), 'caption':cap, 'frcnn':"train_frcnn.zip@/{:08d}.json".format(cnt)}
+ json.dump(d, outfile)
+ outfile.write('\n')
\ No newline at end of file
diff --git a/data/preprocess/cc3m/cc3m_train_download.sh b/data/preprocess/cc3m/cc3m_train_download.sh
new file mode 100644
index 0000000000000000000000000000000000000000..218a56f9652f4d794bc90d142330b3589e065fc7
--- /dev/null
+++ b/data/preprocess/cc3m/cc3m_train_download.sh
@@ -0,0 +1,8 @@
+# use 20 threads
+
+cat train4download.txt | xargs -n 2 -P 20 wget -nc -U 'Mozilla/5.0 (X11; Linux i686) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.27 Safari/537.17' --timeout=1 --waitretry=0 --tries=5 --retry-connrefused -nv -O
+find ../train_image -type f -size -1c -exec rm {} \;
+ls -d ../train_image/* | xargs -n 1 -P 20 python check_valid.py | tee train_size_invalid.txt
+xargs rm < train_size_invalid.txt
+rm train_size_invalid.txt
+ls ../train_image > train_valid.txt
\ No newline at end of file
diff --git a/data/preprocess/cc3m/cc3m_train_download_list.py b/data/preprocess/cc3m/cc3m_train_download_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdd3f879aee1948d582a59403c77938a4e87ef03
--- /dev/null
+++ b/data/preprocess/cc3m/cc3m_train_download_list.py
@@ -0,0 +1,17 @@
+import os
+
+captions = []
+urls = []
+
+with open('Train_GCC-training.tsv') as fp:
+ for cnt, line in enumerate(fp):
+ s = line.split('\t')
+ captions.append(s[0].split(' '))
+ urls.append(s[1][:-1])
+
+with open('train4download.txt', 'w') as fp:
+ for cnt, url in enumerate(urls):
+ fp.write("../train_image/{:08d}.jpg\t\"{}\"\n".format(cnt, url))
+
+if not os.path.exists('../train_image'):
+ os.makedirs('../train_image')
\ No newline at end of file
diff --git a/data/preprocess/cc3m/cc3m_val_download.sh b/data/preprocess/cc3m/cc3m_val_download.sh
new file mode 100644
index 0000000000000000000000000000000000000000..7d15ac0d549bf87b2fe8196b853a9ec8999803c7
--- /dev/null
+++ b/data/preprocess/cc3m/cc3m_val_download.sh
@@ -0,0 +1,8 @@
+# use 20 threads
+
+cat val4download.txt | xargs -n 2 -P 20 wget -nc -U 'Mozilla/5.0 (X11; Linux i686) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.27 Safari/537.17' --timeout=1 --waitretry=0 --tries=5 --retry-connrefused -nv -O
+find ../val_image -type f -size -1c -exec rm {} \;
+ls -d ../val_image/* | xargs -n 1 -P 20 python check_valid.py | tee val_size_invalid.txt
+xargs rm < val_size_invalid.txt
+rm val_size_invalid.txt
+ls ../val_image > val_valid.txt
\ No newline at end of file
diff --git a/data/preprocess/cc3m/cc3m_val_download_list.py b/data/preprocess/cc3m/cc3m_val_download_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c78354f09f25c40599ae8502c519359ec731cf
--- /dev/null
+++ b/data/preprocess/cc3m/cc3m_val_download_list.py
@@ -0,0 +1,17 @@
+import os
+
+captions = []
+urls = []
+
+with open('Validation_GCC-1.1.0-Validation.tsv') as fp:
+ for cnt, line in enumerate(fp):
+ s = line.split('\t')
+ captions.append(s[0].split(' '))
+ urls.append(s[1][:-1])
+
+with open('val4download.txt', 'w') as fp:
+ for cnt, url in enumerate(urls):
+ fp.write("../val_image/{:08d}.jpg\t\"{}\"\n".format(cnt, url))
+
+if not os.path.exists('../val_image'):
+ os.makedirs('../val_image')
\ No newline at end of file
diff --git a/data/preprocess/cc3m/check_valid.py b/data/preprocess/cc3m/check_valid.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba0fd4ce4dd811996ecc164af2fc2ff3924bf170
--- /dev/null
+++ b/data/preprocess/cc3m/check_valid.py
@@ -0,0 +1,14 @@
+import sys
+from PIL import Image
+
+import warnings
+
+warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
+
+try:
+ im = Image.open(sys.argv[1]).convert('RGB')
+ # remove images with too small or too large size
+ if (im.size[0] < 10 or im.size[1] < 10 or im.size[0] > 10000 or im.size[1] > 10000):
+ raise Exception('')
+except:
+ print(sys.argv[1])
\ No newline at end of file
diff --git a/data/preprocess/cc3m/make_cc3m_train_json.py b/data/preprocess/cc3m/make_cc3m_train_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..77fff5baf5bd726cee598fc70f4dc579184be543
--- /dev/null
+++ b/data/preprocess/cc3m/make_cc3m_train_json.py
@@ -0,0 +1,32 @@
+captions = []
+urls = []
+
+with open('Train_GCC-training.tsv') as fp:
+ for cnt, line in enumerate(fp):
+ s = line.split('\t')
+ captions.append(s[0].split(' '))
+ urls.append(s[1][:-1])
+
+valids = set([])
+with open('train_valid.txt') as fp:
+ for cnt, line in enumerate(fp):
+ valids.add(line[:-1])
+
+import json
+with open('train.json', 'w') as outfile:
+ for cnt, (cap, url) in enumerate(zip(captions, urls)):
+ im = "{:08d}.jpg".format(cnt)
+ if (im in valids):
+ d = {'image':"train_image.zip@/{}".format(im), 'caption':cap}
+ json.dump(d, outfile)
+ outfile.write('\n')
+
+
+import json
+with open('train_frcnn.json', 'w') as outfile:
+ for cnt, (cap, url) in enumerate(zip(captions, urls)):
+ im = "{:08d}.jpg".format(cnt)
+ if (im in valids):
+ d = {'image':"train_image.zip@/{}".format(im), 'caption':cap, 'frcnn':"train_frcnn.zip@/{:08d}.json".format(cnt)}
+ json.dump(d, outfile)
+ outfile.write('\n')
\ No newline at end of file
diff --git a/data/preprocess/cc3m/make_cc3m_val_json.py b/data/preprocess/cc3m/make_cc3m_val_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ab8f727c2af62c15758873a91c988a7f64efd15
--- /dev/null
+++ b/data/preprocess/cc3m/make_cc3m_val_json.py
@@ -0,0 +1,31 @@
+captions = []
+urls = []
+
+with open('Validation_GCC-1.1.0-Validation.tsv') as fp:
+ for cnt, line in enumerate(fp):
+ s = line.split('\t')
+ captions.append(s[0].split(' '))
+ urls.append(s[1][:-1])
+
+valids = set([])
+with open('val_valid.txt') as fp:
+ for cnt, line in enumerate(fp):
+ valids.add(line[:-1])
+
+import json
+with open('val.json', 'w') as outfile:
+ for cnt, (cap, url) in enumerate(zip(captions, urls)):
+ im = "{:08d}.jpg".format(cnt)
+ if (im in valids):
+ d = {'image':"val_image.zip@/{}".format(im), 'caption':cap}
+ json.dump(d, outfile)
+ outfile.write('\n')
+
+import json
+with open('val_frcnn.json', 'w') as outfile:
+ for cnt, (cap, url) in enumerate(zip(captions, urls)):
+ im = "{:08d}.jpg".format(cnt)
+ if (im in valids):
+ d = {'image':"val_image.zip@/{}".format(im), 'caption':cap, 'frcnn':"val_frcnn.zip@/{:08d}.json".format(cnt)}
+ json.dump(d, outfile)
+ outfile.write('\n')
\ No newline at end of file
diff --git a/data/preprocess/coco_preprocess.py b/data/preprocess/coco_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..b15507f6c3fa44d1b6944d19da928d93ae78b792
--- /dev/null
+++ b/data/preprocess/coco_preprocess.py
@@ -0,0 +1,50 @@
+import json
+from collections import defaultdict
+
+original_json = json.load(open("mscoco_dataset/new_annotations/dataset_coco.json"))
+
+subsets = ['train', 'val', 'test']
+savepath = "mscoco_dataset/new_annotations"
+
+import os
+if not os.path.exists(savepath):
+ os.makedirs(savepath)
+
+savename = {
+ 'train': "captions_train113k.json",
+ 'val': "captions_val5k.json",
+ 'test': "captions_test5k.json",
+}
+
+imagefields = defaultdict(list)
+annotationsfields = defaultdict(list)
+
+for imagecaps in original_json['images']:
+ filepath = imagecaps['filepath']
+ filename = imagecaps['filename']
+ image_id = int(filename.split(".")[0].split('_')[-1])
+ split = imagecaps['split']
+ if split == 'restval':
+ split = 'train'
+ imagefields[split].append({
+ "file_name": filename,
+ "file_path": filepath,
+ "id": image_id
+ })
+ for sen in imagecaps['sentences']:
+ annotationsfields[split].append({
+ "image_id": image_id,
+ "id": sen["sentid"],
+ "caption": sen["raw"],
+ })
+
+for subset in subsets:
+ data = {
+ "images": imagefields[subset],
+ "annotations": annotationsfields[subset]
+ }
+ json.dump(data, open(os.path.join(savepath, savename[subset]), "w"))
+pass
+
+
+
diff --git a/data/preprocess/k400_construct_csv.py b/data/preprocess/k400_construct_csv.py
new file mode 100644
index 0000000000000000000000000000000000000000..dab509668ebb59267ba906b37df9f90f39b8e089
--- /dev/null
+++ b/data/preprocess/k400_construct_csv.py
@@ -0,0 +1,35 @@
+import os
+import json
+
+
+data = dict()
+data['database'] = dict()
+
+categories = list()
+
+with open("K400_val.csv", 'w') as f:
+ for root, dirs, files in os.walk("./validation"):
+ label = root.strip().split('/')[-1]
+ if files and label not in categories:
+ categories.append(label)
+ for fi in files:
+ f.write("{},{}\n".format(os.path.join(root, fi), label))
+ data['database'][fi] = {'subset': 'validation', 'annotations': {'label': label}}
+
+with open("K400_train.csv", 'w') as f:
+ for root, dirs, files in os.walk("./training"):
+ label = root.strip().split('/')[-1]
+ if files and label not in categories:
+ categories.append(label)
+ for fi in files:
+ f.write("{},{}\n".format(os.path.join(root, fi), label))
+ data['database'][fi] = {'subset': 'training', 'annotations': {'label': label}}
+
+with open("categories.txt", 'w') as f:
+ for i, label in enumerate(categories):
+ f.write("{},{}\n".format(label, i))
+
+with open("annotation.json", 'w') as f:
+ json.dump(data, f)
+
+
diff --git a/data/preprocess/k700_construct_csv.py b/data/preprocess/k700_construct_csv.py
new file mode 100644
index 0000000000000000000000000000000000000000..437b886829e8008c45000dde7c73e5a44232014a
--- /dev/null
+++ b/data/preprocess/k700_construct_csv.py
@@ -0,0 +1,40 @@
+import os
+import json
+import random
+
+
+data = dict()
+data['database'] = dict()
+
+categories = list()
+
+
+with open("K700_val.csv", 'w') as f:
+ for root, dirs, files in os.walk("./validation"):
+ label = root.strip().split('/')[-1]
+ if files:
+ categories.append(label)
+ for fi in files:
+# f.write("{},{}\n".format(os.path.join(root, fi), label))
+ data['database'][fi] = {'subset': 'validation', 'annotations': {'label': label}}
+
+print("{} validation instances".format(len(data['database'])))
+
+with open("K700_train.csv", 'w') as f:
+ for root, dirs, files in os.walk("./training"):
+ label = root.strip().split('/')[-1]
+ if files:
+ categories.append(label)
+ for fi in files:
+ f.write("{},{}\n".format(os.path.join(root, fi), label))
+ data['database'][fi] = {'subset': 'training', 'annotations': {'label': label}}
+
+with open("categories.txt", 'w') as f:
+ categories = sorted(categories)
+ for i, label in enumerate(categories):
+ f.write("{},{}\n".format(label, i))
+
+with open("annotation.json", 'w') as f:
+ json.dump(data, f)
+
+
diff --git a/data/preprocess/moments_construct_csv.py b/data/preprocess/moments_construct_csv.py
new file mode 100644
index 0000000000000000000000000000000000000000..96039b0f4680f636129bf73eab06840e64bc5522
--- /dev/null
+++ b/data/preprocess/moments_construct_csv.py
@@ -0,0 +1,33 @@
+import os
+import scandir
+import json
+import random
+
+
+data = dict()
+data['database'] = dict()
+
+categories = list()
+
+print("{} validation instances".format(len(data['database'])))
+c = 0
+with open("moments_train.csv", 'w') as f:
+ for root, dirs, files in scandir.walk("./training"):
+ print(c)
+ c += 1
+ label = root.strip().split('/')[-1]
+ if files:
+ categories.append(label)
+ for fi in files:
+ f.write("{},{}\n".format(os.path.join(root, fi), label))
+ data['database'][fi] = {'subset': 'training', 'annotations': {'label': label}}
+
+with open("categories.txt", 'w') as f:
+ categories = sorted(categories)
+ for i, label in enumerate(categories):
+ f.write("{},{}\n".format(label, i))
+
+with open("annotation.json", 'w') as f:
+ json.dump(data, f)
+
+
diff --git a/data/preprocess/msrvtt_dataprocess_1k.ipynb b/data/preprocess/msrvtt_dataprocess_1k.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..e3a2feaa2a968c24276dbb4d159b588c7c290331
--- /dev/null
+++ b/data/preprocess/msrvtt_dataprocess_1k.ipynb
@@ -0,0 +1,302 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import json\n",
+ "from collections import defaultdict\n",
+ "import csv"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "dict_keys(['info', 'videos', 'sentences'])"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset_info = json.load(open('train_val_videodatainfo.json' ))\n",
+ "dataset_info.keys()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "2990"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "test_dataset_info = json.load(open('test_videodatainfo.json' ))\n",
+ "test_dataset_info.keys()\n",
+ "len(test_dataset_info['videos'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd \n",
+ "train_9k_video_id = pd.read_csv(open('msrvtt_data/MSRVTT_train.9k.csv'))\n",
+ "train_9k_video_id = train_9k_video_id['video_id'].tolist()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "9000"
+ ]
+ },
+ "execution_count": 18,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(train_9k_video_id)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "images_field = []\n",
+ "images_val_field = []\n",
+ "id2split = {}\n",
+ "\n",
+ "for video_info in dataset_info['videos']:\n",
+ " if video_info['video_id'] in train_9k_video_id:\n",
+ " images_field.append(\n",
+ " {\n",
+ " \"id\": int(video_info['video_id'].split('video')[-1]),\n",
+ " 'file_name': video_info['video_id']\n",
+ " }\n",
+ " )\n",
+ " \n",
+ " id2split[video_info['video_id']] = 'train'\n",
+ " else:\n",
+ " images_val_field.append(\n",
+ " {\n",
+ " \"id\": int(video_info['video_id'].split('video')[-1]),\n",
+ " 'file_name': video_info['video_id']\n",
+ " }\n",
+ " )\n",
+ " \n",
+ " id2split[video_info['video_id']] = 'test'\n",
+ "\n",
+ "\n",
+ "for video_info in test_dataset_info['videos']:\n",
+ "\n",
+ " if video_info['video_id'] in train_9k_video_id:\n",
+ " images_field.append(\n",
+ " {\n",
+ " \"id\": int(video_info['video_id'].split('video')[-1]),\n",
+ " 'file_name': video_info['video_id']\n",
+ " }\n",
+ " )\n",
+ " \n",
+ " id2split[video_info['video_id']] = 'train'\n",
+ " else:\n",
+ " images_val_field.append(\n",
+ " {\n",
+ " \"id\": int(video_info['video_id'].split('video')[-1]),\n",
+ " 'file_name': video_info['video_id']\n",
+ " }\n",
+ " )\n",
+ " \n",
+ " id2split[video_info['video_id']] = 'test'\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "length train: 9000 test: 1000\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "print(\"length train: {} test: {}\".format(len(images_field), len(images_val_field)))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "annotations_field = []\n",
+ "\n",
+ "for sentence_info in dataset_info['sentences']:\n",
+ " if id2split[sentence_info['video_id']] == 'train':\n",
+ " video_id = int(sentence_info['video_id'].split('video')[-1])\n",
+ "\n",
+ " annotations_field.append(\n",
+ " {\n",
+ " \"image_id\": video_id,\n",
+ " 'id': sentence_info['sen_id'],\n",
+ " \"caption\": sentence_info['caption']\n",
+ " }\n",
+ " )\n",
+ "\n",
+ "for sentence_info in test_dataset_info['sentences']:\n",
+ " if id2split[sentence_info['video_id']] == 'train':\n",
+ " \n",
+ " video_id = int(sentence_info['video_id'].split('video')[-1])\n",
+ "\n",
+ " annotations_field.append(\n",
+ " {\n",
+ " \"image_id\": video_id,\n",
+ " 'id': sentence_info['sen_id'],\n",
+ " \"caption\": sentence_info['caption']\n",
+ " }\n",
+ " )\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "180000"
+ ]
+ },
+ "execution_count": 36,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(annotations_field)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "data = {\n",
+ " 'images': images_field,\n",
+ " \"annotations\": annotations_field\n",
+ " }\n",
+ "json.dump(data, open('annotations_new/caption_msrvtt_1k_trainval_cocostyle.json', 'w'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "test_info = pd.read_csv('msrvtt_data/MSRVTT_JSFUSION_test.csv')\n",
+ "videoids = test_info['video_id'].tolist()\n",
+ "sentences = test_info['sentence'].tolist()\n",
+ "images_field = []\n",
+ "\n",
+ "annotations_field = []\n",
+ "\n",
+ "for video_id, sentence in zip(videoids, sentences):\n",
+ " images_field.append(\n",
+ " {\n",
+ " \"id\": int(video_id.split('video')[-1]),\n",
+ " 'file_name': video_id\n",
+ " }\n",
+ " )\n",
+ " annotations_field.append(\n",
+ " {\n",
+ " \"image_id\": int(video_id.split('video')[-1]),\n",
+ " 'id': int(video_id.split('video')[-1]),\n",
+ " \"caption\": sentence\n",
+ " }\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data = {\n",
+ " 'images': images_field,\n",
+ " \"annotations\": annotations_field\n",
+ " }\n",
+ "json.dump(data, open('annotations_new/caption_msrvtt_1k_test_cocostyle.json', 'w'))"
+ ]
+ }
+ ],
+ "metadata": {
+ "interpreter": {
+ "hash": "a745cf6333d4d8275ecd56c526d26202f2d2beb96e1206fac92576cf98b427be"
+ },
+ "kernelspec": {
+ "display_name": "Python 3.7.11 64-bit ('xmodaler': conda)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.11"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/data/preprocess/msvd_preprocess.py b/data/preprocess/msvd_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..e039a02ce8c33a909eef1bc00c0a69f4029f10e0
--- /dev/null
+++ b/data/preprocess/msvd_preprocess.py
@@ -0,0 +1,48 @@
+import json
+import os
+
+subsets = ["train", "val", "test"]
+save_path = 'msvd_dataset/new_annotations'
+
+subset = subsets[1]
+
+videoindex = open("msvd_dataset/txt_labels/youtube_mapping.txt", 'r').readlines()
+sentence_count = 1
+for subset in subsets:
+ name2idx = dict()
+ idx2name = dict()
+
+ for v in videoindex:
+ name2idx[v.split()[0]] = v.split()[1]
+ idx2name[v.split()[1]] = v.split()[0]
+
+ images_field = []
+ annotations_field = []
+ visited_imames = set()
+ txtfile = "msvd_dataset/txt_labels/sents_{}_lc_nopunc.txt".format(subset)
+ capinfos = open(txtfile, 'r').readlines()
+ for caption in capinfos:
+ vidindex = caption.split('\t')[0]
+ if vidindex not in visited_imames:
+ visited_imames.add(vidindex)
+ images_field.append(
+ {
+ "id": int(vidindex.replace('vid', '')),
+ "file_name": idx2name[vidindex]
+ }
+ )
+ annotations_field.append(
+ {
+ "image_id":int(caption.split()[0].replace('vid', '')),
+ "id": sentence_count,
+ "caption": caption.split('\t')[1].strip()
+
+ }
+ )
+ sentence_count += 1
+
+ data = {
+ "images": images_field,
+ "annotations": annotations_field
+ }
+ json.dump(data, open(os.path.join(save_path, "caption_msvd_{}_cocostyle.json".format(subset)), "w"))
\ No newline at end of file
diff --git a/data/preprocess/process_flickr_caption_json.py b/data/preprocess/process_flickr_caption_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7b7289f65cc219e09b36c5e16d409f664be8adb
--- /dev/null
+++ b/data/preprocess/process_flickr_caption_json.py
@@ -0,0 +1,53 @@
+import json
+from collections import defaultdict
+import jsonlines
+
+subsets = ['train', 'val', 'test']
+savepath = "flickr30k/annotations"
+
+set2jsonline = {
+ 'train': 'flickr30k/all_data_final_train_2014.jsonline',
+ 'val': 'flickr30k/all_data_final_val_set0_2014.jsonline',
+ 'test': 'flickr30k/all_data_final_test_set0_2014.jsonline',
+}
+
+import os
+if not os.path.exists(savepath):
+ os.makedirs(savepath)
+
+
+savename = {
+ 'train': "flickr30k/captions_train.json",
+ 'val': "flickr30k/captions_val.json",
+ 'test': "flickr30k/captions_test.json",
+}
+
+# imagefields = defaultdict(list)
+# annotationsfields = defaultdict(list)
+
+for subset in subsets:
+ imagefield = []
+ annotaionfiled = []
+ sen_id = 0
+ with jsonlines.open(set2jsonline[subset]) as reader:
+ for annotation in reader:
+ sentences = annotation["sentences"]
+ image_id = annotation["img_path"]
+ imagefield.append({
+ "filename": annotation["img_path"],
+ "id": annotation['id'],
+ })
+ for sentence in sentences:
+ annotaionfiled.append({
+ "image_id": annotation['id'],
+ "id": sen_id,
+ "caption": sentence,
+ })
+ sen_id += 1
+
+ data = {
+ "images": imagefield,
+ "annotations": annotaionfiled,
+ }
+ json.dump( data, open(savename[subset], "w"))
+
\ No newline at end of file
diff --git a/data/preprocess/region_descriptions.ipynb b/data/preprocess/region_descriptions.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..2d7cef8682d29113617a0c4de5afd47db465b34f
--- /dev/null
+++ b/data/preprocess/region_descriptions.ipynb
@@ -0,0 +1,384 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/nfs/zhujinguo/datasets/visual_genome/annotations\n"
+ ]
+ }
+ ],
+ "source": [
+ "import glob \n",
+ "import json\n",
+ "import os\n",
+ "print(os.getcwd())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# 5.41M caption for \"region_descriptions.json\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(f\"region_descriptions.json\", \"r\") as fp: #5.41\n",
+ " captions = json.load(fp)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "dict_keys(['regions', 'id'])\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(captions[1].keys())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "2"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "captions[1]['id']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 90,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[{'region_id': 1387,\n",
+ " 'width': 43,\n",
+ " 'height': 17,\n",
+ " 'image_id': 2,\n",
+ " 'phrase': 'walk sign is lit up',\n",
+ " 'y': 193,\n",
+ " 'x': 465},\n",
+ " {'region_id': 1388,\n",
+ " 'width': 133,\n",
+ " 'height': 253,\n",
+ " 'image_id': 2,\n",
+ " 'phrase': 'man wearing silver backpack',\n",
+ " 'y': 322,\n",
+ " 'x': 331}]"
+ ]
+ },
+ "execution_count": 90,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "captions[1]['regions'][:2]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "256\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(len(captions[1]['regions']))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 92,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'region_id': 1387,\n",
+ " 'width': 43,\n",
+ " 'height': 17,\n",
+ " 'image_id': 2,\n",
+ " 'phrase': 'walk sign is lit up',\n",
+ " 'y': 193,\n",
+ " 'x': 465}"
+ ]
+ },
+ "execution_count": 92,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "captions[1]['regions'][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 108077/108077 [00:01<00:00, 58748.17it/s]\n"
+ ]
+ }
+ ],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 93,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 108077/108077 [00:03<00:00, 34841.72it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from collections import defaultdict\n",
+ "from tqdm import tqdm\n",
+ "iid2captions = defaultdict(set)\n",
+ "for cap in tqdm(captions):\n",
+ " cap = cap[\"regions\"]\n",
+ " \n",
+ " for c in cap:\n",
+ " # v0\n",
+ " # iid2captions[c[\"image_id\"]].append(c['phrase'])\n",
+ " region_area = int(c['height'])*int(c['width'])\n",
+ " if region_area >= 128*128:\n",
+ " iid2captions[c[\"image_id\"]].add(c['phrase'])\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 109,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for iid in iid2captions.keys():\n",
+ " iid2captions[iid] = list(iid2captions[iid])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 110,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "107823"
+ ]
+ },
+ "execution_count": 110,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(iid2captions)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "iid2captions[1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 111,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "not all images have caption annotations\n",
+ "108249 107823 107823 107823\n"
+ ]
+ }
+ ],
+ "source": [
+ "import random\n",
+ "from glob import glob\n",
+ "paths = list(glob(f\"../images/VG_100K/*.jpg\")) + list(\n",
+ " glob(f\"../images/VG_100K_2/*.jpg\")\n",
+ ")\n",
+ "random.shuffle(paths)\n",
+ "caption_paths = [\n",
+ " path for path in paths if int(path.split(\"/\")[-1][:-4]) in iid2captions\n",
+ "]\n",
+ "iid2subset = {}\n",
+ "for path in paths:\n",
+ " if int(path.split(\"/\")[-1][:-4]) in iid2captions:\n",
+ " iid2subset[int(path.split(\"/\")[-1][:-4])] = os.path.join(path.split(\"/\")[-2],path.split(\"/\")[-1])\n",
+ " \n",
+ "\n",
+ "if len(paths) == len(caption_paths):\n",
+ " print(\"all images have caption annotations\")\n",
+ "else:\n",
+ " print(\"not all images have caption annotations\")\n",
+ "print(\n",
+ " len(paths), len(caption_paths), len(iid2captions), len(iid2subset)\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "paths\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 112,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1645544\n"
+ ]
+ }
+ ],
+ "source": [
+ "num=0\n",
+ "for iid in iid2captions.keys():\n",
+ " num += len(iid2captions[iid])\n",
+ "print(num)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "25614848"
+ ]
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 107,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'VG_100K_2/1.jpg'"
+ ]
+ },
+ "execution_count": 107,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "iid2subset[1]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 113,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "data = {\n",
+ " \"phrase\": iid2captions,\n",
+ " \"subset\": iid2subset,\n",
+ "}\n",
+ "json.dump(data, open(\"vg_captions_128filter.json\", \"w\"))\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "interpreter": {
+ "hash": "a745cf6333d4d8275ecd56c526d26202f2d2beb96e1206fac92576cf98b427be"
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/data/preprocess/sbu/check_valid.py b/data/preprocess/sbu/check_valid.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d8009e5c299ca425e54c8ddb68342e684d2fe98
--- /dev/null
+++ b/data/preprocess/sbu/check_valid.py
@@ -0,0 +1,36 @@
+import sys
+from PIL import Image
+
+import warnings
+from glob import glob
+import os
+
+warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
+
+def check_image_size(img_path):
+ try:
+ im = Image.open(img_path).convert('RGB')
+ # remove images with too small or too large size
+ if (im.size[0] < 10 or im.size[1] < 10 or im.size[0] > 10000 or im.size[1] > 10000):
+ raise Exception('')
+
+ except:
+ # print(sys.argv[1])
+ return img_path
+ else:
+ return None
+
+def main():
+ image_already_dl = list(glob("images/*"))
+ print('already download {} images.'.format(len(image_already_dl)))
+ for image_path in image_already_dl:
+ ret = check_image_size(image_path)
+ if ret is not None:
+ os.remove(ret)
+
+ image_already_dl = list(glob("images/*"))
+ print('after check size, {} images left.'.format(len(image_already_dl)))
+
+if __name__ == "__main__":
+ print('remove images with too small or too large size')
+ main()
\ No newline at end of file
diff --git a/data/preprocess/sbu/make_sbu_json.py b/data/preprocess/sbu/make_sbu_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b85b174d766119425a971209932c4291518592f
--- /dev/null
+++ b/data/preprocess/sbu/make_sbu_json.py
@@ -0,0 +1,23 @@
+import json
+import os
+from glob import glob
+imagefile = open('dataset/SBU_captioned_photo_dataset_urls.txt', 'r').readlines()
+captionfile = open('dataset/SBU_captioned_photo_dataset_captions.txt', 'r').readlines()
+
+valid_list = list(glob("images/*"))
+valid_list = [ i.split('/')[-1] for i in valid_list]
+
+
+name2cap = {}
+for imageurl, caption in zip(imagefile, captionfile):
+ filename = imageurl.strip().split('/')[-1]
+ name2cap[filename] = caption.strip()
+
+data_list = {}
+for valid_img in valid_list:
+ data_list[valid_img]=name2cap[valid_img]
+
+fp = open('annotations/subcaption.json', 'w')
+json.dump(data_list, fp)
+
+print(len(data_list))
diff --git a/data/preprocess/sbu/sbu_download.sh b/data/preprocess/sbu/sbu_download.sh
new file mode 100644
index 0000000000000000000000000000000000000000..169ed96f71341f2bc33746e1683fb6272463069b
--- /dev/null
+++ b/data/preprocess/sbu/sbu_download.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+export http_proxy=http://172.16.1.135:3128/ ; export https_proxy=http://172.16.1.135:3128/ ; export HTTP_PROXY=http://172.16.1.135:3128/ ; export HTTPS_PROXY=http://172.16.1.135:3128/
+
+
+srun -p cpu --cpus-per-task 20 \
+cat dataset/download_list.txt | xargs -n 2 -P 20 wget -nc -U 'Mozilla/5.0 (X11; Linux i686) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.27 Safari/537.17' --timeout=1 --waitretry=0 --tries=5 --retry-connrefused -nv -O
+
+
+find ./images -type f -size -1c -exec rm {} \;
+ls -d ./images/* | xargs -n 1 -P 20 python check_valid.py | tee image_size_invalid.txt
+xargs rm < image_size_invalid.txt
+rm image_size_invalid.txt
+ls ../image > image_valid.txt
\ No newline at end of file
diff --git a/data/preprocess/sbu/sbu_download_list.py b/data/preprocess/sbu/sbu_download_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..d36b16d756278fafe79477935b5d799d405fded2
--- /dev/null
+++ b/data/preprocess/sbu/sbu_download_list.py
@@ -0,0 +1,10 @@
+f = open('dataset/SBU_captioned_photo_dataset_urls.txt', 'r')
+
+download_list = 'dataset/download_list.txt'
+out = open(download_list, 'w')
+
+image_path = '/mnt/lustrenew/share_data/zhujinguo/SBU/images'
+for url in f.readlines():
+ url = url.strip()
+ filename = url.split('/')[-1]
+ out.write("{path}/{filename}\t{url}\n".format(path=image_path, filename=filename, url=url))
\ No newline at end of file
diff --git a/data/pretraining.md b/data/pretraining.md
new file mode 100644
index 0000000000000000000000000000000000000000..7911c2a6410d5e8f078d28a6d303b4a25301ad85
--- /dev/null
+++ b/data/pretraining.md
@@ -0,0 +1,91 @@
+# Pre-Training
+
+## Preparation
+You should prepare training data following [PREPARE_DATA.md](prepare_data.md), and make sure that the environment variable `DATA_PATH` is indeed the location path that stores pre-training data.
+
+```
+echo $DATA_PATH
+```
+
+## Training on singe node
+```
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=$PORT \
+main.py --num-gpus $GPUS --config-file ${CONFIG} OUTPUT_DIR $WORK_DIR
+```
+
+where `$GPUS` is GPU number, `${CONFIG}` is the configuration file, `$PORT` is the specified available port used for distributed training, and `$WORK_DIR` is the directory used to store checkpoint and training log.
+
+For exmaple, the command for pre-training a Uni-Perceiver-Tiny model with `configs/BERT_L12_H192_experiments/7tasks_berttiny_training.yaml` is as folllowing:
+```
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=26511 \
+main.py --num-gpus 8 --config-file configs/BERT_L12_H192_experiments/7tasks_berttiny_training.yaml OUTPUT_DIR work_dirs/exp_demo_log
+```
+Another training example with gradient accumulation :
+```
+ python -m torch.distributed.launch --nproc_per_node=4 --master_port=26511 \
+ main.py --num-gpus 4 --config-file configs/BERT_L12_H384_experiments/in1k_training.yaml SOLVER.ACCUM_ITER 2 OUTPUT_DIR work_dirs/deepspeed_moe/BERT_L12_H384_experiments/debug
+ ```
+
+
+
+## Evaluation without any tuning
+
+You can evaluate the pre-training tasks by adding the `--eval-only` argument.
+```
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=$PORT \
+main.py --num-gpus $GPUS --config-file ${CONFIG} --eval-only OUTPUT_DIR $WORK_DIR
+```
+
+## Training on multiple nodes
+For example, the command for training Uni-Perceiver on 2 nodes of each with 8 GPUs is as following:
+
+On node 1:
+```
+MASTER_ADDR= NODE_RANK=0 GPUS_PER_NODE=8
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=$PORT \
+main.py --num-gpus $GPUS --config-file ${CONFIG} OUTPUT_DIR $WORK_DIR
+```
+
+On node 2:
+```
+MASTER_ADDR= NODE_RANK=1 GPUS_PER_NODE=8
+python -m torch.distributed.launch --nproc_per_node=8 --master_port=$PORT \
+main.py --num-gpus $GPUS --config-file ${CONFIG} OUTPUT_DIR $WORK_DIR
+```
+
+## Training on slurm cluster
+
+If you are using slurm cluster, you can simply run the following command to train Uni-Perceiver on `GPUS/8` nodes with `GPUS` GPUs:
+
+```
+sh run.sh ${CONFIG} ${JOBNAME} ${GPUS} ${PARTITION}
+```
+* Note: you should change the `DATA_PATH` in the script `./run.sh` before your training.
+
+
+## Pre-Training of Uni-Perceiver models
+To save the computation cost, Uni-Perceiver and Uni-Perceiver-MoE are both pre-trained in a two-stage way:
+they are pre-trained with the image resolution of 160x160 firstly, and then are pre-trained for another 10% of total iterations on a higher resolution of 224x224.
+The two-stage training strategy makes our training more effective.
+
+### Uni-Perceiver
+ Take __Uni-Perceiver-Base__ as an example, the 1-st pre-training stage can be conducted as
+```
+sh run.sh configs/BERT_L12_H768_experiments/16tasks_training_basedense_stage1_64gpu.yaml base_pretrain_stage1 64 partitionname
+```
+After the 1-stage, you can run the 2-nd stage pre-training as
+```
+sh run.sh configs/BERT_L12_H768_experiments/16tasks_training_basedense_stage2_64gpu.yaml base_pretrain_stage2 64 partitionname MODEL.WEIGHTS work_dirs/BERT_L12_H768_experiments/16tasks_training_basedense_stage1_64gpu/base_pretrain_stage1/model_Epoch_200000_Iter_0199999.pth
+```
+
+### Uni-Perceiver-MoE
+The __Uni-Perceiver-MoE__ model can also be pre-trained in a similar way, which also follows two-stage pre-training.
+```
+sh run.sh configs/BERT_L12_H768_experiments/16tasks_training_basemoe_stage1_56gpu.yaml base_moe_pretrain_stage1 56 partitionname
+```
+
+```
+sh run.sh configs/BERT_L12_H768_experiments/16tasks_training_basemoe_stage2_56gpu.yaml base_moe_pretrain_stage2 56 partitionname MODEL.WEIGHTS work_dirs/BERT_L12_H768_experiments/16tasks_training_basemoe_stage1_56gpu/base_moe_pretrain_stage1/model_Epoch_200000_Iter_0199999.pth
+
+```
+By the way, you should adjust the training iteration and learning scheduler accordingly as you use a different number of GPUs.
\ No newline at end of file
diff --git a/data/prompt_tuning.md b/data/prompt_tuning.md
new file mode 100644
index 0000000000000000000000000000000000000000..883a6eb2e155175399c0411cbaa659618be723b0
--- /dev/null
+++ b/data/prompt_tuning.md
@@ -0,0 +1,22 @@
+# Prompt Tuning
+
+For reproducing the fine-tuning results in our paper, we provide the corresponding prompt-tuning configs in `configs/BERT_L12_H768_experiments/prompt_tuning` and `configs/BERT_L12_H768_experiments/moe_prompt_tuning` for Uni-Perceiver-Base and Uni-Perceiver-MoE-Base, respectively.
+
+Specifically, we prompt-tuned the ImageNet-1K dataset with image classification task. For video classification, we fine-tuned Kinetics-400. We also employed image caption and image-text retrieval tasks on MSCOCO caption and FLicker-30K datasets.
+In addition, video caption and video-text retrieval tasks are conducted on MSVD dataset.
+Please perpare the dataset following [PREPARE_DATA.md](prepare_data.md)
+
+---
+
+In our experiments, prompt-tuning on all datasets benchmarks is performed on 16 NVIDIA-V100 GPUs with 80GB memory.
+Taking Imagenet-1K as an example, the __Uni-Perceiver-Base__ can be prompt-tuned as
+```
+
+sh run.sh configs/BERT_L12_H768_experiments/prompt_tuning/in1k_prompt_tuning_0.01data_lr1e-4.yaml in1k-pt 16 partitionname MODEL.WEIGHTS work_dirs/pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained.pth
+
+```
+The __Uni-Perceiver-MoE-Base__ can also be fine-tuned in a similar way:
+```
+sh run.sh configs/BERT_L12_H768_experiments/moe_prompt_tuning/in1k_prompt_tuning_0.01data_lr1e-4.yaml in1k-moe-pt 16 partitionname MODEL.WEIGHTS work_dirs/pretrained_models/uni-perceiver-moe-base-L12-H768-224size-pretrained.pth
+```
+
diff --git a/figs/overview.png b/figs/overview.png
new file mode 100644
index 0000000000000000000000000000000000000000..d5f04aef4917ebf2fce1dbb532814fb9b2d48882
Binary files /dev/null and b/figs/overview.png differ
diff --git a/figs/overview_moe.png b/figs/overview_moe.png
new file mode 100644
index 0000000000000000000000000000000000000000..0925546a7769cd6a3f920a982f4795d23500490b
Binary files /dev/null and b/figs/overview_moe.png differ
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..a91512e16c1ea2479a4f94e9931b05f164d37182
--- /dev/null
+++ b/main.py
@@ -0,0 +1,184 @@
+"""
+A main training script.
+"""
+
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+import warnings
+warnings.filterwarnings('ignore') # never print matching warnings
+import logging
+import os
+from collections import OrderedDict
+import torch
+import uniperceiver.utils.comm as comm
+from uniperceiver.config import get_cfg, CfgNode
+from uniperceiver.engine import DefaultTrainer, default_argument_parser, default_setup, launch, build_engine, add_moe_arguments
+
+#!TODO re-implement hooks
+from uniperceiver.engine import hooks
+from uniperceiver.modeling import add_config
+from uniperceiver.utils.env import init_distributed_mode, check_dist_portfile
+try:
+ import deepspeed
+ DEEPSPEED_INSTALLED = True
+except:
+ DEEPSPEED_INSTALLED = False
+
+import copy
+
+def add_data_prefix(cfg):
+ # TODO: more flexible method
+ data_dir = os.getenv("DATA_PATH", None)
+ mapping_list = [
+ [cfg.DATALOADER, 'FEATS_FOLDER', ['DATALOADER',]],
+ [cfg.DATALOADER, 'ANNO_FOLDER', ['DATALOADER', ]],
+ [cfg.DATALOADER, 'CLASS_NAME_FILE', ['DATALOADER', ]],
+ [cfg.INFERENCE, 'VOCAB', ['INFERENCE', ]],
+ [cfg.INFERENCE, 'VAL_ANNFILE', ['INFERENCE', ]],
+ [cfg.INFERENCE, 'TEST_ANNFILE', ['INFERENCE',]],
+ [cfg.MODEL, 'WEIGHTS', ['MODEL',]],
+ ]
+ whitelist = ["BERT", "CLIP", "CLIP_CAPTION"]
+ if data_dir:
+ for node, attr ,_ in mapping_list:
+ if node[attr] != '' and not node[attr].startswith('.') and not node[attr].startswith('/') and not node[attr].startswith('work_dirs') and not node[attr].startswith('cluster') and not node[attr].startswith('s3://') and node[attr] not in whitelist:
+ setattr(node, attr, os.path.join(data_dir, node[attr]))
+ for task in cfg.TASKS:
+ for _, item, key_list in mapping_list:
+ config_tmp = task
+ for key in key_list:
+ if key in config_tmp:
+ config_tmp = config_tmp[key]
+ if item in config_tmp and config_tmp[item] != '' and not config_tmp[item].startswith('.') and not config_tmp[item].startswith('/') and not config_tmp[item].startswith('work_dirs') and not config_tmp[item].startswith('cluster') and not config_tmp[item].startswith('s3://') and config_tmp[item] not in whitelist:
+ config_tmp[item] = os.path.join(data_dir, config_tmp[item])
+
+ mapping_list = [
+ ['', 'FILE_PATH', ['SHARED_TARGETS_CFG',]],
+ ]
+ if cfg.SHARED_TARGETS is None:
+ cfg.SHARED_TARGETS = []
+ for share_targets in cfg.SHARED_TARGETS:
+ for _, item, key_list in mapping_list:
+ config_tmp = share_targets
+ for key in key_list:
+ config_tmp = config_tmp[key]
+ if item in config_tmp and config_tmp[item] != '' and not config_tmp[item].startswith('.') and not config_tmp[item].startswith(
+ '/') and not config_tmp[item].startswith('work_dirs') and not config_tmp[item].startswith(
+ 'cluster') and not config_tmp[item].startswith('s3://') and config_tmp[item] not in whitelist:
+ config_tmp[item] = os.path.join(data_dir, config_tmp[item])
+
+
+
+def add_default_setting_for_multitask_config(cfg):
+ # merge some default config in (CfgNode) uniperceiver/config/defaults.py to each task config (dict)
+
+ tasks_config_temp = cfg.TASKS
+ num_tasks = len(tasks_config_temp)
+ cfg.pop('TASKS', None)
+
+ cfg.TASKS = [copy.deepcopy(cfg) for _ in range(num_tasks)]
+
+ for i, task_config in enumerate(tasks_config_temp):
+ cfg.TASKS[i].merge_from_other_cfg(CfgNode(task_config))
+ cfg.TASKS[i] = cfg.TASKS[i].to_dict_object()
+ pass
+
+
+def setup(args):
+ """
+ Create configs and perform basic setups.
+ """
+ cfg = get_cfg()
+ tmp_cfg = cfg.load_from_file_tmp(args.config_file)
+ add_config(cfg, tmp_cfg)
+
+ cfg.merge_from_file(args.config_file)
+ add_data_prefix(cfg)
+
+ cfg.merge_from_list(args.opts)
+ #
+ add_default_setting_for_multitask_config(cfg)
+ cfg.freeze()
+ default_setup(cfg, args)
+ return cfg
+
+def main(args):
+ cfg = setup(args)
+
+ """
+ If you'd like to do anything fancier than the standard training logic,
+ consider writing your own training loop (see plain_train_net.py) or
+ subclassing the trainer.
+ """
+ trainer = build_engine(cfg)
+ trainer.resume_or_load(resume=args.resume)
+ trainer.cast_layers()
+
+ if args.eval_only:
+ print('---------------------------')
+ print('eval model only')
+ print('---------------------------\n')
+ res = None
+ if trainer.val_data_loader is not None:
+
+ if trainer.model_ema is not None and args.eval_ema:
+ if comm.is_main_process():
+ print('using ema model for evaluation')
+ res = trainer.test(trainer.cfg, trainer.model_ema.ema, trainer.val_data_loader, trainer.val_evaluator, epoch=-1)
+ else:
+ if args.eval_ema and comm.is_main_process():
+ print('no ema model exists! using master model for evaluation')
+ res = trainer.test(trainer.cfg, trainer.model, trainer.val_data_loader, trainer.val_evaluator, epoch=-1)
+
+ if comm.is_main_process():
+ print(res)
+
+ if trainer.test_data_loader is not None:
+ if trainer.model_ema is not None and args.eval_ema:
+ if comm.is_main_process():
+ print('using ema model for evaluation')
+ res = trainer.test(trainer.cfg, trainer.model_ema.ema, trainer.test_data_loader, trainer.test_evaluator, epoch=-1)
+ else:
+ if args.eval_ema and comm.is_main_process():
+ print('no ema model exists! using master model for evaluation')
+ res = trainer.test(trainer.cfg, trainer.model, trainer.test_data_loader, trainer.test_evaluator, epoch=-1)
+ if comm.is_main_process():
+ print(res)
+ return res
+
+ return trainer.train()
+
+def get_args_parser():
+ parser = default_argument_parser()
+ if DEEPSPEED_INSTALLED:
+ parser = deepspeed.add_config_arguments(parser)
+ parser = add_moe_arguments(parser)
+
+ parser.add_argument('--init_method', default='slurm', type=str)
+ parser.add_argument('--local_rank', default=0, type=int)
+ parser.add_argument("--eval-ema", action="store_true", help="perform evaluation using ema")
+ args = parser.parse_args()
+
+ return args
+
+if __name__ == "__main__":
+ args = get_args_parser()
+ print("Command Line Args:", args)
+ if args.init_method == 'slurm':
+ # slurm init
+ check_dist_portfile()
+ init_distributed_mode(args)
+ main(args)
+ elif args.init_method == 'pytorch':
+ main(args)
+ else:
+ # follow 'd2' use default `mp.spawn` to init dist training
+ print('using \'mp.spawn\' for dist init! ')
+ launch(
+ main,
+ args.num_gpus,
+ num_machines=args.num_machines,
+ machine_rank=args.machine_rank,
+ dist_url=args.dist_url,
+ args=(args,),
+ )
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..353f9d53d45f83af07801804a827466f538007ae
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,32 @@
+av==9.0.0
+einops==0.4.1
+ftfy==6.1.1
+fvcore==0.1.5.post20220305
+h5py==3.6.0
+hjson==3.0.2
+huggingface-hub==0.6.0
+hydra-core==1.1.2
+iopath==0.1.9
+jsonlines==3.0.0
+matplotlib==3.5.1
+mkl-service==2.4.0
+ninja==1.10.2.3
+omegaconf==2.1.2
+opencv-python==4.5.5.64
+pandas==1.3.5
+panopticapi @ git+https://github.com/cocodataset/panopticapi.git@7bb4655548f98f3fedc07bf37e9040a992b054b0
+Pillow==9.0.1
+pyarrow==7.0.0
+pycocoevalcap==1.2
+pycocotools==2.0.4
+pytorch-transformers==1.2.0
+PyYAML==6.0
+scikit-learn==1.0.2
+scipy==1.7.3
+sklearn==0.0
+tensorboard==2.8.0
+timm==0.4.5
+tqdm==4.63.0
+transformers==4.19.1
+cloudpickle
+protobuf==3.19
\ No newline at end of file
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f5042dfa341b70232499de34b260bcc103267b32
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,36 @@
+#!/bin/bash
+
+a=$(echo $HOSTNAME | cut -c12-16)
+
+CONFIG=$1
+JOB_NAME=${2:-"experiments"}
+GPUS=${3:-8}
+
+partition=${4:-'local'} #
+
+GPUS_PER_NODE=${GPUS:-8}
+if [ $GPUS_PER_NODE -ge 8 ]; then
+ GPUS_PER_NODE=8
+fi
+CPUS_PER_TASK=${CPUS_PER_TASK:-4}
+SRUN_ARGS=${SRUN_ARGS:-""}
+
+PY_ARGS=${@:5}
+
+WORK_DIR=${CONFIG//configs/work_dirs}
+WORK_DIR=${WORK_DIR//.yaml//$JOB_NAME}
+echo $WORK_DIR
+mkdir -p $WORK_DIR
+mkdir -p data/temp
+
+# please change DATA_PATH where you put the training data
+export DATA_PATH='/mnt/lustre/share_data/zhujinguo'
+
+srun --partition=${partition} $SRUN_ARGS \
+--job-name=${JOB_NAME} -n$GPUS --gres=gpu:${GPUS_PER_NODE} \
+--ntasks-per-node=${GPUS_PER_NODE} \
+--kill-on-bad-exit=1 --cpus-per-task 12 \
+python -u main.py --num-gpus $GPUS \
+--config-file ${CONFIG} --init_method slurm --resume \
+${PY_ARGS} OUTPUT_DIR $WORK_DIR
+
diff --git a/slurm_run.sh b/slurm_run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..9bc758e369680779c05ae684b29e8b0e77c99cb5
--- /dev/null
+++ b/slurm_run.sh
@@ -0,0 +1,98 @@
+#!/bin/bash
+
+a=$(echo $HOSTNAME | cut -c12-16)
+
+CONFIG=$1
+JOB_NAME=${2:-"experiments"}
+GPUS=${3:-8}
+
+SRUN=${4:-'reserved'}
+
+GPUS_PER_NODE=${GPUS:-8}
+if [ $GPUS_PER_NODE -ge 8 ]; then
+ GPUS_PER_NODE=8
+fi
+CPUS_PER_TASK=${CPUS_PER_TASK:-4}
+SRUN_ARGS=${SRUN_ARGS:-""}
+
+PY_ARGS=${@:5}
+
+WORK_DIR=${CONFIG//configs/work_dirs}
+WORK_DIR=${WORK_DIR//.yaml//$JOB_NAME}
+echo $WORK_DIR
+mkdir -p $WORK_DIR
+mkdir -p data/temp
+
+now=$(date +"%Y%m%d_%H%M%S")
+
+a=$(echo $HOSTNAME | cut -c12-16)
+
+
+if [ $a == '140-0' ]; then
+ export DATA_PATH='/mnt/lustre/share_data/zhujinguo'
+ export LD_LIBRARY_PATH=/mnt/cache/zhujinguo/anaconda3/envs/py36/lib:$LD_LIBRARY_PATH
+ export TORCH_EXTENSIONS_DIR='/mnt/lustre/zhujinguo/.cache/torch_extensions'
+ export NO_NVRTC=0
+ partition='INTERN'
+ CEPH_CONFIG='slurm_tools/petreloss_1400.config'
+ SRUNreal=${SRUN}
+
+ if [ ${SRUN} == 'vcspot' ]; then
+ SRUNreal='spot --async'
+ partition=VC
+ elif [ ${SRUN} == 'vcauto' ]; then
+ SRUNreal='auto --async'
+ partition=VC
+ elif [ ${SRUN} == 'vcreserved' ]; then
+ SRUNreal='reserved'
+ partition=VC
+ elif [ ${SRUN} == 'spot' ]; then
+ SRUNreal='spot --async'
+ elif [ ${SRUN} == 'auto' ]; then
+ SRUNreal='auto --async'
+
+ fi
+
+elif [ $a == '142-4' ]; then
+ # 1424
+ export DATA_PATH='/mnt/lustre/share_data/zhujinguo'
+ export LD_LIBRARY_PATH=/mnt/cache/zhujinguo/anaconda3/envs/py36/lib:$LD_LIBRARY_PATH
+ export TORCH_EXTENSIONS_DIR='/mnt/lustre/zhujinguo/.cache/torch_extensions'
+ export NO_NVRTC=0
+ partition='vc_research_5'
+ CEPH_CONFIG='slurm_tools/petreloss_1424.config'
+
+ SRUNreal=${SRUN}
+
+ if [ ${SRUN} == 'vc4spot' ]; then
+ SRUNreal='spot --async'
+ partition=vc_research_4
+ elif [ ${SRUN} == 'vc4auto' ]; then
+ SRUNreal='auto --async -x SH-IDC1-10-142-4-76'
+ partition=vc_research_4
+ elif [ ${SRUN} == 'vc4reserved' ]; then
+ SRUNreal='reserved'
+ partition=vc_research_4
+ elif [ ${SRUN} == 'spot' ]; then
+ SRUNreal='spot --async'
+ elif [ ${SRUN} == 'auto' ]; then
+ SRUNreal='auto --async'
+ fi
+
+else
+ echo only SH1424 and SH1400 supported now
+
+fi
+
+srun --partition=${partition} $SRUN_ARGS --quotatype=${SRUNreal} -o $WORK_DIR/phoenix-slurm-%j-$now.out \
+--job-name=${JOB_NAME} -n$GPUS --gres=gpu:${GPUS_PER_NODE} \
+--ntasks-per-node=${GPUS_PER_NODE} \
+--kill-on-bad-exit=1 --cpus-per-task 12 \
+python -u main.py --num-gpus $GPUS \
+--config-file ${CONFIG} --init_method slurm --resume \
+${PY_ARGS} OUTPUT_DIR $WORK_DIR DATALOADER.USE_CEPH True \
+DATALOADER.TCS_CONF_PATH $CEPH_CONFIG SOLVER.CHECKPOINT_PERIOD 10000 SOLVER.CHECKPOINT_MAX_SAVE 1 \
+${OTHERARGS} 2>&1
+
+# SOLVER.ACCUM_ITER 2 SOLVER.CHECKPOINT_PERIOD 1000 SOLVER.CHECKPOINT_MAX_SAVE 1 MODEL.BERT.DROP_PATH_PROB 0.1
+
diff --git a/tools/ceph_test.ipynb b/tools/ceph_test.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..dd6708588da6cfa3f640a7581e23c22af7f78ffc
--- /dev/null
+++ b/tools/ceph_test.ipynb
@@ -0,0 +1,80 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "df7559ae-7130-46d2-9ea1-6a8276893274",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.append('../')\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "7562f117-2723-4cf8-80d3-9f0363f82c8e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from uniperceiver.datasets import TCSLoader\n",
+ "tcs_loader = TCSLoader('../slurm_tools/petreloss_1400.config')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "46d941b6-47bc-42b4-992b-cdf403675e47",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "xdg-open: no method available for opening '/tmp/tmpj8ge0htm.PNG'\n",
+ "xdg-open: no method available for opening '/tmp/tmpuw7r_irw.PNG'\n"
+ ]
+ }
+ ],
+ "source": [
+ "images = [ 'cluster2:s3://imagenet/train/n02127052/n02127052_19569.JPEG', 's3://visual_genome/images/VG_100K/2368406.jpg']\n",
+ "for image_path in images:\n",
+ " img = tcs_loader(image_path).convert('RGB')\n",
+ " img.show()\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2b5375e0-ea19-4b08-9cde-d50fe77dd5eb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "img\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.13"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tools/convert_checkpoint.ipynb b/tools/convert_checkpoint.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..c6ef3635b1da9839eed2516a7fc555281a6f9a09
--- /dev/null
+++ b/tools/convert_checkpoint.ipynb
@@ -0,0 +1,195 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "f2254819-deaf-48ba-848c-471f51ee1221",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "4fe03089-ec4d-4bda-9b02-46cb320e516a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "origin_checkpoint_path = '/mnt/cache/zhujinguo/codes/UniPerceiver/work_dirs/deepspeed_moe/BERT_L12_H768_experiments/16task_90k_bertbase_lr1e-3_wd0.2_gc0.1_prenorm_warm10k_layerscale1e-3_uniformdp0.1_maeinit_fixedpos_torchfp16_unifieddataset_changeweight_stage2_224size/bertbase_womoe_pretrain2/89999/mp_rank_00_model_states.pt'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "edc282cd-8345-4321-b0a0-3e21d64bfa35",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "dict_keys(['module', 'buffer_names', 'optimizer', 'lr_scheduler', 'sparse_tensor_module_names', 'skipped_steps', 'global_steps', 'global_samples', 'dp_world_size', 'mp_world_size', 'ds_config', 'ds_version', 'iteration'])"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "origin_checkpoint = torch.load(origin_checkpoint_path, 'cpu')\n",
+ "origin_checkpoint.keys()\n",
+ "# list(origin_checkpoint['module'].keys())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "79d9f479-3144-4791-82ba-71fec264aa29",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "201"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(list(origin_checkpoint['module'].keys()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "3452947d-4593-4431-a772-3a8ad4882c03",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "dict_keys(['model', 'trainer', 'amp_scaler', 'scheduler', 'iteration'])"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# new_checkpoint_path = 'new_exp/model_Epoch_00160_Iter_0000159.pth'\n",
+ "# new_checkpoint = torch.load(new_checkpoint_path, 'cpu')\n",
+ "# new_checkpoint.keys()\n",
+ "# list(new_checkpoint['model'].keys())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "ffdcf5c5-ffd4-4379-89d7-37ce05c4c0f2",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "41"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# len(list(new_checkpoint['model'].keys()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "fec7a303-a30c-4e92-9452-b534a52d67e9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mapping_dict = {\n",
+ "\n",
+ " 'encoder.': 'fused_encoder.',\n",
+ " 'attention.self.qkv_proj.weight': 'self_attn.in_proj_weight',\n",
+ " 'attention.self.qkv_proj.bias': 'self_attn.in_proj_bias',\n",
+ " 'attention.output.dense': 'self_attn.out_proj',\n",
+ " 'attention_output.residual_scale': 'gamma_1',\n",
+ " 'ffn.dense.': 'linear1.',\n",
+ " 'ffn.dense2.': 'linear2.',\n",
+ " 'ffn_output.residual_scale': 'gamma_2',\n",
+ " 'LayerNormModules.0.': 'norm1.',\n",
+ " 'LayerNormModules.1.': 'norm2.',\n",
+ " 'predictor.': 'loss_prepare.',\n",
+ " \n",
+ "}\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "897ff2f0-1232-4d25-9c13-7ea9568da362",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_checkpoint = { } \n",
+ "\n",
+ "module_checkpoint = origin_checkpoint['module']\n",
+ "\n",
+ "for k, v in module_checkpoint.items():\n",
+ " if k.endswith('residual_scale'):\n",
+ " v.squeeze_(1).squeeze_(0)\n",
+ " if k.startswith('visual_embed'):\n",
+ " continue\n",
+ " for origin_str, target_str in mapping_dict.items():\n",
+ " if origin_str in k:\n",
+ " k = k.replace(origin_str, target_str)\n",
+ " \n",
+ " new_checkpoint[k] = v.float()\n",
+ "\n",
+ "# merge type embedding in video_embed \n",
+ "new_checkpoint['video_embed.embeddings.bias'] = new_checkpoint['video_embed.embeddings.bias'] + new_checkpoint['video_embed.embeddings_type.weight'][0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "3c26719f-7451-4c0a-85c3-640c820dfe98",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "torch.save({ 'model': new_checkpoint}, '/mnt/lustre/zhujinguo/codes/Uni-Perceiver/work_dirs/pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained.pth')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tools/convert_checkpoint.py b/tools/convert_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..5be6c69d0cde7abb9233cae4378ca01f50dd7486
--- /dev/null
+++ b/tools/convert_checkpoint.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python
+# coding: utf-8
+
+# In[1]:
+
+
+import torch
+
+
+# In[2]:
+# ckpt_path = '/mnt/cache/zhujinguo/codes/UniPerceiver/work_dirs/deepspeed_moe/BERT_L12_H768_experiments/16task_90k_bertbase_lr1e-3_wd0.2_gc0.1_prenorm_warm10k_layerscale1e-3_uniformdp0.1_maeinit_fixedpos_torchfp16_unifieddataset_changeweight_stage2_224size/bertbase_womoe_pretrain2/89999/mp_rank_00_model_states.pt'
+# save_path = '/mnt/lustre/zhujinguo/codes/Uni-Perceiver/work_dirs/pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained.pth'
+
+ckpt_path = '/mnt/cache/zhujinguo/codes/UniPerceiver/work_dirs/deepspeed_moe/BERT_L24_H1024_experiments/16task_90k_bertlarge_lr2e-5_wd0.05_gc0.1_prenorm_warm5k_layerscale1e-3_uniformdp0.2_maeinit_fixedpos_torchfp16_unifieddataset_pretrain_stage2_224size_bw128_all0.5_accum2_bwv2_k700_8frames_yfccfixcap_womixup/all0.5_rmmixup_from430/89999/mp_rank_00_model_states.pt'
+save_path = '/mnt/lustre/zhujinguo/codes/Uni-Perceiver/work_dirs/pretrained_models/uni-perceiver-large-L24-H1024-224size-pretrained.pth'
+origin_checkpoint_path = ckpt_path
+
+
+# In[3]:
+
+
+origin_checkpoint = torch.load(origin_checkpoint_path, 'cpu')
+origin_checkpoint.keys()
+# list(origin_checkpoint['module'].keys())
+
+
+# In[4]:
+
+
+len(list(origin_checkpoint['module'].keys()))
+
+
+# In[8]:
+
+
+# new_checkpoint_path = 'new_exp/model_Epoch_00160_Iter_0000159.pth'
+# new_checkpoint = torch.load(new_checkpoint_path, 'cpu')
+# new_checkpoint.keys()
+# list(new_checkpoint['model'].keys())
+
+
+# In[10]:
+
+
+# len(list(new_checkpoint['model'].keys()))
+
+
+# In[5]:
+
+
+mapping_dict = {
+
+ 'encoder.': 'fused_encoder.',
+ 'attention.self.qkv_proj.weight': 'self_attn.in_proj_weight',
+ 'attention.self.qkv_proj.bias': 'self_attn.in_proj_bias',
+ 'attention.output.dense': 'self_attn.out_proj',
+ 'attention_output.residual_scale': 'gamma_1',
+ 'ffn.dense.': 'linear1.',
+ 'ffn.dense2.': 'linear2.',
+ 'ffn_output.residual_scale': 'gamma_2',
+ 'LayerNormModules.0.': 'norm1.',
+ 'LayerNormModules.1.': 'norm2.',
+ 'predictor.': 'loss_prepare.',
+
+}
+
+
+# In[6]:
+
+
+new_checkpoint = { }
+
+module_checkpoint = origin_checkpoint['module']
+
+for k, v in module_checkpoint.items():
+ if k.endswith('residual_scale'):
+ v.squeeze_(1).squeeze_(0)
+ if k.startswith('visual_embed'):
+ continue
+ for origin_str, target_str in mapping_dict.items():
+ if origin_str in k:
+ k = k.replace(origin_str, target_str)
+
+ new_checkpoint[k] = v.float()
+
+# merge type embedding in video_embed
+new_checkpoint['video_embed.embeddings.bias'] = new_checkpoint['video_embed.embeddings.bias'] + new_checkpoint['video_embed.embeddings_type.weight'][0]
+
+# In[7]:
+
+
+
+torch.save({ 'model': new_checkpoint}, save_path)
+
diff --git a/tools/convertmoe.ipynb b/tools/convertmoe.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..0588c0ab060d7031d3f13513f252d0e08afa2889
--- /dev/null
+++ b/tools/convertmoe.ipynb
@@ -0,0 +1,169 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "5fa051c8-8f5e-4809-b90e-bf129a701352",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d549cf85-c638-4dec-a436-254da7060ee3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "checkpoints = torch.load('../output/model_Epoch_00030_Iter_0000029.pth', 'cpu')['model']\n",
+ "new_keys = list(list(checkpoints.keys()))\n",
+ "new_keys[:10]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "81fd1bd8-d8ed-457f-a9ec-d2b0bf91c8c7",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'/mnt/lustre/zhujinguo/jinguo_data/codes/Uni-Perceiver/work_dirs'"
+ ]
+ },
+ "execution_count": 1,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import os\n",
+ "os.getcwd()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "d9a698e1-de54-4257-a45f-65cd2d7cf095",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ds_checkpoints = torch.load('deepspeed_moe/BERT_L12_H192_experiments/7task_150k_berttiny_lr1e-3_wd0.05_gc0.1_prenorm_warm10k_layerscale1e-3_uniformdp0.1_maeinit_fixedpos_torchfp16_unifieddataset_224inputsize_tagmoe_alllayer/tagmoe_alllayer_exp4/149999/mp_rank_00_model_states.pt', 'cpu')['module']\n",
+ "# ds_checkpoints = torch.load('/nfs/zhujinguo/codes/xmodaler/work_dirs/deepspeed_moe/BERT_L12_H768_experiments/basetagmoe_pretrainstage2/89999/mp_rank_00_model_states.pt', 'cpu')['module']\n",
+ "# ds_checkpoints = torch.load('/nfs/zhujinguo/codes/xmodaler/work_dirs/deepspeed_moe/BERT_L12_H768_experiments/bertbase_womoe/89999/mp_rank_00_model_states.pt', 'cpu')['module']\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "d5f02e7e-95eb-4386-ac68-f8be0f7ac3c1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "oldkeys = list(ds_checkpoints.keys())\n",
+ "# oldkeys[:20]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "6a9c4928-e808-4e5d-a8cb-c4d133cc9c6c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mapping_dict = {\n",
+ "\n",
+ " 'encoder.': 'fused_encoder.',\n",
+ " # 'attention.self.qkv_proj.weight': 'self_attn.in_proj_weight',\n",
+ " # 'attention.self.qkv_proj.bias': 'self_attn.in_proj_bias',\n",
+ " 'attention.self.qkv_proj': 'self_attn.qkv_proj',\n",
+ " 'deepspeed_moe.gate': 'gate',\n",
+ " 'deepspeed_moe.experts': 'experts',\n",
+ " 'attention.output.dense': 'self_attn.dense',\n",
+ " 'attention_output.residual_scale': 'gamma_1',\n",
+ " 'ffn.dense.': 'linear1.',\n",
+ " 'ffn.dense2.': 'linear2.',\n",
+ " 'ffn_output.residual_scale': 'gamma_2',\n",
+ " 'LayerNormModules.0.': 'norm1.',\n",
+ " 'LayerNormModules.1.': 'norm2.',\n",
+ " 'predictor.': 'loss_prepare.',\n",
+ " \n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "9a84319a-d13c-411a-bd21-c0a9c1adb872",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "new_checkpoint = {}\n",
+ "for k, v in ds_checkpoints.items():\n",
+ " if k.endswith('residual_scale'):\n",
+ " v = v.squeeze(1).squeeze(0)\n",
+ " # print(v.shape )\n",
+ " \n",
+ " if k.startswith('visual_embed'):\n",
+ " continue\n",
+ " \n",
+ " \n",
+ " for origin_str, target_str in mapping_dict.items():\n",
+ " if origin_str in k:\n",
+ " k = k.replace(origin_str, target_str)\n",
+ " # merge type embedding in video_embed \n",
+ " # if k=='video_embed.embeddings.bias':\n",
+ " # v = v + ds_checkpoints['video_embed.embeddings_type.weight'][0]\n",
+ "\n",
+ " new_checkpoint[k] = v.float()\n",
+ " # if 'wg' in k:\n",
+ " # print(f'{k}, {v}')\n",
+ "# new_checkpoint['video_embed.embeddings.bias'] = new_checkpoint['video_embed.embeddings.bias'] + new_checkpoint['video_embed.embeddings_type.weight'][0]\n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "b5c999e3-e4b1-4949-b89e-4ee2259db8fa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "torch.save({ 'model': new_checkpoint}, 'pretrained_models/uni-perceiver-moe-tiny-L12-H192-224size-pretrained-withvtype.pth')\n",
+ "\n",
+ "\n",
+ "# torch.save({ 'model': new_checkpoint}, 'pretrained_models/uni-perceiver-moe-tiny-L12-H192-224size-pretrained.pth')\n",
+ "# torch.save({ 'model': new_checkpoint}, 'pretrained_models/uni-perceiver-moe-base-L12-H768-224size-pretrained.pth')\n",
+ "# torch.save({ 'model': new_checkpoint}, 'pretrained_models/uni-perceiver-base-L12-H768-224size-pretrained-custom-attn-module.pth')\n",
+ "\n",
+ "\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tools/generate_target_sets.ipynb b/tools/generate_target_sets.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..e7471e42b6ae39d052f5a20909ca38cccf59a44b
--- /dev/null
+++ b/tools/generate_target_sets.ipynb
@@ -0,0 +1,1584 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pickle\n",
+ "import numpy as np\n",
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "sys.path.append('../')\n",
+ "from uniperceiver.tokenization import ClipTokenizer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. ImageNet 1k class names"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Tokenize the class names"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# We follow CLIP to use the ImageNet-1k class names sourced from Anish Athalye's imagenet-simple-labels.\n",
+ "\n",
+ "class_names = [\"tench\", \"goldfish\", \"great white shark\", \"tiger shark\", \"hammerhead shark\", \"electric ray\", \"stingray\", \"rooster\", \"hen\", \"ostrich\", \"brambling\", \"goldfinch\", \"house finch\", \"junco\", \"indigo bunting\", \"American robin\", \"bulbul\", \"jay\", \"magpie\", \"chickadee\", \"American dipper\", \"kite (bird of prey)\", \"bald eagle\", \"vulture\", \"great grey owl\", \"fire salamander\", \"smooth newt\", \"newt\", \"spotted salamander\", \"axolotl\", \"American bullfrog\", \"tree frog\", \"tailed frog\", \"loggerhead sea turtle\", \"leatherback sea turtle\", \"mud turtle\", \"terrapin\", \"box turtle\", \"banded gecko\", \"green iguana\", \"Carolina anole\", \"desert grassland whiptail lizard\", \"agama\", \"frilled-necked lizard\", \"alligator lizard\", \"Gila monster\", \"European green lizard\", \"chameleon\", \"Komodo dragon\", \"Nile crocodile\", \"American alligator\", \"triceratops\", \"worm snake\", \"ring-necked snake\", \"eastern hog-nosed snake\", \"smooth green snake\", \"kingsnake\", \"garter snake\", \"water snake\", \"vine snake\", \"night snake\", \"boa constrictor\", \"African rock python\", \"Indian cobra\", \"green mamba\", \"sea snake\", \"Saharan horned viper\", \"eastern diamondback rattlesnake\", \"sidewinder rattlesnake\", \"trilobite\", \"harvestman\", \"scorpion\", \"yellow garden spider\", \"barn spider\", \"European garden spider\", \"southern black widow\", \"tarantula\", \"wolf spider\", \"tick\", \"centipede\", \"black grouse\", \"ptarmigan\", \"ruffed grouse\", \"prairie grouse\", \"peafowl\", \"quail\", \"partridge\", \"african grey parrot\", \"macaw\", \"sulphur-crested cockatoo\", \"lorikeet\", \"coucal\", \"bee eater\", \"hornbill\", \"hummingbird\", \"jacamar\", \"toucan\", \"duck\", \"red-breasted merganser\", \"goose\", \"black swan\", \"tusker\", \"echidna\", \"platypus\", \"wallaby\", \"koala\", \"wombat\", \"jellyfish\", \"sea anemone\", \"brain coral\", \"flatworm\", \"nematode\", \"conch\", \"snail\", \"slug\", \"sea slug\", \"chiton\", \"chambered nautilus\", \"Dungeness crab\", \"rock crab\", \"fiddler crab\", \"red king crab\", \"American lobster\", \"spiny lobster\", \"crayfish\", \"hermit crab\", \"isopod\", \"white stork\", \"black stork\", \"spoonbill\", \"flamingo\", \"little blue heron\", \"great egret\", \"bittern bird\", \"crane bird\", \"limpkin\", \"common gallinule\", \"American coot\", \"bustard\", \"ruddy turnstone\", \"dunlin\", \"common redshank\", \"dowitcher\", \"oystercatcher\", \"pelican\", \"king penguin\", \"albatross\", \"grey whale\", \"killer whale\", \"dugong\", \"sea lion\", \"Chihuahua\", \"Japanese Chin\", \"Maltese\", \"Pekingese\", \"Shih Tzu\", \"King Charles Spaniel\", \"Papillon\", \"toy terrier\", \"Rhodesian Ridgeback\", \"Afghan Hound\", \"Basset Hound\", \"Beagle\", \"Bloodhound\", \"Bluetick Coonhound\", \"Black and Tan Coonhound\", \"Treeing Walker Coonhound\", \"English foxhound\", \"Redbone Coonhound\", \"borzoi\", \"Irish Wolfhound\", \"Italian Greyhound\", \"Whippet\", \"Ibizan Hound\", \"Norwegian Elkhound\", \"Otterhound\", \"Saluki\", \"Scottish Deerhound\", \"Weimaraner\", \"Staffordshire Bull Terrier\", \"American Staffordshire Terrier\", \"Bedlington Terrier\", \"Border Terrier\", \"Kerry Blue Terrier\", \"Irish Terrier\", \"Norfolk Terrier\", \"Norwich Terrier\", \"Yorkshire Terrier\", \"Wire Fox Terrier\", \"Lakeland Terrier\", \"Sealyham Terrier\", \"Airedale Terrier\", \"Cairn Terrier\", \"Australian Terrier\", \"Dandie Dinmont Terrier\", \"Boston Terrier\", \"Miniature Schnauzer\", \"Giant Schnauzer\", \"Standard Schnauzer\", \"Scottish Terrier\", \"Tibetan Terrier\", \"Australian Silky Terrier\", \"Soft-coated Wheaten Terrier\", \"West Highland White Terrier\", \"Lhasa Apso\", \"Flat-Coated Retriever\", \"Curly-coated Retriever\", \"Golden Retriever\", \"Labrador Retriever\", \"Chesapeake Bay Retriever\", \"German Shorthaired Pointer\", \"Vizsla\", \"English Setter\", \"Irish Setter\", \"Gordon Setter\", \"Brittany dog\", \"Clumber Spaniel\", \"English Springer Spaniel\", \"Welsh Springer Spaniel\", \"Cocker Spaniel\", \"Sussex Spaniel\", \"Irish Water Spaniel\", \"Kuvasz\", \"Schipperke\", \"Groenendael dog\", \"Malinois\", \"Briard\", \"Australian Kelpie\", \"Komondor\", \"Old English Sheepdog\", \"Shetland Sheepdog\", \"collie\", \"Border Collie\", \"Bouvier des Flandres dog\", \"Rottweiler\", \"German Shepherd Dog\", \"Dobermann\", \"Miniature Pinscher\", \"Greater Swiss Mountain Dog\", \"Bernese Mountain Dog\", \"Appenzeller Sennenhund\", \"Entlebucher Sennenhund\", \"Boxer\", \"Bullmastiff\", \"Tibetan Mastiff\", \"French Bulldog\", \"Great Dane\", \"St. Bernard\", \"husky\", \"Alaskan Malamute\", \"Siberian Husky\", \"Dalmatian\", \"Affenpinscher\", \"Basenji\", \"pug\", \"Leonberger\", \"Newfoundland dog\", \"Great Pyrenees dog\", \"Samoyed\", \"Pomeranian\", \"Chow Chow\", \"Keeshond\", \"brussels griffon\", \"Pembroke Welsh Corgi\", \"Cardigan Welsh Corgi\", \"Toy Poodle\", \"Miniature Poodle\", \"Standard Poodle\", \"Mexican hairless dog (xoloitzcuintli)\", \"grey wolf\", \"Alaskan tundra wolf\", \"red wolf or maned wolf\", \"coyote\", \"dingo\", \"dhole\", \"African wild dog\", \"hyena\", \"red fox\", \"kit fox\", \"Arctic fox\", \"grey fox\", \"tabby cat\", \"tiger cat\", \"Persian cat\", \"Siamese cat\", \"Egyptian Mau\", \"cougar\", \"lynx\", \"leopard\", \"snow leopard\", \"jaguar\", \"lion\", \"tiger\", \"cheetah\", \"brown bear\", \"American black bear\", \"polar bear\", \"sloth bear\", \"mongoose\", \"meerkat\", \"tiger beetle\", \"ladybug\", \"ground beetle\", \"longhorn beetle\", \"leaf beetle\", \"dung beetle\", \"rhinoceros beetle\", \"weevil\", \"fly\", \"bee\", \"ant\", \"grasshopper\", \"cricket insect\", \"stick insect\", \"cockroach\", \"praying mantis\", \"cicada\", \"leafhopper\", \"lacewing\", \"dragonfly\", \"damselfly\", \"red admiral butterfly\", \"ringlet butterfly\", \"monarch butterfly\", \"small white butterfly\", \"sulphur butterfly\", \"gossamer-winged butterfly\", \"starfish\", \"sea urchin\", \"sea cucumber\", \"cottontail rabbit\", \"hare\", \"Angora rabbit\", \"hamster\", \"porcupine\", \"fox squirrel\", \"marmot\", \"beaver\", \"guinea pig\", \"common sorrel horse\", \"zebra\", \"pig\", \"wild boar\", \"warthog\", \"hippopotamus\", \"ox\", \"water buffalo\", \"bison\", \"ram (adult male sheep)\", \"bighorn sheep\", \"Alpine ibex\", \"hartebeest\", \"impala (antelope)\", \"gazelle\", \"arabian camel\", \"llama\", \"weasel\", \"mink\", \"European polecat\", \"black-footed ferret\", \"otter\", \"skunk\", \"badger\", \"armadillo\", \"three-toed sloth\", \"orangutan\", \"gorilla\", \"chimpanzee\", \"gibbon\", \"siamang\", \"guenon\", \"patas monkey\", \"baboon\", \"macaque\", \"langur\", \"black-and-white colobus\", \"proboscis monkey\", \"marmoset\", \"white-headed capuchin\", \"howler monkey\", \"titi monkey\", \"Geoffroy's spider monkey\", \"common squirrel monkey\", \"ring-tailed lemur\", \"indri\", \"Asian elephant\", \"African bush elephant\", \"red panda\", \"giant panda\", \"snoek fish\", \"eel\", \"silver salmon\", \"rock beauty fish\", \"clownfish\", \"sturgeon\", \"gar fish\", \"lionfish\", \"pufferfish\", \"abacus\", \"abaya\", \"academic gown\", \"accordion\", \"acoustic guitar\", \"aircraft carrier\", \"airliner\", \"airship\", \"altar\", \"ambulance\", \"amphibious vehicle\", \"analog clock\", \"apiary\", \"apron\", \"trash can\", \"assault rifle\", \"backpack\", \"bakery\", \"balance beam\", \"balloon\", \"ballpoint pen\", \"Band-Aid\", \"banjo\", \"baluster / handrail\", \"barbell\", \"barber chair\", \"barbershop\", \"barn\", \"barometer\", \"barrel\", \"wheelbarrow\", \"baseball\", \"basketball\", \"bassinet\", \"bassoon\", \"swimming cap\", \"bath towel\", \"bathtub\", \"station wagon\", \"lighthouse\", \"beaker\", \"military hat (bearskin or shako)\", \"beer bottle\", \"beer glass\", \"bell tower\", \"baby bib\", \"tandem bicycle\", \"bikini\", \"ring binder\", \"binoculars\", \"birdhouse\", \"boathouse\", \"bobsleigh\", \"bolo tie\", \"poke bonnet\", \"bookcase\", \"bookstore\", \"bottle cap\", \"hunting bow\", \"bow tie\", \"brass memorial plaque\", \"bra\", \"breakwater\", \"breastplate\", \"broom\", \"bucket\", \"buckle\", \"bulletproof vest\", \"high-speed train\", \"butcher shop\", \"taxicab\", \"cauldron\", \"candle\", \"cannon\", \"canoe\", \"can opener\", \"cardigan\", \"car mirror\", \"carousel\", \"tool kit\", \"cardboard box / carton\", \"car wheel\", \"automated teller machine\", \"cassette\", \"cassette player\", \"castle\", \"catamaran\", \"CD player\", \"cello\", \"mobile phone\", \"chain\", \"chain-link fence\", \"chain mail\", \"chainsaw\", \"storage chest\", \"chiffonier\", \"bell or wind chime\", \"china cabinet\", \"Christmas stocking\", \"church\", \"movie theater\", \"cleaver\", \"cliff dwelling\", \"cloak\", \"clogs\", \"cocktail shaker\", \"coffee mug\", \"coffeemaker\", \"spiral or coil\", \"combination lock\", \"computer keyboard\", \"candy store\", \"container ship\", \"convertible\", \"corkscrew\", \"cornet\", \"cowboy boot\", \"cowboy hat\", \"cradle\", \"construction crane\", \"crash helmet\", \"crate\", \"infant bed\", \"Crock Pot\", \"croquet ball\", \"crutch\", \"cuirass\", \"dam\", \"desk\", \"desktop computer\", \"rotary dial telephone\", \"diaper\", \"digital clock\", \"digital watch\", \"dining table\", \"dishcloth\", \"dishwasher\", \"disc brake\", \"dock\", \"dog sled\", \"dome\", \"doormat\", \"drilling rig\", \"drum\", \"drumstick\", \"dumbbell\", \"Dutch oven\", \"electric fan\", \"electric guitar\", \"electric locomotive\", \"entertainment center\", \"envelope\", \"espresso machine\", \"face powder\", \"feather boa\", \"filing cabinet\", \"fireboat\", \"fire truck\", \"fire screen\", \"flagpole\", \"flute\", \"folding chair\", \"football helmet\", \"forklift\", \"fountain\", \"fountain pen\", \"four-poster bed\", \"freight car\", \"French horn\", \"frying pan\", \"fur coat\", \"garbage truck\", \"gas mask or respirator\", \"gas pump\", \"goblet\", \"go-kart\", \"golf ball\", \"golf cart\", \"gondola\", \"gong\", \"gown\", \"grand piano\", \"greenhouse\", \"radiator grille\", \"grocery store\", \"guillotine\", \"hair clip\", \"hair spray\", \"half-track\", \"hammer\", \"hamper\", \"hair dryer\", \"hand-held computer\", \"handkerchief\", \"hard disk drive\", \"harmonica\", \"harp\", \"combine harvester\", \"hatchet\", \"holster\", \"home theater\", \"honeycomb\", \"hook\", \"hoop skirt\", \"gymnastic horizontal bar\", \"horse-drawn vehicle\", \"hourglass\", \"iPod\", \"clothes iron\", \"carved pumpkin\", \"jeans\", \"jeep\", \"T-shirt\", \"jigsaw puzzle\", \"rickshaw\", \"joystick\", \"kimono\", \"knee pad\", \"knot\", \"lab coat\", \"ladle\", \"lampshade\", \"laptop computer\", \"lawn mower\", \"lens cap\", \"letter opener\", \"library\", \"lifeboat\", \"lighter\", \"limousine\", \"ocean liner\", \"lipstick\", \"slip-on shoe\", \"lotion\", \"music speaker\", \"loupe magnifying glass\", \"sawmill\", \"magnetic compass\", \"messenger bag\", \"mailbox\", \"tights\", \"one-piece bathing suit\", \"manhole cover\", \"maraca\", \"marimba\", \"mask\", \"matchstick\", \"maypole\", \"maze\", \"measuring cup\", \"medicine cabinet\", \"megalith\", \"microphone\", \"microwave oven\", \"military uniform\", \"milk can\", \"minibus\", \"miniskirt\", \"minivan\", \"missile\", \"mitten\", \"mixing bowl\", \"mobile home\", \"ford model t\", \"modem\", \"monastery\", \"monitor\", \"moped\", \"mortar and pestle\", \"graduation cap\", \"mosque\", \"mosquito net\", \"vespa\", \"mountain bike\", \"tent\", \"computer mouse\", \"mousetrap\", \"moving van\", \"muzzle\", \"metal nail\", \"neck brace\", \"necklace\", \"baby pacifier\", \"notebook computer\", \"obelisk\", \"oboe\", \"ocarina\", \"odometer\", \"oil filter\", \"pipe organ\", \"oscilloscope\", \"overskirt\", \"bullock cart\", \"oxygen mask\", \"product packet / packaging\", \"paddle\", \"paddle wheel\", \"padlock\", \"paintbrush\", \"pajamas\", \"palace\", \"pan flute\", \"paper towel\", \"parachute\", \"parallel bars\", \"park bench\", \"parking meter\", \"railroad car\", \"patio\", \"payphone\", \"pedestal\", \"pencil case\", \"pencil sharpener\", \"perfume\", \"Petri dish\", \"photocopier\", \"plectrum\", \"Pickelhaube\", \"picket fence\", \"pickup truck\", \"pier\", \"piggy bank\", \"pill bottle\", \"pillow\", \"ping-pong ball\", \"pinwheel\", \"pirate ship\", \"drink pitcher\", \"block plane\", \"planetarium\", \"plastic bag\", \"plate rack\", \"farm plow\", \"plunger\", \"Polaroid camera\", \"pole\", \"police van\", \"poncho\", \"pool table\", \"soda bottle\", \"plant pot\", \"potter's wheel\", \"power drill\", \"prayer rug\", \"printer\", \"prison\", \"missile\", \"projector\", \"hockey puck\", \"punching bag\", \"purse\", \"quill\", \"quilt\", \"race car\", \"racket\", \"radiator\", \"radio\", \"radio telescope\", \"rain barrel\", \"recreational vehicle\", \"fishing casting reel\", \"reflex camera\", \"refrigerator\", \"remote control\", \"restaurant\", \"revolver\", \"rifle\", \"rocking chair\", \"rotisserie\", \"eraser\", \"rugby ball\", \"ruler measuring stick\", \"sneaker\", \"safe\", \"safety pin\", \"salt shaker\", \"sandal\", \"sarong\", \"saxophone\", \"scabbard\", \"weighing scale\", \"school bus\", \"schooner\", \"scoreboard\", \"CRT monitor\", \"screw\", \"screwdriver\", \"seat belt\", \"sewing machine\", \"shield\", \"shoe store\", \"shoji screen / room divider\", \"shopping basket\", \"shopping cart\", \"shovel\", \"shower cap\", \"shower curtain\", \"ski\", \"balaclava ski mask\", \"sleeping bag\", \"slide rule\", \"sliding door\", \"slot machine\", \"snorkel\", \"snowmobile\", \"snowplow\", \"soap dispenser\", \"soccer ball\", \"sock\", \"solar thermal collector\", \"sombrero\", \"soup bowl\", \"keyboard space bar\", \"space heater\", \"space shuttle\", \"spatula\", \"motorboat\", \"spider web\", \"spindle\", \"sports car\", \"spotlight\", \"stage\", \"steam locomotive\", \"through arch bridge\", \"steel drum\", \"stethoscope\", \"scarf\", \"stone wall\", \"stopwatch\", \"stove\", \"strainer\", \"tram\", \"stretcher\", \"couch\", \"stupa\", \"submarine\", \"suit\", \"sundial\", \"sunglasses\", \"sunglasses\", \"sunscreen\", \"suspension bridge\", \"mop\", \"sweatshirt\", \"swim trunks / shorts\", \"swing\", \"electrical switch\", \"syringe\", \"table lamp\", \"tank\", \"tape player\", \"teapot\", \"teddy bear\", \"television\", \"tennis ball\", \"thatched roof\", \"front curtain\", \"thimble\", \"threshing machine\", \"throne\", \"tile roof\", \"toaster\", \"tobacco shop\", \"toilet seat\", \"torch\", \"totem pole\", \"tow truck\", \"toy store\", \"tractor\", \"semi-trailer truck\", \"tray\", \"trench coat\", \"tricycle\", \"trimaran\", \"tripod\", \"triumphal arch\", \"trolleybus\", \"trombone\", \"hot tub\", \"turnstile\", \"typewriter keyboard\", \"umbrella\", \"unicycle\", \"upright piano\", \"vacuum cleaner\", \"vase\", \"vaulted or arched ceiling\", \"velvet fabric\", \"vending machine\", \"vestment\", \"viaduct\", \"violin\", \"volleyball\", \"waffle iron\", \"wall clock\", \"wallet\", \"wardrobe\", \"military aircraft\", \"sink\", \"washing machine\", \"water bottle\", \"water jug\", \"water tower\", \"whiskey jug\", \"whistle\", \"hair wig\", \"window screen\", \"window shade\", \"Windsor tie\", \"wine bottle\", \"airplane wing\", \"wok\", \"wooden spoon\", \"wool\", \"split-rail fence\", \"shipwreck\", \"sailboat\", \"yurt\", \"website\", \"comic book\", \"crossword\", \"traffic or street sign\", \"traffic light\", \"dust jacket\", \"menu\", \"plate\", \"guacamole\", \"consomme\", \"hot pot\", \"trifle\", \"ice cream\", \"popsicle\", \"baguette\", \"bagel\", \"pretzel\", \"cheeseburger\", \"hot dog\", \"mashed potatoes\", \"cabbage\", \"broccoli\", \"cauliflower\", \"zucchini\", \"spaghetti squash\", \"acorn squash\", \"butternut squash\", \"cucumber\", \"artichoke\", \"bell pepper\", \"cardoon\", \"mushroom\", \"Granny Smith apple\", \"strawberry\", \"orange\", \"lemon\", \"fig\", \"pineapple\", \"banana\", \"jackfruit\", \"cherimoya (custard apple)\", \"pomegranate\", \"hay\", \"carbonara\", \"chocolate syrup\", \"dough\", \"meatloaf\", \"pizza\", \"pot pie\", \"burrito\", \"red wine\", \"espresso\", \"tea cup\", \"eggnog\", \"mountain\", \"bubble\", \"cliff\", \"coral reef\", \"geyser\", \"lakeshore\", \"promontory\", \"sandbar\", \"beach\", \"valley\", \"volcano\", \"baseball player\", \"bridegroom\", \"scuba diver\", \"rapeseed\", \"daisy\", \"yellow lady's slipper\", \"corn\", \"acorn\", \"rose hip\", \"horse chestnut seed\", \"coral fungus\", \"agaric\", \"gyromitra\", \"stinkhorn mushroom\", \"earth star fungus\", \"hen of the woods mushroom\", \"bolete\", \"corn cob\", \"toilet paper\"]\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = ClipTokenizer()\n",
+ "class_names_tokenized = [torch.tensor(tokenizer.encode(x + \" <|endoftext|>\")) for x in class_names]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Add the metadata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shared_target_set = {\n",
+ " 'data': class_names_tokenized,\n",
+ " 'modality': 'text'\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Save the target set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# This file location should be specified in the \"SHARED_TARGETS\" section of the config file.\n",
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/imagenet_class_name_CLIP_with_endoftext.pkl', 'wb') as f:\n",
+ " pickle.dump(shared_target_set, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. ImageNet 22k class names"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# For ImageNet 22k, we use the ImageNet 1k class names as the first 1000 categories, and use the WordNet synonyms for the remaining categories.\n",
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/imagenet_22k_class_names.pkl', 'rb') as f:\n",
+ " class_name_mapping = pickle.load(f)\n",
+ "\n",
+ "class_names = [c[1] for c in class_name_mapping]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = ClipTokenizer()\n",
+ "class_names_tokenized = [[torch.tensor(tokenizer.encode(x + \" <|endoftext|>\")) for x in c] for c in class_names]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Add the metadata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shared_target_set = {\n",
+ " 'data': class_names_tokenized,\n",
+ " 'modality': 'text'\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Save the target set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# This file location should be specified in the \"SHARED_TARGETS\" section of the config file.\n",
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/imagenet_22k_class_name_CLIP_with_endoftext.pkl', 'wb') as f:\n",
+ " pickle.dump(shared_target_set, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Vocabulary"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Tokenize the vocabulary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = ClipTokenizer()\n",
+ "eot = tokenizer.encoder['<|endoftext|>']\n",
+ "vocabulary_tokenized = [torch.tensor([v, eot]) for v in tokenizer.encoder.values]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Add the metadata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shared_target_set = {\n",
+ " 'data': vocabulary_tokenized,\n",
+ " 'modality': 'text'\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Save the target set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/vocabulary_CLIP_with_endoftext.pkl', 'wb') as f:\n",
+ " pickle.dump(shared_target_set, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 4. Kinetics-400 class names"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Tokenize the class names"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Class names obtained from https://gist.github.com/willprice/f19da185c9c5f32847134b87c1960769\n",
+ "\n",
+ "class_names = ['abseiling', 'air drumming', 'answering questions', 'applauding', 'applying cream', 'archery', 'arm wrestling', 'arranging flowers', 'assembling computer', 'auctioning', 'baby waking up', 'baking cookies', 'balloon blowing', 'bandaging', 'barbequing', 'bartending', 'beatboxing', 'bee keeping', 'belly dancing', 'bench pressing', 'bending back', 'bending metal', 'biking through snow', 'blasting sand', 'blowing glass', 'blowing leaves', 'blowing nose', 'blowing out candles', 'bobsledding', 'bookbinding', 'bouncing on trampoline', 'bowling', 'braiding hair', 'breading or breadcrumbing', 'breakdancing', 'brush painting', 'brushing hair', 'brushing teeth', 'building cabinet', 'building shed', 'bungee jumping', 'busking', 'canoeing or kayaking', 'capoeira', 'carrying baby', 'cartwheeling', 'carving pumpkin', 'catching fish', 'catching or throwing baseball', 'catching or throwing frisbee', 'catching or throwing softball', 'celebrating', 'changing oil', 'changing wheel', 'checking tires', 'cheerleading', 'chopping wood', 'clapping', 'clay pottery making', 'clean and jerk', 'cleaning floor', 'cleaning gutters', 'cleaning pool', 'cleaning shoes', 'cleaning toilet', 'cleaning windows', 'climbing a rope', 'climbing ladder', 'climbing tree', 'contact juggling', 'cooking chicken', 'cooking egg', 'cooking on campfire', 'cooking sausages', 'counting money', 'country line dancing', 'cracking neck', 'crawling baby', 'crossing river', 'crying', 'curling hair', 'cutting nails', 'cutting pineapple', 'cutting watermelon', 'dancing ballet', 'dancing charleston', 'dancing gangnam style', 'dancing macarena', 'deadlifting', 'decorating the christmas tree', 'digging', 'dining', 'disc golfing', 'diving cliff', 'dodgeball', 'doing aerobics', 'doing laundry', 'doing nails', 'drawing', 'dribbling basketball', 'drinking', 'drinking beer', 'drinking shots', 'driving car', 'driving tractor', 'drop kicking', 'drumming fingers', 'dunking basketball', 'dying hair', 'eating burger', 'eating cake', 'eating carrots', 'eating chips', 'eating doughnuts', 'eating hotdog', 'eating ice cream', 'eating spaghetti', 'eating watermelon', 'egg hunting', 'exercising arm', 'exercising with an exercise ball', 'extinguishing fire', 'faceplanting', 'feeding birds', 'feeding fish', 'feeding goats', 'filling eyebrows', 'finger snapping', 'fixing hair', 'flipping pancake', 'flying kite', 'folding clothes', 'folding napkins', 'folding paper', 'front raises', 'frying vegetables', 'garbage collecting', 'gargling', 'getting a haircut', 'getting a tattoo', 'giving or receiving award', 'golf chipping', 'golf driving', 'golf putting', 'grinding meat', 'grooming dog', 'grooming horse', 'gymnastics tumbling', 'hammer throw', 'headbanging', 'headbutting', 'high jump', 'high kick', 'hitting baseball', 'hockey stop', 'holding snake', 'hopscotch', 'hoverboarding', 'hugging', 'hula hooping', 'hurdling', 'hurling (sport)', 'ice climbing', 'ice fishing', 'ice skating', 'ironing', 'javelin throw', 'jetskiing', 'jogging', 'juggling balls', 'juggling fire', 'juggling soccer ball', 'jumping into pool', 'jumpstyle dancing', 'kicking field goal', 'kicking soccer ball', 'kissing', 'kitesurfing', 'knitting', 'krumping', 'laughing', 'laying bricks', 'long jump', 'lunge', 'making a cake', 'making a sandwich', 'making bed', 'making jewelry', 'making pizza', 'making snowman', 'making sushi', 'making tea', 'marching', 'massaging back', 'massaging feet', 'massaging legs', \"massaging person's head\", 'milking cow', 'mopping floor', 'motorcycling', 'moving furniture', 'mowing lawn', 'news anchoring', 'opening bottle', 'opening present', 'paragliding', 'parasailing', 'parkour', 'passing American football (in game)', 'passing American football (not in game)', 'peeling apples', 'peeling potatoes', 'petting animal (not cat)', 'petting cat', 'picking fruit', 'planting trees', 'plastering', 'playing accordion', 'playing badminton', 'playing bagpipes', 'playing basketball', 'playing bass guitar', 'playing cards', 'playing cello', 'playing chess', 'playing clarinet', 'playing controller', 'playing cricket', 'playing cymbals', 'playing didgeridoo', 'playing drums', 'playing flute', 'playing guitar', 'playing harmonica', 'playing harp', 'playing ice hockey', 'playing keyboard', 'playing kickball', 'playing monopoly', 'playing organ', 'playing paintball', 'playing piano', 'playing poker', 'playing recorder', 'playing saxophone', 'playing squash or racquetball', 'playing tennis', 'playing trombone', 'playing trumpet', 'playing ukulele', 'playing violin', 'playing volleyball', 'playing xylophone', 'pole vault', 'presenting weather forecast', 'pull ups', 'pumping fist', 'pumping gas', 'punching bag', 'punching person (boxing)', 'push up', 'pushing car', 'pushing cart', 'pushing wheelchair', 'reading book', 'reading newspaper', 'recording music', 'riding a bike', 'riding camel', 'riding elephant', 'riding mechanical bull', 'riding mountain bike', 'riding mule', 'riding or walking with horse', 'riding scooter', 'riding unicycle', 'ripping paper', 'robot dancing', 'rock climbing', 'rock scissors paper', 'roller skating', 'running on treadmill', 'sailing', 'salsa dancing', 'sanding floor', 'scrambling eggs', 'scuba diving', 'setting table', 'shaking hands', 'shaking head', 'sharpening knives', 'sharpening pencil', 'shaving head', 'shaving legs', 'shearing sheep', 'shining shoes', 'shooting basketball', 'shooting goal (soccer)', 'shot put', 'shoveling snow', 'shredding paper', 'shuffling cards', 'side kick', 'sign language interpreting', 'singing', 'situp', 'skateboarding', 'ski jumping', 'skiing (not slalom or crosscountry)', 'skiing crosscountry', 'skiing slalom', 'skipping rope', 'skydiving', 'slacklining', 'slapping', 'sled dog racing', 'smoking', 'smoking hookah', 'snatch weight lifting', 'sneezing', 'sniffing', 'snorkeling', 'snowboarding', 'snowkiting', 'snowmobiling', 'somersaulting', 'spinning poi', 'spray painting', 'spraying', 'springboard diving', 'squat', 'sticking tongue out', 'stomping grapes', 'stretching arm', 'stretching leg', 'strumming guitar', 'surfing crowd', 'surfing water', 'sweeping floor', 'swimming backstroke', 'swimming breast stroke', 'swimming butterfly stroke', 'swing dancing', 'swinging legs', 'swinging on something', 'sword fighting', 'tai chi', 'taking a shower', 'tango dancing', 'tap dancing', 'tapping guitar', 'tapping pen', 'tasting beer', 'tasting food', 'testifying', 'texting', 'throwing axe', 'throwing ball', 'throwing discus', 'tickling', 'tobogganing', 'tossing coin', 'tossing salad', 'training dog', 'trapezing', 'trimming or shaving beard', 'trimming trees', 'triple jump', 'tying bow tie', 'tying knot (not on a tie)', 'tying tie', 'unboxing', 'unloading truck', 'using computer', 'using remote controller (not gaming)', 'using segway', 'vault', 'waiting in line', 'walking the dog', 'washing dishes', 'washing feet', 'washing hair', 'washing hands', 'water skiing', 'water sliding', 'watering plants', 'waxing back', 'waxing chest', 'waxing eyebrows', 'waxing legs', 'weaving basket', 'welding', 'whistling', 'windsurfing', 'wrapping present', 'wrestling', 'writing', 'yawning', 'yoga', 'zumba']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = ClipTokenizer()\n",
+ "class_names_tokenized = [torch.tensor(tokenizer.encode(x + \" <|endoftext|>\")) for x in class_names]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Add the metadata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shared_target_set = {\n",
+ " 'data': class_names_tokenized,\n",
+ " 'modality': 'text'\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Save the target set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# This file location should be specified in the \"SHARED_TARGETS\" section of the config file.\n",
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/k400_class_name_CLIP_with_endoftext.pkl', 'wb') as f:\n",
+ " pickle.dump(shared_target_set, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Kinetics-700 class names"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Tokenize the class names"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Class names obtained from https://gist.github.com/willprice/f19da185c9c5f32847134b87c1960769\n",
+ "\n",
+ "class_names = ['abseiling', 'acting in play', 'adjusting glasses', 'air drumming', 'alligator wrestling', 'answering questions', 'applauding', 'applying cream', 'archaeological excavation', 'archery', 'arguing', 'arm wrestling', 'arranging flowers', 'arresting', 'assembling bicycle', 'assembling computer', 'attending conference', 'auctioning', 'baby waking up', 'backflip (human)', 'baking cookies', 'bandaging', 'barbequing', 'bartending', 'base jumping', 'bathing dog', 'battle rope training', 'beatboxing', 'bee keeping', 'being excited', 'being in zero gravity', 'belly dancing', 'bench pressing', 'bending back', 'bending metal', 'biking through snow', 'blasting sand', 'blending fruit', 'blowdrying hair', 'blowing bubble gum', 'blowing glass', 'blowing leaves', 'blowing nose', 'blowing out candles', 'bobsledding', 'bodysurfing', 'bookbinding', 'bottling', 'bouncing ball (not juggling)', 'bouncing on bouncy castle', 'bouncing on trampoline', 'bowling', 'braiding hair', 'breading or breadcrumbing', 'breakdancing', 'breaking boards', 'breaking glass', 'breathing fire', 'brush painting', 'brushing floor', 'brushing hair', 'brushing teeth', 'building cabinet', 'building lego', 'building sandcastle', 'building shed', 'bulldozing', 'bungee jumping', 'burping', 'busking', 'calculating', 'calligraphy', 'canoeing or kayaking', 'capoeira', 'capsizing', 'card stacking', 'card throwing', 'carrying baby', 'carrying weight', 'cartwheeling', 'carving ice', 'carving marble', 'carving pumpkin', 'carving wood with a knife', 'casting fishing line', 'catching fish', 'catching or throwing baseball', 'catching or throwing frisbee', 'catching or throwing softball', 'celebrating', 'changing gear in car', 'changing oil', 'changing wheel (not on bike)', 'chasing', 'checking tires', 'checking watch', 'cheerleading', 'chewing gum', 'chiseling stone', 'chiseling wood', 'chopping meat', 'chopping wood', 'clam digging', 'clapping', 'clay pottery making', 'clean and jerk', 'cleaning gutters', 'cleaning pool', 'cleaning shoes', 'cleaning toilet', 'cleaning windows', 'climbing a rope', 'climbing ladder', 'climbing tree', 'closing door', 'coloring in', 'combing hair', 'contact juggling', 'contorting', 'cooking chicken', 'cooking egg', 'cooking on campfire', 'cooking sausages (not on barbeque)', 'cooking scallops', 'cosplaying', 'coughing', 'counting money', 'country line dancing', 'cracking back', 'cracking knuckles', 'cracking neck', 'crawling baby', 'crocheting', 'crossing eyes', 'crossing river', 'crying', 'cumbia', 'curling (sport)', 'curling eyelashes', 'curling hair', 'cutting apple', 'cutting cake', 'cutting nails', 'cutting orange', 'cutting pineapple', 'cutting watermelon', 'dancing ballet', 'dancing charleston', 'dancing gangnam style', 'dancing macarena', 'deadlifting', 'dealing cards', 'decorating the christmas tree', 'decoupage', 'delivering mail', 'digging', 'dining', 'directing traffic', 'disc golfing', 'diving cliff', 'docking boat', 'dodgeball', 'doing aerobics', 'doing jigsaw puzzle', 'doing laundry', 'doing nails', 'doing sudoku', 'drawing', 'dribbling basketball', 'drinking shots', 'driving car', 'driving tractor', 'drooling', 'drop kicking', 'drumming fingers', 'dumpster diving', 'dunking basketball', 'dyeing eyebrows', 'dyeing hair', 'eating burger', 'eating cake', 'eating carrots', 'eating chips', 'eating doughnuts', 'eating hotdog', 'eating ice cream', 'eating nachos', 'eating spaghetti', 'eating watermelon', 'egg hunting', 'embroidering', 'entering church', 'exercising arm', 'exercising with an exercise ball', 'extinguishing fire', 'faceplanting', 'falling off bike', 'falling off chair', 'feeding birds', 'feeding fish', 'feeding goats', 'fencing (sport)', 'fidgeting', 'filling cake', 'filling eyebrows', 'finger snapping', 'fixing bicycle', 'fixing hair', 'flint knapping', 'flipping bottle', 'flipping pancake', 'fly tying', 'flying kite', 'folding clothes', 'folding napkins', 'folding paper', 'front raises', 'frying vegetables', 'gargling', 'geocaching', 'getting a haircut', 'getting a piercing', 'getting a tattoo', 'giving or receiving award', 'gold panning', 'golf chipping', 'golf driving', 'golf putting', 'gospel singing in church', 'grinding meat', 'grooming cat', 'grooming dog', 'grooming horse', 'gymnastics tumbling', 'hammer throw', 'hand washing clothes', 'head stand', 'headbanging', 'headbutting', 'helmet diving', 'herding cattle', 'high fiving', 'high jump', 'high kick', 'historical reenactment', 'hitting baseball', 'hockey stop', 'holding snake', 'home roasting coffee', 'hopscotch', 'hoverboarding', 'huddling', 'hugging (not baby)', 'hugging baby', 'hula hooping', 'hurdling', 'hurling (sport)', 'ice climbing', 'ice fishing', 'ice skating', 'ice swimming', 'inflating balloons', 'installing carpet', 'ironing', 'ironing hair', 'javelin throw', 'jaywalking', 'jetskiing', 'jogging', 'juggling balls', 'juggling fire', 'juggling soccer ball', 'jumping bicycle', 'jumping into pool', 'jumping jacks', 'jumping sofa', 'jumpstyle dancing', 'karaoke', 'kicking field goal', 'kicking soccer ball', 'kissing', 'kitesurfing', 'knitting', 'krumping', 'land sailing', 'laughing', 'lawn mower racing', 'laying bricks', 'laying concrete', 'laying decking', 'laying stone', 'laying tiles', 'leatherworking', 'letting go of balloon', 'licking', 'lifting hat', 'lighting candle', 'lighting fire', 'listening with headphones', 'lock picking', 'long jump', 'longboarding', 'looking at phone', 'looking in mirror', 'luge', 'lunge', 'making a cake', 'making a sandwich', 'making balloon shapes', 'making bubbles', 'making cheese', 'making horseshoes', 'making jewelry', 'making latte art', 'making paper aeroplanes', 'making pizza', 'making slime', 'making snowman', 'making sushi', 'making tea', 'making the bed', 'marching', 'marriage proposal', 'massaging back', 'massaging feet', 'massaging legs', 'massaging neck', \"massaging person's head\", 'metal detecting', 'milking cow', 'milking goat', 'mixing colours', 'moon walking', 'mopping floor', 'mosh pit dancing', 'motorcycling', 'mountain climber (exercise)', 'moving baby', 'moving child', 'moving furniture', 'mowing lawn', 'mushroom foraging', 'needle felting', 'news anchoring', 'opening bottle (not wine)', 'opening coconuts', 'opening door', 'opening present', 'opening refrigerator', 'opening wine bottle', 'packing', 'paragliding', 'parasailing', 'parkour', 'passing American football (in game)', 'passing American football (not in game)', 'passing soccer ball', 'peeling apples', 'peeling banana', 'peeling potatoes', 'person collecting garbage', 'petting animal (not cat)', 'petting cat', 'petting horse', 'photobombing', 'photocopying', 'picking apples', 'picking blueberries', 'pillow fight', 'pinching', 'pirouetting', 'planing wood', 'planting trees', 'plastering', 'playing accordion', 'playing american football', 'playing badminton', 'playing bagpipes', 'playing basketball', 'playing bass guitar', 'playing beer pong', 'playing billiards', 'playing blackjack', 'playing cards', 'playing cello', 'playing checkers', 'playing chess', 'playing clarinet', 'playing controller', 'playing cricket', 'playing cymbals', 'playing darts', 'playing didgeridoo', 'playing dominoes', 'playing drums', 'playing field hockey', 'playing flute', 'playing gong', 'playing guitar', 'playing hand clapping games', 'playing harmonica', 'playing harp', 'playing ice hockey', 'playing keyboard', 'playing kickball', 'playing laser tag', 'playing lute', 'playing mahjong', 'playing maracas', 'playing marbles', 'playing monopoly', 'playing netball', 'playing nose flute', 'playing oboe', 'playing ocarina', 'playing organ', 'playing paintball', 'playing pan pipes', 'playing piano', 'playing piccolo', 'playing pinball', 'playing ping pong', 'playing poker', 'playing polo', 'playing recorder', 'playing road hockey', 'playing rounders', 'playing rubiks cube', 'playing saxophone', 'playing scrabble', 'playing shuffleboard', 'playing slot machine', 'playing squash or racquetball', 'playing tennis', 'playing trombone', 'playing trumpet', 'playing ukulele', 'playing violin', 'playing volleyball', 'playing with trains', 'playing xylophone', 'poaching eggs', 'poking bellybutton', 'pole vault', 'polishing furniture', 'polishing metal', 'popping balloons', 'pouring beer', 'pouring milk', 'pouring wine', 'preparing salad', 'presenting weather forecast', 'pretending to be a statue', 'pull ups', 'pulling espresso shot', 'pulling rope (game)', 'pumping fist', 'pumping gas', 'punching bag', 'punching person (boxing)', 'push up', 'pushing car', 'pushing cart', 'pushing wheelbarrow', 'pushing wheelchair', 'putting in contact lenses', 'putting on eyeliner', 'putting on foundation', 'putting on lipstick', 'putting on mascara', 'putting on sari', 'putting on shoes', 'putting wallpaper on wall', 'raising eyebrows', 'reading book', 'reading newspaper', 'recording music', 'repairing puncture', 'riding a bike', 'riding camel', 'riding elephant', 'riding mechanical bull', 'riding mule', 'riding or walking with horse', 'riding scooter', 'riding snow blower', 'riding unicycle', 'ripping paper', 'roasting marshmallows', 'roasting pig', 'robot dancing', 'rock climbing', 'rock scissors paper', 'roller skating', 'rolling eyes', 'rolling pastry', 'rope pushdown', 'running on treadmill', 'sailing', 'salsa dancing', 'saluting', 'sanding floor', 'sanding wood', 'sausage making', 'sawing wood', 'scrambling eggs', 'scrapbooking', 'scrubbing face', 'scuba diving', 'seasoning food', 'separating eggs', 'setting table', 'sewing', 'shaking hands', 'shaking head', 'shaping bread dough', 'sharpening knives', 'sharpening pencil', 'shaving head', 'shaving legs', 'shearing sheep', 'shining flashlight', 'shining shoes', 'shoot dance', 'shooting basketball', 'shooting goal (soccer)', 'shooting off fireworks', 'shopping', 'shot put', 'shouting', 'shoveling snow', 'shredding paper', 'shucking oysters', 'shuffling cards', 'shuffling feet', 'side kick', 'sieving', 'sign language interpreting', 'silent disco', 'singing', 'sipping cup', 'situp', 'skateboarding', 'ski ballet', 'ski jumping', 'skiing crosscountry', 'skiing mono', 'skiing slalom', 'skipping rope', 'skipping stone', 'skydiving', 'slacklining', 'slapping', 'sled dog racing', 'sleeping', 'slicing onion', 'smashing', 'smelling feet', 'smoking', 'smoking hookah', 'smoking pipe', 'snatch weight lifting', 'sneezing', 'snorkeling', 'snowboarding', 'snowkiting', 'snowmobiling', 'somersaulting', 'spelunking', 'spinning plates', 'spinning poi', 'splashing water', 'spray painting', 'spraying', 'springboard diving', 'square dancing', 'squat', 'squeezing orange', 'stacking cups', 'stacking dice', 'standing on hands', 'staring', 'steer roping', 'steering car', 'sticking tongue out', 'stomping grapes', 'stretching arm', 'stretching leg', 'sucking lolly', 'surfing crowd', 'surfing water', 'surveying', 'sweeping floor', 'swimming backstroke', 'swimming breast stroke', 'swimming butterfly stroke', 'swimming front crawl', 'swimming with dolphins', 'swimming with sharks', 'swing dancing', 'swinging baseball bat', 'swinging on something', 'sword fighting', 'sword swallowing', 'tackling', 'tagging graffiti', 'tai chi', 'taking photo', 'talking on cell phone', 'tango dancing', 'tap dancing', 'tapping guitar', 'tapping pen', 'tasting beer', 'tasting food', 'tasting wine', 'testifying', 'texting', 'threading needle', 'throwing axe', 'throwing ball (not baseball or American football)', 'throwing discus', 'throwing knife', 'throwing snowballs', 'throwing tantrum', 'throwing water balloon', 'tickling', 'tie dying', 'tightrope walking', 'tiptoeing', 'tobogganing', 'tossing coin', 'tossing salad', 'training dog', 'trapezing', 'treating wood', 'trimming or shaving beard', 'trimming shrubs', 'trimming trees', 'triple jump', 'twiddling fingers', 'tying bow tie', 'tying knot (not on a tie)', 'tying necktie', 'tying shoe laces', 'unboxing', 'uncorking champagne', 'unloading truck', 'using a microscope', 'using a paint roller', 'using a power drill', 'using a sledge hammer', 'using a wrench', 'using atm', 'using bagging machine', 'using circular saw', 'using inhaler', 'using megaphone', 'using puppets', 'using remote controller (not gaming)', 'using segway', 'vacuuming car', 'vacuuming floor', 'visiting the zoo', 'wading through mud', 'wading through water', 'waiting in line', 'waking up', 'walking on stilts', 'walking the dog', 'walking through snow', 'walking with crutches', 'washing dishes', 'washing feet', 'washing hair', 'washing hands', 'watching tv', 'water skiing', 'water sliding', 'watering plants', 'waving hand', 'waxing armpits', 'waxing back', 'waxing chest', 'waxing eyebrows', 'waxing legs', 'weaving basket', 'weaving fabric', 'welding', 'whistling', 'windsurfing', 'winking', 'wood burning (art)', 'wrapping present', 'wrestling', 'writing', 'yarn spinning', 'yawning', 'yoga', 'zumba']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = ClipTokenizer()\n",
+ "class_names_tokenized = [torch.tensor(tokenizer.encode(x + \" <|endoftext|>\")) for x in class_names]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Add the metadata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shared_target_set = {\n",
+ " 'data': class_names_tokenized,\n",
+ " 'modality': 'text'\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Save the target set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# This file location should be specified in the \"SHARED_TARGETS\" section of the config file.\n",
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/k700_class_name_CLIP_with_endoftext.pkl', 'wb') as f:\n",
+ " pickle.dump(shared_target_set, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 6. Moments in time class names"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class_names = ['adult female singing', 'adult female speaking', 'adult male singing', 'adult male speaking', 'aiming', 'applauding', 'arresting', 'ascending', 'asking', 'assembling', 'attacking', 'autographing', 'baking', 'balancing', 'baptizing', 'barbecuing', 'barking', 'bathing', 'bending', 'bicycling', 'biting', 'blocking', 'blowing', 'boarding', 'boating', 'boiling', 'bouncing', 'bowing', 'bowling', 'boxing', 'breaking', 'brushing', 'bubbling', 'building', 'bulldozing', 'burning', 'burying', 'buttoning', 'buying', 'calling', 'camping', 'carrying', 'carving', 'catching', 'celebrating', 'chasing', 'cheering', 'cheerleading', 'chewing', 'child singing', 'child speaking', 'chopping', 'clapping', 'clawing', 'cleaning', 'clearing', 'climbing', 'clinging', 'clipping', 'closing', 'coaching', 'colliding', 'combing', 'combusting', 'competing', 'constructing', 'cooking', 'coughing', 'covering', 'cracking', 'crafting', 'cramming', 'crashing', 'crawling', 'crouching', 'crushing', 'crying', 'cuddling', 'cutting', 'dancing', 'descending', 'destroying', 'digging', 'dining', 'dipping', 'discussing', 'diving', 'dragging', 'draining', 'drawing', 'drenching', 'dressing', 'drilling', 'drinking', 'dripping', 'driving', 'dropping', 'drumming', 'drying', 'dunking', 'dusting', 'eating', 'emptying', 'entering', 'erupting', 'exercising', 'exiting', 'extinguishing', 'falling', 'feeding', 'fencing', 'fighting', 'filling', 'filming', 'fishing', 'flicking', 'flipping', 'floating', 'flooding', 'flowing', 'flying', 'folding', 'frowning', 'frying', 'fueling', 'gambling', 'gardening', 'giggling', 'giving', 'grilling', 'grinning', 'gripping', 'grooming', 'guarding', 'hammering', 'handcuffing', 'handwriting', 'hanging', 'hiking', 'hitchhiking', 'hitting', 'howling', 'hugging', 'hunting', 'imitating', 'inflating', 'injecting', 'instructing', 'interviewing', 'jogging', 'joining', 'juggling', 'jumping', 'kicking', 'kissing', 'kneeling', 'knitting', 'knocking', 'landing', 'laughing', 'launching', 'leaking', 'leaning', 'leaping', 'lecturing', 'licking', 'lifting', 'loading', 'locking', 'manicuring', 'marching', 'marrying', 'massaging', 'measuring', 'mopping', 'mowing', 'officiating', 'opening', 'operating', 'overflowing', 'packaging', 'packing', 'painting', 'parading', 'paying', 'pedaling', 'peeling', 'performing', 'photographing', 'picking', 'piloting', 'pitching', 'placing', 'planting', 'playing', 'playing fun', 'playing music', 'playing sports', 'playing videogames', 'plugging', 'plunging', 'pointing', 'poking', 'pouring', 'praying', 'preaching', 'pressing', 'protesting', 'pulling', 'punching', 'punting', 'pushing', 'putting', 'queuing', 'racing', 'rafting', 'raining', 'raising', 'reaching', 'reading', 'removing', 'repairing', 'resting', 'riding', 'rinsing', 'rising', 'roaring', 'rocking', 'rolling', 'rowing', 'rubbing', 'running', 'sailing', 'saluting', 'sanding', 'sawing', 'scratching', 'screwing', 'scrubbing', 'selling', 'serving', 'sewing', 'shaking', 'shaving', 'shooting', 'shopping', 'shouting', 'shoveling', 'shredding', 'shrugging', 'signing', 'singing', 'sitting', 'skating', 'sketching', 'skiing', 'skipping', 'slapping', 'sleeping', 'slicing', 'sliding', 'slipping', 'smashing', 'smelling', 'smiling', 'smoking', 'snapping', 'sneezing', 'sniffing', 'snowing', 'snuggling', 'socializing', 'sowing', 'speaking', 'spilling', 'spinning', 'spitting', 'splashing', 'spraying', 'spreading', 'sprinkling', 'sprinting', 'squatting', 'squinting', 'stacking', 'standing', 'starting', 'stealing', 'steering', 'stirring', 'stitching', 'stomping', 'stopping', 'storming', 'stretching', 'stroking', 'studying', 'submerging', 'surfing', 'sweeping', 'swerving', 'swimming', 'swinging', 'talking', 'taping', 'tapping', 'tattooing', 'teaching', 'tearing', 'telephoning', 'throwing', 'tickling', 'towing', 'trimming', 'tripping', 'tuning', 'turning', 'twisting', 'tying', 'typing', 'unloading', 'unpacking', 'vacuuming', 'waking', 'walking', 'washing', 'watering', 'waving', 'waxing', 'weeding', 'welding', 'wetting', 'whistling', 'winking', 'working', 'wrapping', 'wrestling', 'writing', 'yawning']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = ClipTokenizer()\n",
+ "class_names_tokenized = [torch.tensor(tokenizer.encode(x + \" <|endoftext|>\")) for x in class_names]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Add the metadata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shared_target_set = {\n",
+ " 'data': class_names_tokenized,\n",
+ " 'modality': 'text'\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Save the target set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# This file location should be specified in the \"SHARED_TARGETS\" section of the config file.\n",
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/MiT_class_name_CLIP_with_endoftext.pkl', 'wb') as f:\n",
+ " pickle.dump(shared_target_set, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 7.GLUE datasets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# the order is crucial.\n",
+ "class_names_dict = {\n",
+ " 'SST-2': [\"bad\", \"wonderful\"],\n",
+ " \"CoLA\": [ \"incorrect\", \"correct\" ],\n",
+ " \"RTE\": [ \"no\", \"yes\",],\n",
+ " \"MRPC\": [ \"no\", \"yes\",],\n",
+ " \"QQP\": [ \"no\", \"yes\",],\n",
+ " \"QNLI\": [ \"no\", \"yes\",],\n",
+ " \"MNLI\": [\"no\", \"maybe\", \"yes\"]\n",
+ "}\n",
+ "tokenizer = ClipTokenizer()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for dataset_name, class_names in class_names_dict.items():\n",
+ " print(f'{dataset_name}, {class_names}')\n",
+ " class_names_tokenized = [torch.tensor(tokenizer.encode(x + \" <|endoftext|>\")) for x in class_names]\n",
+ " # Add the metadata\n",
+ " shared_target_set = {\n",
+ " 'data': class_names_tokenized,\n",
+ " 'modality': 'text'\n",
+ " }\n",
+ " # This file location should be specified in the \"SHARED_TARGETS\" section of the config file.\n",
+ " with open(f'/nfs/zhujinguo/datasets/open_source_dataset/bert_pretrain_data/glue_data/{dataset_name}_class_name_CLIP_with_endoftext.pkl', 'wb') as f:\n",
+ " pickle.dump(shared_target_set, f)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 8.VQA All Answers "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['',\n",
+ " 'woods',\n",
+ " 'trash can',\n",
+ " 'hanging',\n",
+ " 'wooden',\n",
+ " 'cooking',\n",
+ " 'china',\n",
+ " 'kids',\n",
+ " 'bike rack',\n",
+ " 'on phone',\n",
+ " 'music',\n",
+ " 'travel',\n",
+ " 'tulip',\n",
+ " 'arrow',\n",
+ " 'branch',\n",
+ " 'chevron',\n",
+ " 'mouth',\n",
+ " 'on right',\n",
+ " 'rice',\n",
+ " 'plate',\n",
+ " 'lots',\n",
+ " 'nature',\n",
+ " 'fruits',\n",
+ " 'throwing frisbee',\n",
+ " 'blonde',\n",
+ " 'life jacket',\n",
+ " 'ham',\n",
+ " 'hay',\n",
+ " 'hat',\n",
+ " 'to get to other side',\n",
+ " '12:35',\n",
+ " 'shadow',\n",
+ " '12:30',\n",
+ " 'crown',\n",
+ " 'crows',\n",
+ " 'bottom',\n",
+ " 'fish',\n",
+ " 'benches',\n",
+ " 'fabric',\n",
+ " 'congratulations',\n",
+ " 'passenger',\n",
+ " 'ana',\n",
+ " 'triangles',\n",
+ " 'roll',\n",
+ " 'chain',\n",
+ " 'chair',\n",
+ " 'crates',\n",
+ " '11:45',\n",
+ " 'shirts',\n",
+ " 'mirrors',\n",
+ " 'parachute',\n",
+ " 'surfing',\n",
+ " 'mango',\n",
+ " 'years',\n",
+ " 'apron',\n",
+ " 'entering',\n",
+ " '6:20',\n",
+ " '6:25',\n",
+ " 'in her hand',\n",
+ " 'night time',\n",
+ " 'zoo',\n",
+ " 'printer',\n",
+ " 'suitcases',\n",
+ " 's',\n",
+ " '4:20',\n",
+ " 'wine tasting',\n",
+ " 'west',\n",
+ " 'newspaper',\n",
+ " 'wires',\n",
+ " 'singapore',\n",
+ " 'advertisement',\n",
+ " 'for photo',\n",
+ " 'multi colored',\n",
+ " 'steamed',\n",
+ " 'traffic',\n",
+ " 'brushing',\n",
+ " 'stagecoach',\n",
+ " 'thailand',\n",
+ " 'pelicans',\n",
+ " 'pull',\n",
+ " 'dirty',\n",
+ " 'rust',\n",
+ " 'tennis shoes',\n",
+ " 'cream',\n",
+ " 'puppy',\n",
+ " 'in box',\n",
+ " 'waving',\n",
+ " 'ceramic',\n",
+ " 'skate park',\n",
+ " 'sand',\n",
+ " 'bathing suit',\n",
+ " 'on tower',\n",
+ " '193',\n",
+ " 'fast food',\n",
+ " 'clock',\n",
+ " 'full',\n",
+ " 'adidas',\n",
+ " 'more',\n",
+ " '100 feet',\n",
+ " 'door',\n",
+ " 'seagull',\n",
+ " 'huge',\n",
+ " 'city bus',\n",
+ " 'paper',\n",
+ " 'signs',\n",
+ " 'smiling',\n",
+ " 'sauce',\n",
+ " 'weeds',\n",
+ " 'very tall',\n",
+ " 'right side',\n",
+ " 'theater',\n",
+ " 'emirates',\n",
+ " 'graffiti',\n",
+ " 'yellow and red',\n",
+ " 'back left',\n",
+ " 'visilab',\n",
+ " 'singles',\n",
+ " 'barrel',\n",
+ " 'blender',\n",
+ " 'sunny',\n",
+ " 'caramel',\n",
+ " 'analog',\n",
+ " 'snowboarder',\n",
+ " 'bell',\n",
+ " 'belt',\n",
+ " 'awake',\n",
+ " '37',\n",
+ " 'cheddar',\n",
+ " 'emergency',\n",
+ " 'couple',\n",
+ " 'harry potter',\n",
+ " 'china airlines',\n",
+ " 'hawaiian',\n",
+ " 'lufthansa',\n",
+ " 'land',\n",
+ " 'windowsill',\n",
+ " 'video',\n",
+ " 'scania',\n",
+ " 'listening to music',\n",
+ " 'pencil',\n",
+ " 'baby',\n",
+ " 'pocket',\n",
+ " 'robot',\n",
+ " 'spatula',\n",
+ " 'mutt',\n",
+ " 'snake',\n",
+ " '1:55',\n",
+ " \"valentine's day\",\n",
+ " '1:50',\n",
+ " 'glasses',\n",
+ " 'books',\n",
+ " 'in cabbage town',\n",
+ " 'festival',\n",
+ " 'big ben',\n",
+ " 'stomach',\n",
+ " 'school bus',\n",
+ " 'mountainous',\n",
+ " 'dishes',\n",
+ " 'hot dog',\n",
+ " 'thanksgiving',\n",
+ " 'teddy bear',\n",
+ " 'watching tv',\n",
+ " 'jar',\n",
+ " 'police officer',\n",
+ " 'jal',\n",
+ " 'tape',\n",
+ " 'riding',\n",
+ " 'styrofoam',\n",
+ " 'frame',\n",
+ " 'roman numerals',\n",
+ " 'skateboarding',\n",
+ " 'staring',\n",
+ " 'doorway',\n",
+ " 'commuter',\n",
+ " 'tattoo',\n",
+ " 'off',\n",
+ " 'clocks',\n",
+ " 'wet',\n",
+ " 'street sign',\n",
+ " 'pier',\n",
+ " 'flying kites',\n",
+ " 'pomeranian',\n",
+ " 'taking selfie',\n",
+ " 'swimming',\n",
+ " 'letters',\n",
+ " 'cat and dog',\n",
+ " 'ponytail',\n",
+ " \"don't know\",\n",
+ " 'by window',\n",
+ " 'water ski',\n",
+ " 'panda',\n",
+ " 'holding umbrella',\n",
+ " 'ski slope',\n",
+ " 'plow',\n",
+ " 'sweater',\n",
+ " 'coins',\n",
+ " 'comcast',\n",
+ " '1990',\n",
+ " 'nuts',\n",
+ " 'ladder',\n",
+ " 'cell phone',\n",
+ " 'arch',\n",
+ " '1 foot',\n",
+ " 'coffee',\n",
+ " '1 in back',\n",
+ " 'white and blue',\n",
+ " 'safe',\n",
+ " 'l',\n",
+ " 'turn right',\n",
+ " 'sled',\n",
+ " 'barrier',\n",
+ " 'tree branch',\n",
+ " 'on chair',\n",
+ " 'john',\n",
+ " 'dogs',\n",
+ " 'on grass',\n",
+ " 'cereal',\n",
+ " 'crosstown',\n",
+ " 'giraffes',\n",
+ " 'gravy',\n",
+ " 'germany',\n",
+ " 'rectangle',\n",
+ " 't shirt',\n",
+ " 'away',\n",
+ " 'playing game',\n",
+ " 'skull and crossbones',\n",
+ " 'kitchen',\n",
+ " '3 feet',\n",
+ " 'onion rings',\n",
+ " 'cylinder',\n",
+ " 'tissue',\n",
+ " 'cone',\n",
+ " 'wheel',\n",
+ " 'hand',\n",
+ " 'yamaha',\n",
+ " 'cooler',\n",
+ " 'night',\n",
+ " 'cirrus',\n",
+ " 'bottom right',\n",
+ " 'sailboats',\n",
+ " 'silverware',\n",
+ " 'gun',\n",
+ " 'horseback riding',\n",
+ " 'slow down',\n",
+ " '2:30',\n",
+ " '2:35',\n",
+ " 'teacher',\n",
+ " '3:25',\n",
+ " '3:20',\n",
+ " 'pillow',\n",
+ " 'green and red',\n",
+ " 'uphill',\n",
+ " 'flat screen',\n",
+ " 'produce',\n",
+ " 'vases',\n",
+ " 'corona',\n",
+ " 'serving',\n",
+ " 'still',\n",
+ " 'tiara',\n",
+ " 'top left',\n",
+ " 'slacks',\n",
+ " 'not',\n",
+ " 'now',\n",
+ " '5 star',\n",
+ " 'do not enter',\n",
+ " 'stir fry',\n",
+ " 'ring',\n",
+ " 'button up',\n",
+ " 'coffee pot',\n",
+ " 'hsbc',\n",
+ " 'apples',\n",
+ " 'windsor',\n",
+ " 'america',\n",
+ " '7:35',\n",
+ " 'cauliflower',\n",
+ " 'living room',\n",
+ " 'jeans',\n",
+ " 'tropical',\n",
+ " 'skiers',\n",
+ " 'playing baseball',\n",
+ " 'foam',\n",
+ " 'westin',\n",
+ " 'taxi',\n",
+ " 'neon',\n",
+ " 'foot',\n",
+ " 'swim',\n",
+ " 'blackberry',\n",
+ " 'apartments',\n",
+ " 'kite flying',\n",
+ " 'stars',\n",
+ " 'pitcher',\n",
+ " 'omelet',\n",
+ " '600',\n",
+ " 'cabinets',\n",
+ " 'not at all',\n",
+ " 'poor',\n",
+ " 'poop',\n",
+ " 'pool',\n",
+ " 'red bull',\n",
+ " 'tracks',\n",
+ " 'all way',\n",
+ " 'grill',\n",
+ " 'us airways',\n",
+ " 'horizontal',\n",
+ " 'chopsticks',\n",
+ " 'computers',\n",
+ " 'evening',\n",
+ " 'talking on phone',\n",
+ " 'starbucks',\n",
+ " 'heavy',\n",
+ " 'safety',\n",
+ " '7',\n",
+ " 'houses',\n",
+ " 'dishwasher',\n",
+ " 'american',\n",
+ " 'horse',\n",
+ " 'station',\n",
+ " 'on track',\n",
+ " 'boredom',\n",
+ " 'toward',\n",
+ " 'silver and red',\n",
+ " 'in grass',\n",
+ " 'gray and black',\n",
+ " 'tongs',\n",
+ " 'close',\n",
+ " 'pictures',\n",
+ " 'zipper',\n",
+ " 'gmc',\n",
+ " 'empty',\n",
+ " 'juice',\n",
+ " '000',\n",
+ " 'cherries',\n",
+ " 'motorola',\n",
+ " 'vests',\n",
+ " 'cones',\n",
+ " 'no man',\n",
+ " 'towards',\n",
+ " 'blue and white',\n",
+ " 'squirrel',\n",
+ " 'at&t',\n",
+ " 'working',\n",
+ " 'on car',\n",
+ " 'e',\n",
+ " 'pizza hut',\n",
+ " 'wave',\n",
+ " 'snowboards',\n",
+ " 'jump',\n",
+ " 'pancake',\n",
+ " 'on sidewalk',\n",
+ " 'camera',\n",
+ " 'visibility',\n",
+ " \"i don't know\",\n",
+ " 'on woman',\n",
+ " 'toaster oven',\n",
+ " 'bandana',\n",
+ " 'farm',\n",
+ " '10:35',\n",
+ " '10:45',\n",
+ " 'taking pictures',\n",
+ " 'costume',\n",
+ " 'slide',\n",
+ " 'ankle',\n",
+ " '9:50',\n",
+ " 'ottoman',\n",
+ " 'baggage claim',\n",
+ " 'trucks',\n",
+ " 'tent',\n",
+ " 'kicking',\n",
+ " 'looking at phone',\n",
+ " 'controller',\n",
+ " 'smoking',\n",
+ " 'toothbrushes',\n",
+ " 'speaker',\n",
+ " 'party',\n",
+ " '42',\n",
+ " 'balloons',\n",
+ " 'riding elephant',\n",
+ " 'propeller',\n",
+ " 'intersection',\n",
+ " 'library',\n",
+ " 'home',\n",
+ " 'grinding',\n",
+ " 'blue and gray',\n",
+ " 'north',\n",
+ " 'strawberries',\n",
+ " 'display',\n",
+ " 'finch',\n",
+ " 'star',\n",
+ " 'foil',\n",
+ " 'samsung',\n",
+ " 'helmets',\n",
+ " 'wakeboard',\n",
+ " 'men',\n",
+ " 'joshua',\n",
+ " 'sliced',\n",
+ " 'jackets',\n",
+ " 'under',\n",
+ " 'for fun',\n",
+ " 'room',\n",
+ " 'roof',\n",
+ " 'checkerboard',\n",
+ " 'before',\n",
+ " 'glazed',\n",
+ " 'ascending',\n",
+ " 'hoodie',\n",
+ " 'blinders',\n",
+ " 'in street',\n",
+ " 'lettuce',\n",
+ " 'frosted',\n",
+ " 'ginger',\n",
+ " 'san diego',\n",
+ " 'blue',\n",
+ " 'mario',\n",
+ " 'no cat',\n",
+ " 'hotel room',\n",
+ " 'celery',\n",
+ " 'watermelon',\n",
+ " 'ears',\n",
+ " 'lotion',\n",
+ " 'parsley',\n",
+ " 'mozzarella',\n",
+ " 'basil',\n",
+ " 'hot sauce',\n",
+ " 'cane',\n",
+ " 'radiator',\n",
+ " 'yes',\n",
+ " 'strap',\n",
+ " 'kitesurfing',\n",
+ " 'dead',\n",
+ " '10 years',\n",
+ " 'ham and cheese',\n",
+ " 'skateboarder',\n",
+ " 'magazine',\n",
+ " 'afternoon',\n",
+ " 'selfie',\n",
+ " 'down',\n",
+ " 'tennis',\n",
+ " 'batman',\n",
+ " 'landing',\n",
+ " 'muffins',\n",
+ " 'handicap',\n",
+ " 'jacket',\n",
+ " 'riding bikes',\n",
+ " 'father',\n",
+ " '0',\n",
+ " 'round',\n",
+ " 'frosting',\n",
+ " 'box',\n",
+ " 'boy',\n",
+ " 'bow',\n",
+ " 'bob',\n",
+ " 'sun hat',\n",
+ " 'pizza box',\n",
+ " 'man in middle',\n",
+ " 'tablecloth',\n",
+ " 'basket',\n",
+ " 'shoes',\n",
+ " 'police',\n",
+ " 'monitor',\n",
+ " 'lunch',\n",
+ " 'man made',\n",
+ " 'elephants',\n",
+ " 'first',\n",
+ " '4 ft',\n",
+ " 'pastry',\n",
+ " 'jungle',\n",
+ " '200',\n",
+ " 'sheets',\n",
+ " 'not high',\n",
+ " 'morning',\n",
+ " 'seat',\n",
+ " 'boundaries',\n",
+ " '2:00',\n",
+ " 'tour',\n",
+ " 'pacifier',\n",
+ " 'squash',\n",
+ " 'laptops',\n",
+ " 'riding motorcycle',\n",
+ " 'sleeping',\n",
+ " 'westjet',\n",
+ " 'outdoor',\n",
+ " 'bow tie',\n",
+ " 'wii remotes',\n",
+ " 'genetics',\n",
+ " 'knife',\n",
+ " 'pockets',\n",
+ " 'harness',\n",
+ " 'traffic lights',\n",
+ " 'on water',\n",
+ " 'on road',\n",
+ " 'sitting',\n",
+ " 'washing',\n",
+ " 'africa',\n",
+ " 'dachshund',\n",
+ " 'brushing her teeth',\n",
+ " 'numbers',\n",
+ " 'comfort',\n",
+ " 'coming',\n",
+ " 'toiletries',\n",
+ " 'dragon',\n",
+ " 'faucet',\n",
+ " 'top',\n",
+ " 'fork and spoon',\n",
+ " 'urban',\n",
+ " '1:05',\n",
+ " '1:00',\n",
+ " 'couch',\n",
+ " 'eggs',\n",
+ " '11:10',\n",
+ " '11:15',\n",
+ " 'opaque',\n",
+ " 'motorcycles',\n",
+ " 'doubles',\n",
+ " 'mouthwash',\n",
+ " 'mailbox',\n",
+ " 'restaurant',\n",
+ " 'baseball bat',\n",
+ " 'very big',\n",
+ " 'peppers',\n",
+ " 'kodak',\n",
+ " 'windmill',\n",
+ " 'stripe',\n",
+ " 'shape',\n",
+ " 'cut',\n",
+ " 'cup',\n",
+ " 'easter',\n",
+ " 'candle',\n",
+ " '5:18',\n",
+ " '5:15',\n",
+ " '5:10',\n",
+ " 'tennis court',\n",
+ " 'planes',\n",
+ " 'curb',\n",
+ " 'cafe',\n",
+ " 'mayonnaise',\n",
+ " 'cooked',\n",
+ " 'pink and blue',\n",
+ " 'dessert',\n",
+ " 'legos',\n",
+ " 'packing',\n",
+ " 'orange and yellow',\n",
+ " 'ketchup and mustard',\n",
+ " 'mustache',\n",
+ " 'horns',\n",
+ " 'potato salad',\n",
+ " 'backwards',\n",
+ " 'yard',\n",
+ " 'skateboard',\n",
+ " 'yarn',\n",
+ " 'bulldog',\n",
+ " 'nobody',\n",
+ " 'beets',\n",
+ " 'mercedes',\n",
+ " 'daffodils',\n",
+ " 'racket',\n",
+ " 'to dry',\n",
+ " 'hammock',\n",
+ " 'statues',\n",
+ " 'denim',\n",
+ " 'flashlight',\n",
+ " 'carnations',\n",
+ " 'hair',\n",
+ " 'shaking hands',\n",
+ " 'socks',\n",
+ " 'female',\n",
+ " 'looking at camera',\n",
+ " 'coke',\n",
+ " 'flip',\n",
+ " 'circus',\n",
+ " 'dresser',\n",
+ " 'cocker spaniel',\n",
+ " 'school',\n",
+ " 'feathers',\n",
+ " 'umbrella',\n",
+ " 'luggage',\n",
+ " 'leaves',\n",
+ " 'to left',\n",
+ " '2:05',\n",
+ " '400',\n",
+ " '3:15',\n",
+ " '3:10',\n",
+ " 'detroit',\n",
+ " 'public market center',\n",
+ " 'small',\n",
+ " 'soda',\n",
+ " 'fire truck',\n",
+ " '700',\n",
+ " 'schnauzer',\n",
+ " 'sparrow',\n",
+ " 'fork and knife',\n",
+ " 'wedding',\n",
+ " 'beads',\n",
+ " 'mud',\n",
+ " 'mug',\n",
+ " 'finger',\n",
+ " 'herding',\n",
+ " 'news',\n",
+ " 'behind fence',\n",
+ " 'elm',\n",
+ " 't shirt and jeans',\n",
+ " 'conference',\n",
+ " 'monkey',\n",
+ " 'n',\n",
+ " 'nike',\n",
+ " 'ocean',\n",
+ " 'cherry',\n",
+ " 'teddy bears',\n",
+ " 'wii controllers',\n",
+ " 'descending',\n",
+ " 'coffee maker',\n",
+ " 'khaki',\n",
+ " 'dinner',\n",
+ " 'toothpicks',\n",
+ " 'fern',\n",
+ " 'bats',\n",
+ " 'sunlight',\n",
+ " 'towing',\n",
+ " 'kiting',\n",
+ " 'setting',\n",
+ " 'papers',\n",
+ " 'picture',\n",
+ " 'football',\n",
+ " 'long time',\n",
+ " 'posts',\n",
+ " 'cloudy',\n",
+ " 'pork',\n",
+ " 'tuxedo',\n",
+ " '1 in middle',\n",
+ " 'pickle',\n",
+ " 'nursing',\n",
+ " 'black and brown',\n",
+ " 'sailboat',\n",
+ " '2 years',\n",
+ " 'platform',\n",
+ " 'farmer',\n",
+ " 'cutting hair',\n",
+ " 'catching',\n",
+ " '100 year party ct',\n",
+ " 'turn',\n",
+ " \"men's\",\n",
+ " 'batting',\n",
+ " 'surf',\n",
+ " 'equestrian',\n",
+ " 'wii remote',\n",
+ " '1 hour',\n",
+ " 'guitar',\n",
+ " 'turkey',\n",
+ " 'direction',\n",
+ " 'pilot',\n",
+ " 'case',\n",
+ " 'statue',\n",
+ " 'peanut butter',\n",
+ " 'towel',\n",
+ " 'tower',\n",
+ " 'competition',\n",
+ " 'moon',\n",
+ " 'burton',\n",
+ " 'on floor',\n",
+ " 'barber shop',\n",
+ " 'flats',\n",
+ " 'grass',\n",
+ " 'roses',\n",
+ " 'pillows',\n",
+ " 'qantas',\n",
+ " 'apartment',\n",
+ " 'dairy',\n",
+ " 'crest',\n",
+ " 'sub',\n",
+ " 'sun',\n",
+ " 'suv',\n",
+ " 'christian',\n",
+ " 'donuts',\n",
+ " 'horses',\n",
+ " 'flat',\n",
+ " 'flag',\n",
+ " 'lighting',\n",
+ " 'short',\n",
+ " 'shore',\n",
+ " '10:50',\n",
+ " '10:55',\n",
+ " 'soccer',\n",
+ " 'clouds',\n",
+ " 'alcohol',\n",
+ " 'hill',\n",
+ " 'snowboarding',\n",
+ " 'urinal',\n",
+ " 'stork',\n",
+ " 'storm',\n",
+ " 'store',\n",
+ " 'surfboard',\n",
+ " 'king',\n",
+ " '1 on right',\n",
+ " '8:35',\n",
+ " 'skyscrapers',\n",
+ " 'wilson',\n",
+ " 'electric',\n",
+ " 'patterned',\n",
+ " 'national express',\n",
+ " 'opponent',\n",
+ " 'triangle',\n",
+ " '9:05',\n",
+ " 'sweet',\n",
+ " 'shirt',\n",
+ " '9',\n",
+ " 'cement',\n",
+ " 'prince',\n",
+ " 'ping pong',\n",
+ " 'phone',\n",
+ " 'playing video game',\n",
+ " 'sports',\n",
+ " 'sepia',\n",
+ " 'rainbow',\n",
+ " 'peach',\n",
+ " 'peace',\n",
+ " 'parking meter',\n",
+ " 'windy',\n",
+ " 'very high',\n",
+ " 'gray and white',\n",
+ " 'compaq',\n",
+ " 'kangaroo',\n",
+ " 'leaving',\n",
+ " 'cell phones',\n",
+ " 'duck',\n",
+ " 'frog',\n",
+ " 'parrots',\n",
+ " 'hammer time',\n",
+ " 'wrist',\n",
+ " 'stuffed animal',\n",
+ " 'on bed',\n",
+ " 'out',\n",
+ " 'brown and white',\n",
+ " 'organic',\n",
+ " 'g',\n",
+ " 'tennis racket',\n",
+ " 'umbrellas',\n",
+ " 'unknown',\n",
+ " 'blue and black',\n",
+ " 'dreadlocks',\n",
+ " 'clip',\n",
+ " 'pipe',\n",
+ " 'cubs',\n",
+ " 'married',\n",
+ " 'identification',\n",
+ " 'on his face',\n",
+ " 'stucco',\n",
+ " 'unclear',\n",
+ " 'motorbike',\n",
+ " 'rooster',\n",
+ " 'camel',\n",
+ " 'name',\n",
+ " \"can't tell\",\n",
+ " 'in middle',\n",
+ " 'pasta',\n",
+ " 'spiral',\n",
+ " 'coffee cup',\n",
+ " 'cake',\n",
+ " 'maroon',\n",
+ " 'space shuttle',\n",
+ " 'left and right',\n",
+ " 'yellow',\n",
+ " 'electricity',\n",
+ " 'deli',\n",
+ " 'dell',\n",
+ " 'eagle',\n",
+ " 'behind woman',\n",
+ " 'military',\n",
+ " 'on dresser',\n",
+ " 'machine',\n",
+ " 'gaming',\n",
+ " 'in sky',\n",
+ " 'wing',\n",
+ " 'wind',\n",
+ " 'wine',\n",
+ " 'baseball uniform',\n",
+ " 'new orleans',\n",
+ " 'silver',\n",
+ " 'bowls',\n",
+ " 'lego',\n",
+ " 'guitar hero',\n",
+ " 'legs',\n",
+ " 'man on left',\n",
+ " 'person',\n",
+ " 'in stands',\n",
+ " 'zebras',\n",
+ " 'victoria',\n",
+ " 'little girl',\n",
+ " 'recliner',\n",
+ " 'menu',\n",
+ " 'fair',\n",
+ " 'looking out window',\n",
+ " 'pirate',\n",
+ " 'vaio',\n",
+ " 'chrome',\n",
+ " 'life',\n",
+ " 'lift',\n",
+ " 'child',\n",
+ " 'chili',\n",
+ " 'above toilet',\n",
+ " 'gray and red',\n",
+ " 'babies',\n",
+ " 'bird feeder',\n",
+ " 'roast beef',\n",
+ " 'main street',\n",
+ " 'train car',\n",
+ " 'pavement',\n",
+ " 'steps',\n",
+ " 'people',\n",
+ " 'fox',\n",
+ " 'fog',\n",
+ " 'happy',\n",
+ " 'tigers',\n",
+ " 'lays',\n",
+ " '4 way',\n",
+ " 'peacock',\n",
+ " 'venice',\n",
+ " 'ollie',\n",
+ " 'subway',\n",
+ " 'teal',\n",
+ " 'team',\n",
+ " 'current',\n",
+ " 'no smoking',\n",
+ " 'love',\n",
+ " 'sunbathing',\n",
+ " 'heineken',\n",
+ " 'winter',\n",
+ " 'elephant',\n",
+ " 'wheelchair',\n",
+ " '8 feet',\n",
+ " 'colorado',\n",
+ " 'bagels',\n",
+ " 'polka dot',\n",
+ " '2 men',\n",
+ " 'gatorade',\n",
+ " 'arriving',\n",
+ " 'bowling',\n",
+ " 'not there',\n",
+ " 'container',\n",
+ " 'dodgers',\n",
+ " '12:10',\n",
+ " '12:15',\n",
+ " 'ostrich',\n",
+ " '2',\n",
+ " 'typing',\n",
+ " 'give way',\n",
+ " 'hockey',\n",
+ " 'robe',\n",
+ " 'bridge',\n",
+ " 'barbed wire',\n",
+ " 'cantaloupe',\n",
+ " 'jets',\n",
+ " '1:35',\n",
+ " '1:30',\n",
+ " 'twin',\n",
+ " 'teddy',\n",
+ " 'toothpaste',\n",
+ " 'vacation',\n",
+ " 'orange juice',\n",
+ " 'no dog',\n",
+ " 'skier',\n",
+ " 'orange and black',\n",
+ " 'clay',\n",
+ " 'tying tie',\n",
+ " 'tube',\n",
+ " '6:40',\n",
+ " '6:45',\n",
+ " 'take off',\n",
+ " 'bundt',\n",
+ " 'aa',\n",
+ " 'gone',\n",
+ " 'am',\n",
+ " '5:45',\n",
+ " '5:40',\n",
+ " 'for sale',\n",
+ " 'dinosaur',\n",
+ " '2 hours',\n",
+ " 'dead end',\n",
+ " 'bottom left',\n",
+ " 'using laptop',\n",
+ " 'young',\n",
+ " 'indoors',\n",
+ " 'parking garage',\n",
+ " 'pitching',\n",
+ " '4:00',\n",
+ " '4:05',\n",
+ " 'information',\n",
+ " 'not very',\n",
+ " 'no parking',\n",
+ " 'countryside',\n",
+ " 'backyard',\n",
+ " 'hauling',\n",
+ " 'pans',\n",
+ " 'rottweiler',\n",
+ " 'mexican',\n",
+ " 'swimsuit',\n",
+ " 'harbor',\n",
+ " 'air canada',\n",
+ " 'highway',\n",
+ " 'porcelain',\n",
+ " 'hydrant',\n",
+ " 'w',\n",
+ " '2 people',\n",
+ " 'cartoon',\n",
+ " 'obama',\n",
+ " 'refrigerators',\n",
+ " 'play',\n",
+ " 'cover',\n",
+ " 'sleeve',\n",
+ " 'crosswalk',\n",
+ " '1 year',\n",
+ " 'whole',\n",
+ " 'lanyard',\n",
+ " 'hitting',\n",
+ " 'fire',\n",
+ " 'farmers market',\n",
+ " 'white and brown',\n",
+ " 'sesame',\n",
+ " 'knee pads',\n",
+ " 'audi',\n",
+ " 'cucumber',\n",
+ " 'stainless steel',\n",
+ " 'baseball glove',\n",
+ " 'santa',\n",
+ " 'cobblestone',\n",
+ " 'taking picture',\n",
+ " 'garden',\n",
+ " 'index',\n",
+ " 'twins',\n",
+ " 'bird',\n",
+ " 'leg',\n",
+ " 'banana split',\n",
+ " 'mariners',\n",
+ " 'standing',\n",
+ " 'casserole',\n",
+ " 'high',\n",
+ " 'animal',\n",
+ " 'stop',\n",
+ " 'wallpaper',\n",
+ " 'notebook',\n",
+ " 'logitech',\n",
+ " 'next to toilet',\n",
+ " 'catch frisbee',\n",
+ " 'jesus',\n",
+ " 'owner',\n",
+ " 'steel',\n",
+ " '2:55',\n",
+ " '2:50',\n",
+ " 'soap',\n",
+ " 'biker',\n",
+ " 'bikes',\n",
+ " 'not possible',\n",
+ " 'amtrak',\n",
+ " 'email',\n",
+ " 'black and red',\n",
+ " 'cd',\n",
+ " 'ripe',\n",
+ " 'north face',\n",
+ " 'ball',\n",
+ " 'dusk',\n",
+ " 'bald',\n",
+ " 'overalls',\n",
+ " 'bananas',\n",
+ " '7:55',\n",
+ " 'fire hydrant',\n",
+ " 'conference room',\n",
+ " 'residential',\n",
+ " 'to see',\n",
+ " 'onion',\n",
+ " 'transport',\n",
+ " 'fedex',\n",
+ " 'seeds',\n",
+ " 'bud light',\n",
+ " 'delivery',\n",
+ " 'construction',\n",
+ " 'smooth',\n",
+ " 'volvo',\n",
+ " 'black and yellow',\n",
+ " 'mushroom',\n",
+ " 'footprints',\n",
+ " 'rural',\n",
+ " 'buoy',\n",
+ " 'all',\n",
+ " 'wreath',\n",
+ " 'blanket',\n",
+ " \"women's\",\n",
+ " 'home plate',\n",
+ " 'lizard',\n",
+ " 'above',\n",
+ " 'smoke',\n",
+ " 'indians',\n",
+ " 'boats',\n",
+ " 'green',\n",
+ " 'giraffe',\n",
+ " 'on tray',\n",
+ " 'lying down',\n",
+ " 'in vase',\n",
+ " 'sliding',\n",
+ " 'pooping',\n",
+ " 'train tracks',\n",
+ " 'goalie',\n",
+ " 'dc',\n",
+ " ...]"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/VQA/trainval_label2ans.pkl', 'rb') as f:\n",
+ " answers_mapping = pickle.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = ClipTokenizer()\n",
+ "answer_tokenized = [torch.tensor(tokenizer.encode(x + \" <|endoftext|>\")) for x in answers_mapping]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Add meta data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "shared_target_set = {\n",
+ " 'data': answer_tokenized,\n",
+ " 'modality': 'text'\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Save the target set"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# This file location should be specified in the \"SHARED_TARGETS\" section of the config file.\n",
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/VQA_Answers_CLIP_with_endoftext.pkl', 'wb') as f:\n",
+ " pickle.dump(shared_target_set, f)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.11"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "a745cf6333d4d8275ecd56c526d26202f2d2beb96e1206fac92576cf98b427be"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/tools/video_categories.ipynb b/tools/video_categories.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..36c151ea46d7846eb63321082453ae1011c552ba
--- /dev/null
+++ b/tools/video_categories.ipynb
@@ -0,0 +1,128 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "4f3d41ec-7908-48a3-9be6-83f4d33401f1",
+ "metadata": {},
+ "source": [
+ "# Generate categories for video datasets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "886f5daa-a807-48f6-a03b-cd0f5d666a24",
+ "metadata": {},
+ "source": [
+ "## K400 "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "d25c153b-bdab-4cf8-bc8e-748f2a4dd489",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Class names obtained from https://gist.github.com/willprice/f19da185c9c5f32847134b87c1960769\n",
+ "\n",
+ "class_names = ['abseiling', 'air drumming', 'answering questions', 'applauding', 'applying cream', 'archery', 'arm wrestling', 'arranging flowers', 'assembling computer', 'auctioning', 'baby waking up', 'baking cookies', 'balloon blowing', 'bandaging', 'barbequing', 'bartending', 'beatboxing', 'bee keeping', 'belly dancing', 'bench pressing', 'bending back', 'bending metal', 'biking through snow', 'blasting sand', 'blowing glass', 'blowing leaves', 'blowing nose', 'blowing out candles', 'bobsledding', 'bookbinding', 'bouncing on trampoline', 'bowling', 'braiding hair', 'breading or breadcrumbing', 'breakdancing', 'brush painting', 'brushing hair', 'brushing teeth', 'building cabinet', 'building shed', 'bungee jumping', 'busking', 'canoeing or kayaking', 'capoeira', 'carrying baby', 'cartwheeling', 'carving pumpkin', 'catching fish', 'catching or throwing baseball', 'catching or throwing frisbee', 'catching or throwing softball', 'celebrating', 'changing oil', 'changing wheel', 'checking tires', 'cheerleading', 'chopping wood', 'clapping', 'clay pottery making', 'clean and jerk', 'cleaning floor', 'cleaning gutters', 'cleaning pool', 'cleaning shoes', 'cleaning toilet', 'cleaning windows', 'climbing a rope', 'climbing ladder', 'climbing tree', 'contact juggling', 'cooking chicken', 'cooking egg', 'cooking on campfire', 'cooking sausages', 'counting money', 'country line dancing', 'cracking neck', 'crawling baby', 'crossing river', 'crying', 'curling hair', 'cutting nails', 'cutting pineapple', 'cutting watermelon', 'dancing ballet', 'dancing charleston', 'dancing gangnam style', 'dancing macarena', 'deadlifting', 'decorating the christmas tree', 'digging', 'dining', 'disc golfing', 'diving cliff', 'dodgeball', 'doing aerobics', 'doing laundry', 'doing nails', 'drawing', 'dribbling basketball', 'drinking', 'drinking beer', 'drinking shots', 'driving car', 'driving tractor', 'drop kicking', 'drumming fingers', 'dunking basketball', 'dying hair', 'eating burger', 'eating cake', 'eating carrots', 'eating chips', 'eating doughnuts', 'eating hotdog', 'eating ice cream', 'eating spaghetti', 'eating watermelon', 'egg hunting', 'exercising arm', 'exercising with an exercise ball', 'extinguishing fire', 'faceplanting', 'feeding birds', 'feeding fish', 'feeding goats', 'filling eyebrows', 'finger snapping', 'fixing hair', 'flipping pancake', 'flying kite', 'folding clothes', 'folding napkins', 'folding paper', 'front raises', 'frying vegetables', 'garbage collecting', 'gargling', 'getting a haircut', 'getting a tattoo', 'giving or receiving award', 'golf chipping', 'golf driving', 'golf putting', 'grinding meat', 'grooming dog', 'grooming horse', 'gymnastics tumbling', 'hammer throw', 'headbanging', 'headbutting', 'high jump', 'high kick', 'hitting baseball', 'hockey stop', 'holding snake', 'hopscotch', 'hoverboarding', 'hugging', 'hula hooping', 'hurdling', 'hurling (sport)', 'ice climbing', 'ice fishing', 'ice skating', 'ironing', 'javelin throw', 'jetskiing', 'jogging', 'juggling balls', 'juggling fire', 'juggling soccer ball', 'jumping into pool', 'jumpstyle dancing', 'kicking field goal', 'kicking soccer ball', 'kissing', 'kitesurfing', 'knitting', 'krumping', 'laughing', 'laying bricks', 'long jump', 'lunge', 'making a cake', 'making a sandwich', 'making bed', 'making jewelry', 'making pizza', 'making snowman', 'making sushi', 'making tea', 'marching', 'massaging back', 'massaging feet', 'massaging legs', \"massaging person's head\", 'milking cow', 'mopping floor', 'motorcycling', 'moving furniture', 'mowing lawn', 'news anchoring', 'opening bottle', 'opening present', 'paragliding', 'parasailing', 'parkour', 'passing American football (in game)', 'passing American football (not in game)', 'peeling apples', 'peeling potatoes', 'petting animal (not cat)', 'petting cat', 'picking fruit', 'planting trees', 'plastering', 'playing accordion', 'playing badminton', 'playing bagpipes', 'playing basketball', 'playing bass guitar', 'playing cards', 'playing cello', 'playing chess', 'playing clarinet', 'playing controller', 'playing cricket', 'playing cymbals', 'playing didgeridoo', 'playing drums', 'playing flute', 'playing guitar', 'playing harmonica', 'playing harp', 'playing ice hockey', 'playing keyboard', 'playing kickball', 'playing monopoly', 'playing organ', 'playing paintball', 'playing piano', 'playing poker', 'playing recorder', 'playing saxophone', 'playing squash or racquetball', 'playing tennis', 'playing trombone', 'playing trumpet', 'playing ukulele', 'playing violin', 'playing volleyball', 'playing xylophone', 'pole vault', 'presenting weather forecast', 'pull ups', 'pumping fist', 'pumping gas', 'punching bag', 'punching person (boxing)', 'push up', 'pushing car', 'pushing cart', 'pushing wheelchair', 'reading book', 'reading newspaper', 'recording music', 'riding a bike', 'riding camel', 'riding elephant', 'riding mechanical bull', 'riding mountain bike', 'riding mule', 'riding or walking with horse', 'riding scooter', 'riding unicycle', 'ripping paper', 'robot dancing', 'rock climbing', 'rock scissors paper', 'roller skating', 'running on treadmill', 'sailing', 'salsa dancing', 'sanding floor', 'scrambling eggs', 'scuba diving', 'setting table', 'shaking hands', 'shaking head', 'sharpening knives', 'sharpening pencil', 'shaving head', 'shaving legs', 'shearing sheep', 'shining shoes', 'shooting basketball', 'shooting goal (soccer)', 'shot put', 'shoveling snow', 'shredding paper', 'shuffling cards', 'side kick', 'sign language interpreting', 'singing', 'situp', 'skateboarding', 'ski jumping', 'skiing (not slalom or crosscountry)', 'skiing crosscountry', 'skiing slalom', 'skipping rope', 'skydiving', 'slacklining', 'slapping', 'sled dog racing', 'smoking', 'smoking hookah', 'snatch weight lifting', 'sneezing', 'sniffing', 'snorkeling', 'snowboarding', 'snowkiting', 'snowmobiling', 'somersaulting', 'spinning poi', 'spray painting', 'spraying', 'springboard diving', 'squat', 'sticking tongue out', 'stomping grapes', 'stretching arm', 'stretching leg', 'strumming guitar', 'surfing crowd', 'surfing water', 'sweeping floor', 'swimming backstroke', 'swimming breast stroke', 'swimming butterfly stroke', 'swing dancing', 'swinging legs', 'swinging on something', 'sword fighting', 'tai chi', 'taking a shower', 'tango dancing', 'tap dancing', 'tapping guitar', 'tapping pen', 'tasting beer', 'tasting food', 'testifying', 'texting', 'throwing axe', 'throwing ball', 'throwing discus', 'tickling', 'tobogganing', 'tossing coin', 'tossing salad', 'training dog', 'trapezing', 'trimming or shaving beard', 'trimming trees', 'triple jump', 'tying bow tie', 'tying knot (not on a tie)', 'tying tie', 'unboxing', 'unloading truck', 'using computer', 'using remote controller (not gaming)', 'using segway', 'vault', 'waiting in line', 'walking the dog', 'washing dishes', 'washing feet', 'washing hair', 'washing hands', 'water skiing', 'water sliding', 'watering plants', 'waxing back', 'waxing chest', 'waxing eyebrows', 'waxing legs', 'weaving basket', 'welding', 'whistling', 'windsurfing', 'wrapping present', 'wrestling', 'writing', 'yawning', 'yoga', 'zumba']\n",
+ "\n",
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/K400_official/category_mapping.txt', 'w') as f:\n",
+ " for idx, class_name in enumerate(class_names):\n",
+ " f.write(f'{class_name}\\t{idx}\\n')\n",
+ " \n",
+ " \n",
+ " "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2b70ab11-6029-4790-91f4-c4e52cd36e8b",
+ "metadata": {},
+ "source": [
+ "## K700"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "fd06fe59-5e18-42db-b6bc-b655b547fb6d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Class names obtained from https://gist.github.com/willprice/f19da185c9c5f32847134b87c1960769\n",
+ "\n",
+ "class_names = ['abseiling', 'acting in play', 'adjusting glasses', 'air drumming', 'alligator wrestling', 'answering questions', 'applauding', 'applying cream', 'archaeological excavation', 'archery', 'arguing', 'arm wrestling', 'arranging flowers', 'arresting', 'assembling bicycle', 'assembling computer', 'attending conference', 'auctioning', 'baby waking up', 'backflip (human)', 'baking cookies', 'bandaging', 'barbequing', 'bartending', 'base jumping', 'bathing dog', 'battle rope training', 'beatboxing', 'bee keeping', 'being excited', 'being in zero gravity', 'belly dancing', 'bench pressing', 'bending back', 'bending metal', 'biking through snow', 'blasting sand', 'blending fruit', 'blowdrying hair', 'blowing bubble gum', 'blowing glass', 'blowing leaves', 'blowing nose', 'blowing out candles', 'bobsledding', 'bodysurfing', 'bookbinding', 'bottling', 'bouncing ball (not juggling)', 'bouncing on bouncy castle', 'bouncing on trampoline', 'bowling', 'braiding hair', 'breading or breadcrumbing', 'breakdancing', 'breaking boards', 'breaking glass', 'breathing fire', 'brush painting', 'brushing floor', 'brushing hair', 'brushing teeth', 'building cabinet', 'building lego', 'building sandcastle', 'building shed', 'bulldozing', 'bungee jumping', 'burping', 'busking', 'calculating', 'calligraphy', 'canoeing or kayaking', 'capoeira', 'capsizing', 'card stacking', 'card throwing', 'carrying baby', 'carrying weight', 'cartwheeling', 'carving ice', 'carving marble', 'carving pumpkin', 'carving wood with a knife', 'casting fishing line', 'catching fish', 'catching or throwing baseball', 'catching or throwing frisbee', 'catching or throwing softball', 'celebrating', 'changing gear in car', 'changing oil', 'changing wheel (not on bike)', 'chasing', 'checking tires', 'checking watch', 'cheerleading', 'chewing gum', 'chiseling stone', 'chiseling wood', 'chopping meat', 'chopping wood', 'clam digging', 'clapping', 'clay pottery making', 'clean and jerk', 'cleaning gutters', 'cleaning pool', 'cleaning shoes', 'cleaning toilet', 'cleaning windows', 'climbing a rope', 'climbing ladder', 'climbing tree', 'closing door', 'coloring in', 'combing hair', 'contact juggling', 'contorting', 'cooking chicken', 'cooking egg', 'cooking on campfire', 'cooking sausages (not on barbeque)', 'cooking scallops', 'cosplaying', 'coughing', 'counting money', 'country line dancing', 'cracking back', 'cracking knuckles', 'cracking neck', 'crawling baby', 'crocheting', 'crossing eyes', 'crossing river', 'crying', 'cumbia', 'curling (sport)', 'curling eyelashes', 'curling hair', 'cutting apple', 'cutting cake', 'cutting nails', 'cutting orange', 'cutting pineapple', 'cutting watermelon', 'dancing ballet', 'dancing charleston', 'dancing gangnam style', 'dancing macarena', 'deadlifting', 'dealing cards', 'decorating the christmas tree', 'decoupage', 'delivering mail', 'digging', 'dining', 'directing traffic', 'disc golfing', 'diving cliff', 'docking boat', 'dodgeball', 'doing aerobics', 'doing jigsaw puzzle', 'doing laundry', 'doing nails', 'doing sudoku', 'drawing', 'dribbling basketball', 'drinking shots', 'driving car', 'driving tractor', 'drooling', 'drop kicking', 'drumming fingers', 'dumpster diving', 'dunking basketball', 'dyeing eyebrows', 'dyeing hair', 'eating burger', 'eating cake', 'eating carrots', 'eating chips', 'eating doughnuts', 'eating hotdog', 'eating ice cream', 'eating nachos', 'eating spaghetti', 'eating watermelon', 'egg hunting', 'embroidering', 'entering church', 'exercising arm', 'exercising with an exercise ball', 'extinguishing fire', 'faceplanting', 'falling off bike', 'falling off chair', 'feeding birds', 'feeding fish', 'feeding goats', 'fencing (sport)', 'fidgeting', 'filling cake', 'filling eyebrows', 'finger snapping', 'fixing bicycle', 'fixing hair', 'flint knapping', 'flipping bottle', 'flipping pancake', 'fly tying', 'flying kite', 'folding clothes', 'folding napkins', 'folding paper', 'front raises', 'frying vegetables', 'gargling', 'geocaching', 'getting a haircut', 'getting a piercing', 'getting a tattoo', 'giving or receiving award', 'gold panning', 'golf chipping', 'golf driving', 'golf putting', 'gospel singing in church', 'grinding meat', 'grooming cat', 'grooming dog', 'grooming horse', 'gymnastics tumbling', 'hammer throw', 'hand washing clothes', 'head stand', 'headbanging', 'headbutting', 'helmet diving', 'herding cattle', 'high fiving', 'high jump', 'high kick', 'historical reenactment', 'hitting baseball', 'hockey stop', 'holding snake', 'home roasting coffee', 'hopscotch', 'hoverboarding', 'huddling', 'hugging (not baby)', 'hugging baby', 'hula hooping', 'hurdling', 'hurling (sport)', 'ice climbing', 'ice fishing', 'ice skating', 'ice swimming', 'inflating balloons', 'installing carpet', 'ironing', 'ironing hair', 'javelin throw', 'jaywalking', 'jetskiing', 'jogging', 'juggling balls', 'juggling fire', 'juggling soccer ball', 'jumping bicycle', 'jumping into pool', 'jumping jacks', 'jumping sofa', 'jumpstyle dancing', 'karaoke', 'kicking field goal', 'kicking soccer ball', 'kissing', 'kitesurfing', 'knitting', 'krumping', 'land sailing', 'laughing', 'lawn mower racing', 'laying bricks', 'laying concrete', 'laying decking', 'laying stone', 'laying tiles', 'leatherworking', 'letting go of balloon', 'licking', 'lifting hat', 'lighting candle', 'lighting fire', 'listening with headphones', 'lock picking', 'long jump', 'longboarding', 'looking at phone', 'looking in mirror', 'luge', 'lunge', 'making a cake', 'making a sandwich', 'making balloon shapes', 'making bubbles', 'making cheese', 'making horseshoes', 'making jewelry', 'making latte art', 'making paper aeroplanes', 'making pizza', 'making slime', 'making snowman', 'making sushi', 'making tea', 'making the bed', 'marching', 'marriage proposal', 'massaging back', 'massaging feet', 'massaging legs', 'massaging neck', \"massaging person's head\", 'metal detecting', 'milking cow', 'milking goat', 'mixing colours', 'moon walking', 'mopping floor', 'mosh pit dancing', 'motorcycling', 'mountain climber (exercise)', 'moving baby', 'moving child', 'moving furniture', 'mowing lawn', 'mushroom foraging', 'needle felting', 'news anchoring', 'opening bottle (not wine)', 'opening coconuts', 'opening door', 'opening present', 'opening refrigerator', 'opening wine bottle', 'packing', 'paragliding', 'parasailing', 'parkour', 'passing American football (in game)', 'passing American football (not in game)', 'passing soccer ball', 'peeling apples', 'peeling banana', 'peeling potatoes', 'person collecting garbage', 'petting animal (not cat)', 'petting cat', 'petting horse', 'photobombing', 'photocopying', 'picking apples', 'picking blueberries', 'pillow fight', 'pinching', 'pirouetting', 'planing wood', 'planting trees', 'plastering', 'playing accordion', 'playing american football', 'playing badminton', 'playing bagpipes', 'playing basketball', 'playing bass guitar', 'playing beer pong', 'playing billiards', 'playing blackjack', 'playing cards', 'playing cello', 'playing checkers', 'playing chess', 'playing clarinet', 'playing controller', 'playing cricket', 'playing cymbals', 'playing darts', 'playing didgeridoo', 'playing dominoes', 'playing drums', 'playing field hockey', 'playing flute', 'playing gong', 'playing guitar', 'playing hand clapping games', 'playing harmonica', 'playing harp', 'playing ice hockey', 'playing keyboard', 'playing kickball', 'playing laser tag', 'playing lute', 'playing mahjong', 'playing maracas', 'playing marbles', 'playing monopoly', 'playing netball', 'playing nose flute', 'playing oboe', 'playing ocarina', 'playing organ', 'playing paintball', 'playing pan pipes', 'playing piano', 'playing piccolo', 'playing pinball', 'playing ping pong', 'playing poker', 'playing polo', 'playing recorder', 'playing road hockey', 'playing rounders', 'playing rubiks cube', 'playing saxophone', 'playing scrabble', 'playing shuffleboard', 'playing slot machine', 'playing squash or racquetball', 'playing tennis', 'playing trombone', 'playing trumpet', 'playing ukulele', 'playing violin', 'playing volleyball', 'playing with trains', 'playing xylophone', 'poaching eggs', 'poking bellybutton', 'pole vault', 'polishing furniture', 'polishing metal', 'popping balloons', 'pouring beer', 'pouring milk', 'pouring wine', 'preparing salad', 'presenting weather forecast', 'pretending to be a statue', 'pull ups', 'pulling espresso shot', 'pulling rope (game)', 'pumping fist', 'pumping gas', 'punching bag', 'punching person (boxing)', 'push up', 'pushing car', 'pushing cart', 'pushing wheelbarrow', 'pushing wheelchair', 'putting in contact lenses', 'putting on eyeliner', 'putting on foundation', 'putting on lipstick', 'putting on mascara', 'putting on sari', 'putting on shoes', 'putting wallpaper on wall', 'raising eyebrows', 'reading book', 'reading newspaper', 'recording music', 'repairing puncture', 'riding a bike', 'riding camel', 'riding elephant', 'riding mechanical bull', 'riding mule', 'riding or walking with horse', 'riding scooter', 'riding snow blower', 'riding unicycle', 'ripping paper', 'roasting marshmallows', 'roasting pig', 'robot dancing', 'rock climbing', 'rock scissors paper', 'roller skating', 'rolling eyes', 'rolling pastry', 'rope pushdown', 'running on treadmill', 'sailing', 'salsa dancing', 'saluting', 'sanding floor', 'sanding wood', 'sausage making', 'sawing wood', 'scrambling eggs', 'scrapbooking', 'scrubbing face', 'scuba diving', 'seasoning food', 'separating eggs', 'setting table', 'sewing', 'shaking hands', 'shaking head', 'shaping bread dough', 'sharpening knives', 'sharpening pencil', 'shaving head', 'shaving legs', 'shearing sheep', 'shining flashlight', 'shining shoes', 'shoot dance', 'shooting basketball', 'shooting goal (soccer)', 'shooting off fireworks', 'shopping', 'shot put', 'shouting', 'shoveling snow', 'shredding paper', 'shucking oysters', 'shuffling cards', 'shuffling feet', 'side kick', 'sieving', 'sign language interpreting', 'silent disco', 'singing', 'sipping cup', 'situp', 'skateboarding', 'ski ballet', 'ski jumping', 'skiing crosscountry', 'skiing mono', 'skiing slalom', 'skipping rope', 'skipping stone', 'skydiving', 'slacklining', 'slapping', 'sled dog racing', 'sleeping', 'slicing onion', 'smashing', 'smelling feet', 'smoking', 'smoking hookah', 'smoking pipe', 'snatch weight lifting', 'sneezing', 'snorkeling', 'snowboarding', 'snowkiting', 'snowmobiling', 'somersaulting', 'spelunking', 'spinning plates', 'spinning poi', 'splashing water', 'spray painting', 'spraying', 'springboard diving', 'square dancing', 'squat', 'squeezing orange', 'stacking cups', 'stacking dice', 'standing on hands', 'staring', 'steer roping', 'steering car', 'sticking tongue out', 'stomping grapes', 'stretching arm', 'stretching leg', 'sucking lolly', 'surfing crowd', 'surfing water', 'surveying', 'sweeping floor', 'swimming backstroke', 'swimming breast stroke', 'swimming butterfly stroke', 'swimming front crawl', 'swimming with dolphins', 'swimming with sharks', 'swing dancing', 'swinging baseball bat', 'swinging on something', 'sword fighting', 'sword swallowing', 'tackling', 'tagging graffiti', 'tai chi', 'taking photo', 'talking on cell phone', 'tango dancing', 'tap dancing', 'tapping guitar', 'tapping pen', 'tasting beer', 'tasting food', 'tasting wine', 'testifying', 'texting', 'threading needle', 'throwing axe', 'throwing ball (not baseball or American football)', 'throwing discus', 'throwing knife', 'throwing snowballs', 'throwing tantrum', 'throwing water balloon', 'tickling', 'tie dying', 'tightrope walking', 'tiptoeing', 'tobogganing', 'tossing coin', 'tossing salad', 'training dog', 'trapezing', 'treating wood', 'trimming or shaving beard', 'trimming shrubs', 'trimming trees', 'triple jump', 'twiddling fingers', 'tying bow tie', 'tying knot (not on a tie)', 'tying necktie', 'tying shoe laces', 'unboxing', 'uncorking champagne', 'unloading truck', 'using a microscope', 'using a paint roller', 'using a power drill', 'using a sledge hammer', 'using a wrench', 'using atm', 'using bagging machine', 'using circular saw', 'using inhaler', 'using megaphone', 'using puppets', 'using remote controller (not gaming)', 'using segway', 'vacuuming car', 'vacuuming floor', 'visiting the zoo', 'wading through mud', 'wading through water', 'waiting in line', 'waking up', 'walking on stilts', 'walking the dog', 'walking through snow', 'walking with crutches', 'washing dishes', 'washing feet', 'washing hair', 'washing hands', 'watching tv', 'water skiing', 'water sliding', 'watering plants', 'waving hand', 'waxing armpits', 'waxing back', 'waxing chest', 'waxing eyebrows', 'waxing legs', 'weaving basket', 'weaving fabric', 'welding', 'whistling', 'windsurfing', 'winking', 'wood burning (art)', 'wrapping present', 'wrestling', 'writing', 'yarn spinning', 'yawning', 'yoga', 'zumba']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "1f3aac0e-80b9-4bad-8522-d21396243305",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/K700/category_mapping.txt', 'w') as f:\n",
+ " for idx, class_name in enumerate(class_names):\n",
+ " f.write(f'{class_name}\\t{idx}\\n')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "88db7e57-0753-4ac8-bcc2-a9b519f155f8",
+ "metadata": {},
+ "source": [
+ "## Moments in Time"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "f6ec7830-76cb-419d-bcd9-c201a862bd8e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "classname_file = '/nfs/zhujinguo/datasets/open_source_dataset/MomentsInTime/categories.txt'\n",
+ "class_names = [ ] \n",
+ "with open(classname_file) as f:\n",
+ " for line in f.readlines():\n",
+ " info = line.strip().split(',')[0]\n",
+ " class_names.append(info)\n",
+ " # class_names.append(info.replace('+',' '))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "6e121450-f683-4260-b17c-7d69090aa1e5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open('/nfs/zhujinguo/datasets/open_source_dataset/MomentsInTime/category_mapping.txt', 'w') as f:\n",
+ " for idx, class_name in enumerate(class_names):\n",
+ " f.write(f'{class_name}\\t{idx}\\n')"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/uniperceiver/checkpoint/__init__.py b/uniperceiver/checkpoint/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b2a6fa774a22991fb39413bca47756c33ae2a29
--- /dev/null
+++ b/uniperceiver/checkpoint/__init__.py
@@ -0,0 +1,3 @@
+
+from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer
+from .custom_checkpoint import PeriodicEpochCheckpointer, TorchCheckpointer
diff --git a/uniperceiver/checkpoint/c2_model_loading.py b/uniperceiver/checkpoint/c2_model_loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c8d181bd7200bd3fd38446e743f8f16780d6e76
--- /dev/null
+++ b/uniperceiver/checkpoint/c2_model_loading.py
@@ -0,0 +1,407 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import logging
+import re
+from typing import Dict, List
+import torch
+from tabulate import tabulate
+
+
+def convert_basic_c2_names(original_keys):
+ """
+ Apply some basic name conversion to names in C2 weights.
+ It only deals with typical backbone models.
+
+ Args:
+ original_keys (list[str]):
+ Returns:
+ list[str]: The same number of strings matching those in original_keys.
+ """
+ layer_keys = copy.deepcopy(original_keys)
+ layer_keys = [
+ {"pred_b": "linear_b", "pred_w": "linear_w"}.get(k, k) for k in layer_keys
+ ] # some hard-coded mappings
+
+ layer_keys = [k.replace("_", ".") for k in layer_keys]
+ layer_keys = [re.sub("\\.b$", ".bias", k) for k in layer_keys]
+ layer_keys = [re.sub("\\.w$", ".weight", k) for k in layer_keys]
+ # Uniform both bn and gn names to "norm"
+ layer_keys = [re.sub("bn\\.s$", "norm.weight", k) for k in layer_keys]
+ layer_keys = [re.sub("bn\\.bias$", "norm.bias", k) for k in layer_keys]
+ layer_keys = [re.sub("bn\\.rm", "norm.running_mean", k) for k in layer_keys]
+ layer_keys = [re.sub("bn\\.running.mean$", "norm.running_mean", k) for k in layer_keys]
+ layer_keys = [re.sub("bn\\.riv$", "norm.running_var", k) for k in layer_keys]
+ layer_keys = [re.sub("bn\\.running.var$", "norm.running_var", k) for k in layer_keys]
+ layer_keys = [re.sub("bn\\.gamma$", "norm.weight", k) for k in layer_keys]
+ layer_keys = [re.sub("bn\\.beta$", "norm.bias", k) for k in layer_keys]
+ layer_keys = [re.sub("gn\\.s$", "norm.weight", k) for k in layer_keys]
+ layer_keys = [re.sub("gn\\.bias$", "norm.bias", k) for k in layer_keys]
+
+ # stem
+ layer_keys = [re.sub("^res\\.conv1\\.norm\\.", "conv1.norm.", k) for k in layer_keys]
+ # to avoid mis-matching with "conv1" in other components (e.g. detection head)
+ layer_keys = [re.sub("^conv1\\.", "stem.conv1.", k) for k in layer_keys]
+
+ # layer1-4 is used by torchvision, however we follow the C2 naming strategy (res2-5)
+ # layer_keys = [re.sub("^res2.", "layer1.", k) for k in layer_keys]
+ # layer_keys = [re.sub("^res3.", "layer2.", k) for k in layer_keys]
+ # layer_keys = [re.sub("^res4.", "layer3.", k) for k in layer_keys]
+ # layer_keys = [re.sub("^res5.", "layer4.", k) for k in layer_keys]
+
+ # blocks
+ layer_keys = [k.replace(".branch1.", ".shortcut.") for k in layer_keys]
+ layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys]
+ layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys]
+ layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys]
+
+ # DensePose substitutions
+ layer_keys = [re.sub("^body.conv.fcn", "body_conv_fcn", k) for k in layer_keys]
+ layer_keys = [k.replace("AnnIndex.lowres", "ann_index_lowres") for k in layer_keys]
+ layer_keys = [k.replace("Index.UV.lowres", "index_uv_lowres") for k in layer_keys]
+ layer_keys = [k.replace("U.lowres", "u_lowres") for k in layer_keys]
+ layer_keys = [k.replace("V.lowres", "v_lowres") for k in layer_keys]
+ return layer_keys
+
+
+def convert_c2_detectron_names(weights):
+ """
+ Map Caffe2 Detectron weight names to Detectron2 names.
+
+ Args:
+ weights (dict): name -> tensor
+
+ Returns:
+ dict: detectron2 names -> tensor
+ dict: detectron2 names -> C2 names
+ """
+ logger = logging.getLogger(__name__)
+ logger.info("Renaming Caffe2 weights ......")
+ original_keys = sorted(weights.keys())
+ layer_keys = copy.deepcopy(original_keys)
+
+ layer_keys = convert_basic_c2_names(layer_keys)
+
+ # --------------------------------------------------------------------------
+ # RPN hidden representation conv
+ # --------------------------------------------------------------------------
+ # FPN case
+ # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then
+ # shared for all other levels, hence the appearance of "fpn2"
+ layer_keys = [
+ k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys
+ ]
+ # Non-FPN case
+ layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys]
+
+ # --------------------------------------------------------------------------
+ # RPN box transformation conv
+ # --------------------------------------------------------------------------
+ # FPN case (see note above about "fpn2")
+ layer_keys = [
+ k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas")
+ for k in layer_keys
+ ]
+ layer_keys = [
+ k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits")
+ for k in layer_keys
+ ]
+ # Non-FPN case
+ layer_keys = [
+ k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys
+ ]
+ layer_keys = [
+ k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits")
+ for k in layer_keys
+ ]
+
+ # --------------------------------------------------------------------------
+ # Fast R-CNN box head
+ # --------------------------------------------------------------------------
+ layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys]
+ layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys]
+ layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys]
+ layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys]
+ # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s
+ layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys]
+
+ # --------------------------------------------------------------------------
+ # FPN lateral and output convolutions
+ # --------------------------------------------------------------------------
+ def fpn_map(name):
+ """
+ Look for keys with the following patterns:
+ 1) Starts with "fpn.inner."
+ Example: "fpn.inner.res2.2.sum.lateral.weight"
+ Meaning: These are lateral pathway convolutions
+ 2) Starts with "fpn.res"
+ Example: "fpn.res2.2.sum.weight"
+ Meaning: These are FPN output convolutions
+ """
+ splits = name.split(".")
+ norm = ".norm" if "norm" in splits else ""
+ if name.startswith("fpn.inner."):
+ # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight']
+ stage = int(splits[2][len("res") :])
+ return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1])
+ elif name.startswith("fpn.res"):
+ # splits example: ['fpn', 'res2', '2', 'sum', 'weight']
+ stage = int(splits[1][len("res") :])
+ return "fpn_output{}{}.{}".format(stage, norm, splits[-1])
+ return name
+
+ layer_keys = [fpn_map(k) for k in layer_keys]
+
+ # --------------------------------------------------------------------------
+ # Mask R-CNN mask head
+ # --------------------------------------------------------------------------
+ # roi_heads.StandardROIHeads case
+ layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys]
+ layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys]
+ layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys]
+ # roi_heads.Res5ROIHeads case
+ layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys]
+
+ # --------------------------------------------------------------------------
+ # Keypoint R-CNN head
+ # --------------------------------------------------------------------------
+ # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX"
+ layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys]
+ layer_keys = [
+ k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys
+ ]
+ layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys]
+
+ # --------------------------------------------------------------------------
+ # Done with replacements
+ # --------------------------------------------------------------------------
+ assert len(set(layer_keys)) == len(layer_keys)
+ assert len(original_keys) == len(layer_keys)
+
+ new_weights = {}
+ new_keys_to_original_keys = {}
+ for orig, renamed in zip(original_keys, layer_keys):
+ new_keys_to_original_keys[renamed] = orig
+ if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."):
+ # remove the meaningless prediction weight for background class
+ new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1
+ new_weights[renamed] = weights[orig][new_start_idx:]
+ logger.info(
+ "Remove prediction weight for background class in {}. The shape changes from "
+ "{} to {}.".format(
+ renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape)
+ )
+ )
+ elif renamed.startswith("cls_score."):
+ # move weights of bg class from original index 0 to last index
+ logger.info(
+ "Move classification weights for background class in {} from index 0 to "
+ "index {}.".format(renamed, weights[orig].shape[0] - 1)
+ )
+ new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]])
+ else:
+ new_weights[renamed] = weights[orig]
+
+ return new_weights, new_keys_to_original_keys
+
+
+# Note the current matching is not symmetric.
+# it assumes model_state_dict will have longer names.
+def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
+ """
+ Match names between the two state-dict, and returns a new chkpt_state_dict with names
+ converted to match model_state_dict with heuristics. The returned dict can be later
+ loaded with fvcore checkpointer.
+ If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
+ model and will be renamed at first.
+
+ Strategy: suppose that the models that we will create will have prefixes appended
+ to each of its keys, for example due to an extra level of nesting that the original
+ pre-trained weights from ImageNet won't contain. For example, model.state_dict()
+ might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
+ res2.conv1.weight. We thus want to match both parameters together.
+ For that, we look for each model weight, look among all loaded keys if there is one
+ that is a suffix of the current weight name, and use it if that's the case.
+ If multiple matches exist, take the one with longest size
+ of the corresponding name. For example, for the same model as before, the pretrained
+ weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
+ we want to match backbone[0].body.conv1.weight to conv1.weight, and
+ backbone[0].body.res2.conv1.weight to res2.conv1.weight.
+ """
+ model_keys = sorted(model_state_dict.keys())
+ if c2_conversion:
+ ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict)
+ # original_keys: the name in the original dict (before renaming)
+ else:
+ original_keys = {x: x for x in ckpt_state_dict.keys()}
+ ckpt_keys = sorted(ckpt_state_dict.keys())
+
+ def match(a, b):
+ # Matched ckpt_key should be a complete (starts with '.') suffix.
+ # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1,
+ # but matches whatever_conv1 or mesh_head.whatever_conv1.
+ return a == b or a.endswith("." + b)
+
+ # get a matrix of string matches, where each (i, j) entry correspond to the size of the
+ # ckpt_key string, if it matches
+ match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys]
+ match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys))
+ # use the matched one with longest size in case of multiple matches
+ max_match_size, idxs = match_matrix.max(1)
+ # remove indices that correspond to no-match
+ idxs[max_match_size == 0] = -1
+
+ logger = logging.getLogger(__name__)
+ # matched_pairs (matched checkpoint key --> matched model key)
+ matched_keys = {}
+ result_state_dict = {}
+ for idx_model, idx_ckpt in enumerate(idxs.tolist()):
+ if idx_ckpt == -1:
+ continue
+ key_model = model_keys[idx_model]
+ key_ckpt = ckpt_keys[idx_ckpt]
+ value_ckpt = ckpt_state_dict[key_ckpt]
+ shape_in_model = model_state_dict[key_model].shape
+
+ if shape_in_model != value_ckpt.shape:
+ logger.warning(
+ "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format(
+ key_ckpt, value_ckpt.shape, key_model, shape_in_model
+ )
+ )
+ logger.warning(
+ "{} will not be loaded. Please double check and see if this is desired.".format(
+ key_ckpt
+ )
+ )
+ continue
+
+ assert key_model not in result_state_dict
+ result_state_dict[key_model] = value_ckpt
+ if key_ckpt in matched_keys: # already added to matched_keys
+ logger.error(
+ "Ambiguity found for {} in checkpoint!"
+ "It matches at least two keys in the model ({} and {}).".format(
+ key_ckpt, key_model, matched_keys[key_ckpt]
+ )
+ )
+ raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")
+
+ matched_keys[key_ckpt] = key_model
+
+ # logging:
+ matched_model_keys = sorted(matched_keys.values())
+ if len(matched_model_keys) == 0:
+ logger.warning("No weights in checkpoint matched with model.")
+ return ckpt_state_dict
+ common_prefix = _longest_common_prefix(matched_model_keys)
+ rev_matched_keys = {v: k for k, v in matched_keys.items()}
+ original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}
+
+ model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
+ table = []
+ memo = set()
+ for key_model in matched_model_keys:
+ if key_model in memo:
+ continue
+ if key_model in model_key_groups:
+ group = model_key_groups[key_model]
+ memo |= set(group)
+ shapes = [tuple(model_state_dict[k].shape) for k in group]
+ table.append(
+ (
+ _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
+ _group_str([original_keys[k] for k in group]),
+ " ".join([str(x).replace(" ", "") for x in shapes]),
+ )
+ )
+ else:
+ key_checkpoint = original_keys[key_model]
+ shape = str(tuple(model_state_dict[key_model].shape))
+ table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
+ table_str = tabulate(
+ table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"]
+ )
+ logger.info(
+ "Following weights matched with "
+ + (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
+ + ":\n"
+ + table_str
+ )
+
+ unmatched_ckpt_keys = [k for k in ckpt_keys if k not in set(matched_keys.keys())]
+ for k in unmatched_ckpt_keys:
+ result_state_dict[k] = ckpt_state_dict[k]
+ return result_state_dict
+
+
+def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
+ """
+ Params in the same submodule are grouped together.
+
+ Args:
+ keys: names of all parameters
+ original_names: mapping from parameter name to their name in the checkpoint
+
+ Returns:
+ dict[name -> all other names in the same group]
+ """
+
+ def _submodule_name(key):
+ pos = key.rfind(".")
+ if pos < 0:
+ return None
+ prefix = key[: pos + 1]
+ return prefix
+
+ all_submodules = [_submodule_name(k) for k in keys]
+ all_submodules = [x for x in all_submodules if x]
+ all_submodules = sorted(all_submodules, key=len)
+
+ ret = {}
+ for prefix in all_submodules:
+ group = [k for k in keys if k.startswith(prefix)]
+ if len(group) <= 1:
+ continue
+ original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
+ if len(original_name_lcp) == 0:
+ # don't group weights if original names don't share prefix
+ continue
+
+ for k in group:
+ if k in ret:
+ continue
+ ret[k] = group
+ return ret
+
+
+def _longest_common_prefix(names: List[str]) -> str:
+ """
+ ["abc.zfg", "abc.zef"] -> "abc."
+ """
+ names = [n.split(".") for n in names]
+ m1, m2 = min(names), max(names)
+ ret = [a for a, b in zip(m1, m2) if a == b]
+ ret = ".".join(ret) + "." if len(ret) else ""
+ return ret
+
+
+def _longest_common_prefix_str(names: List[str]) -> str:
+ m1, m2 = min(names), max(names)
+ lcp = [a for a, b in zip(m1, m2) if a == b]
+ lcp = "".join(lcp)
+ return lcp
+
+
+def _group_str(names: List[str]) -> str:
+ """
+ Turn "common1", "common2", "common3" into "common{1,2,3}"
+ """
+ lcp = _longest_common_prefix_str(names)
+ rest = [x[len(lcp) :] for x in names]
+ rest = "{" + ",".join(rest) + "}"
+ ret = lcp + rest
+
+ # add some simplification for BN specifically
+ ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*")
+ ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*")
+ return ret
diff --git a/uniperceiver/checkpoint/custom_checkpoint.py b/uniperceiver/checkpoint/custom_checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..7efb51d144028ce5ff9ffbaa96f7590580cd4952
--- /dev/null
+++ b/uniperceiver/checkpoint/custom_checkpoint.py
@@ -0,0 +1,305 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import os
+import pickle
+import torch
+# from typing import Any
+from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer, _IncompatibleKeys
+from fvcore.common.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
+from torch.nn.parallel import DistributedDataParallel
+
+import uniperceiver.utils.comm as comm
+from uniperceiver.utils.env import TORCH_VERSION
+from uniperceiver.utils.file_io import PathManager
+from collections import defaultdict
+import copy
+import io
+
+from .c2_model_loading import align_and_update_state_dicts
+from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Tuple
+
+# import deepspeed
+# from deepspeed.runtime.engine import DeepSpeedEngine
+import shutil
+from timm.utils import ModelEma
+
+
+class PeriodicEpochCheckpointer(PeriodicCheckpointer):
+ def step(self, iteration: int, epoch: int, **kwargs: Any) -> None:
+ """
+ Perform the appropriate action at the given iteration.
+
+ Args:
+ iteration (int): the current iteration, ranged in [0, max_iter-1].
+ kwargs (Any): extra data to save, same as in
+ :meth:`Checkpointer.save`.
+ """
+ iteration = int(iteration)
+ epoch = int(epoch)
+ additional_state = {"iteration": iteration}
+ additional_state.update(kwargs)
+
+ if (iteration + 1) % self.period == 0:
+
+ self.checkpointer.save(
+ "{}_Epoch_{:05d}_Iter_{:07d}".format(self.file_prefix, epoch,
+ iteration),
+ **additional_state)
+
+ if self.max_to_keep is not None:
+ self.recent_checkpoints.append(
+ self.checkpointer.get_checkpoint_file())
+ # pyre-fixme[58]: `>` is not supported for operand types `int` and
+ # `Optional[int]`.
+ if len(self.recent_checkpoints) > self.max_to_keep:
+ file_to_delete = self.recent_checkpoints.pop(0)
+ if self.path_manager.exists(
+ file_to_delete) and not file_to_delete.endswith(
+ f"{self.file_prefix}_final.pth"):
+ self.path_manager.rm(file_to_delete)
+
+
+
+class TorchCheckpointer(Checkpointer):
+ """
+ Same as :class:`Checkpointer`, but is able to handle models in uniperceiver
+ model zoo, and apply conversions for legacy models.
+ """
+ def __init__(
+ self,
+ model,
+ model_ema: ModelEma,
+ save_dir="",
+ *,
+ save_to_disk=None,
+ checkpoint_mapping=None,
+ mapping=False,
+ resume_tau=True,
+ ceph_save=False,
+ ceph_config=None,
+ **checkpointables,
+ ):
+ is_main_process = comm.is_main_process()
+ super().__init__(
+ model,
+ save_dir,
+ save_to_disk=is_main_process
+ if save_to_disk is None else save_to_disk,
+ **checkpointables,
+ )
+ self.path_manager = PathManager
+
+ if checkpoint_mapping is None:
+ self.checkpoint_mapping = None
+ else:
+ self.checkpoint_mapping = defaultdict(list)
+ for mapping_pair in checkpoint_mapping:
+ self.checkpoint_mapping[mapping_pair['ORIGIN']].append(
+ mapping_pair['DEST'])
+ self.mapping = mapping
+ self.resume_tau = resume_tau
+ self.ceph_save = ceph_save
+ if self.ceph_save:
+ self.path_prefix = 's3://checkpoints_zjg/'
+ self.client = PetrelBackend(path_mapping={},
+ tcs_conf_path=ceph_config)
+ # if self.ceph_save and is_main_process:
+ # # for local machine debug
+ # if os.path.relpath(self.save_dir, os.getcwd()).startswith('outputs'):
+ # self.client.remove(self.save_dir)
+
+ def _load_file(self, filename):
+ if filename.endswith(".pkl"):
+ with PathManager.open(filename, "rb") as f:
+ data = pickle.load(f, encoding="latin1")
+ if "model" in data and "__author__" in data:
+ # file is in Detectron2 model zoo format
+ self.logger.info("Reading a file from '{}'".format(
+ data["__author__"]))
+ return data
+ else:
+ # assume file is from Caffe2 / Detectron1 model zoo
+ if "blobs" in data:
+ # Detection models have "blobs", but ImageNet models don't
+ data = data["blobs"]
+ data = {
+ k: v
+ for k, v in data.items() if not k.endswith("_momentum")
+ }
+ return {
+ "model": data,
+ "__author__": "Caffe2",
+ "matching_heuristics": True
+ }
+ if self.ceph_save:
+ relpath = os.path.relpath(filename, os.getcwd())
+ s3url = os.path.join(self.path_prefix, relpath)
+ with io.BytesIO(self.client.get(s3url)) as buffer:
+ loaded = torch.load(buffer, map_location=torch.device("cpu"))
+ else:
+ loaded = super()._load_file(filename) # load native pth checkpoint
+ if "model" not in loaded:
+ loaded = {"model": loaded}
+ return loaded
+
+ def save(self, name: str, **kwargs: Any) -> None:
+ """
+ Dump model and checkpointables to a file.
+
+ Args:
+ name (str): name of the file.
+ kwargs (dict): extra arbitrary data to save.
+ """
+ if not self.save_dir or not self.save_to_disk:
+ return
+
+ data = {}
+ data["model"] = self.model.state_dict()
+ for key, obj in self.checkpointables.items():
+ data[key] = obj.state_dict()
+ data.update(kwargs)
+
+ basename = "{}.pth".format(name)
+
+ if self.ceph_save:
+ local_save_file = os.path.join(self.save_dir, basename)
+ relpath = os.path.relpath(local_save_file, os.getcwd())
+ save_file = os.path.join(self.path_prefix, relpath)
+ assert os.path.basename(save_file) == basename, basename
+ self.logger.info("Saving checkpoint to {}".format(save_file))
+ with io.BytesIO() as f:
+ torch.save(data, f)
+ self.client.put(f.getvalue(), save_file)
+ else:
+ save_file = os.path.join(self.save_dir, basename)
+ assert os.path.basename(save_file) == basename, basename
+ self.logger.info("Saving checkpoint to {}".format(save_file))
+ with self.path_manager.open(save_file, "wb") as f:
+ torch.save(data, f)
+ self.tag_last_checkpoint(basename)
+
+ def load(self,
+ path: str,
+ checkpointables: Optional[List[str]] = None) -> Dict[str, Any]:
+ """
+ Load from the given checkpoint.
+
+ Args:
+ path (str): path or url to the checkpoint. If empty, will not load
+ anything.
+ checkpointables (list): List of checkpointable names to load. If not
+ specified (None), will load all the possible checkpointables.
+ Returns:
+ dict:
+ extra data loaded from the checkpoint that has not been
+ processed. For example, those saved with
+ :meth:`.save(**extra_data)`.
+ """
+ if not path:
+ # no checkpoint provided
+ self.logger.info(
+ "No checkpoint found. Initializing model from scratch")
+ return {}
+ self.logger.info("[Checkpointer] Loading from {} ...".format(path))
+ if not self.ceph_save:
+ if not os.path.isfile(path):
+ path = self.path_manager.get_local_path(path)
+ assert os.path.isfile(path), "Checkpoint {} not found!".format(
+ path)
+ else:
+ relpath = os.path.relpath(path, os.getcwd())
+ s3url = os.path.join(self.path_prefix, relpath)
+ #TODO: dev branch is needed
+ # if not self.client.exists(s3url):
+ # assert self.client.exists(s3url), "Checkpoint {} not found!".format(s3url)
+
+ checkpoint = self._load_file(path)
+ incompatible = self._load_model(checkpoint)
+ if (incompatible is not None
+ ): # handle some existing subclasses that returns None
+ self._log_incompatible_keys(incompatible)
+
+ for key in self.checkpointables if checkpointables is None else checkpointables:
+ if key in checkpoint:
+ self.logger.info("Loading {} from {} ...".format(key, path))
+ obj = self.checkpointables[key]
+ obj.load_state_dict(checkpoint.pop(key))
+
+ # return any further checkpoint data
+ return checkpoint
+
+ def _convert_checkpoint(self, checkpoint):
+ # for multitask pretrain and fintune
+ if self.checkpoint_mapping is not None and self.mapping:
+ pretrain_checkpoint = checkpoint["model"]
+ for origin_task in self.checkpoint_mapping.keys():
+ for k in list(pretrain_checkpoint.keys()):
+ if origin_task in k:
+ # mapping to downstrean task
+ state_dict_temp = copy.deepcopy(
+ pretrain_checkpoint.pop(k))
+ for subtask in self.checkpoint_mapping[origin_task]:
+ new_key = k.replace(origin_task, subtask)
+ pretrain_checkpoint[new_key] = state_dict_temp
+ checkpoint["model"] = pretrain_checkpoint
+
+ if not self.resume_tau:
+ pretrain_checkpoint = checkpoint["model"]
+ for k in list(pretrain_checkpoint.keys()):
+ if "logit_scale" in k:
+ pretrain_checkpoint.pop(k)
+ checkpoint["model"] = pretrain_checkpoint
+ return checkpoint
+
+ def _load_model(self, checkpoint):
+ if checkpoint.get("matching_heuristics", False):
+ self._convert_ndarray_to_tensor(checkpoint["model"])
+ # convert weights by name-matching heuristics
+ model_state_dict = self.model.state_dict()
+ align_and_update_state_dicts(
+ model_state_dict,
+ checkpoint["model"],
+ c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
+ )
+ checkpoint["model"] = model_state_dict
+
+ # convert checkpoint for pretrained model between different tasks
+ checkpoint = self._convert_checkpoint(checkpoint)
+
+ # for non-caffe2 models, use standard ways to load it
+ incompatible = super()._load_model(checkpoint)
+ if incompatible is None: # support older versions of fvcore
+ return None
+
+ model_buffers = dict(self.model.named_buffers(recurse=False))
+ for k in ["pixel_mean", "pixel_std"]:
+ # Ignore missing key message about pixel_mean/std.
+ # Though they may be missing in old checkpoints, they will be correctly
+ # initialized from config anyway.
+ if k in model_buffers:
+ try:
+ incompatible.missing_keys.remove(k)
+ except ValueError:
+ pass
+ return incompatible
+
+ def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
+ """
+ Log information about the incompatible keys returned by ``_load_model``.
+ """
+ for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
+ self.logger.warning(
+ "Skip loading parameter '{}' to the model due to incompatible "
+ "shapes: {} in the checkpoint but {} in the "
+ "model! You might want to double check if this is expected.".
+ format(k, shape_checkpoint, shape_model))
+ if incompatible.missing_keys:
+ self.logger.info(
+ get_missing_parameters_message(incompatible.missing_keys))
+ if incompatible.unexpected_keys:
+ self.logger.info(
+ get_unexpected_parameters_message(
+ incompatible.unexpected_keys))
+
+ def resume_or_load(self, path, resume: bool = True, **kwargs):
+ super().resume_or_load(path, resume=resume)
diff --git a/uniperceiver/config/__init__.py b/uniperceiver/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..269b4cb1787cedd4260a3e9c71654263f0930fa0
--- /dev/null
+++ b/uniperceiver/config/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .compat import downgrade_config, upgrade_config
+from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable
+
+
+
+
+
+__all__ = [
+ "CfgNode", "get_cfg", "global_cfg", "set_global_cfg", "downgrade_config",
+ "upgrade_config", "configurable",
+]
diff --git a/uniperceiver/config/compat.py b/uniperceiver/config/compat.py
new file mode 100644
index 0000000000000000000000000000000000000000..11a08c439bf14defd880e37a938fab8a08e68eeb
--- /dev/null
+++ b/uniperceiver/config/compat.py
@@ -0,0 +1,229 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+"""
+Backward compatibility of configs.
+
+Instructions to bump version:
++ It's not needed to bump version if new keys are added.
+ It's only needed when backward-incompatible changes happen
+ (i.e., some existing keys disappear, or the meaning of a key changes)
++ To bump version, do the following:
+ 1. Increment _C.VERSION in defaults.py
+ 2. Add a converter in this file.
+
+ Each ConverterVX has a function "upgrade" which in-place upgrades config from X-1 to X,
+ and a function "downgrade" which in-place downgrades config from X to X-1
+
+ In each function, VERSION is left unchanged.
+
+ Each converter assumes that its input has the relevant keys
+ (i.e., the input is not a partial config).
+ 3. Run the tests (test_config.py) to make sure the upgrade & downgrade
+ functions are consistent.
+"""
+
+import logging
+from typing import List, Optional, Tuple
+
+from .config import CfgNode as CN
+from .defaults import _C
+
+__all__ = ["upgrade_config", "downgrade_config"]
+
+
+def upgrade_config(cfg: CN, to_version: Optional[int] = None) -> CN:
+ """
+ Upgrade a config from its current version to a newer version.
+
+ Args:
+ cfg (CfgNode):
+ to_version (int): defaults to the latest version.
+ """
+ cfg = cfg.clone()
+ if to_version is None:
+ to_version = _C.VERSION
+
+ assert cfg.VERSION <= to_version, "Cannot upgrade from v{} to v{}!".format(
+ cfg.VERSION, to_version
+ )
+ for k in range(cfg.VERSION, to_version):
+ converter = globals()["ConverterV" + str(k + 1)]
+ converter.upgrade(cfg)
+ cfg.VERSION = k + 1
+ return cfg
+
+
+def downgrade_config(cfg: CN, to_version: int) -> CN:
+ """
+ Downgrade a config from its current version to an older version.
+
+ Args:
+ cfg (CfgNode):
+ to_version (int):
+
+ Note:
+ A general downgrade of arbitrary configs is not always possible due to the
+ different functionalities in different versions.
+ The purpose of downgrade is only to recover the defaults in old versions,
+ allowing it to load an old partial yaml config.
+ Therefore, the implementation only needs to fill in the default values
+ in the old version when a general downgrade is not possible.
+ """
+ cfg = cfg.clone()
+ assert cfg.VERSION >= to_version, "Cannot downgrade from v{} to v{}!".format(
+ cfg.VERSION, to_version
+ )
+ for k in range(cfg.VERSION, to_version, -1):
+ converter = globals()["ConverterV" + str(k)]
+ converter.downgrade(cfg)
+ cfg.VERSION = k - 1
+ return cfg
+
+
+def guess_version(cfg: CN, filename: str) -> int:
+ """
+ Guess the version of a partial config where the VERSION field is not specified.
+ Returns the version, or the latest if cannot make a guess.
+
+ This makes it easier for users to migrate.
+ """
+ logger = logging.getLogger(__name__)
+
+ def _has(name: str) -> bool:
+ cur = cfg
+ for n in name.split("."):
+ if n not in cur:
+ return False
+ cur = cur[n]
+ return True
+
+ # Most users' partial configs have "MODEL.WEIGHT", so guess on it
+ ret = None
+ if _has("MODEL.WEIGHT") or _has("TEST.AUG_ON"):
+ ret = 1
+
+ if ret is not None:
+ logger.warning("Config '{}' has no VERSION. Assuming it to be v{}.".format(filename, ret))
+ else:
+ ret = _C.VERSION
+ logger.warning(
+ "Config '{}' has no VERSION. Assuming it to be compatible with latest v{}.".format(
+ filename, ret
+ )
+ )
+ return ret
+
+
+def _rename(cfg: CN, old: str, new: str) -> None:
+ old_keys = old.split(".")
+ new_keys = new.split(".")
+
+ def _set(key_seq: List[str], val: str) -> None:
+ cur = cfg
+ for k in key_seq[:-1]:
+ if k not in cur:
+ cur[k] = CN()
+ cur = cur[k]
+ cur[key_seq[-1]] = val
+
+ def _get(key_seq: List[str]) -> CN:
+ cur = cfg
+ for k in key_seq:
+ cur = cur[k]
+ return cur
+
+ def _del(key_seq: List[str]) -> None:
+ cur = cfg
+ for k in key_seq[:-1]:
+ cur = cur[k]
+ del cur[key_seq[-1]]
+ if len(cur) == 0 and len(key_seq) > 1:
+ _del(key_seq[:-1])
+
+ _set(new_keys, _get(old_keys))
+ _del(old_keys)
+
+
+class _RenameConverter:
+ """
+ A converter that handles simple rename.
+ """
+
+ RENAME: List[Tuple[str, str]] = [] # list of tuples of (old name, new name)
+
+ @classmethod
+ def upgrade(cls, cfg: CN) -> None:
+ for old, new in cls.RENAME:
+ _rename(cfg, old, new)
+
+ @classmethod
+ def downgrade(cls, cfg: CN) -> None:
+ for old, new in cls.RENAME[::-1]:
+ _rename(cfg, new, old)
+
+
+class ConverterV1(_RenameConverter):
+ RENAME = [("MODEL.RPN_HEAD.NAME", "MODEL.RPN.HEAD_NAME")]
+
+
+class ConverterV2(_RenameConverter):
+ """
+ A large bulk of rename, before public release.
+ """
+
+ RENAME = [
+ ("MODEL.WEIGHT", "MODEL.WEIGHTS"),
+ ("MODEL.PANOPTIC_FPN.SEMANTIC_LOSS_SCALE", "MODEL.SEM_SEG_HEAD.LOSS_WEIGHT"),
+ ("MODEL.PANOPTIC_FPN.RPN_LOSS_SCALE", "MODEL.RPN.LOSS_WEIGHT"),
+ ("MODEL.PANOPTIC_FPN.INSTANCE_LOSS_SCALE", "MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT"),
+ ("MODEL.PANOPTIC_FPN.COMBINE_ON", "MODEL.PANOPTIC_FPN.COMBINE.ENABLED"),
+ (
+ "MODEL.PANOPTIC_FPN.COMBINE_OVERLAP_THRESHOLD",
+ "MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH",
+ ),
+ (
+ "MODEL.PANOPTIC_FPN.COMBINE_STUFF_AREA_LIMIT",
+ "MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT",
+ ),
+ (
+ "MODEL.PANOPTIC_FPN.COMBINE_INSTANCES_CONFIDENCE_THRESHOLD",
+ "MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH",
+ ),
+ ("MODEL.ROI_HEADS.SCORE_THRESH", "MODEL.ROI_HEADS.SCORE_THRESH_TEST"),
+ ("MODEL.ROI_HEADS.NMS", "MODEL.ROI_HEADS.NMS_THRESH_TEST"),
+ ("MODEL.RETINANET.INFERENCE_SCORE_THRESHOLD", "MODEL.RETINANET.SCORE_THRESH_TEST"),
+ ("MODEL.RETINANET.INFERENCE_TOPK_CANDIDATES", "MODEL.RETINANET.TOPK_CANDIDATES_TEST"),
+ ("MODEL.RETINANET.INFERENCE_NMS_THRESHOLD", "MODEL.RETINANET.NMS_THRESH_TEST"),
+ ("TEST.DETECTIONS_PER_IMG", "TEST.DETECTIONS_PER_IMAGE"),
+ ("TEST.AUG_ON", "TEST.AUG.ENABLED"),
+ ("TEST.AUG_MIN_SIZES", "TEST.AUG.MIN_SIZES"),
+ ("TEST.AUG_MAX_SIZE", "TEST.AUG.MAX_SIZE"),
+ ("TEST.AUG_FLIP", "TEST.AUG.FLIP"),
+ ]
+
+ @classmethod
+ def upgrade(cls, cfg: CN) -> None:
+ super().upgrade(cfg)
+
+ if cfg.MODEL.META_ARCHITECTURE == "RetinaNet":
+ _rename(
+ cfg, "MODEL.RETINANET.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS"
+ )
+ _rename(cfg, "MODEL.RETINANET.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
+ del cfg["MODEL"]["RPN"]["ANCHOR_SIZES"]
+ del cfg["MODEL"]["RPN"]["ANCHOR_ASPECT_RATIOS"]
+ else:
+ _rename(cfg, "MODEL.RPN.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS")
+ _rename(cfg, "MODEL.RPN.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES")
+ del cfg["MODEL"]["RETINANET"]["ANCHOR_SIZES"]
+ del cfg["MODEL"]["RETINANET"]["ANCHOR_ASPECT_RATIOS"]
+ del cfg["MODEL"]["RETINANET"]["ANCHOR_STRIDES"]
+
+ @classmethod
+ def downgrade(cls, cfg: CN) -> None:
+ super().downgrade(cfg)
+
+ _rename(cfg, "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS", "MODEL.RPN.ANCHOR_ASPECT_RATIOS")
+ _rename(cfg, "MODEL.ANCHOR_GENERATOR.SIZES", "MODEL.RPN.ANCHOR_SIZES")
+ cfg.MODEL.RETINANET.ANCHOR_ASPECT_RATIOS = cfg.MODEL.RPN.ANCHOR_ASPECT_RATIOS
+ cfg.MODEL.RETINANET.ANCHOR_SIZES = cfg.MODEL.RPN.ANCHOR_SIZES
+ cfg.MODEL.RETINANET.ANCHOR_STRIDES = [] # this is not used anywhere in any version
diff --git a/uniperceiver/config/config.py b/uniperceiver/config/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..773de3b9bd89de6baee0dfa1d148e951c6198557
--- /dev/null
+++ b/uniperceiver/config/config.py
@@ -0,0 +1,286 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import functools
+import inspect
+import logging
+from fvcore.common.config import CfgNode as _CfgNode
+
+from uniperceiver.utils.file_io import PathManager
+
+
+class CfgNode(_CfgNode):
+ """
+ The same as `fvcore.common.config.CfgNode`, but different in:
+
+ 1. Use unsafe yaml loading by default.
+ Note that this may lead to arbitrary code execution: you must not
+ load a config file from untrusted sources before manually inspecting
+ the content of the file.
+ 2. Support config versioning.
+ When attempting to merge an old config, it will convert the old config automatically.
+
+ .. automethod:: clone
+ .. automethod:: freeze
+ .. automethod:: defrost
+ .. automethod:: is_frozen
+ .. automethod:: load_yaml_with_base
+ .. automethod:: merge_from_list
+ .. automethod:: merge_from_other_cfg
+ """
+
+ @classmethod
+ def _open_cfg(cls, filename):
+ return PathManager.open(filename, "r")
+
+ def load_from_file_tmp(self, cfg_filename: str, allow_unsafe: bool = True) -> None:
+ assert PathManager.isfile(cfg_filename), f"Config file '{cfg_filename}' does not exist!"
+ loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
+ loaded_cfg = type(self)(loaded_cfg)
+
+ # defaults.py needs to import CfgNode
+ from .defaults import _C
+
+ latest_ver = _C.VERSION
+ assert (
+ latest_ver == self.VERSION
+ ), "CfgNode.merge_from_file is only allowed on a config object of latest version!"
+
+ return loaded_cfg
+
+ # Note that the default value of allow_unsafe is changed to True
+ def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None:
+ """
+ Load content from the given config file and merge it into self.
+
+ Args:
+ cfg_filename: config filename
+ allow_unsafe: allow unsafe yaml syntax
+ """
+ assert PathManager.isfile(cfg_filename), f"Config file '{cfg_filename}' does not exist!"
+ loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
+ loaded_cfg = type(self)(loaded_cfg)
+
+ # defaults.py needs to import CfgNode
+ from .defaults import _C
+
+ latest_ver = _C.VERSION
+ assert (
+ latest_ver == self.VERSION
+ ), "CfgNode.merge_from_file is only allowed on a config object of latest version!"
+
+ logger = logging.getLogger(__name__)
+
+ loaded_ver = loaded_cfg.get("VERSION", None)
+ if loaded_ver is None:
+ from .compat import guess_version
+
+ loaded_ver = guess_version(loaded_cfg, cfg_filename)
+ assert loaded_ver <= self.VERSION, "Cannot merge a v{} config into a v{} config.".format(
+ loaded_ver, self.VERSION
+ )
+
+ if loaded_ver == self.VERSION:
+ self.merge_from_other_cfg(loaded_cfg)
+ else:
+ # compat.py needs to import CfgNode
+ from .compat import upgrade_config, downgrade_config
+
+ logger.warning(
+ "Loading an old v{} config file '{}' by automatically upgrading to v{}. "
+ "See docs/CHANGELOG.md for instructions to update your files.".format(
+ loaded_ver, cfg_filename, self.VERSION
+ )
+ )
+ # To convert, first obtain a full config at an old version
+ old_self = downgrade_config(self, to_version=loaded_ver)
+ old_self.merge_from_other_cfg(loaded_cfg)
+ new_config = upgrade_config(old_self)
+ self.clear()
+ self.update(new_config)
+
+ def dump(self, *args, **kwargs):
+ """
+ Returns:
+ str: a yaml string representation of the config
+ """
+ # to make it show up in docs
+ return super().dump(*args, **kwargs)
+
+ def to_dict_object(self,):
+ for k,v in self.items():
+ if isinstance(v, CfgNode):
+ self[k] = v.to_dict_object()
+ return dict(self)
+
+
+global_cfg = CfgNode()
+
+
+def get_cfg() -> CfgNode:
+ """
+ Get a copy of the default config.
+
+ Returns:
+ a detectron2 CfgNode instance.
+ """
+ from .defaults import _C
+
+ return _C.clone()
+
+
+def set_global_cfg(cfg: CfgNode) -> None:
+ """
+ Let the global config point to the given cfg.
+
+ Assume that the given "cfg" has the key "KEY", after calling
+ `set_global_cfg(cfg)`, the key can be accessed by:
+ ::
+ from detectron2.config import global_cfg
+ print(global_cfg.KEY)
+
+ By using a hacky global config, you can access these configs anywhere,
+ without having to pass the config object or the values deep into the code.
+ This is a hacky feature introduced for quick prototyping / research exploration.
+ """
+ global global_cfg
+ global_cfg.clear()
+ global_cfg.update(cfg)
+
+
+def configurable(init_func=None, *, from_config=None):
+ """
+ Decorate a function or a class's __init__ method so that it can be called
+ with a :class:`CfgNode` object using a :func:`from_config` function that translates
+ :class:`CfgNode` to arguments.
+
+ Examples:
+ ::
+ # Usage 1: Decorator on __init__:
+ class A:
+ @configurable
+ def __init__(self, a, b=2, c=3):
+ pass
+
+ @classmethod
+ def from_config(cls, cfg): # 'cfg' must be the first argument
+ # Returns kwargs to be passed to __init__
+ return {"a": cfg.A, "b": cfg.B}
+
+ a1 = A(a=1, b=2) # regular construction
+ a2 = A(cfg) # construct with a cfg
+ a3 = A(cfg, b=3, c=4) # construct with extra overwrite
+
+ # Usage 2: Decorator on any function. Needs an extra from_config argument:
+ @configurable(from_config=lambda cfg: {"a: cfg.A, "b": cfg.B})
+ def a_func(a, b=2, c=3):
+ pass
+
+ a1 = a_func(a=1, b=2) # regular call
+ a2 = a_func(cfg) # call with a cfg
+ a3 = a_func(cfg, b=3, c=4) # call with extra overwrite
+
+ Args:
+ init_func (callable): a class's ``__init__`` method in usage 1. The
+ class must have a ``from_config`` classmethod which takes `cfg` as
+ the first argument.
+ from_config (callable): the from_config function in usage 2. It must take `cfg`
+ as its first argument.
+ """
+
+ if init_func is not None:
+ assert (
+ inspect.isfunction(init_func)
+ and from_config is None
+ and init_func.__name__ == "__init__"
+ ), "Incorrect use of @configurable. Check API documentation for examples."
+
+ @functools.wraps(init_func)
+ def wrapped(self, *args, **kwargs):
+ try:
+ from_config_func = type(self).from_config
+ except AttributeError as e:
+ raise AttributeError(
+ "Class with @configurable must have a 'from_config' classmethod."
+ ) from e
+ if not inspect.ismethod(from_config_func):
+ raise TypeError("Class with @configurable must have a 'from_config' classmethod.")
+
+ if _called_with_cfg(*args, **kwargs):
+ explicit_args = _get_args_from_config(from_config_func, *args, **kwargs)
+ init_func(self, **explicit_args)
+ else:
+ init_func(self, *args, **kwargs)
+
+ return wrapped
+
+ else:
+ if from_config is None:
+ return configurable # @configurable() is made equivalent to @configurable
+ assert inspect.isfunction(
+ from_config
+ ), "from_config argument of configurable must be a function!"
+
+ def wrapper(orig_func):
+ @functools.wraps(orig_func)
+ def wrapped(*args, **kwargs):
+ if _called_with_cfg(*args, **kwargs):
+ explicit_args = _get_args_from_config(from_config, *args, **kwargs)
+ return orig_func(**explicit_args)
+ else:
+ return orig_func(*args, **kwargs)
+
+ wrapped.from_config = from_config
+ return wrapped
+
+ return wrapper
+
+
+def _get_args_from_config(from_config_func, *args, **kwargs):
+ """
+ Use `from_config` to obtain explicit arguments.
+
+ Returns:
+ dict: arguments to be used for cls.__init__
+ """
+ signature = inspect.signature(from_config_func)
+ if list(signature.parameters.keys())[0] != "cfg":
+ if inspect.isfunction(from_config_func):
+ name = from_config_func.__name__
+ else:
+ name = f"{from_config_func.__self__}.from_config"
+ raise TypeError(f"{name} must take 'cfg' as the first argument!")
+ support_var_arg = any(
+ param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD]
+ for param in signature.parameters.values()
+ )
+ if support_var_arg: # forward all arguments to from_config, if from_config accepts them
+ ret = from_config_func(*args, **kwargs)
+ else:
+ # forward supported arguments to from_config
+ supported_arg_names = set(signature.parameters.keys())
+ extra_kwargs = {}
+ for name in list(kwargs.keys()):
+ if name not in supported_arg_names:
+ extra_kwargs[name] = kwargs.pop(name)
+ ret = from_config_func(*args, **kwargs)
+ # forward the other arguments to __init__
+ ret.update(extra_kwargs)
+ return ret
+
+
+def _called_with_cfg(*args, **kwargs):
+ """
+ Returns:
+ bool: whether the arguments contain CfgNode and should be considered
+ forwarded to from_config.
+ """
+ from omegaconf import DictConfig
+
+ if len(args) and isinstance(args[0], (_CfgNode, DictConfig)):
+ return True
+ if isinstance(kwargs.pop("cfg", None), (_CfgNode, DictConfig)):
+ return True
+ # `from_config`'s first argument is forced to be "cfg".
+ # So the above check covers all cases.
+ return False
diff --git a/uniperceiver/config/defaults.py b/uniperceiver/config/defaults.py
new file mode 100644
index 0000000000000000000000000000000000000000..48f369de468a199a5cab37c8caa95af52b4cd201
--- /dev/null
+++ b/uniperceiver/config/defaults.py
@@ -0,0 +1,864 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from .config import CfgNode as CN
+
+# -----------------------------------------------------------------------------
+# Config definition
+# -----------------------------------------------------------------------------
+
+_C = CN()
+
+# The version number, to upgrade from old configs to new ones if any
+# changes happen. It's recommended to keep a VERSION in your config file.
+_C.VERSION = 1
+
+_C.NAME = '' # task name
+
+
+# -----------------------------------------------------------------------------
+# Shared targets
+# -----------------------------------------------------------------------------
+_C.SHARED_TARGETS = []
+_C.SHARED_TARGETS_CFG = CN()
+_C.SHARED_TARGETS_CFG.FILE_PATH = ''
+_C.SHARED_TARGETS_CFG.DISTRIBUTED = False
+
+# -----------------------------------------------------------------------------
+# Dataset
+# -----------------------------------------------------------------------------
+_C.DATASETS = CN() #
+
+_C.DATASETS.TRAIN = ''
+
+_C.DATASETS.VAL = ''
+
+_C.DATASETS.TEST = ''
+
+_C.DATASETS.TASK_TYPE = ''
+_C.DATASETS.DATASET_NAME = ''
+_C.DATASETS.TARGET_SET = ['']
+_C.DATASETS.TRAIN_BATCH_SIZE = 64
+_C.DATASETS.TEST_BATCH_SIZE = 64
+_C.DATASETS.VERSION = 'v1'
+
+
+# -----------------------------------------------------------------------------
+# DataLoader
+# -----------------------------------------------------------------------------
+_C.DATALOADER = CN()
+
+_C.DATALOADER.UNIFIED_DATASET = False
+
+_C.DATALOADER.FAKE_DATA = False
+
+_C.DATALOADER.TASK_TYPE = ''
+
+_C.DATALOADER.TRAIN_BATCH_SIZE = 64
+
+_C.DATALOADER.TEST_BATCH_SIZE = 64
+
+_C.DATALOADER.NUM_WORKERS = 4
+
+_C.DATALOADER.FEATS_FOLDER = ''
+
+_C.DATALOADER.LOCAL_PREFIX=''
+
+_C.DATALOADER.SAMPLER=''
+_C.DATALOADER.CACHE_MODE=True
+
+_C.DATALOADER.APPEND_EOS=True
+_C.DATALOADER.ONE_STREAM=True
+_C.DATALOADER.RANDOM_MASK=True
+
+
+
+
+_C.DATALOADER.LOCAL_PREFIX=''
+_C.DATALOADER.CLASS_NAME_FILE = ''
+
+_C.DATALOADER.VISUAL_FEAT = True
+_C.DATALOADER.ANNO_FOLDER = ''
+_C.DATALOADER.ANNO_FILENAME = None
+_C.DATALOADER.S3_PATH = ''
+_C.DATALOADER.S3_ANNO_FOLDER = None
+_C.DATALOADER.CIRCULAR_CACHE_MODE = False
+_C.DATALOADER.ZIP_MODE = False
+_C.DATALOADER.CACHE_ORIGIN_IMAGE = False
+_C.DATALOADER.RANDOM_CAPTION = True
+_C.DATALOADER.AS_NUMPY_AS_POSSIBLE = False
+
+_C.DATALOADER.RELATION_FILE = ''
+
+_C.DATALOADER.GV_FEAT_FILE = ''
+
+_C.DATALOADER.ATTRIBUTE_FILE = ''
+
+_C.DATALOADER.SEQ_PER_SAMPLE = 5
+_C.DATALOADER.MIN_SEQ_PER_SAMPLE = 5
+
+_C.DATALOADER.MAX_FEAT_NUM = -1
+
+_C.DATALOADER.NEGATIVE_SIZE = -1
+
+_C.DATALOADER.INF_BATCH_SIZE = 200 # for single stream retrieval only, chunk size
+
+_C.DATALOADER.USE_GLOBAL_V = True
+
+_C.DATALOADER.USE_WEIGHTED_SAMPLER = False
+
+_C.DATALOADER.SAMPLING_WEIGHT = 1.0
+_C.DATALOADER.TRANSFORM = ''
+
+
+
+# xiaoshi: added for video cls
+_C.DATALOADER.FRAMES_PER_CLIP = 4
+_C.DATALOADER.STRIDE = 5
+_C.DATALOADER.FILE_EXTENSION = ''
+_C.DATALOADER.ANNO_FILE = 'annotation.json'
+_C.DATALOADER.TIMESFORMER_AUG = False
+
+# hao:
+_C.DATALOADER.DO_AS_RETRIEVAL = False
+
+_C.DATALOADER.USE_CEPH = False
+
+# xiaoshi: added for vqa, specify inference mode
+_C.DATALOADER.DO_AS_GEN = True
+_C.DATALOADER.VQA_INPUT = ['image', 'question']
+_C.DATALOADER.SINGLE_CLASS = False
+_C.DATALOADER.SMALL_VAL = True
+_C.DATALOADER.BLOCK_VQ = False
+_C.DATALOADER.DATA_PERCENTAGE = 1.0
+_C.DATALOADER.TWO_EOT = False
+_C.DATALOADER.DATA_K_SAMPLE = -1
+
+_C.DATALOADER.PIN_MEM = True
+_C.DATALOADER.PREFETCH_FACTOR = 2
+_C.DATALOADER.PADDING_TO_MAX = False
+_C.DATALOADER.LOAD_INLABEL = True
+
+
+
+_C.DATALOADER.MULTI_VEIW_NUM = 1
+_C.DATALOADER.MULTI_VEIW = 'v0'
+
+
+
+
+_C.TASKS = [] # task config
+
+_C.ENCODERS = [] # multi encoder config
+# -----------------------------------------------------------------------------
+# Engine
+# -----------------------------------------------------------------------------
+_C.ENGINE = CN()
+
+_C.ENGINE.NAME = ''
+
+_C.ENGINE.MIXUP = 0.
+_C.ENGINE.CUTMIX = 0.
+_C.ENGINE.MIXUP_PROB = 0.
+_C.ENGINE.MIXUP_SWITCH_PROB = 0.0
+_C.ENGINE.MIXUP_MODE = ''
+_C.ENGINE.MIXUP_LABEL_SMOOTHING = 0.0
+
+# change to dataloader
+_C.DATALOADER.MIXUP = 0.
+_C.DATALOADER.CUTMIX = 0.
+_C.DATALOADER.MIXUP_PROB = 0.
+_C.DATALOADER.MIXUP_SWITCH_PROB = 0.0
+_C.DATALOADER.MIXUP_MODE = ''
+_C.DATALOADER.MIXUP_LABEL_SMOOTHING = 0.0
+
+_C.DATALOADER.MINI_BATCHES = 1
+_C.DATALOADER.SYNC_TASK = False
+_C.DATALOADER.STRATEGY = ''
+_C.DATALOADER.TURN_LOG = True
+_C.DATALOADER.TCS_CONF_PATH = 'petreloss.config'
+_C.DATALOADER.NUM_GTS = 1
+_C.DATALOADER.USE_SEG_ID = False
+
+
+# -----------------------------------------------------------------------------
+# Scheduled sampling
+# -----------------------------------------------------------------------------
+_C.SCHEDULED_SAMPLING = CN()
+
+_C.SCHEDULED_SAMPLING.START_EPOCH = 0
+
+_C.SCHEDULED_SAMPLING.INC_EVERY_EPOCH = 5
+
+_C.SCHEDULED_SAMPLING.INC_PROB = 0.05
+
+_C.SCHEDULED_SAMPLING.MAX_PROB = 0.25
+
+# -----------------------------------------------------------------------------
+# Model
+# -----------------------------------------------------------------------------
+_C.MODEL = CN()
+
+_C.MODEL.DEVICE = "cuda"
+
+_C.MODEL.TEMP_NAME = ""
+
+_C.MODEL.IMG_INPUT_SIZE = 224
+_C.MODEL.PATCH_SIZE = 16
+
+_C.MODEL.BLOCK_IMAGENET = False
+
+_C.MODEL.FAKE_PAD_TO_MAX = False
+
+_C.MODEL.VOCAB_SIZE = 1000 # include /
+
+_C.MODEL.META_ARCHITECTURE = ''
+
+_C.MODEL.ENCODER = ''
+
+_C.MODEL.ENCODER_DIM = 1024
+
+_C.MODEL.DECODER = ''
+
+_C.MODEL.DECODER_DIM = 1024
+
+_C.MODEL.PRED_DROPOUT = 0.0
+
+_C.MODEL.PREDICTOR = ''
+
+_C.MODEL.V_PREDICTOR = ''
+
+_C.MODEL.USE_PREDICTOR_BIAS = False
+_C.MODEL.SHARE_PREDICTOR_HIDDEN = False
+_C.MODEL.SHARE_CLS_NAME_QUERY_EMBED = False
+
+_C.MODEL.PRED_TEMPERATURE = 1.0
+
+_C.MODEL.PRED_USE_NORM = True
+
+_C.MODEL.MAX_SEQ_LEN = 17
+
+_C.MODEL.EVAL_MAX_SEQ_LEN = 17
+
+_C.MODEL.MAX_LABEL_LEN = 5
+
+_C.MODEL.WEIGHTS = ''
+
+_C.MODEL.ITM_NEG_PROB = 0.5
+
+# used for image patch
+_C.MODEL.CLS_TOKEN = False
+# xiaoshi: added for video cls
+_C.MODEL.BACKBONE = 'deit_base'
+_C.MODEL.CENTRAL_FRAME_INIT = False
+
+_C.MODEL.SHARE_MODULES = []
+
+_C.MODEL.PROMPT = False
+
+_C.MODEL.PROMPT_PARAM = []
+_C.MODEL.FC_PROMPT = False
+_C.MODEL.FC_PROMPT_OUT = -1
+_C.MODEL.TWO_LOSS = False
+_C.MODEL.FC_BIAS = 0.0
+_C.MODEL.FC_PROMPT_WEIGHTS = 'learn'
+_C.MODEL.FC_PROMPT_INDEX = -1
+
+
+_C.MODEL.GEN_MASK = True
+_C.MODEL.SKIP_WORD_EMB = False
+_C.MODEL.IN_TUNING = False
+
+# ----------------------------------------------------------------------------
+# Token embedding
+# ----------------------------------------------------------------------------
+_C.MODEL.TOKEN_EMBED = CN()
+
+_C.MODEL.TOKEN_EMBED.NAME = ''
+
+_C.MODEL.TOKEN_EMBED.DIM = 1024
+
+_C.MODEL.TOKEN_EMBED.ACTIVATION = 'none'
+
+_C.MODEL.TOKEN_EMBED.ELU_ALPHA = 0.5
+
+_C.MODEL.TOKEN_EMBED.USE_NORM = False
+
+_C.MODEL.TOKEN_EMBED.DROPOUT = 0.0
+
+_C.MODEL.TOKEN_EMBED.POSITION = 'none'
+
+_C.MODEL.TOKEN_EMBED.POSITION_MAX_LEN = 5000
+
+_C.MODEL.TOKEN_EMBED.TYPE_VOCAB_SIZE = 0
+
+_C.MODEL.TOKEN_EMBED.TYPE_SEG_SIZE = 0
+
+_C.MODEL.OLD_CHECKPONT = True
+
+# ----------------------------------------------------------------------------
+# Visual embedding
+# ----------------------------------------------------------------------------
+_C.MODEL.VISUAL_EMBED = CN()
+
+_C.MODEL.VISUAL_EMBED.NAME = ''
+
+_C.MODEL.VISUAL_EMBED.IN_DIM = 2048
+
+_C.MODEL.VISUAL_EMBED.OUT_DIM = 1024
+
+_C.MODEL.VISUAL_EMBED.ACTIVATION = 'none'
+
+_C.MODEL.VISUAL_EMBED.ELU_ALPHA = 0.5
+
+_C.MODEL.VISUAL_EMBED.USE_NORM = False
+
+_C.MODEL.VISUAL_EMBED.DROPOUT = 0.0
+
+_C.MODEL.VISUAL_EMBED.LOCATION_SIZE = 0
+
+_C.MODEL.VISUAL_EMBED.TYPE_SIZE = 0 # type embedding for image
+
+_C.MODEL.VISUAL_EMBED.PATCH_SIZE = 16
+
+_C.MODEL.VISUAL_EMBED.IMAGE_SIZE = 224
+
+# video embedding
+
+_C.MODEL.VIDEO_EMBED = CN()
+
+_C.MODEL.VIDEO_EMBED.NAME = ''
+
+_C.MODEL.VIDEO_EMBED.IN_DIM = 2048
+
+_C.MODEL.VIDEO_EMBED.OUT_DIM = 1024
+
+_C.MODEL.VIDEO_EMBED.ACTIVATION = 'none'
+
+_C.MODEL.VIDEO_EMBED.ELU_ALPHA = 0.5
+
+_C.MODEL.VIDEO_EMBED.USE_NORM = False
+
+_C.MODEL.VIDEO_EMBED.DROPOUT = 0.0
+
+_C.MODEL.VIDEO_EMBED.POSITION = 'none'
+
+_C.MODEL.VIDEO_EMBED.MAX_LENGTH = 1000
+
+_C.MODEL.VIDEO_EMBED.TYPE_SIZE = 0 # type embedding for image
+
+_C.MODEL.VIDEO_EMBED.ADD_TYPE_EMBED = False
+
+_C.MODEL.VIDEO_EMBED.PATCH_SIZE_S = 16
+
+_C.MODEL.VIDEO_EMBED.PATCH_SIZE_T = 8
+
+_C.MODEL.VIDEO_EMBED.DIVIDE_ST_POS = False
+
+_C.MODEL.VIDEO_EMBED.USE_VISUAL_TOKENIZER = False
+
+_C.MODEL.VIDEO_EMBED.USE_VISUAL_POS = False
+
+_C.MODEL.VIDEO_EMBED.MAX_FRAMES = 8
+_C.MODEL.VIDEO_EMBED.POS_RANDOM = True
+
+# video tokenizer
+
+_C.MODEL.VIDEO_TOKENIZER = CN()
+
+# _C.MODEL.VIDEO_TOKENIZER.PATCH_SIZE_S = 16
+
+# _C.MODEL.VIDEO_TOKENIZER.PATCH_SIZE_T = 8
+
+_C.MODEL.VIDEO_TOKENIZER.FPS = -1 # -1 means using a fixed number of frames
+
+_C.MODEL.VIDEO_TOKENIZER.NUM_FRAMES = 50 # works only when VIDEO_TOKENIZER.NUM_FRAMES == -1
+
+_C.MODEL.VIDEO_TOKENIZER.SAMPLE_OFFSET = 0
+
+_C.MODEL.VIDEO_TOKENIZER.MAX_FRAMES = 40
+
+
+# xiaoshi: added for video cls
+_C.MODEL.NUM_CLASSES = 339
+
+#
+_C.MODEL.PRETRAIN = False
+
+_C.MODEL.FIX_PRETRAIN_PARAM = True
+
+_C.MODEL.USE_ORIGINAL_CODER = False
+
+
+# prompt embedding
+
+_C.MODEL.PROMPT_EMBED = CN()
+
+_C.MODEL.PROMPT_EMBED.NAME = "none"
+
+_C.MODEL.PROMPT_EMBED.DIM = 512
+
+_C.MODEL.PROMPT_EMBED.PROMPT_LENGTH = 10
+_C.MODEL.PROMPT_EMBED.TARGET_PROMPT_LENGTH = 10
+_C.MODEL.PROMPT_EMBED.INPUT_DEEP_PROMPT_LENGTH = 10
+_C.MODEL.PROMPT_EMBED.TARGET_DEEP_PROMPT_LENGTH = 10
+
+
+_C.MODEL.PROMPT_EMBED.ACTIVATION = 'none'
+
+_C.MODEL.PROMPT_EMBED.ELU_ALPHA = 0.5
+
+_C.MODEL.PROMPT_EMBED.USE_NORM = False
+
+_C.MODEL.PROMPT_EMBED.DROPOUT = 0.0
+
+_C.MODEL.PROMPT_EMBED.WITH_POS = False
+
+_C.MODEL.PROMPT_EMBED.INPUT_PROMPT = False
+_C.MODEL.PROMPT_EMBED.TARGET_PROMPT = False
+
+_C.MODEL.PROMPT_EMBED.DEEP_PROMPT = False
+_C.MODEL.PROMPT_EMBED.TARGET_DEEP_PROMPT = False
+_C.MODEL.PROMPT_EMBED.SHARE_DEEP_PROMPT = False
+
+_C.MODEL.PROMPT_EMBED.LABLE_PROMPT = False
+_C.MODEL.PROMPT_EMBED.LABEL_SIZE = 0
+
+# ----------------------------------------------------------------------------
+# Pre-training
+# ----------------------------------------------------------------------------
+_C.MODEL.PRETRAINING = CN()
+
+_C.MODEL.PRETRAINING.MODEL_NAME = 'bert-base-uncased'
+
+_C.MODEL.PRETRAINING.FROM_PRETRAINED = 'bert-base-uncased'
+
+_C.MODEL.PRETRAINING.DO_LOWER_CASE = True
+
+# ----------------------------------------------------------------------------
+# BERT
+# ----------------------------------------------------------------------------
+_C.MODEL.BERT = CN()
+
+_C.MODEL.BERT.SCALE_MULTI_BEFORE = False
+
+_C.MODEL.BERT.DROP_PATH_PROB = 0.0
+
+_C.MODEL.BERT.DROP_PATH_PROB_FIXED = False
+
+
+_C.MODEL.BERT.HIDDEN_SIZE = 512
+
+_C.MODEL.BERT.HIDDEN_DROPOUT_PROB = 0.1
+
+_C.MODEL.BERT.HIDDEN_ACT = "gelu"
+
+_C.MODEL.BERT.NUM_ATTENTION_HEADS = 8
+
+_C.MODEL.BERT.INTERMEDIATE_SIZE = 2048
+
+_C.MODEL.BERT.INTERMEDIATE_DROP = 0.1
+
+_C.MODEL.BERT.FFN_DROPOUT_PROB = 0.1
+
+_C.MODEL.BERT.ATTENTION_PROBS_DROPOUT_PROB = 0.1
+
+_C.MODEL.BERT.V_TARGET_SIZE = 0
+
+_C.MODEL.BERT.NUM_HIDDEN_LAYERS = 12
+
+_C.MODEL.BERT.LAYER_DROP = 0.0
+
+_C.MODEL.BERT.V_NUM_HIDDEN_LAYERS = 6
+
+_C.MODEL.BERT.V_LAYER_DROP = 0.0
+
+_C.MODEL.BERT.NUM_UNDERSTANDING_LAYERS = 6
+
+_C.MODEL.BERT.U_LAYER_DROP = 0.0
+
+_C.MODEL.BERT.NUM_GENERATION_LAYERS = 6
+
+_C.MODEL.BERT.G_LAYER_DROP = 0.0
+
+_C.MODEL.BERT.SKIP_TARGET_ENCODE = False
+_C.MODEL.BERT.NORMALIZE_BEFORE = False
+_C.MODEL.BERT.NORMALIZE_DECISION = ''
+_C.MODEL.BERT.QKV_BIAS = True
+
+_C.MODEL.BERT.UNIFY_QKV = True
+
+_C.MODEL.FEATURE_GATHER = False
+_C.MODEL.FEATURE_GATHER_FORCE = False
+
+_C.MODEL.LEARN_TEMP = False
+
+_C.MODEL.LABELS_NUM = 1000
+_C.MODEL.TRANSFORM = True
+
+_C.MODEL.QUEUE_LEN = 1024
+
+_C.MODEL.SwitchParamsInit = False
+_C.MODEL.TimmParamsInit = False
+_C.MODEL.MAEParamsInit = False
+_C.MODEL.MOCOv3ParamsInit = False
+_C.MODEL.POSEMBEDFIX = False
+_C.MODEL.POSEMBED_SCALE = 1.0
+_C.MODEL.CHECKPOINT_FILETER = True
+_C.MODEL.CHECKPOINT_FILETER_VIDEO = True
+_C.MODEL.TimmParamsInitSTD = 0.02
+_C.MODEL.TimmParamsINIT_EMBEDDING_STD = 0.02
+
+_C.MODEL.SHARE_LAYERNORM = False
+_C.MODEL.BW_WORD_ALONE = False
+
+_C.MODEL.BW_EMBED_SPE = True
+_C.MODEL.WORD_SEPERATE = True
+_C.MODEL.BW_OWD_EMBED = False
+_C.MODEL.TEXT_VISUAL_SEPARATE = False
+_C.MODEL.OUTPUT_PROJ = False
+_C.MODEL.POS_BEFORE = True
+_C.MODEL.LN_FP32 = False
+_C.MODEL.GATE_FP32 = False
+_C.MODEL.TAG_TRANSFORM_FP32 = False
+_C.MODEL.MODEL_EMA = False
+_C.MODEL.MODEL_EMA_DECAY = 0.9999
+_C.MODEL.MODEL_EMA_FORCE_CPU = False
+
+_C.MODEL.LAYER_SCALE = False
+_C.MODEL.LAYER_SCALE_INIT = 1e-5
+_C.MODEL.LAYER_SCALE_FP32 = True
+
+_C.MODEL.MASK_RAND = False
+_C.MODEL.MASK_RATIO = 0.25
+_C.MODEL.MIXUP_ALIGN = False
+
+_C.MODEL.LAYER_TOKEN_MASK = False
+_C.MODEL.LAYER_MASK_IDX = [4]
+_C.MODEL.LAYER_MASK_RATIO = [0.25]
+
+_C.MODEL.TOKEN_EMBED_COPY = False
+_C.MODEL.TOKEN_EMBED_VALID_END = 128
+
+
+# ----------------------------------------------------------------------------
+# Solver
+# ----------------------------------------------------------------------------
+_C.SOLVER = CN()
+
+_C.SOLVER.NAME = 'Adam'
+_C.SOLVER.DEEPSPEED = True
+
+_C.SOLVER.RESUME_OPTIMIZER = False
+
+_C.SOLVER.TORCH_OPTIMIZER = False
+_C.SOLVER.PARAMS_SEPERATE = False
+_C.SOLVER.PARAMS_GROUP = False
+
+_C.SOLVER.TORCH_OPTIMIZER = False
+_C.SOLVER.PARAMS_SEPERATE = False
+_C.SOLVER.PARAMS_GROUP = False
+
+_C.SOLVER.EPOCH = 10
+
+_C.SOLVER.MAX_ITER = 10000
+
+_C.SOLVER.CHECKPOINT_PERIOD = 1
+
+_C.SOLVER.CHECKPOINT_MAX_SAVE = 1000
+
+_C.SOLVER.EVAL_PERIOD = 1
+
+_C.SOLVER.BASE_LR = 0.0005
+
+_C.SOLVER.ACCUM_ITER = 0
+
+_C.SOLVER.BIAS_LR_FACTOR = 1.0
+
+_C.SOLVER.WG_LR_FACTOR = 1.0
+
+_C.SOLVER.LR_DECAY = 0.0
+
+_C.SOLVER.WEIGHT_DECAY = 0.0
+
+_C.SOLVER.WEIGHT_DECAY_NORM = 0.0
+
+_C.SOLVER.WEIGHT_DECAY_NORMBIAS_WEIGHT = True
+
+_C.SOLVER.WEIGHT_DECAY_BIAS = 0.0
+
+_C.SOLVER.WEIGHT_DECAY_WG = 0.0
+
+_C.SOLVER.WEIGHT_DECAY_EMBEDDING = 0.05
+
+_C.SOLVER.OUTPUTPROJ_NOWD = False
+
+_C.SOLVER.INITIAL_ACCUMULATOR_VALUE = 0.0
+
+_C.SOLVER.MOMENTUM = 0.9
+
+_C.SOLVER.DAMPENING = 0.0
+
+_C.SOLVER.NESTEROV = 0.0
+
+_C.SOLVER.ALPHA = 0.99
+
+_C.SOLVER.BETAS = [0.9, 0.999]
+
+_C.SOLVER.EPS = 1e-8
+
+_C.SOLVER.AMSGRAD = False
+
+_C.SOLVER.CENTERED = False
+
+_C.SOLVER.GRAD_CLIP_TYPE = 'norm' # norm, value
+
+_C.SOLVER.GRAD_CLIP = 0.1
+
+_C.SOLVER.MIN_LOSS_SCLE = 2048.0
+_C.SOLVER.LOSS_SCALE_WINDOW = 500
+
+_C.SOLVER.NORM_TYPE = 2.0
+
+_C.SOLVER.WRITE_PERIOD = 20
+_C.SOLVER.GradHistogram = False
+_C.SOLVER.GradHistogramPeriod = 200
+
+_C.SOLVER.COMPUTE_MOE_DECISION = False
+
+
+_C.SOLVER.LOG_GRAD = False
+_C.SOLVER.LOG_GRAD_ITER = 300
+
+
+_C.SOLVER.AMP_FP16 = False
+
+_C.SOLVER.APEX_FP16 = False
+
+_C.SOLVER.APEX_OPT_LEVEL = 'O1'
+_C.SOLVER.APEX_MASTER_WEIGHTS = True
+
+_C.SOLVER.FUSED_LAYERNORM = False
+
+_C.SOLVER.BF16 = False
+
+_C.SOLVER.ZEROSTAGE = 0
+
+# used by xiaoshi in default trainer
+_C.SOLVER.FP16 = False
+
+
+_C.SOLVER.GRAD_PRINT = False
+
+_C.SOLVER.CHECKPOINT_MAPPING = []
+_C.SOLVER.CHECKPOINT_MAP = True
+_C.SOLVER.RESUME_TAU = True
+_C.SOLVER.CHECKPOINT_CEPH_SAVE = False
+
+_C.SOLVER.BALANCE_LOSSESS = False
+_C.SOLVER.BALANCE_LOSSESS_WEIGHT = 0.01
+_C.SOLVER.CONSISTENCE_LOSSESS = 0.01
+_C.SOLVER.DIVEGENCE_LOSSESS = 0.01
+_C.SOLVER.WORD_BALANCE_LOSSESS = False
+_C.SOLVER.IMPORTANCE_LOSS = False
+
+_C.SOLVER.AUGLOSS = False
+_C.SOLVER.AUGLOSS_START = -1
+_C.SOLVER.AUGLOSS_INTERVAL = -1
+_C.SOLVER.AUGLOSS_ENDITER = -1
+
+_C.SOLVER.CROSS_LOSS = False
+
+
+_C.SOLVER.LAYER_LR_DECAY = 1.0
+
+_C.SOLVER.FORCE_SOFTMAX_FP16 = False
+_C.SOLVER.FORCE_LN_FP16 = False
+_C.SOLVER.FORCE_NORM_FP16 = False
+_C.SOLVER.FORCE_TEMP_FP16 = False
+
+_C.SOLVER.FORCE_WG_RECAST = False
+
+_C.SOLVER.FORCE_EXPERT_ADDING_FP16 = False
+_C.SOLVER.FORCE_EMBED_FP16 = False
+
+
+
+# ----------------------------------------------------------------------------
+# lr scheduler
+# ----------------------------------------------------------------------------
+_C.LR_SCHEDULER = CN()
+
+_C.LR_SCHEDULER.NAME = 'StepLR'
+
+_C.LR_SCHEDULER.STEP_SIZE = 3
+
+_C.LR_SCHEDULER.GAMMA = 0.1
+
+_C.LR_SCHEDULER.MODEL_SIZE = -1 # for Noam only
+
+_C.LR_SCHEDULER.FACTOR = 1.0 # for Noam only
+
+_C.LR_SCHEDULER.WARMUP = 1000 # epoch, for WarmupXXX; iteration, for Noam
+
+_C.LR_SCHEDULER.MIN_LR = 0.000001
+
+_C.LR_SCHEDULER.STEPS = (3,) # for WarmupMultiStep only
+
+_C.LR_SCHEDULER.WARMUP_FACTOR = 0.0 # for WarmupMultiStep only
+
+_C.LR_SCHEDULER.WARMUP_METHOD = "linear" # for WarmupMultiStep only
+_C.LR_SCHEDULER.WARMUPTYPE = "linear" # for WarmupMultiStep only
+
+_C.LR_SCHEDULER.MILESTONES = []
+
+# ---------------------------------------------------------------------------- #
+# Losses
+# ---------------------------------------------------------------------------- #
+_C.LOSSES = CN()
+
+_C.LOSSES.NAMES = ['']
+
+_C.LOSSES.LOSS_WEIGHT = 1.0
+
+_C.LOSSES.REDUCTION = 'mean'
+
+_C.LOSSES.LABELSMOOTHING = 0.1
+
+_C.LOSSES.MARGIN = 0.2
+
+_C.LOSSES.LOSS_FP32 = False
+
+_C.LOSSES.MAX_VIOLATION = True
+
+# ---------------------------------------------------------------------------- #
+# SCORER options
+# ---------------------------------------------------------------------------- #
+_C.SCORER = CN()
+
+_C.SCORER.NAME = ''
+
+_C.SCORER.TYPES = ['']
+
+_C.SCORER.WEIGHTS = [1.0]
+
+_C.SCORER.GT_PATH = 'coco_train_gts.pkl'
+
+_C.SCORER.CIDER_CACHED = 'coco_train_cider.pkl'
+
+_C.SCORER.EOS_ID = 0
+
+# ---------------------------------------------------------------------------- #
+# Decode strategy
+# ---------------------------------------------------------------------------- #
+_C.DECODE_STRATEGY = CN()
+
+_C.DECODE_STRATEGY.NAME = 'none'
+
+_C.DECODE_STRATEGY.BEAM_SIZE = 1
+
+_C.DECODE_STRATEGY.LEN_PENALTY = 0.0
+
+# ---------------------------------------------------------------------------- #
+# INFERENCE options
+# ---------------------------------------------------------------------------- #
+_C.INFERENCE = CN()
+
+_C.INFERENCE.NAME = ''
+
+_C.INFERENCE.VOCAB = 'CLIP'
+
+_C.INFERENCE.ID_KEY = 'image_id'
+
+_C.INFERENCE.VALUE = 'caption'
+
+_C.INFERENCE.VAL_ANNFILE = 'captions_val5k.json'
+
+_C.INFERENCE.TEST_ANNFILE = 'captions_test5k.json'
+
+_C.INFERENCE.GENERATION_MODE = True
+
+_C.INFERENCE.VAL_EVAL_START = -1
+
+_C.INFERENCE.TEST_EVAL_START = -1
+
+_C.INFERENCE.ITER_BASED = True
+
+_C.INFERENCE.EVAL_BS = 100
+
+# xiaoshi: added for video cls
+_C.INFERENCE.NUM_VIEWS = 1
+
+# ---------------------------------------------------------------------------- #
+# Misc options
+# ---------------------------------------------------------------------------- #
+_C.OUTPUT_DIR = "./output"
+
+_C.SEED = -1
+
+_C.CUDNN_BENCHMARK = False
+
+_C.find_unused_parameters = True
+
+_C.MOE = CN()
+
+_C.MOE.MOE = False
+
+_C.MOE.EP_WORLD_SIZE = 1
+
+_C.MOE.NUM_EXPERTS = 1
+
+_C.MOE.TOP_K = 1
+
+_C.MOE.CAPACITY_FACTOR = 1.0
+
+_C.MOE.EVAL_MIN_CAPACITY = 1.0
+
+_C.MOE.MIN_CAPACITY = 4
+
+_C.MOE.NOISY_GATE_POLICY = 'RSample'
+
+_C.MOE.USE_RTS = True
+
+_C.MOE.USE_TUTEL = False
+
+_C.MOE.MOE_PARAM_GROUP = True
+
+_C.MOE.MOE_EXPERT_TYPE = 'FFN'
+
+_C.MOE.MOE_EXPERT_LOCATION = 'odd'
+
+_C.MOE.SA_LINEAR_OUT_MOE = False
+
+_C.MOE.KV_SHARED = False
+
+_C.MOE.TASK_MOE = False
+
+_C.MOE.CUSTOM_MOE = False
+
+_C.MOE.MOE_TYPE = 'attribute'
+_C.MOE.ATTRIBUTE_LENGTH = 8
+
+
+
+_C.MOE.GATE_SOURCE = 'spe'
+_C.MOE.LAUX_CONFIG = ''
+_C.MOE.LAUX_ONEHOT = '' # batchonehot sampleonehot
+_C.MOE.LAUX_TYPE = 'std' # batchonehot sampleonehot
+_C.MOE.WORD_LAUX = 'even' # onehot
+_C.MOE.ATTENTION_OUT = 'mean'
+_C.MOE.WORD_EXPERT_REGULARIZER = False
+
+_C.MOE.MOE_LAYER_START_IDX = -1
+_C.MOE.MOE_LAYER_END_IDX = 999
+_C.MOE.BATCH_PRIO = False
+_C.MOE.GATE_TYPE = 'deepspeed'
+_C.MOE.LN_MOE = False
+_C.MOE.FFN_SHARE_GATE_DECISION = False
+_C.MOE.FFN_SA_SHARE_GATE = False
+_C.MOE.FFN_MOE_SEPARATE = False
+_C.MOE.MERGE_EXPERTS = False
+_C.MOE.TAG_Transform = False
+_C.MOE.TAG_Transform_ACT = False
+_C.MOE.TAG_Transform_ALONE = False
+_C.MOE.NOISE_STD = 1.0
+
+_C.SOLVER.FLOPS_PROFILER = False
diff --git a/uniperceiver/datasets/__init__.py b/uniperceiver/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..10ed23f1b4e8c978532d5669f96e770e33bc5883
--- /dev/null
+++ b/uniperceiver/datasets/__init__.py
@@ -0,0 +1,19 @@
+
+from .build import (
+ build_dataset_mapper,
+ build_standard_train_loader,
+ build_standard_valtest_loader,
+ build_unified_train_loader,
+)
+
+from .task_dataset.imagenet import ImageNetDataset, ImageNet22KDataset
+from .task_dataset.image_text_pair import ImageTextPairDataset
+from .task_dataset.general_corpus import GeneralCorpusDataset
+from .task_dataset.video_raw import VideoDataSet
+from .task_dataset.vqa import VQADataset
+from .task_dataset.msvd import MSVDDataset
+from .task_dataset.msrvtt import MSRVTTDataset
+
+from .task_dataset.GLUE import GLUEDataset
+
+from .tcsreader import TCSLoader
diff --git a/uniperceiver/datasets/batch_sampler.py b/uniperceiver/datasets/batch_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..815623ca221f5c5c3c1bf2af2e5d069c4c8bfbd1
--- /dev/null
+++ b/uniperceiver/datasets/batch_sampler.py
@@ -0,0 +1,149 @@
+import numpy as np
+import os
+import math
+import torch
+import torch.distributed as dist
+from torch.utils.data.sampler import Sampler
+import random
+from uniperceiver.utils import comm
+import itertools
+
+from .sampler import TrainingSampler, NaiveSampler, NodeDistributedSampler
+
+from uniperceiver.datasets.unified_dataset import UnifiedDataset
+try:
+ import deepspeed.utils.groups as groups
+ DEEPSPEED_INSTALLED = True
+except:
+
+ DEEPSPEED_INSTALLED = False
+
+
+
+
+
+class WeightedBatchSampler(torch.utils.data.sampler.BatchSampler):
+ def __init__(self,
+ dataset: UnifiedDataset,
+ cfg,
+ task_cfg,
+ stage='train',
+ shuffle=True,
+ drop_last=True):
+ self.dataset = dataset
+ self.cfg = cfg
+ self.task_cfg = task_cfg
+
+ self._tasks = list(self.task_cfg.keys())
+
+ # dataset_names = self.dataset.dataset_name
+ # dataset_list = self.dataset.dataset_list
+
+ unit_sampler = dict()
+ for name, new_cfg in self.task_cfg.items():
+ if new_cfg.DATASETS.DATASET_NAME in [
+ "MSCOCO", "FLICKR", "ImageNet22k", "ImageNet1k", "VG", "VideoDataSet", "K700", 'K400', 'MiT', 'MSVDDataset', 'MSRVTTDataset',
+ "RTE", "CoLA", "SST-2", "MRPC", "QQP", "QNLI", "MNLI", "MNLI_Match", "VQA"
+ ]:
+ sampler = TrainingSampler(self.dataset.datasets[name])
+ elif new_cfg.DATASETS.DATASET_NAME in ["BooksWiki"]:
+ # block cache
+ sampler = NaiveSampler(self.dataset.datasets[name])
+ elif new_cfg.DATASETS.DATASET_NAME in [
+ # "ImageTextPairDataset", 'SBUDataset', 'TQAPretrain'
+ 'YFCC', 'CC12M', 'CC3M', 'SBU', 'TQAPretrain'
+ ]:
+ sampler = NodeDistributedSampler(
+ self.dataset.datasets[name],
+ shuffle=True,
+ num_replicas=comm.get_world_size(),
+ rank=comm.get_rank(),
+ local_rank=comm.get_local_rank(),
+ local_size=comm.get_local_size())
+ else:
+ raise NotImplementedError(
+ f'please check the sampler used for this dataset {new_cfg.DATASETS.DATASET_NAME}'
+ )
+ unit_sampler[name] = sampler
+ self.unit_sampler = unit_sampler
+
+ self.unit_sampler_iter = {
+ k: iter(v)
+ for k, v in self.unit_sampler.items()
+ }
+
+ self.sampling_weights = {
+ k: v.DATALOADER.SAMPLING_WEIGHT
+ for k, v in self.task_cfg.items()
+ }
+
+ self._weights = [self.sampling_weights[k] for k in self._tasks]
+
+ self.stage = stage
+ if self.stage == 'train':
+ self.task_batch_size = {
+ k: v.DATALOADER.TRAIN_BATCH_SIZE
+ for k, v in self.task_cfg.items()
+ }
+ else:
+ raise NotImplementedError('only train dataset supportted now!')
+
+
+
+ self.len = [ len_ds//bs for len_ds, bs in zip([len(ds) for ds in self.dataset.dataset_list], self.task_batch_size.values())]
+
+ self.special_strategy = cfg.DATALOADER.STRATEGY
+
+ self.count = 0
+
+ self.task_index_offset = {
+ k: v
+ for k, v in zip(self.task_cfg.keys(),self.dataset.dataset_scale.tolist())
+ }
+
+
+ def __len__(self):
+ return sum(self.len)
+
+ def __iter__(self):
+
+ batch = []
+ while True:
+
+ if self.special_strategy == 'uniform':
+ task = self._tasks[comm.get_local_rank() % len(self._tasks)]
+ elif self.special_strategy == 'uniformv2':
+ task = self._tasks[(self.count + comm.get_rank()) %
+ len(self._tasks)]
+ self.count = (self.count + 1) % len(self._tasks)
+ elif self.special_strategy == 'turn':
+ task = self._tasks[self.count % len(self._tasks)]
+ self.count = (self.count + 1) % len(self._tasks)
+ else:
+ task = random.choices(self._tasks,
+ weights=self._weights)[0]
+
+ if self.cfg.MOE.MOE and DEEPSPEED_INSTALLED and groups.expert_parallel_is_initialized(
+ ) and groups.get_expert_data_parallel_world_size() > 1:
+ task = comm.broadcast_object(
+ task,
+ src=comm.get_rank() -
+ comm.get_rank() % groups.get_expert_parallel_world_size(),
+ group=groups.get_expert_parallel_group())
+
+ """
+ all sampler are infinite stream
+ """
+ sample_index_offset = self.task_index_offset[task]
+ for i in range(self.task_batch_size[task]):
+ try:
+ batch.append(
+ next(self.unit_sampler_iter[task]) + sample_index_offset)
+ except:
+ self.unit_sampler_iter[task] = iter(self.unit_sampler[task])
+ batch.append(
+ next(self.unit_sampler_iter[task]) + sample_index_offset)
+
+ assert len(batch) == self.task_batch_size[task]
+ yield batch
+ batch = []
diff --git a/uniperceiver/datasets/build.py b/uniperceiver/datasets/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..95a9ae2978e1b81ae80b221295110c2fb814f9ea
--- /dev/null
+++ b/uniperceiver/datasets/build.py
@@ -0,0 +1,355 @@
+import itertools
+import logging
+import numpy as np
+import operator
+import pickle
+from tabulate import tabulate
+from termcolor import colored
+import torch
+import torch.utils.data
+from torch.utils.data import RandomSampler
+from torch.utils.data.distributed import DistributedSampler
+
+from uniperceiver.config import configurable
+from uniperceiver.utils.comm import get_world_size, get_rank
+from uniperceiver.utils.env import seed_all_rng
+from uniperceiver.utils.file_io import PathManager
+from uniperceiver.utils.logger import log_first_n
+from uniperceiver.utils.registry import Registry
+from .common import DatasetFromList, MapDataset
+
+from uniperceiver.functional import pad_tensor, dict_to_cuda, flat_list_of_lists
+from .sampler import NodeDistributedSampler
+from uniperceiver.utils import comm
+from .sampler import TrainingSampler, NaiveSampler
+from .moe_embeddings import get_moe_embedding, get_embed_with_task_type, get_embed_with_shared_tagert_name
+
+
+
+from functools import partial
+
+DATASETS_REGISTRY = Registry("DATASETS") # noqa F401 isort:skip
+DATASETS_REGISTRY.__doc__ = """
+Registry for datasets, i.e. the whole model
+"""
+
+from uniperceiver.datasets.unified_dataset import UnifiedDataset
+from .batch_sampler import WeightedBatchSampler
+
+
+def build_dataset_mapper(cfg, name, stage):
+ dataset_mapper = DATASETS_REGISTRY.get(name)(cfg, stage)
+ return dataset_mapper
+
+def trivial_batch_collator(batch):
+ return batch
+
+def preprocess_batch_collator(batched_inputs, cfg=dict(), shared_targets=dict()):
+
+ ret = {}
+ if cfg.MOE.MOE:
+ moe_type = cfg.MOE.MOE_TYPE
+ else:
+ moe_type = None
+ # sample lists
+ for data_name in ['input_sample', 'target_sample']:
+ ret[(data_name + '_list')] = []
+ num_data = len(batched_inputs[0][data_name])
+ for i in range(num_data):
+ # All samples in data_list can be either be Tensors or groups (i.e., list of Tensors, [Tensors]).
+ # If the samples in data_list are groups, each element in each group will be padded individually, and then all elements in the same group will be concatenated along axis 1.
+ data_list = [sample[data_name][i]['data'] for sample in batched_inputs]
+ # valid_mask_list = [sample[data_name][i]['valid_mask'] for sample in batched_inputs]
+ modality = batched_inputs[0][data_name][i]['modality']
+ data_type = batched_inputs[0][data_name][i]['data_type']
+ sample_info_list = [sample[data_name][i]['sample_info'] for sample in batched_inputs]
+ padding_value = sample_info_list[0].get('padding_value', 0)
+
+ if isinstance(data_list[0], list):
+ if not batched_inputs[0][data_name][i]['sample_info'].get('sample_alone', False):
+ # some data are concatenated inside one sample, e.g. the caption text part during the training.
+ data_group_size = len(data_list[0])
+ # padding individually for each element in each group
+ data_, valid_mask_ = zip(*[pad_tensor(
+ tensor=[data_group[idx] for data_group in data_list],
+ padding_value=padding_value,
+ use_mask=True) for idx in range(data_group_size)])
+
+ # concatenate all elements in the same group along axis 1
+ data = torch.cat(data_, dim=1)
+ valid_mask = torch.cat(valid_mask_, dim=1)
+ else:
+ # image-text retrieval may have multi-caption for one image when inference, e.g., MSCOCO caption dataset.
+ data_list = flat_list_of_lists(data_list)
+ data, valid_mask = pad_tensor(tensor=data_list, padding_value=padding_value, use_mask=True)
+
+ elif isinstance(data_list[0], torch.Tensor):
+ if sample_info_list[0].get('cat_along_first_dim', False):
+ # concatenate data along the first dimention, e.g.: video data
+ data = torch.cat(data_list, dim=0)
+ valid_mask = None
+ else:
+ data, valid_mask = pad_tensor(tensor=data_list, padding_value=padding_value, use_mask=True) # Do we have valid mask that is not caused by padding? AND 1/0 for what?
+
+ else:
+ raise TypeError
+
+ if valid_mask is not None and valid_mask.all():
+ valid_mask = None
+
+ ret[(data_name + '_list')].append({
+ 'data':
+ data,
+ 'invalid_mask':
+ 1 - valid_mask if valid_mask is not None else None,
+ 'modality':
+ modality,
+ 'data_type':
+ data_type,
+ 'sample_info':
+ sample_info_list,
+ 'moe_embedding':
+ get_embed_with_task_type(moe_type, batched_inputs[0]['task_info']['task_type'], data_type)
+ })
+
+
+ # target sets
+ num_target_sets = len(batched_inputs[0]['target_idx'])
+ # change value to -1 for padding location
+ ret['target_idx_list'] = [ pad_tensor(tensor=[sample['target_idx'][i] for sample in batched_inputs], padding_value=-1, use_mask=False) if isinstance(batched_inputs[0]['target_idx'][i], torch.Tensor) else torch.tensor([sample['target_idx'][i] for sample in batched_inputs] ) for i in range(num_target_sets) ]
+ ret['target_set_list'] = [batched_inputs[0]['target_set'][i] for i in range(num_target_sets)]
+
+ # shared target sets
+ ret['shared_target_sets'] = {}
+ for k in shared_targets:
+ padding_value = shared_targets[k]['sample_info'].get('padding_value', 0)
+ if isinstance(shared_targets[k]['data'][0], list):
+ data_list = [d[np.random.randint(0, len(d))] for d in shared_targets[k]['data']] # Randomly choose one for each list
+ else:
+ data_list = shared_targets[k]['data']
+
+ data, valid_mask = pad_tensor(tensor=data_list, padding_value=padding_value, use_mask=True)
+ if valid_mask.all():
+ valid_mask = None
+ ret['shared_target_sets'][k] = [{
+ 'data': data,
+ 'invalid_mask': 1 - valid_mask if valid_mask is not None else None,
+ 'modality': shared_targets[k]['modality'],
+ 'data_type': 'target',
+ 'sample_info': shared_targets[k]['sample_info'],
+ 'moe_embedding': get_embed_with_shared_tagert_name(moe_type, k)
+ }]
+
+ # task info
+ ret['task_info'] = batched_inputs[0]['task_info'] # should task_name be put into task_info?
+
+ ret['task_info']['task_name'] = batched_inputs[0].get('task_name', None)
+
+
+ return ret
+
+
+
+def worker_init_reset_seed(worker_id):
+ seed_all_rng(np.random.randint(2 ** 31) + worker_id)
+
+def load_pkl_file(filepath):
+ return pickle.load(open(filepath, 'rb'), encoding='bytes') if len(filepath) > 0 else None
+
+def load_shared_targets(cfg, stage='train'):
+ shared_targets_cfg = cfg.SHARED_TARGETS
+ shared_targets = {}
+ for shared_target_cfg in shared_targets_cfg:
+ name = shared_target_cfg['NAME']
+
+ if (stage != 'train') and (name not in cfg.DATASETS.TARGET_SET):
+ # For validation and test, we build a dataloader for each task / dataset.
+ # Therefore, the dataloader only needs to load its corresponding shared target set.
+ continue
+
+ # For validation and test, we do not distribute the shared targets
+ distributed = shared_target_cfg['SHARED_TARGETS_CFG']['DISTRIBUTED'] and (stage == 'train')
+
+ shared_targets[name] = load_pkl_file(shared_target_cfg['SHARED_TARGETS_CFG']['FILE_PATH'])
+
+ data = shared_targets[name]['data']
+ if isinstance(data[0], list):
+ max_len = max([len(t) for tl in data for t in tl])
+ else:
+ max_len = max([len(t) for t in data])
+ shared_targets[name]['sample_info'] = {'distributed': distributed, 'max_len': max_len}
+
+ if distributed:
+ world_size = get_world_size()
+ rank = get_rank()
+ total_num = len(shared_targets[name]['data'])
+ local_num = int(np.ceil(total_num / world_size))
+
+ # we pad the shared_targets to a value that can be divided by WORLD_SIZE with no remainer, and then slice it
+ if local_num * world_size > total_num:
+ data = data + [data[0] for _ in range(local_num * world_size - total_num)]
+ shared_targets[name]['data'] = data[rank * local_num : (rank + 1) * local_num]
+
+ # compute the real (unpadded) length of the local slice
+ start_idx = min(rank * local_num, total_num)
+ end_idx = min((rank + 1) * local_num, total_num)
+
+ shared_targets[name]['sample_info'].update({
+ 'total_num': total_num,
+ 'local_num': end_idx - start_idx,
+ 'world_size': world_size,
+ 'rank': rank
+ })
+
+ return shared_targets
+
+
+
+def build_unified_train_loader(cfg, task_cfg, model=None):
+ dataset = UnifiedDataset(cfg, task_cfg, stage="train")
+ batchsampler = WeightedBatchSampler(dataset, cfg, task_cfg)
+ shared_targets = load_shared_targets(cfg)
+ dataloader = torch.utils.data.DataLoader(
+ dataset=dataset,
+ batch_sampler=batchsampler,
+ # sampler=sampler,
+ # batch_size=cfg.DATALOADER.TRAIN_BATCH_SIZE,
+ num_workers=cfg.DATALOADER.NUM_WORKERS,
+ collate_fn=partial(preprocess_batch_collator, shared_targets=shared_targets, cfg=cfg),
+ pin_memory=cfg.DATALOADER.PIN_MEM,
+ worker_init_fn=worker_init_reset_seed,
+ # drop_last=True,
+ prefetch_factor=cfg.DATALOADER.PREFETCH_FACTOR, # default: 2
+ persistent_workers=cfg.DATALOADER.NUM_WORKERS>0)
+
+
+ return dataloader
+
+
+def build_standard_train_loader(cfg, model=None):
+ dataset = build_dataset_mapper(cfg, name=cfg.DATASETS.TRAIN, stage="train")
+ if cfg.DATASETS.TRAIN in [ "ImageTextPairDataset", "ImageNet22KDataset", "ImageNetDataset", "VGPretrain", "VideoDataSet", "VQADataset" ]:
+ sampler = TrainingSampler(dataset)
+ elif cfg.DATASETS.TRAIN in ["GeneralCorpusDataset"]:
+ sampler = NaiveSampler(dataset)
+ else:
+ sampler = NodeDistributedSampler(
+ dataset, shuffle=True,
+ num_replicas=comm.get_world_size(), rank=comm.get_rank(),
+ local_rank=comm.get_local_rank(), local_size=comm.get_local_size())
+ # sampler = TrainingSampler(dataset)
+ dataloader = torch.utils.data.DataLoader(
+ dataset=dataset,
+ sampler=sampler,
+ batch_size=cfg.DATALOADER.TRAIN_BATCH_SIZE,
+ num_workers=cfg.DATALOADER.NUM_WORKERS,
+ collate_fn=partial(preprocess_batch_collator, model=model),
+ pin_memory=cfg.DATALOADER.PIN_MEM,
+ worker_init_fn=worker_init_reset_seed,
+ drop_last=True,
+ persistent_workers=True)
+ return dataloader
+
+
+def _single_modal_dataset(cfg, dataset_mapper=None, *, datalist=None, sampler=None):
+ if len(cfg.DATASETS.TRAIN) > 0:
+ if dataset_mapper is None:
+ dataset_mapper = build_dataset_mapper(cfg, name=cfg.DATASETS.TRAIN, stage="train")
+ if datalist is None:
+ datalist = dataset_mapper.load_data(cfg)
+ else:
+ dataset_mapper = None
+ datalist = None
+ return datalist, dataset_mapper
+
+
+def _train_loader_from_config(cfg,
+ dataset_mapper=None,
+ *,
+ datalist=None,
+ sampler=None,
+ model=None):
+ # xiaoshi: mscoco image captioning: called from defaulttainer, only cfg passed
+ datalist, dataset_mapper = _single_modal_dataset(
+ cfg, dataset_mapper=dataset_mapper, datalist=datalist, sampler=sampler)
+
+ return {
+ "datalist": datalist,
+ "dataset_mapper": dataset_mapper,
+ "num_workers": cfg.DATALOADER.NUM_WORKERS,
+ "batch_size": cfg.DATALOADER.TRAIN_BATCH_SIZE,
+ "cfg": cfg,
+ "model": model,
+ }
+
+
+
+def _valtest_loader_from_config(cfg, dataset_mapper=None, *, datalist=None, sampler=None, stage="val"):
+ dataset_names = {
+ "val": cfg.DATASETS.VAL,
+ "test": cfg.DATASETS.TEST,
+ }
+ dataset_name = dataset_names[stage]
+ if len(dataset_name) > 0:
+ if dataset_mapper is None:
+ dataset_mapper = build_dataset_mapper(cfg, name=dataset_name, stage=stage)
+ if datalist is None:
+ datalist = dataset_mapper.load_data(cfg)
+ else:
+ dataset_mapper = None
+ datalist = None
+
+ if dataset_name in ['Flickr30kDatasetForSingleStreamVal', 'Flickr30kDatasetForSingleStreamValV2']:
+ multi_gpu_eval = True
+ batch_size = 1
+ else:
+ multi_gpu_eval = False
+ batch_size = cfg.DATALOADER.TEST_BATCH_SIZE
+
+ return {
+ "datalist": datalist,
+ "dataset_mapper": dataset_mapper,
+ "num_workers": cfg.DATALOADER.NUM_WORKERS,
+ "batch_size": batch_size,
+ "multi_gpu_eval": multi_gpu_eval,
+ "cfg": cfg,
+ "stage": stage
+ }
+
+
+def build_standard_valtest_loader(cfg, task_cfg, stage, multi_gpu_eval):
+ dataset_names = {
+ "val": cfg.DATASETS.VAL,
+ "test": cfg.DATASETS.TEST,
+ }
+ dataset_name = dataset_names[stage]
+ if len(dataset_name) > 0:
+ dataset = build_dataset_mapper(cfg, name=dataset_name, stage=stage)
+ else:
+ return None
+
+ shared_targets = load_shared_targets(cfg, stage=stage)
+
+ if multi_gpu_eval and get_world_size() > 1:
+ # multi-gpu-eval for single stream retrieval
+ sampler = DistributedSampler(dataset, shuffle=True)
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=cfg.DATALOADER.TEST_BATCH_SIZE,
+ num_workers=cfg.DATALOADER.NUM_WORKERS,
+ drop_last=False,
+ sampler=sampler,
+ collate_fn=partial(preprocess_batch_collator, shared_targets=shared_targets, cfg=cfg),
+ pin_memory=cfg.DATALOADER.PIN_MEM,
+ )
+ else:
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ batch_size=cfg.DATALOADER.TEST_BATCH_SIZE,
+ num_workers=cfg.DATALOADER.NUM_WORKERS,
+ drop_last=False,
+ shuffle=False,
+ collate_fn=partial(preprocess_batch_collator, shared_targets=shared_targets, cfg=cfg),
+ pin_memory=cfg.DATALOADER.PIN_MEM,
+ )
+ return data_loader
\ No newline at end of file
diff --git a/uniperceiver/datasets/circular_cached_loader.py b/uniperceiver/datasets/circular_cached_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..787ee2c9da285c0a07e56f342c5b8fd72fdcc433
--- /dev/null
+++ b/uniperceiver/datasets/circular_cached_loader.py
@@ -0,0 +1,192 @@
+import queue
+import random
+from threading import Thread
+import time
+
+import pyarrow as pa
+import torch.multiprocessing as multiprocessing
+
+import torch
+from copy import deepcopy
+
+string_classes = (str, bytes)
+import collections.abc as container_abcs
+import re
+
+def pin_memory(data):
+ if isinstance(data, torch.Tensor):
+ return data.pin_memory()
+ elif isinstance(data, string_classes):
+ return data
+ elif isinstance(data, container_abcs.Mapping):
+ return {k: pin_memory(sample) for k, sample in data.items()}
+ elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
+ return type(data)(*(pin_memory(sample) for sample in data))
+ elif isinstance(data, container_abcs.Sequence):
+ return [pin_memory(sample) for sample in data]
+ elif hasattr(data, "pin_memory"):
+ return data.pin_memory()
+ else:
+ return data
+
+
+np_str_obj_array_pattern = re.compile(r'[SaUO]')
+default_collate_err_msg_format = (
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
+ "dicts or lists; found {}")
+
+
+def default_collate(batch):
+ r"""Puts each data field into a tensor with outer dimension batch size"""
+
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ out = None
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum([x.numel() for x in batch])
+ storage = elem.storage()._new_shared(numel)
+ out = elem.new(storage)
+ return torch.stack(batch, 0, out=out)
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ elem = batch[0]
+ if elem_type.__name__ == 'ndarray':
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+ return default_collate([torch.as_tensor(b) for b in batch])
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(elem, int):
+ return torch.tensor(batch)
+ elif isinstance(elem, string_classes):
+ return batch
+ elif isinstance(elem, container_abcs.Mapping):
+ return {key: default_collate([d[key] for d in batch]) for key in elem}
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
+ return elem_type(*(default_collate(samples) for samples in zip(*batch)))
+ elif isinstance(elem, container_abcs.Sequence):
+ transposed = zip(*batch)
+ return [default_collate(samples) for samples in transposed]
+
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
+
+
+class CircularCachedInputIterator(object):
+ """
+ chunk: a serialized List[Dict] in the apache arrow format,
+ could be sequentially loaded into memory with minimum deserialization cost(<1ms)
+ shard: a part of dataset which is allocated to a specific rank(process) in the world,
+ generally contains multiple chunks
+
+ main thread:
+ - populate chunk_index_queue
+ - swap new chunk and old chunk
+ prefetch threads:
+ - fetch chunk_index_queue
+ - populate loaded_chunk_queue
+
+ chunk_index_queue: main -> prefetch, used for shuffling chunk order per epoch
+ loaded_chunk_queue: preftch -> main, a limited-size channel for prefetching worker to send back result
+ """
+ def __init__(self,
+ input_map,
+ batch_size,
+ chunk_path_list,
+ num_data_point,
+ num_shards,
+ shard_id,
+ random_shuffle,
+ num_prefetch_chunk=4,
+ num_worker=4):
+ self.input_map = input_map
+ self.batch_size = batch_size
+ self.num_shareds = num_shards
+ self.shard_id = shard_id
+ self.random_shuffle = random_shuffle
+ self.num_data_point = num_data_point
+ self.chunk_filename_list = chunk_path_list
+ self.chunk = None
+ self.next_chunk_queue = queue.Queue(num_prefetch_chunk)
+ self.index_queue = queue.Queue()
+ self.chunk_index_queue = queue.Queue()
+ self.num_chunk_in_shard = None
+ self.chunk_indexes = None
+ self.worker = None
+ self.num_worker = num_worker
+ self.setup_shard()
+ self.warmup_cache()
+
+ def setup_shard(self):
+ # ensure each shard has the same of of chunks per epoch
+ # this might not be necessary
+ self.num_chunk_in_shard = len(self.chunk_filename_list) // self.num_shareds
+ # [start, end)
+ shard_start = self.num_chunk_in_shard * self.shard_id
+ shard_end = len(self.chunk_filename_list) if self.shard_id == self.num_shareds - 1 else self.num_chunk_in_shard * (self.shard_id + 1)
+ self.chunk_indexes = list(range(shard_start, shard_end))
+
+ def _chunk_prefetch_worker(self):
+ while True:
+ chunk_index = self.get_chunk_index()
+ chunk_filename = self.chunk_filename_list[chunk_index]
+ with open(chunk_filename, "rb") as fin:
+ chunk = pa.deserialize_from(fin, None)
+ self.next_chunk_queue.put(chunk)
+
+ def warmup_cache(self):
+ self.worker = [Thread(target=self._chunk_prefetch_worker, args=[]) for _ in range(self.num_worker)]
+ for worker in self.worker:
+ worker.daemon = True
+ worker.start()
+
+ def get_chunk_index(self):
+ if self.chunk_index_queue.empty():
+ if self.random_shuffle:
+ random.shuffle(self.chunk_indexes)
+ for ind in self.chunk_indexes[:self.num_chunk_in_shard]:
+ self.chunk_index_queue.put(ind)
+ return self.chunk_index_queue.get()
+
+ def get_index(self):
+ if self.index_queue.empty():
+ if self.chunk is not None:
+ del self.chunk # release memory
+ self.chunk = self.next_chunk_queue.get()
+ self.indexes = list(range(len(self.chunk)))
+ if self.random_shuffle:
+ random.shuffle(self.indexes)
+ # keep all shards of the same size
+ for ind in self.indexes:
+ self.index_queue.put(ind)
+ return self.index_queue.get()
+
+ def epoch_size(self):
+ return self.num_data_point // self.num_shareds
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ datas = tuple([] for _ in self.input_map)
+ for _ in range(self.batch_size):
+ ind = self.get_index()
+ data = self.chunk[ind]
+ # value = data['jpeg']
+ # label = data['label']
+ # # DO NOT reuse the buffer
+ # jpegs.append(value)
+ # labels.append(np.array([label], dtype=np.int32))
+ # datas.append(data)
+ for i, k in enumerate(self.input_map):
+ datas[i].append(deepcopy(data[k]))
+ return datas
+
+ next = __next__
+
diff --git a/uniperceiver/datasets/common.py b/uniperceiver/datasets/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f0e823e1b4c01e0a5a61dc415ec1eaab55063bb
--- /dev/null
+++ b/uniperceiver/datasets/common.py
@@ -0,0 +1,244 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import copy
+import itertools
+import logging
+import numpy as np
+import pickle
+import random
+import torch.utils.data as data
+from torch.utils.data.sampler import Sampler
+
+from uniperceiver.utils.serialize import PicklableWrapper
+
+__all__ = ["MapDataset", "DatasetFromList", "AspectRatioGroupedDataset", "ToIterableDataset"]
+
+
+def _shard_iterator_dataloader_worker(iterable):
+ # Shard the iterable if we're currently inside pytorch dataloader worker.
+ worker_info = data.get_worker_info()
+ if worker_info is None or worker_info.num_workers == 1:
+ # do nothing
+ yield from iterable
+ else:
+ yield from itertools.islice(iterable, worker_info.id, None, worker_info.num_workers)
+
+
+class _MapIterableDataset(data.IterableDataset):
+ """
+ Map a function over elements in an IterableDataset.
+
+ Similar to pytorch's MapIterDataPipe, but support filtering when map_func
+ returns None.
+
+ This class is not public-facing. Will be called by `MapDataset`.
+ """
+
+ def __init__(self, dataset, map_func):
+ self._dataset = dataset
+ self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work
+
+ def __len__(self):
+ return len(self._dataset)
+
+ def __iter__(self):
+ for x in map(self._map_func, self._dataset):
+ if x is not None:
+ yield x
+
+
+class MapDataset(data.Dataset):
+ """
+ Map a function over the elements in a dataset.
+ """
+
+ def __init__(self, dataset, map_func):
+ """
+ Args:
+ dataset: a dataset where map function is applied. Can be either
+ map-style or iterable dataset. When given an iterable dataset,
+ the returned object will also be an iterable dataset.
+ map_func: a callable which maps the element in dataset. map_func can
+ return None to skip the data (e.g. in case of errors).
+ How None is handled depends on the style of `dataset`.
+ If `dataset` is map-style, it randomly tries other elements.
+ If `dataset` is iterable, it skips the data and tries the next.
+ """
+ self._dataset = dataset
+ self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work
+
+ self._rng = random.Random(42)
+ self._fallback_candidates = set(range(len(dataset)))
+
+ def __new__(cls, dataset, map_func):
+ is_iterable = isinstance(dataset, data.IterableDataset)
+ if is_iterable:
+ return _MapIterableDataset(dataset, map_func)
+ else:
+ return super().__new__(cls)
+
+ def __getnewargs__(self):
+ return self._dataset, self._map_func
+
+ def __len__(self):
+ return len(self._dataset)
+
+ def __getitem__(self, idx):
+ retry_count = 0
+ cur_idx = int(idx)
+
+ while True:
+ data = self._map_func(self._dataset[cur_idx])
+ if data is not None:
+ self._fallback_candidates.add(cur_idx)
+ return data
+
+ # _map_func fails for this idx, use a random new index from the pool
+ retry_count += 1
+ self._fallback_candidates.discard(cur_idx)
+ cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0]
+
+ if retry_count >= 3:
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "Failed to apply `_map_func` for idx: {}, retry count: {}".format(
+ idx, retry_count
+ )
+ )
+
+
+class DatasetFromList(data.Dataset):
+ """
+ Wrap a list to a torch Dataset. It produces elements of the list as data.
+ """
+
+ def __init__(self, lst: list, copy: bool = True, serialize: bool = True):
+ """
+ Args:
+ lst (list): a list which contains elements to produce.
+ copy (bool): whether to deepcopy the element when producing it,
+ so that the result can be modified in place without affecting the
+ source in the list.
+ serialize (bool): whether to hold memory using serialized objects, when
+ enabled, data loader workers can use shared RAM from master
+ process instead of making a copy.
+ """
+ self._lst = lst
+ self._copy = copy
+ self._serialize = serialize
+
+ def _serialize(data):
+ buffer = pickle.dumps(data, protocol=-1)
+ return np.frombuffer(buffer, dtype=np.uint8)
+
+ if self._serialize:
+ logger = logging.getLogger(__name__)
+ logger.info(
+ "Serializing {} elements to byte tensors and concatenating them all ...".format(
+ len(self._lst)
+ )
+ )
+ self._lst = [_serialize(x) for x in self._lst]
+ self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64)
+ self._addr = np.cumsum(self._addr)
+ self._lst = np.concatenate(self._lst)
+ logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2))
+
+ def __len__(self):
+ if self._serialize:
+ return len(self._addr)
+ else:
+ return len(self._lst)
+
+ def __getitem__(self, idx):
+ if self._serialize:
+ start_addr = 0 if idx == 0 else self._addr[idx - 1].item()
+ end_addr = self._addr[idx].item()
+ bytes = memoryview(self._lst[start_addr:end_addr])
+ return pickle.loads(bytes)
+ elif self._copy:
+ return copy.deepcopy(self._lst[idx])
+ else:
+ return self._lst[idx]
+
+
+class ToIterableDataset(data.IterableDataset):
+ """
+ Convert an old indices-based (also called map-style) dataset
+ to an iterable-style dataset.
+ """
+
+ def __init__(self, dataset: data.Dataset, sampler: Sampler, shard_sampler: bool = True):
+ """
+ Args:
+ dataset: an old-style dataset with ``__getitem__``
+ sampler: a cheap iterable that produces indices to be applied on ``dataset``.
+ shard_sampler: whether to shard the sampler based on the current pytorch data loader
+ worker id. When an IterableDataset is forked by pytorch's DataLoader into multiple
+ workers, it is responsible for sharding its data based on worker id so that workers
+ don't produce identical data.
+
+ Most samplers (like our TrainingSampler) do not shard based on dataloader worker id
+ and this argument should be set to True. But certain samplers may be already
+ sharded, in that case this argument should be set to False.
+ """
+ assert not isinstance(dataset, data.IterableDataset), dataset
+ assert isinstance(sampler, Sampler), sampler
+ self.dataset = dataset
+ self.sampler = sampler
+ self.shard_sampler = shard_sampler
+
+ def __iter__(self):
+ if not self.shard_sampler:
+ sampler = self.sampler
+ else:
+ # With map-style dataset, `DataLoader(dataset, sampler)` runs the
+ # sampler in main process only. But `DataLoader(ToIterableDataset(dataset, sampler))`
+ # will run sampler in every of the N worker. So we should only keep 1/N of the ids on
+ # each worker. The assumption is that sampler is cheap to iterate so it's fine to
+ # discard ids in workers.
+ sampler = _shard_iterator_dataloader_worker(self.sampler)
+ for idx in sampler:
+ yield self.dataset[idx]
+
+ def __len__(self):
+ return len(self.sampler)
+
+
+class AspectRatioGroupedDataset(data.IterableDataset):
+ """
+ Batch data that have similar aspect ratio together.
+ In this implementation, images whose aspect ratio < (or >) 1 will
+ be batched together.
+ This improves training speed because the images then need less padding
+ to form a batch.
+
+ It assumes the underlying dataset produces dicts with "width" and "height" keys.
+ It will then produce a list of original dicts with length = batch_size,
+ all with similar aspect ratios.
+ """
+
+ def __init__(self, dataset, batch_size):
+ """
+ Args:
+ dataset: an iterable. Each element must be a dict with keys
+ "width" and "height", which will be used to batch data.
+ batch_size (int):
+ """
+ self.dataset = dataset
+ self.batch_size = batch_size
+ self._buckets = [[] for _ in range(2)]
+ # Hard-coded two aspect ratio groups: w > h and w < h.
+ # Can add support for more aspect ratio groups, but doesn't seem useful
+
+ def __iter__(self):
+ for d in self.dataset:
+ w, h = d["width"], d["height"]
+ bucket_id = 0 if w > h else 1
+ bucket = self._buckets[bucket_id]
+ bucket.append(d)
+ if len(bucket) == self.batch_size:
+ data = bucket[:]
+ # Clear bucket first, because code after yield is not
+ # guaranteed to execute
+ del bucket[:]
+ yield data
diff --git a/uniperceiver/datasets/custom_transforms.py b/uniperceiver/datasets/custom_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ddbebf223195b0ff750c2f8da1fd67a8f4ca8c0
--- /dev/null
+++ b/uniperceiver/datasets/custom_transforms.py
@@ -0,0 +1,40 @@
+from torchvision import transforms as T
+try:
+ from torchvision.transforms import InterpolationMode
+ BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+ from PIL import Image
+ BICUBIC = Image.BICUBIC
+
+
+def clip_transforms(mode='train', img_size=224, flip_prob=0.5):
+ assert mode in ['train', 'test', 'val']
+ min_size = img_size
+ max_size = img_size
+ # assert min_size <= max_size
+
+
+ normalize_transform = T.Normalize(
+ mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
+ )
+
+ if mode == 'train':
+ transform = T.Compose(
+ [
+ T.Resize(max_size, BICUBIC),
+ T.RandomCrop(min_size),
+ T.RandomHorizontalFlip(flip_prob),
+ T.ToTensor(),
+ normalize_transform,
+ ]
+ )
+ else:
+ transform = T.Compose(
+ [
+ T.Resize(max_size, BICUBIC),
+ T.CenterCrop(min_size),
+ T.ToTensor(),
+ normalize_transform,
+ ]
+ )
+ return transform
diff --git a/uniperceiver/datasets/moe_embeddings.py b/uniperceiver/datasets/moe_embeddings.py
new file mode 100644
index 0000000000000000000000000000000000000000..69c7db4fe009067e5954f8202c02b0e068c8f98f
--- /dev/null
+++ b/uniperceiver/datasets/moe_embeddings.py
@@ -0,0 +1,95 @@
+import torch
+
+def get_moe_embedding(moe_type):
+
+ if moe_type == 'attribute':
+ Task_attribute = {
+ # task input -- TASK_TYPE & data_type
+ 'image_classification': {
+ "input":
+ torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
+ },
+ 'video_classification': {
+ "input":
+ torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
+ },
+ 'text_mlm': {
+ "input":
+ torch.tensor([[0, 1, 0, 1, 0, 1, 0, 0]], dtype=torch.float),
+ },
+ 'image_caption': {
+ "input":
+ torch.tensor(
+ [[1, 1, 0, 1, 1, 0, 0, 0], [1, 1, 0, 1, 0, 1, 0, 1]],
+ dtype=torch.float)
+ },
+ 'video_caption': {
+ "input":
+ torch.tensor(
+ [[1, 1, 0, 1, 1, 0, 0, 0], [1, 1, 0, 1, 0, 1, 0, 1]],
+ dtype=torch.float)
+ },
+ 'image_retrieval': {
+ 'input':
+ torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
+ 'target':
+ torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ },
+ 'video_retrieval': {
+ 'input':
+ torch.tensor([[1, 0, 0, 1, 1, 0, 0, 0]], dtype=torch.float),
+ 'target':
+ torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ },
+ 'text_classification': {
+ "input":
+ torch.tensor([[0, 1, 0, 1, 0, 1, 0, 0]], dtype=torch.float),
+ },
+
+
+ # SHARED_TARGETS
+ "ImageNet1k":
+ torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "ImageNet22k":
+ torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "MomentsInTime":
+ torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "Kinetics700":
+ torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "Kinetics400":
+ torch.tensor([[1, 0, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "Vocab_Word":
+ torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "CoLA-target":
+ torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "MNLI-target":
+ torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "MRPC-target":
+ torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "QNLI-target":
+ torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "QQP-target":
+ torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "RTE-target":
+ torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ "SST-2-target":
+ torch.tensor([[1, 1, 0, 1, 0, 1, 1, 0]], dtype=torch.float),
+ }
+ return Task_attribute
+ else:
+ raise NotImplementedError(f'please check MOE_TYPE {moe_type}')
+
+
+
+def get_embed_with_task_type(moe_type: str, task_type: str, data_type: str):
+ if moe_type is None:
+ return None
+ embedding_dict = get_moe_embedding(moe_type)
+ return embedding_dict[task_type][data_type]
+
+
+def get_embed_with_shared_tagert_name(moe_type: str, set_name: str,):
+ if moe_type is None:
+ return None
+ embedding_dict = get_moe_embedding(moe_type)
+ return embedding_dict[set_name]
diff --git a/uniperceiver/datasets/sampler.py b/uniperceiver/datasets/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..388ab8a51ffbde4716214b7792f27689568e9c22
--- /dev/null
+++ b/uniperceiver/datasets/sampler.py
@@ -0,0 +1,310 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# Code is copy-pasted exactly as in torch.utils.data.distributed.
+# FIXME remove this once c10d fixes the bug it has
+import os
+import math
+import torch
+import torch.distributed as dist
+from torch.utils.data.sampler import Sampler
+import random
+from uniperceiver.utils import comm
+import itertools
+
+class DistributedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+ .. note::
+ Dataset is assumed to be of constant size.
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True, dataset_repeat=1):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+ self.shuffle = shuffle
+ self.dataset_repeat = dataset_repeat
+
+ def __iter__(self):
+ if self.shuffle:
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ indices += indices[: (self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
+ offset = self.num_samples * self.rank
+ indices = indices[offset : offset + self.num_samples]
+ assert len(indices) == self.num_samples
+
+ repeated_indices = []
+ for _ in range(self.dataset_repeat):
+ repeated_indices += torch.tensor(indices)[torch.randperm(len(indices), generator=g)].tolist()
+
+ return iter(repeated_indices)
+
+ def __len__(self):
+ return self.num_samples * self.dataset_repeat
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+class TrainingSampler(Sampler):
+ """
+ In training, we only care about the "infinite stream" of training data.
+ So this sampler produces an infinite stream of indices and
+ all workers cooperate to correctly shuffle the indices and sample different indices.
+
+ The samplers in each worker effectively produces `indices[worker_id::num_workers]`
+ where `indices` is an infinite stream of indices consisting of
+ `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
+ or `range(size) + range(size) + ...` (if shuffle is False)
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True, seed = None):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) -1
+ self.total_size = len(dataset)
+ self.shuffle = shuffle
+ # self.dataset_repeat = dataset_repeat
+ if seed is None:
+ seed = comm.shared_random_seed()
+ self.seed = int(seed)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __iter__(self):
+ start = self.rank
+ yield from itertools.islice(self._infinite_indices(), start, None, self.num_replicas)
+
+ def _infinite_indices(self):
+ g = torch.Generator()
+ g.manual_seed(self.seed)
+ while True:
+ if self.shuffle:
+ yield from torch.randperm(self.total_size, generator=g).tolist()
+ else:
+ yield from torch.arange(self.total_size).tolist()
+
+class NaiveSampler(Sampler):
+ """
+ In training, we only care about the "infinite stream" of training data.
+ So this sampler produces an infinite stream of indices and
+ all workers cooperate to correctly shuffle the indices and sample different indices.
+
+ The samplers in each worker effectively produces `indices[worker_id::num_workers]`
+ where `indices` is an infinite stream of indices consisting of
+ `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True)
+ or `range(size) + range(size) + ...` (if shuffle is False)
+
+ for bookswiki node-block cache
+
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True, seed = None):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size() // comm.get_local_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = comm.get_rank() // comm.get_local_size()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples =int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) -1
+ self.total_size = len(dataset)
+ self.shuffle = shuffle
+ # self.dataset_repeat = dataset_repeat
+ if seed is None:
+ seed = comm.shared_random_seed()
+ self.seed = int(seed)
+
+ def __len__(self):
+ return self.num_samples
+
+ def __iter__(self):
+ start = self.rank
+ yield from itertools.islice(self._infinite_indices(), start, None, self.num_replicas)
+
+ def _infinite_indices(self):
+ g = torch.Generator()
+ g.manual_seed(self.seed)
+ while True:
+ if self.shuffle:
+ yield from torch.randperm(self.total_size, generator=g).tolist()
+ else:
+ yield from torch.arange(self.total_size).tolist()
+
+class NodeDistributedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+ .. note::
+ Dataset is assumed to be of constant size.
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ if local_rank is None:
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ if local_size is None:
+ local_size = int(os.environ.get('LOCAL_SIZE', 1))
+ self.dataset = dataset
+ self.shuffle = shuffle
+ self.num_replicas = num_replicas
+ self.num_parts = local_size
+ self.rank = rank
+ self.local_rank = local_rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+
+ self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts
+
+ self.indices = [i for i in range(len(self.dataset)) if i % self.num_parts == self.local_rank]
+
+ seed = comm.shared_random_seed()
+ self.seed = int(seed)
+
+ def __iter__(self):
+ start = self.rank // self.num_parts
+ yield from itertools.islice(self._infinite_indices(), start, None, self.num_replicas // self.num_parts)
+
+ def _infinite_indices(self):
+ g = torch.Generator()
+ g.manual_seed(self.seed)
+ while True:
+ if self.shuffle:
+ yield from torch.tensor(self.indices)[torch.randperm(len(self.indices), generator=g)].tolist()
+ else:
+ yield from self.indices
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+
+
+
+
+
+class NodeDistributedSampler_bak(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+ .. note::
+ Dataset is assumed to be of constant size.
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ if local_rank is None:
+ local_rank = int(os.environ.get('LOCAL_RANK', 0))
+ if local_size is None:
+ local_size = int(os.environ.get('LOCAL_SIZE', 1))
+ self.dataset = dataset
+ self.shuffle = shuffle
+ self.num_replicas = num_replicas
+ self.num_parts = local_size
+ self.rank = rank
+ self.local_rank = local_rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+
+ self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts
+
+ def __iter__(self):
+ if self.shuffle:
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+ indices = [i for i in indices if i % self.num_parts == self.local_rank]
+
+ # add extra samples to make it evenly divisible
+ indices += indices[:(self.total_size_parts - len(indices))]
+ assert len(indices) == self.total_size_parts
+
+ # subsample
+ indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
\ No newline at end of file
diff --git a/uniperceiver/datasets/task_dataset/GLUE.py b/uniperceiver/datasets/task_dataset/GLUE.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a781eb506ba52b74d72efe65a35c64449df1f97
--- /dev/null
+++ b/uniperceiver/datasets/task_dataset/GLUE.py
@@ -0,0 +1,379 @@
+import os
+import copy
+import pickle
+import random
+import json
+import glob
+import numpy as np
+from uniperceiver.config import configurable
+from uniperceiver.functional import dict_as_tensor
+from uniperceiver.tokenization import ClipTokenizer
+from ..build import DATASETS_REGISTRY
+import pyarrow as pa
+
+__all__ = ["GLUEDataset"]
+
+
+@DATASETS_REGISTRY.register()
+class GLUEDataset:
+ @configurable
+ def __init__(
+ self,
+ cfg: dict,
+ stage: str,
+ anno_file: str,
+ max_seq_len: int,
+ tokenizer,
+ tokenizer_name,
+ input_columns,
+ label_column,
+ input_count,
+ task_name,
+ data_percentage,
+ data_k_sample,
+ ):
+ self.cfg = cfg
+ self.stage = stage
+ self.anno_file = anno_file
+ self.tokenizer = tokenizer
+ self.tokenizer_name = tokenizer_name
+ self.max_seq_len = max_seq_len
+
+ self.input_columns = input_columns
+ self.label_column = label_column
+ self.input_count = input_count
+
+ self.task_name = task_name
+
+ self.data_percentage = data_percentage
+ self.data_k_sample = data_k_sample
+
+ self.task_info = {
+ 'task_type' : self.cfg.DATASETS.TASK_TYPE,
+ 'dataset_name' : self.cfg.DATASETS.DATASET_NAME,
+ 'batch_size' : self.cfg.DATALOADER.TRAIN_BATCH_SIZE if self.stage == 'train' else self.cfg.DATALOADER.TEST_BATCH_SIZE,
+ 'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT,
+ }
+ self.target_set = cfg.DATASETS.TARGET_SET
+
+ self.load_data(cfg)
+
+ @classmethod
+ def from_config(cls, cfg, stage: str = "train"):
+ task_name = cfg.DATASETS.DATASET_NAME
+ namesmapping = {
+ "train": "train",
+ "val": "dev",
+ "test": "test",
+ }
+ data_dir = cfg.DATALOADER.ANNO_FOLDER
+ if task_name in ['MNLI', 'QNLI', 'QQP', 'RTE', 'SST-2', 'MRPC', 'CoLA', 'STS-B']:
+ anno_file = os.path.join(data_dir, task_name, 'processed/{name}.tsv'.format(name=namesmapping[stage]))
+ elif task_name == 'MNLI_Match':
+ namesmapping = {
+ "train": "train",
+ "val": "dev_matched",
+ "test": "test_matched",
+ }
+ anno_file = os.path.join(data_dir, 'MNLI', 'processed/{name}.tsv'.format(name=namesmapping[stage]))
+ elif task_name == 'MNLI_Mismatch':
+ namesmapping = {
+ "train": "train",
+ "val": "dev_mismatched",
+ "test": "test_mismatched",
+ }
+ anno_file = os.path.join(data_dir, 'MNLI', 'processed/{name}.tsv'.format(name=namesmapping[stage]))
+
+ input_count = 2
+ if task_name == "QQP":
+ input_columns = [4, 5]
+ if stage == 'test':
+ input_columns = [2, 3]
+ label_column = 6
+ elif task_name in ["MNLI_Match", "MNLI_Mismatch"]: # "MNLI" :
+ input_columns = [9, 10]
+ if stage == 'test':
+ input_columns = [9, 10]
+
+ label_column = 12
+ if stage == 'val':
+ label_column = 16
+ elif task_name == "QNLI":
+ input_columns = [2, 3]
+ if stage == 'test':
+ input_columns = [2, 3]
+ label_column = 4
+ elif task_name == "MRPC":
+ input_columns = [4, 5]
+ if stage == 'test':
+ input_columns = [4, 5]
+ label_column = 1
+ elif task_name == "RTE":
+ input_columns = [2, 3]
+ if stage == 'test':
+ input_columns = [2, 3]
+ label_column = 4
+ elif task_name == "STS-B":
+ input_columns = [8, 9]
+ if stage == 'test':
+ input_columns = [8, 9]
+ label_column = 10
+ # Following are single sentence tasks.
+ elif task_name == "SST-2":
+ input_columns = [1]
+ if stage == 'test':
+ input_columns = [2]
+ label_column = 2
+ input_count = 1
+ elif task_name == "CoLA":
+ input_columns = [4]
+ if stage == 'test':
+ input_columns = [2]
+ label_column = 2
+ input_count = 1
+ else:
+ raise NotImplementedError
+
+ ret = {
+ "cfg": cfg,
+ "stage": stage,
+ "anno_file": anno_file,
+ "max_seq_len": cfg.MODEL.MAX_SEQ_LEN,
+ "input_columns": input_columns,
+ "label_column": label_column,
+ "input_count": input_count,
+ "task_name": task_name,
+ "data_percentage": getattr(cfg.DATALOADER, "DATA_PERCENTAGE", 1.0),
+ "data_k_sample": getattr(cfg.DATALOADER, "DATA_K_SAMPLE", -1),
+ "tokenizer": ClipTokenizer(),
+ "tokenizer_name": "clip"
+ }
+
+ return ret
+
+
+
+ def load_data(self, cfg):
+ cache_path = os.path.join(os.path.dirname(self.anno_file), "cache_GLUE_raw_%s_%s_%s.pkl" % (self.task_name, self.tokenizer_name, self.stage))
+ if not os.path.exists(cache_path):
+ datalist = self.load_raw_data(cfg)
+
+ pickle.dump(datalist, open(cache_path, "wb"))
+
+ datalist = pickle.load(open(cache_path, "rb"))
+
+ # for few shot exp
+
+ if self.data_percentage < 1.0 and self.stage == "train":
+ print("will sample {} data for trianing-->".format(self.data_percentage))
+ labels2l = dict()
+ for data in datalist:
+
+ label = data['label']
+ if label not in labels2l:
+ labels2l[label] = list()
+ labels2l[label].append(data)
+
+ # samplers_label = len(datalist) * self.data_percentage // len(labels2l.keys())
+ datalist = []
+
+ for v in labels2l.values():
+ datalist.extend(random.sample(v, k=int(self.data_percentage * len(v) + 1)))
+ # datalist.extend(random.sample(v, k=int(samplers_label+1)))
+
+ elif self.data_k_sample > 0 and self.stage == "train":
+ print("will sample {} data for each class when training -->".format(self.data_k_sample))
+ labels2l = dict()
+ for data in datalist:
+
+ label = data['label']
+ if label not in labels2l:
+ labels2l[label] = list()
+ labels2l[label].append(data)
+
+ datalist = []
+
+ for v in labels2l.values():
+ datalist.extend(random.sample(v, k=int(self.data_k_sample)))
+
+ while len(datalist) < 200:
+ datalist = datalist + datalist
+
+ self.datalist = datalist
+
+
+ def load_raw_data(self, cfg):
+ datalist = []
+ if self.task_name.startswith("MNLI"):
+ labelmapping = {
+ "contradiction": 0,
+ "neutral": 1,
+ "entailment": 2,
+ }
+ fin = open(self.anno_file, 'r').readlines()
+ for _, line in enumerate(fin):
+ sensinfo = line.strip().split('\t')
+ if self.task_name == "RTE":
+ label = 1.0 if sensinfo[self.label_column - 1] == "entailment" else 0.0
+ elif self.task_name.startswith("MNLI"):
+ label = labelmapping[sensinfo[self.label_column - 1]]
+ elif self.task_name == "QNLI":
+ label = 1.0 if sensinfo[self.label_column - 1] == "entailment" else 0.0
+ elif self.task_name == "STS-B":
+ label = float(sensinfo[self.label_column - 1]) / 5.0
+ else:
+ label = float(sensinfo[self.label_column - 1])
+ datalist.append({
+ # start index from 1 to 0
+ "sentences": [sensinfo[i - 1] for i in self.input_columns],
+ "label": label
+ })
+ return datalist
+
+ def __len__(self):
+ return len(self.datalist)
+
+ def __getitem__(self, index):
+ dataset_dict = copy.deepcopy(self.datalist[index])
+
+ sentences = dataset_dict['sentences']
+
+ # input1: SEN1, this sentence is (spe) input2: word choice: postive and negative
+
+ if self.input_count == 1:
+
+
+ if self.task_name == "SST-2":
+ tokens = self.tokenizer.encode(sentences[0] + " <|endoftext|> It is <|spe|>. <|endoftext|>")
+ elif self.task_name == "CoLA":
+ tokens = self.tokenizer.encode(sentences[0] + " This is <|spe|>. <|endoftext|>")
+ else:
+ raise NotImplementedError
+
+ index = len(tokens) - 3
+ assert index < self.max_seq_len
+ if len(tokens) > self.max_seq_len:
+ tokens = tokens[:self.max_seq_len - 4] + tokens[-4:]
+
+
+
+ else:
+
+ if self.task_name in ["RTE"]:
+ tokens1 = self.tokenizer.encode(sentences[0])
+ if tokens1[-1] == 269:
+ tokens1 = tokens1[:-1]
+ tokens1 = tokens1 + self.tokenizer.encode(" ? <|endoftext|> it is ")
+ tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
+
+ tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
+ if len(tokens2) > self.max_seq_len // 2:
+ tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
+ max_len = self.max_seq_len - len(tokens2)
+
+ elif self.task_name in ["MRPC"]:
+ tokens1 = self.tokenizer.encode(sentences[0])
+ if tokens1[-1] == 269:
+ tokens1 = tokens1[:-1]
+ tokens1 = tokens1 + self.tokenizer.encode(" . ")
+ tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
+
+ tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
+ if len(tokens2) > self.max_seq_len // 2:
+ tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
+ max_len = self.max_seq_len - len(tokens2)
+
+ elif self.task_name in ["QQP"]:
+ tokens1 = self.tokenizer.encode(sentences[0])
+ if tokens1[-1] == 269:
+ tokens1 = tokens1[:-1]
+ tokens1 = tokens1 + self.tokenizer.encode(" <|endoftext|> ")
+ tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
+
+ tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
+ if len(tokens2) > self.max_seq_len // 2:
+ tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
+ max_len = self.max_seq_len - len(tokens2)
+
+ elif self.task_name in ["QNLI"]:
+ tokens1 = self.tokenizer.encode(sentences[0])
+ if tokens1[-1] == 269:
+ tokens1 = tokens1[:-1]
+ tokens1 = tokens1 + self.tokenizer.encode(" <|endoftext|> it is ")
+ tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
+
+ tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
+ if len(tokens2) > self.max_seq_len - len(tokens1):
+ tokens2 = tokens2[:self.max_seq_len - len(tokens1) - 1] + [tokens2[-1]]
+ max_len = self.max_seq_len - len(tokens2)
+
+ elif self.task_name in ["MNLI", "MNLI_Match"]:
+ # sentence0 = sentences[0].replace(")", "").replace("(", "")
+ tokens1 = self.tokenizer.encode(sentences[0])
+ # if tokens1[-1] == 269:
+ # tokens1 = tokens1[:-1]
+ tokens1 = tokens1 # + self.tokenizer.encode(" ? ")
+ tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
+
+ tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
+ if len(tokens2) > self.max_seq_len // 2:
+ tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
+ max_len = self.max_seq_len - len(tokens2)
+
+ elif self.task_name in ["RTE", "QNLI", "MNLI", "MNLI_Match"]:
+ tokens1 = self.tokenizer.encode(sentences[0] + "? <|endoftext|>")
+ tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
+
+ if tokens1[-1] == 269:
+ tokens1 = tokens1[:-1]
+ tokens2 = self.tokenizer.encode(" <|spe|> , ") + tokens2
+ if len(tokens2) > self.max_seq_len // 2:
+ tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
+ max_len = self.max_seq_len - len(tokens2)
+ elif self.task_name in ["MRPC", "QQP"]:
+ tokens1 = self.tokenizer.encode(sentences[0] + " <|endoftext|>")
+ tokens2 = self.tokenizer.encode(sentences[1] + " <|endoftext|> ")
+ tokens2 = self.tokenizer.encode(" <|spe|>, ") + tokens2
+ if len(tokens2) > self.max_seq_len // 2:
+ tokens2 = tokens2[:self.max_seq_len // 2 - 1] + [tokens2[-1]]
+ max_len = self.max_seq_len - len(tokens2)
+ else:
+ NotImplementedError
+
+ # tokens = self.tokenizer.add_special_tokens_sentences_pair(tokens1, tokens2, start_type='SPE')
+ if len(tokens1) > max_len:
+ tokens1 = tokens1[:max_len - 1] + [tokens1[-1]]
+
+ tokens = tokens1 + tokens2
+
+ index = len(tokens1)
+ assert index < self.max_seq_len
+
+
+ sentences = np.array(tokens, dtype=np.int64)
+
+
+ if self.task_name in ["SST-2", "CoLA", "MRPC", "RTE", "QNLI", "MNLI", "QQP", "MNLI_Match"]:
+ label = int(dataset_dict['label'])
+ else:
+ raise NotImplementedError()
+
+
+ ret = {
+ 'input_sample': [{
+ 'data': [sentences],
+ 'modality': 'text',
+ 'data_type': 'input',
+ 'invalid_mask': None,
+ 'sample_info' : {
+ 'spe_index': index
+ }
+ }],
+ 'target_sample': [],
+ 'target_idx' : [label],
+ 'target_set' : copy.deepcopy(self.target_set),
+ 'task_info' : copy.deepcopy(self.task_info)
+ }
+
+ dict_as_tensor(ret)
+ return ret
diff --git a/uniperceiver/datasets/task_dataset/general_corpus.py b/uniperceiver/datasets/task_dataset/general_corpus.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a8a114e5d1b4b37b9504bf48224a7cf6e9d7bdb
--- /dev/null
+++ b/uniperceiver/datasets/task_dataset/general_corpus.py
@@ -0,0 +1,364 @@
+from uniperceiver.functional import read_np, dict_as_tensor, boxes_to_locfeats
+import random
+import numpy as np
+import copy
+from torch.utils.data import Dataset
+from uniperceiver.tokenization import ClipTokenizer
+import logging
+import os
+from ..build import DATASETS_REGISTRY
+from uniperceiver.config import configurable
+import pickle
+from uniperceiver.utils import comm
+
+__all__ = ["GeneralCorpusDataset"]
+
+
+@DATASETS_REGISTRY.register()
+class GeneralCorpusDataset(Dataset):
+ @configurable
+ def __init__(self, ann_file, stage,
+ tokenizer, tokenizer_name,
+ seq_len=64, min_seq_len=64,
+ encoding="utf-8",
+ cache_mode=True, cache_local_rank=0, cache_local_size=1,
+ append_eos=False,
+ one_stream=False,
+ random_mask=False,
+ task_type=None,
+ text_type_id=0,
+ mask_bpe_word='spe',
+ version='v1',
+ task_info=None,
+ target_set=None,
+ **kwargs):
+ assert cache_mode, print("only support cache mode!")
+ assert len(task_type) > 0
+ self.version = version
+ self.stage = stage
+ self.tokenizer = tokenizer
+ self.tokenizer_name = tokenizer_name
+ self.use_clip_tokenizer = tokenizer_name == 'clip'
+ self.task_type = task_type
+ self.append_eos = append_eos
+ self.task_info = task_info
+ self.target_set = target_set
+
+ self.seq_len = seq_len
+ self.min_seq_len = min_seq_len
+ self.cache_mode = cache_mode
+ self.cache_local_size = cache_local_size
+ self.cache_local_rank = cache_local_rank
+
+ self.ann_file = ann_file
+ self.encoding = encoding
+ self.test_mode = False
+ self.random_mask = random_mask
+
+ self.one_stream = one_stream
+
+ self.text_type_id = text_type_id
+
+ self.mask_bpe_word = "<|spe|>" if mask_bpe_word == 'spe' else '<|startoftext|>'
+
+ # load samples into memory
+ if cache_mode:
+ print('dataset cache mode is ON: local size: {}; local rank: {}'.format(cache_local_size,
+ cache_local_rank))
+ self.corpus, self.cursor = self.load_corpus()
+
+ @classmethod
+ def from_config(cls, cfg, stage: str = "train"):
+ version = getattr(cfg.DATASETS, 'VERSION', 'v1')
+ if 'SLURM_PROCID' not in os.environ:
+ version = 'v1'
+ if version == 'v2':
+ ann_files = {
+ "train":
+ os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2.txt")
+ if comm.get_world_size() > 1 else os.path.join(
+ cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc"),
+ "val":
+ os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc")
+ }
+ elif version == 'v3':
+ ann_files = {
+ "train":
+ os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswikiopen.txt")
+ if comm.get_world_size() > 1 else os.path.join(
+ cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc"),
+ "val":
+ os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki_v2-1000.doc")
+ }
+ else:
+ ann_files = {
+
+ "train": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki.doc") if comm.get_world_size() > 1 else
+ os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki-1000.doc"),
+ "val": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "bookswiki-1000.doc")
+ }
+
+ task_info = {
+ 'task_type' : cfg.DATASETS.TASK_TYPE,
+ 'dataset_name' : cfg.DATASETS.DATASET_NAME,
+ 'batch_size' : cfg.DATALOADER.TRAIN_BATCH_SIZE if stage == 'train' else cfg.DATALOADER.TEST_BATCH_SIZE,
+ 'sampling_weight': cfg.DATALOADER.SAMPLING_WEIGHT
+ }
+
+ ret = {
+ "version" : version,
+ "stage" : stage,
+ "ann_file" : ann_files[stage],
+ "seq_len" : cfg.MODEL.MAX_SEQ_LEN,
+ "min_seq_len" : cfg.MODEL.MAX_SEQ_LEN,
+ "cache_mode" : cfg.DATALOADER.CACHE_MODE,
+ "append_eos" : cfg.DATALOADER.APPEND_EOS,
+ "cache_local_rank": comm.get_local_rank(),
+ "cache_local_size": comm.get_local_size(),
+ "one_stream" : cfg.DATALOADER.ONE_STREAM,
+ "task_type" : cfg.DATASETS.TASK_TYPE,
+ "random_mask" : getattr(cfg.DATALOADER, 'RANDOM_MASK', False),
+ "text_type_id" : getattr(cfg.DATALOADER, 'TYPE_EMBEDDING_ID', 0),
+ "mask_bpe_word" : getattr(cfg.DATALOADER, 'MASK_BPE_WORD', 'spe'),
+ "task_info" : task_info,
+ "target_set" : cfg.DATASETS.TARGET_SET
+ }
+
+
+ ret['tokenizer'] = ClipTokenizer()
+ ret['tokenizer_name'] = "clip"
+
+
+ return ret
+
+ @classmethod
+ def add_config(cls, cfg):
+ cfg.DATALOADER.SAMPLER = "NodeDistributed"
+ cfg.DATALOADER.CACHE_MODE = True
+ cfg.DATALOADER.SEQ_PER_SAMPLE = 256
+ cfg.DATALOADER.MIN_SEQ_PER_SAMPLE = 256
+ cfg.DATALOADER.APPEND_EOS = True
+
+
+ def load_corpus(self):
+ if 'SLURM_PROCID' in os.environ:
+ self.cache_local_size = 8 # for convenice
+ cache_path = os.path.dirname(self.ann_file)
+ if self.version == 'v2':
+ cache_filename = 'cache/cache_block' + os.path.basename(self.ann_file).replace('.', "_") + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
+ elif self.version == 'v3':
+ cache_filename = 'cache_v3/cache_block_books_wiki_openweb' + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
+ else:
+ cache_filename = 'cache_block' + os.path.basename(self.ann_file).replace('.', "_") + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
+ cache_file = os.path.join(cache_path, cache_filename)
+ if not os.path.exists(cache_file):
+ if self.version == 'v3':
+ raise NotImplementedError
+ # [HACK] we hard code the corpus length
+ if 'SLURM_PROCID' in os.environ:
+ if self.version == 'v2':
+ self.file_len = 244208263
+ block_size = (self.file_len + self.cache_local_size - 1)// self.cache_local_size
+ block_start = block_size * self.cache_local_rank
+ block_end = (block_size) * (
+ 1 + self.cache_local_rank
+ ) if self.cache_local_rank + 1 < self.cache_local_size else self.file_len
+ else:
+ block_start = self.cache_local_rank * 13000000
+ block_end = ( self.cache_local_rank + 1 ) * 13000000
+ else:
+ block_start = 0
+ block_end = 1000
+ count = 0
+ corpus = bytearray()
+ cursor = []
+ c_ = 0
+ i_ = 0
+ for ann_file in self.ann_file.split('+'):
+ with open(ann_file, 'r', encoding=self.encoding) as f:
+ for l in f:
+ l = l.strip()
+ if l != '':
+ # if i_ % self.cache_local_size != self.cache_local_rank:
+ if i_< block_start or i_ >= block_end:
+ # cursor.append(c_)
+ i_ += 1
+ continue
+ l = l.encode()
+ corpus += l
+ cursor.append(c_)
+ c_ += len(l)
+ i_ += 1
+ count += 1
+ cursor.append(len(corpus))
+ cursor = np.array(cursor).astype(np.int, copy=False)
+ pickle.dump({
+ "corpus": corpus,
+ "cursor": cursor,
+ "count": count,
+ }, open(cache_file, "wb"), protocol=4)
+
+ else:
+ cachedata = pickle.load(open(cache_file, "rb"))
+ corpus, cursor, count = cachedata['corpus'], cachedata['cursor'], cachedata['count']
+
+ print("BooksWiki info: rank {} has {} sentences".format(self.cache_local_rank, count))
+
+ return corpus, cursor
+
+
+ def get_line(self, index):
+ return self.corpus[self.cursor[index]:self.cursor[index+1]].decode()
+
+ @property
+ def data_names(self):
+ return ['text', 'mlm_labels']
+
+ def __len__(self):
+ return len(self.cursor) - 1
+
+ def __getitem__(self, item):
+ # def __call__(self, item):
+ raw = self.get_line(item)
+
+ # tokenize
+ if self.use_clip_tokenizer:
+ tokens = self.tokenizer.basic_tokenize(raw)
+ if len(tokens) > 0 and self.append_eos:
+ tokens.append('<|endoftext|>')
+ else:
+ tokens = self.tokenizer.basic_tokenizer.tokenize(raw)
+
+ # add more tokens if len(tokens) < min_len
+ _cur = (item + 1) % (len(self.cursor) - 1)
+ while len(tokens) < self.min_seq_len:
+ if self.use_clip_tokenizer:
+ _cur_tokens = self.tokenizer.basic_tokenize(self.get_line(_cur))
+ if len(_cur_tokens) > 0 and self.append_eos:
+ _cur_tokens.append('<|endoftext|>')
+ else:
+ _cur_tokens = self.tokenizer.basic_tokenizer.tokenize(self.get_line(_cur))
+ tokens.extend(_cur_tokens)
+ _cur = (_cur + 1) % (len(self.cursor) - 1)
+
+ if self.task_type == 'text_mlm':
+ tokens, mlm_labels = self.random_word_wwm(tokens)
+
+ elif self.task_type == 'caption':
+ tokens_tmp = []
+ for token in tokens:
+ tokens_tmp.extend(self.tokenizer.encode_basic_tokenized_token(token))
+ tokens = tokens_tmp
+ mlm_labels = self.tokenizer.encode(
+ self.mask_bpe_word) * len(tokens)
+
+ if self.use_clip_tokenizer:
+ ids = tokens
+ else:
+ # add [CLS], [SEP]
+ tokens = tokens + ['[SEP]']
+ mlm_labels = mlm_labels + [-1]
+
+ # convert token to its vocab id
+ ids = self.tokenizer.convert_tokens_to_ids(tokens)
+
+ # truncate
+ if len(ids) > self.seq_len:
+ ids = ids[:(self.seq_len-1)] + [ids[-1]]
+ mlm_labels = mlm_labels[:(self.seq_len-1)] + [mlm_labels[-1]]
+ elif len(ids) < self.seq_len:
+ ids = ids + [0 for _ in range(self.seq_len - len(ids))]
+ mlm_labels = mlm_labels + [-1 for _ in range(self.seq_len - len(ids))]
+
+
+
+
+ if self.task_type == 'text_mlm':
+ ret = {
+ 'input_sample': [{
+ 'data' : [np.array(ids, dtype=np.int64)],
+ 'invalid_mask': None,
+ 'modality' : 'text',
+ 'data_type': 'input',
+ 'sample_info' : {
+ 'seq_length': len(ids)
+ }
+ }],
+ 'target_sample': [],
+ 'target_idx' : [np.array(mlm_labels, dtype=np.int64)],
+ 'target_set' : copy.deepcopy(self.target_set),
+ 'task_info' : copy.deepcopy(self.task_info)
+ }
+ elif self.task_type == 'caption':
+ source = np.array(ids, dtype=np.int64)
+ source2 = np.array(mlm_labels, dtype=np.int64)
+
+ ret = {
+ 'input_sample': [{
+ 'data': [source, source2],
+ 'invalid_mask': None,
+ 'modality': 'text',
+ 'data_type': 'input',
+ 'sample_info': {}
+ }],
+ 'target_sample': [],
+ 'target_idx': [np.array(ids, dtype=np.int64)],
+ 'target_set' : copy.deepcopy(self.target_set),
+ 'task_info' : copy.deepcopy(self.task_info)
+ }
+
+ dict_as_tensor(ret)
+ return ret
+
+ def random_word_wwm(self, tokens):
+ output_tokens = []
+ output_label = []
+
+ for i, token in enumerate(tokens):
+ if self.use_clip_tokenizer:
+ sub_tokens = self.tokenizer.encode_basic_tokenized_token(token)
+ else:
+ sub_tokens = self.tokenizer.wordpiece_tokenizer.tokenize(token)
+ prob = random.random()
+ # mask token with 15% probability
+ if prob < 0.15:
+ prob /= 0.15
+
+ # 80% randomly change token to mask token
+ if prob < 0.8:
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_tokens.append(
+ self.tokenizer.encoder[self.mask_bpe_word])
+ else:
+ output_tokens.append("[MASK]")
+ # 10% randomly change token to random token
+ elif prob < 0.9:
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_tokens.append(random.choice(list(range(len(self.tokenizer.encoder)))))
+ else:
+ output_tokens.append(random.choice(list(self.tokenizer.vocab.keys())))
+ # -> rest 10% randomly keep current token
+ else:
+ for sub_token in sub_tokens:
+ output_tokens.append(sub_token)
+
+ # append current token to output (we will predict these later)
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_label.append(sub_token)
+ else:
+ try:
+ output_label.append(self.tokenizer.vocab[sub_token])
+ except KeyError:
+ # For unknown words (should not occur with BPE vocab)
+ output_label.append(self.tokenizer.vocab["[UNK]"])
+ logging.warning("Cannot find sub_token '{}' in vocab. Using [UNK] insetad".format(sub_token))
+ else:
+ for sub_token in sub_tokens:
+ # no masking token (will be ignored by loss function later)
+ output_tokens.append(sub_token)
+ output_label.append(-1)
+
+ return output_tokens, output_label
\ No newline at end of file
diff --git a/uniperceiver/datasets/task_dataset/image_text_pair.py b/uniperceiver/datasets/task_dataset/image_text_pair.py
new file mode 100644
index 0000000000000000000000000000000000000000..722ee58ad3b7f1adc4191ca79b351723a886432d
--- /dev/null
+++ b/uniperceiver/datasets/task_dataset/image_text_pair.py
@@ -0,0 +1,925 @@
+import random
+import os
+import time
+import json
+from tqdm import trange
+# import jsonlines
+from PIL import Image, ImageFile
+import copy
+
+# ImageFile.LOAD_TRUNCATED_IMAGES = True
+import cv2
+import base64
+import numpy as np
+import pyarrow as pa
+import logging
+# import spacy
+import glob
+from io import BytesIO
+import jsonlines
+
+import torch
+from torch.utils.data import Dataset
+from uniperceiver.functional import read_np, dict_as_tensor, boxes_to_locfeats
+from collections import defaultdict
+
+from uniperceiver.datasets.zipreader import ZipReader
+import errno
+from uniperceiver.datasets.circular_cached_loader import CircularCachedInputIterator
+
+from uniperceiver.tokenization import ClipTokenizer
+
+from ..build import DATASETS_REGISTRY
+# from uniperceiver.config import kfg
+from uniperceiver.config import configurable
+import pickle
+from uniperceiver.utils import comm
+
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import create_transform
+from torchvision import transforms
+from uniperceiver.datasets.custom_transforms import clip_transforms
+
+__all__ = ["ImageTextPairDataset"]
+
+memorycache = False
+
+
+
+def makedirsExist(path):
+ try:
+ os.makedirs(path, exist_ok=True)
+ except OSError as e:
+ if e.errno == errno.EEXIST:
+ print('Directory not created.')
+ else:
+ raise
+
+def _smart_join(str_or_list, delim):
+ if isinstance(str_or_list, str):
+ return str_or_list
+ else:
+ return delim.join(str_or_list)
+
+@DATASETS_REGISTRY.register()
+class ImageTextPairDataset(Dataset):
+
+ @configurable
+ def __init__(self, cfg, stage, ann_file, image_set, root_path, data_path, s3_path,
+ feats_folder,
+ dataset_name,
+ data_percentage,
+ seq_per_img,
+ tokenizer, tokenizer_name,
+ seq_len=64,
+ mask_prob=(0.15, 0.8), repl_prob=0.1,
+ task_type=True,
+ transform=None, test_mode=False,
+ zip_mode=False,
+ cache_mode=False,
+ cache_origin_image=False,
+ cache_local_rank=0, cache_local_size=1,
+ circular_cache_mode=False,
+ ignore_db_cache=True,
+ aspect_grouping=False,
+ use_ceph=False,
+ tcs_conf_path='',
+ random_caption=False,
+ max_length=-1,
+ as_numpy_as_possible=False,
+ use_node_distirbuted_sampler=False,
+ **kwargs):
+ """
+ Conceptual Captions Dataset
+
+ :param ann_file: annotation jsonl file
+ :param image_set: image folder name, e.g., 'vcr1images'
+ :param root_path: root path to cache database loaded from annotation file
+ :param data_path: path to vcr dataset
+ :param transform: transform
+ :param test_mode: test mode means no labels available
+ :param zip_mode: reading images and metadata in zip archive
+ :param cache_mode: cache whole dataset to RAM first, then __getitem__ read them from RAM
+ :param ignore_db_cache: ignore previous cached database, reload it from annotation file
+ :param tokenizer: default is BertTokenizer from pytorch_pretrained_bert
+ :param aspect_grouping: whether to group images via their aspect
+ :param kwargs:
+ """
+ super(ImageTextPairDataset, self).__init__()
+
+ # assert not cache_mode, 'currently not support cache mode!'
+ assert not test_mode
+ assert not (cache_mode and circular_cache_mode)
+
+ self.mask_prob = mask_prob
+ self.repl_prob = repl_prob
+ self.seq_len = seq_len
+ self.task_type = task_type
+ self.cfg = cfg
+ self.stage = stage
+ self.dataset_name = dataset_name
+ self.feats_folder = feats_folder
+ self.seq_per_img = seq_per_img
+ assert self.seq_per_img == 1
+ self.data_percentage = data_percentage
+
+
+ self.data_path = data_path
+ self.root_path = root_path
+ self.ann_file = ann_file
+ self.image_set = image_set
+ self.transform = transform
+ self.test_mode = test_mode
+ self.zip_mode = zip_mode
+ self.cache_mode = cache_mode
+ self.cache_origin_image = cache_origin_image
+ self.cache_local_rank = cache_local_rank
+ self.cache_local_size = cache_local_size
+ self.circular_cache_mode = circular_cache_mode
+ self.ignore_db_cache = ignore_db_cache
+ self.aspect_grouping = aspect_grouping
+ self.cache_dir = os.path.join(self.data_path, 'cache')
+ self.use_node_distirbuted_sampler = (use_node_distirbuted_sampler or cache_mode)
+ if not os.path.exists(self.cache_dir):
+ makedirsExist(self.cache_dir)
+
+ self.initialized = False
+
+ self.tokenizer = tokenizer
+ self.tokenizer_name = tokenizer_name
+ self.use_clip_tokenizer = tokenizer_name == 'clip'
+
+ self.zipreader = ZipReader()
+
+ self.use_ceph = use_ceph
+ self.tcs_conf_path = tcs_conf_path
+ if use_ceph:
+ self.data_path = s3_path
+ from uniperceiver.datasets.tcsreader import TCSLoader
+ self.tcs_loader = TCSLoader(tcs_conf_path)
+ else:
+ self.data_path = feats_folder
+
+ if comm.is_main_process():
+ print(f"data_path for Dataset {self.dataset_name} with task {self.task_type}: {self.data_path}")
+
+ self.random_caption = random_caption
+
+
+ if self.dataset_name == 'VG':
+ self.load_VG(self.cfg)
+ elif self.dataset_name in ['MSCOCO', 'FLICKR']:
+ self.load_COCO_flickr(self.cfg)
+ else:
+ self.load_database()
+
+ if self.circular_cache_mode:
+ chunk_dir = os.path.join(self.data_path, '{}_chunks'.format(image_set))
+ self.chunk_path_list = glob.glob(os.path.join(chunk_dir, '*.pa'))
+
+ if self.aspect_grouping:
+ assert False, "not support aspect grouping currently!"
+ self.group_ids = self.group_aspect(self.database)
+
+ self.as_numpy_as_possible = as_numpy_as_possible
+ self.max_length = max_length
+
+ self.task_info = {
+ 'task_type' : self.cfg.DATASETS.TASK_TYPE,
+ 'dataset_name' : self.cfg.DATASETS.DATASET_NAME,
+ 'batch_size' : self.cfg.DATALOADER.TRAIN_BATCH_SIZE if self.stage == 'train' else self.cfg.DATALOADER.TEST_BATCH_SIZE,
+ 'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT
+ }
+
+ @classmethod
+ def from_config(cls, cfg, stage: str = "train"):
+
+ if 'SLURM_PROCID' in os.environ:
+ tcs_conf_path = cfg.DATALOADER.get("TCS_CONF_PATH", "petreloss.config")
+ else:
+ # dev machine
+ tcs_conf_path = "slurm_tools/petreloss_local.config"
+ anno_filename = cfg.DATALOADER.get("ANNO_FILENAME", "train_spacy.json")
+ if cfg.DATALOADER.USE_CEPH and cfg.DATALOADER.S3_ANNO_FOLDER is not None:
+ anno_folder = cfg.DATALOADER.S3_ANNO_FOLDER
+ else:
+ anno_folder = cfg.DATALOADER.ANNO_FOLDER
+ if cfg.DATASETS.DATASET_NAME == 'MSCOCO':
+ anno_files = {
+ "train": [os.path.join(anno_folder, "captions_train113k.json"), os.path.join(anno_folder, "captions_val5k.json")],
+ # no validation
+ "test": os.path.join(anno_folder, "captions_test5k.json")
+ }
+ elif cfg.DATASETS.DATASET_NAME == 'FLICKR':
+ anno_files = {
+ "train": [os.path.join(anno_folder, "all_data_final_train_2014.jsonline"), os.path.join(anno_folder, "all_data_final_val_set0_2014.jsonline")],
+ # no val
+ # "val": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "all_data_final_val_set0_2014.jsonline"),
+ "test": os.path.join(anno_folder, "all_data_final_test_set0_2014.jsonline")
+ }
+ else:
+ anno_files = {
+ "train": os.path.join(anno_folder, anno_filename),
+ "val": os.path.join(anno_folder, anno_filename),
+ "test": os.path.join(anno_folder, anno_filename),
+ }
+ if getattr(cfg.DATALOADER, 'TRANSFORM', None) == 'clip_transforms':
+ transform = clip_transforms(stage, img_size=cfg.MODEL.IMG_INPUT_SIZE)
+ else:
+ # same as imagenet
+ transform = build_transform(is_train=(stage=='train'))
+
+ ret = {
+ 'cfg': cfg,
+ 'stage': stage,
+ 'ann_file' : anno_files[stage],
+ "seq_per_img": 1,
+ 'image_set' : stage,
+ 'root_path' : cfg.DATALOADER.ANNO_FOLDER,
+ 'data_path' : cfg.DATALOADER.FEATS_FOLDER,
+ 's3_path': cfg.DATALOADER.S3_PATH,
+ 'feats_folder': cfg.DATALOADER.FEATS_FOLDER,
+ 'dataset_name': cfg.DATASETS.DATASET_NAME,
+ "data_percentage": cfg.DATALOADER.DATA_PERCENTAGE,
+ 'seq_len': cfg.MODEL.MAX_SEQ_LEN,
+ 'task_type': cfg.DATASETS.TASK_TYPE,
+ 'transform': transform,
+ 'zip_mode': cfg.DATALOADER.ZIP_MODE,
+ "cache_mode": cfg.DATALOADER.CACHE_MODE,
+ 'cache_origin_image': cfg.DATALOADER.CACHE_ORIGIN_IMAGE,
+ "cache_local_rank": comm.get_local_rank(),
+ "cache_local_size": comm.get_local_size(),
+ "circular_cache_mode": cfg.DATALOADER.CIRCULAR_CACHE_MODE,
+ "use_ceph": getattr(cfg.DATALOADER, 'USE_CEPH', False),
+ "tcs_conf_path": tcs_conf_path,
+ "random_caption": cfg.DATALOADER.RANDOM_CAPTION,
+ "as_numpy_as_possible": cfg.DATALOADER.AS_NUMPY_AS_POSSIBLE,
+ "use_node_distirbuted_sampler": cfg.DATALOADER.SAMPLER == 'NodeDistributed',
+ 'tokenizer': ClipTokenizer(),
+ 'tokenizer_name': "clip",
+
+ }
+
+
+
+ return ret
+
+ def _init_memcached(self):
+ pass
+
+ def load_img_info(self, anno_file):
+ id2path = {}
+ with jsonlines.open(anno_file) as reader:
+ for annotation in reader:
+ image_id = annotation["id"]
+ id2path[image_id] = annotation["img_path"]
+
+ return id2path
+
+ def load_COCO_flickr(self, cfg):
+ # for index_mapping
+ self.idx2name = dict()
+ self.name2idx = dict()
+ if isinstance(self.ann_file, list):
+ imageinfo = list()
+ self.id2path = dict()
+ for anno_file in self.ann_file:
+ if self.dataset_name == 'MSCOCO':
+ imageinfo.extend(json.load(open(anno_file))["images"])
+ else:
+ id2path = self.load_img_info(anno_file)
+ self.id2path.update(id2path)
+ else:
+ if self.dataset_name == 'MSCOCO':
+ imageinfo = json.load(open(self.ann_file))["images"]
+ else:
+ self.id2path = self.load_img_info(self.ann_file)
+
+ if self.dataset_name == 'MSCOCO':
+ for info in imageinfo:
+ self.idx2name[info['id']] = {
+ "split": info['file_path'],
+ "name": info['file_name']}
+ self.name2idx[info['file_name']] = info['id']
+
+ if self.stage == "test":
+ if self.dataset_name == 'MSCOCO':
+ cache_path = os.path.join(
+ os.path.dirname(self.ann_file), "cache",
+ "mscoco_caption_w_testcap_%s.pkl" % ( self.stage)
+ )
+ else:
+ cache_path = os.path.join(
+ self.root_path, "cache",
+ "RetrievalFlickr30k_raw_%s_%s_%d.pkl" % (self.tokenizer_name, self.stage, self.seq_len)
+ )
+
+ if not os.path.exists(os.path.dirname(cache_path)):
+ os.makedirs(os.path.dirname(cache_path))
+ if not os.path.exists(cache_path):
+ datalist = self.load_raw_data(cfg, self.ann_file)
+ pickle.dump(datalist, open(cache_path, "wb"))
+ datalist = pickle.load(open(cache_path, "rb"))
+ else:
+ datalist = list()
+ assert self.stage == "train", "no validation now"
+ for i, stage in enumerate(["train", "val"]):
+ if self.dataset_name == 'MSCOCO':
+ cache_path = os.path.join(
+ os.path.dirname(self.ann_file[i]), "cache",
+ "mscoco_caption_w_testcap_%s.pkl" % ( stage)
+ )
+ else:
+ cache_path = os.path.join(
+ self.root_path, "cache",
+ "RetrievalFlickr30k_raw_%s_%s_%d.pkl" % (self.tokenizer_name, stage, self.seq_len)
+ )
+ if not os.path.exists(os.path.dirname(cache_path)):
+ os.makedirs(os.path.dirname(cache_path))
+ if not os.path.exists(cache_path):
+ datalist_part = self.load_raw_data(cfg, self.ann_file[i])
+ pickle.dump(datalist_part, open(cache_path, "wb"))
+ datalist_part = pickle.load(open(cache_path, "rb"))
+ datalist.extend(datalist_part)
+
+ if self.data_percentage < 1.0 and self.stage == 'train':
+ datalist = random.sample(datalist, k = int(self.data_percentage* len(datalist) ) )
+
+ self.database = pa.array(datalist)
+ if comm.is_main_process():
+ import sys
+ print(f"!!! Dataset {self.dataset_name} with task {self.task_type}:")
+ print('!!! length of _temp_list: ', len(datalist))
+ print('!!! size of _temp_list: ', sys.getsizeof(datalist))
+ print('!!! size of pa database: ', sys.getsizeof(self.database))
+ del datalist
+
+ def load_raw_data(self, cfg, anno_file):
+ datalist = []
+ if self.dataset_name == 'MSCOCO':
+ annoinfo = json.load(open(anno_file))
+ captions_train = sorted( annoinfo['annotations'], key=lambda x: x['id'])
+ image_caption_info = defaultdict(list)
+ for cap_info in captions_train:
+ image_caption_info[cap_info['image_id']].append(cap_info['caption'])
+
+ for im_id, caps in image_caption_info.items():
+ datalist.append(
+ {
+ "image_id": im_id,
+ "captions": caps,
+ }
+ )
+ else:
+ with jsonlines.open(anno_file) as reader:
+ for annotation in reader:
+ sentences = annotation["sentences"]
+ image_id = annotation["id"]
+ datalist.append({ "image_id": image_id, "imagename": annotation["img_path"], "captions": sentences })
+
+
+ return datalist
+
+ def load_VG(self, cfg):
+ cache_path = os.path.join(
+ os.path.dirname(self.ann_file), "cache",
+ "vg_caption_spe_raw_%s.pkl" % (self.stage)
+ )
+ if not os.path.exists(os.path.dirname(cache_path)):
+ os.makedirs(os.path.dirname(cache_path))
+ if not os.path.exists(cache_path):
+ _temp_list = []
+ if self.use_ceph:
+ anno_file = os.path.join('s3://visual_genome/annotations', os.path.basename(self.ann_file))
+ annotations = json.load(BytesIO(self.tcs_loader.client.get(anno_file)))
+ else:
+ annotations = json.load(open(self.ann_file))
+
+ for im_id, annoinfo in annotations['phrase'].items():
+ _temp_list.append(
+ {
+ "image_id": im_id,
+ "captions": annoinfo,
+ 'path': annotations['subset'][im_id],
+ }
+ )
+ pickle.dump(_temp_list, open(cache_path, "wb"))
+ else:
+ _temp_list = pickle.load(open(cache_path, "rb"))
+ self.database = pa.array(_temp_list)
+
+ if comm.is_main_process():
+ import sys
+ print(f"!!! Dataset {self.dataset_name} with task {self.task_type}:")
+ print('!!! length of _temp_list: ', len(_temp_list))
+ print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
+ print('!!! size of pa database: ', sys.getsizeof(self.database))
+ del _temp_list
+
+ def load_database(self):
+
+ if self.random_caption:
+ cache_filename = 'spe_cache_random_caption_' + os.path.basename(self.ann_file).replace('.', "_") + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
+ else:
+ cache_filename = 'spe_cache_' + os.path.basename(self.ann_file).replace('.', "_") + "_" + str(self.cache_local_rank) + "_" + str(self.cache_local_size) + '.pkl'
+
+
+ cache_file = os.path.join(self.cache_dir, cache_filename)
+
+ if not os.path.exists((cache_file)):
+ _temp_list = []
+ self.img_path_to_index = {}
+ if self.use_ceph:
+ f = BytesIO(self.tcs_loader.client.get(self.ann_file))
+ else:
+ f = open(self.ann_file, 'r')
+ if self.dataset_name == 'SBU':
+ annofile = json.load(f)
+ else:
+ annofile = f
+ for i, l in enumerate(annofile):
+ if self.use_node_distirbuted_sampler and ((i % self.cache_local_size) != self.cache_local_rank):
+ _temp_list.append(None)
+ continue
+ l = l.strip()
+ if (l == ''):
+ continue
+ if self.dataset_name == 'SBU':
+ self.img_path_to_index[l] = i
+ _temp_list.append([l, annofile[l]])
+ else:
+ _data = json.loads(l)
+ if not self.zip_mode:
+ _data['image'] = _data['image'].replace('.zip@', '')
+ self.img_path_to_index[_data['image']] = i
+ if self.random_caption:
+ _temp_list.append([_data['image'], _smart_join(_data['caption'], '\t'), _data['title'], _data['description']])
+ else:
+ _temp_list.append([_data['image'], _smart_join(_data['caption'], '\t')])
+
+ f.close()
+
+
+ pickle.dump({
+ "path_to_indext": self.img_path_to_index,
+ "temp_list": _temp_list,
+ }, open(cache_file, "wb"), protocol=4)
+ else:
+ cachedata = pickle.load(open(cache_file, "rb"))
+ self.img_path_to_index, _temp_list = cachedata['path_to_indext'], cachedata['temp_list']
+
+ self.database = pa.array(_temp_list)
+ if comm.is_main_process():
+ import sys
+ print(f"!!! Dataset {self.dataset_name} with task {self.task_type}:")
+ print('!!! length of _temp_list: ', len(_temp_list))
+ print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
+ print('!!! size of pa database: ', sys.getsizeof(self.database))
+ del _temp_list
+
+ @property
+ def data_names(self):
+ return ['image', 'im_info', 'text', 'mlm_labels']
+
+ def __getitem__(self, index):
+ for i_try in range(100):
+ try:
+ image_path = None
+ image_id = None
+ idb = None
+ if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
+ self.dataset_dict = self.database[index].as_py()
+ image_id = self.dataset_dict['image_id']
+ if self.dataset_name == 'VG':
+ imagepath = self.dataset_dict['path']
+ image_path = os.path.join(self.data_path, imagepath)
+ elif self.dataset_name == 'FLICKR':
+ image_path = os.path.join(self.data_path, self.id2path[image_id])
+ else:
+ image_split = self.idx2name[int(image_id)]['split']
+ image_name = self.idx2name[int(image_id)]['name']
+ image_path = os.path.join(self.data_path, image_split, image_name)
+ else:
+ _idb = self.database[index]
+ idb = {'image': str(_idb[0]).strip('./'), 'caption': str(_idb[1]).split('\t')}
+ if self.random_caption:
+ idb['title'] = [_idb[2].as_py()]
+ idb['description'] = [_idb[3].as_py()]
+ return self._data_transform(idb, index=index, as_numpy_as_possible=self.as_numpy_as_possible, image_path=image_path, image_id=image_id)
+ except Exception as e:
+ print(
+ "Failed to load image from idb {} with error {} ; trial {};".format(
+ self.database[index], e, i_try
+ )
+ )
+ index = (index + 1)%len(self.database)
+ while (self.database[index].as_py() is None):
+ index = (index + 1)%len(self.database)
+ continue
+
+ def _data_transform(self, idb, index=None, as_numpy_as_possible=False, fail_image_fill=(0.0, 0.0, 0.0), image_path=None, image_id=None):
+
+ if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
+ image = self._load_image(image_path)
+ else:
+ if index is None:
+ index = self.img_path_to_index[idb['image']]
+ # image data
+
+ image = self.get_image(idb, index=index)
+ if isinstance(image, Image.Image):
+ w0, h0 = image.size
+ elif isinstance(image, np.ndarray):
+ h0, w0, c_ = image.shape
+ assert c_ == 3
+ else:
+ raise NotImplementedError
+
+ if self.transform is not None:
+ image = self.transform(image)
+
+ if image_id is not None:
+ img_sample_info = {
+ 'id': image_id,
+ 'path': image_path
+ }
+ else:
+ img_sample_info = {
+ 'id': index
+ }
+ ret = {
+ 'input_sample': [{
+ 'data' : image,
+ 'invalid_mask': None,
+ 'modality' : 'image',
+ 'data_type' : 'input',
+ 'sample_info' : copy.deepcopy(img_sample_info)
+ }]
+ }
+
+ self.target_set = self.cfg.DATASETS.TARGET_SET
+
+ mlm_labels = None
+ u_mask_type = None
+ if self.task_type == 'image_caption' and self.stage != 'train':
+ ret.update({
+ 'target_set': copy.deepcopy(self.target_set),
+ 'target_sample': [],
+ 'target_idx': [],
+ 'task_info': copy.deepcopy(self.task_info)
+ })
+ dict_as_tensor(ret)
+ return ret
+
+ if self.task_type =='image_retrieval' and self.stage != 'train':
+ captions = [caption + " <|endoftext|>" for caption in self.dataset_dict['captions']]
+ caption_tokens_raw = [ self.tokenizer.encode(caption) for caption in captions]
+ if self.dataset_name in ['MSCOCO', 'FLICKR']:
+ caption_tokens = [ caption_token[:(self.seq_len - 1)] + [caption_token[-1]]
+ if len(caption_token) > self.seq_len else caption_token
+ for caption_token in caption_tokens_raw ]
+ return self.package_item(ret, caption_tokens, mlm_labels, u_mask_type)
+
+ # Task #1: Masked Language Modeling
+ if self.random_caption:
+ if len(idb['title']) == 0:
+ caption = idb['description']
+ if len(self.tokenizer.encode(' '.join(caption))) == 0:
+ caption = ['image']
+ else:
+ if random.random() < 0.5:
+ caption = idb['title']
+ if len(self.tokenizer.encode(' '.join(caption))) == 0:
+ caption = idb['description']
+ if len(self.tokenizer.encode(' '.join(caption))) == 0:
+ caption = ['image']
+ else:
+ caption = idb['description']
+ if len(self.tokenizer.encode(' '.join(caption))) == 0:
+ caption = idb['title']
+ if len(self.tokenizer.encode(' '.join(caption))) == 0:
+ caption = ['image']
+ else:
+ if self.dataset_name == 'VG':
+ caption = random.sample(self.dataset_dict['captions'], self.seq_per_img)[0]
+ while len(caption) < 1:
+ caption = random.sample(self.dataset_dict['captions'], self.seq_per_img)[0]
+ if caption and caption.lower()[-1] in "qwertyuiopasdfghjklzxcvbnm1234567890":
+ caption = caption + "."
+ elif self.dataset_name in ['MSCOCO', 'FLICKR']:
+ caption = random.sample(self.dataset_dict['captions'], self.seq_per_img)[0]
+ else:
+ caption = idb['caption']
+ if caption and caption[-1] and caption[-1].lower()[-1] in "1234567890qwertyuiopasdfghjklzxcvbnm":
+ caption.append(".")
+
+ # in CC12m
+ # print('Before:', caption)
+ for i_, tok in enumerate(caption):
+ if '' in tok:
+ tok = tok.replace('', 'person')
+ caption[i_] = tok
+
+ if self.task_type == 'mlm':
+ u_mask_type = 1
+ elif self.task_type == 'image_caption':
+ u_mask_type = 0 # causal mask
+
+ if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
+ caption = caption + " <|endoftext|>"
+ else:
+ caption = caption + ["<|endoftext|>"]
+
+ if self.task_type=='mlm':
+ if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
+ caption_tokens = self.tokenizer.basic_tokenize(caption)
+ else:
+ if self.use_clip_tokenizer:
+ caption_tokens = self.tokenizer.basic_tokenize(' '.join(caption))
+ else:
+ caption_tokens = self.tokenizer.basic_tokenizer.tokenize(' '.join(caption))
+ caption_tokens, mlm_labels = self.random_word_wwm(caption_tokens)
+ elif self.task_type == 'image_caption':
+ if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
+ caption_tokens = self.tokenizer.encode(caption)
+ mlm_labels = self.tokenizer.encode("<|spe|>")*len(caption_tokens)
+ else:
+ # caption
+ caption_tokens = self.tokenizer.encode(' '.join(caption))
+ mlm_labels = self.tokenizer.encode("<|spe|>")*len(caption_tokens)
+ else:
+ if self.dataset_name in ['VG', 'MSCOCO', 'FLICKR']:
+ caption_tokens = self.tokenizer.encode(caption)
+ else:
+ caption_tokens = self.tokenizer.encode(' '.join(caption))
+ mlm_labels = [-1] * len(caption_tokens)
+
+ text = caption_tokens
+
+ # truncate seq to max len
+ if len(text) > self.seq_len:
+ # mlm task
+ text_len_keep = self.seq_len
+ text = text[:(text_len_keep - 1)] + [text[-1]]
+ if self.task_type=='image_caption' or self.task_type=='mlm':
+ mlm_labels = mlm_labels[:(text_len_keep - 1)] + [mlm_labels[-1]]
+
+
+ if as_numpy_as_possible:
+ text = np.array(text)
+ mlm_labels = np.array(mlm_labels)
+
+ return self.package_item(ret, text, mlm_labels, u_mask_type)
+
+
+ # return image, im_info, text, mlm_labels
+
+ def package_item(self, ret, text, mlm_labels, u_mask_type):
+
+
+ if self.task_type == 'image_retrieval':
+ if self.stage == 'train':
+ ret.update({
+ 'target_sample': [{
+ 'data' : [np.array(text, dtype=np.int64)],
+ 'modality' : 'text',
+ 'data_type' : 'target',
+ 'invalid_mask': None,
+ 'sample_info' : {}
+ }],
+ 'target_idx' : [],
+ 'target_set' : [],
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+ else:
+ image_id = ret['input_sample'][0]['sample_info']['id']
+ ret['input_sample'][0]['sample_info']['id'] = (image_id, [image_id] * len(text))
+ ret.update({
+ 'target_sample': [{
+ 'data': [np.array(single_text, dtype=np.int64) for single_text in text],
+ 'modality': 'text',
+ 'invalid_mask': None,
+ 'data_type': 'target',
+ 'sample_info': {
+ 'sample_alone': True,
+ }
+ }],
+ 'target_idx': [],
+ 'target_set': [],
+ 'task_info':
+ copy.deepcopy(self.task_info)
+ })
+
+ elif self.task_type == 'mlm':
+
+ raise NotImplementedError('no needed for masked language modeling when given image now.')
+
+ elif self.task_type == 'image_caption':
+ source = np.array(text, dtype=np.int64)
+ source2 = np.array(mlm_labels, dtype=np.int64)
+ ret['input_sample'].append({
+ 'data': [source, source2],
+ 'invalid_mask': None,
+ 'modality': 'text',
+ 'data_type': 'input',
+ 'sample_info': {
+ 'text_spe_cat': True,
+ }
+ })
+ ret.update({
+ 'target_sample': [],
+ 'target_idx' : [np.array(text, dtype=np.int64)],
+ 'target_set' : copy.deepcopy(self.target_set),
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+ else:
+ raise NotImplementedError
+
+ dict_as_tensor(ret)
+
+ return ret
+
+ def random_word_wwm(self, tokens):
+ output_tokens = []
+ output_label = []
+
+ for i, token in enumerate(tokens):
+ if self.use_clip_tokenizer:
+ sub_tokens = self.tokenizer.encode_basic_tokenized_token(token)
+ else:
+ sub_tokens = self.tokenizer.wordpiece_tokenizer.tokenize(token)
+ prob = random.random()
+ # mask token with 15% probability
+ if prob < 0.15:
+ prob /= 0.15
+
+ # 80% randomly change token to mask token
+ if prob < 0.8:
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_tokens.append(self.tokenizer.encoder["<|spe|>"])
+ else:
+ output_tokens.append("[MASK]")
+ # 10% randomly change token to random token
+ elif prob < 0.9:
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_tokens.append(random.choice(list(range(len(self.tokenizer.encoder)))))
+ else:
+ output_tokens.append(random.choice(list(self.tokenizer.vocab.keys())))
+ # -> rest 10% randomly keep current token
+ else:
+ for sub_token in sub_tokens:
+ output_tokens.append(sub_token)
+
+ # append current token to output (we will predict these later)
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_label.append(sub_token)
+ else:
+ try:
+ output_label.append(self.tokenizer.vocab[sub_token])
+ except KeyError:
+ # For unknown words (should not occur with BPE vocab)
+ output_label.append(self.tokenizer.vocab["[UNK]"])
+ logging.warning("Cannot find sub_token '{}' in vocab. Using [UNK] insetad".format(sub_token))
+ else:
+ for sub_token in sub_tokens:
+ # no masking token (will be ignored by loss function later)
+ output_tokens.append(sub_token)
+ output_label.append(-1)
+
+ return output_tokens, output_label
+
+ def cache_images(self, resize_to=(224, 224)):
+ assert not self.zip_mode
+ encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 95]
+ barray = bytearray()
+ cursor = []
+ c_ = 0
+ for i in trange(len(self.database)):
+ if i % self.cache_local_size != self.cache_local_rank:
+ cursor.append(c_)
+ continue
+ idb = self.database[i]
+ if self.cache_origin_image:
+ try:
+ with open(os.path.join(self.data_path, idb['image']), 'rb') as f:
+ im = f.read()
+ except:
+ print("Failed to cache image {}, cache zero byte!".format(idb['image']))
+ im = bytes()
+ else:
+ im = cv2.imread(os.path.join(self.data_path, idb['image']), cv2.IMREAD_COLOR)
+ if im is None:
+ print("Failed to cache image {}, cache zero image!".format(idb['image']))
+ w, h = resize_to
+ im = np.zeros((h, w, 3), dtype=np.uint8)
+ else:
+ im = cv2.resize(im, resize_to)
+ _, im = cv2.imencode('.jpg', im, encode_param)
+ im = im.tobytes()
+ barray += im
+ cursor.append(c_)
+ c_ += len(im)
+ cursor.append(c_)
+
+ return barray, cursor
+
+ def get_image(self, idb, index=None):
+ if index is None:
+ index = self.img_path_to_index[idb['image']]
+ if self.circular_cache_mode:
+ im = idb['image_augmented']
+ else:
+ im = self._load_image(os.path.join(self.data_path, idb['image']))
+ return im
+
+ @staticmethod
+ def b64_decode(string):
+ return base64.decodebytes(string.encode())
+
+ @staticmethod
+ def group_aspect(database):
+ print('grouping aspect...')
+ t = time.time()
+
+ # get shape of all images
+ widths = torch.as_tensor([idb['width'] for idb in database])
+ heights = torch.as_tensor([idb['height'] for idb in database])
+
+ # group
+ group_ids = torch.zeros(len(database))
+ horz = widths >= heights
+ vert = 1 - horz
+ group_ids[horz] = 0
+ group_ids[vert] = 1
+
+ print('Done (t={:.2f}s)'.format(time.time() - t))
+
+ return group_ids
+
+ def __len__(self):
+ length = len(self.database)
+ if self.max_length > 0:
+ length = min(self.max_length, length)
+ return length
+ # return 10000000
+
+ def _load_image(self, path):
+ if '.zip@' in path:
+ return self.zipreader.imread(path).convert('RGB')
+ else:
+ if self.use_ceph:
+ # print('USE TCS!!!!!')
+ return self.tcs_loader(path).convert('RGB')
+ elif not memorycache:
+ with open(path, 'rb') as f:
+ return Image.open(f).convert('RGB')
+ else:
+ # memcached
+ raise NotImplementedError
+
+ def _load_json(self, path):
+ if '.zip@' in path:
+ f = self.zipreader.read(path)
+ return json.loads(f.decode())
+ else:
+ with open(path, 'r') as f:
+ return json.load(f)
+
+
+def build_transform(is_train,
+ input_size=224,
+ color_jitter=0.4,
+ auto_augment='rand-m9-mstd0.5-inc1',
+ train_interpolation='bicubic',
+ re_prob=0.25,
+ re_mode='pixel',
+ re_count=1
+ ):
+ if is_train:
+ # this should always dispatch to transforms_imagenet_train
+ transform = create_transform(
+ input_size=input_size,
+ is_training=True,
+ color_jitter=color_jitter,
+ auto_augment=auto_augment,
+ interpolation=train_interpolation,
+ re_prob=re_prob,
+ re_mode=re_mode,
+ re_count=re_count
+ )
+
+ return transform
+
+ t = []
+ size = int((256 / 224) * input_size)
+ t.append(
+ transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
+ )
+ t.append(transforms.CenterCrop(input_size))
+
+ t.append(transforms.ToTensor())
+ t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+ return transforms.Compose(t)
diff --git a/uniperceiver/datasets/task_dataset/imagenet.py b/uniperceiver/datasets/task_dataset/imagenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d3d00d6a4d1161625c729edc22da6f79f8c863c
--- /dev/null
+++ b/uniperceiver/datasets/task_dataset/imagenet.py
@@ -0,0 +1,410 @@
+import os
+import copy
+import pickle
+from PIL import Image
+import torch
+from torchvision import transforms
+import random
+from torchvision.transforms.transforms import ToTensor
+from tqdm import tqdm
+import numpy as np
+from uniperceiver.config import configurable
+from uniperceiver.functional import read_np, dict_as_tensor, boxes_to_locfeats
+from ..build import DATASETS_REGISTRY
+import glob
+import json
+from collections import defaultdict
+
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.data import create_transform
+import pyarrow as pa
+from uniperceiver.utils import comm
+
+__all__ = ["ImageNetDataset", "ImageNet22KDataset"]
+
+
+def load_pkl_file(filepath):
+ return pickle.load(open(filepath, 'rb'), encoding='bytes') if len(filepath) > 0 else None
+
+@DATASETS_REGISTRY.register()
+class ImageNetDataset:
+ @configurable
+ def __init__(
+ self,
+ stage: str,
+ anno_file: str,
+ s3_path: str,
+ feats_folder: str,
+ class_names: list,
+ use_ceph: bool,
+ tcs_conf_path,
+ data_percentage,
+ task_info,
+ target_set,
+ cfg,
+ ):
+ self.stage = stage
+ self.ann_file = anno_file
+ self.feats_folder = feats_folder
+ self.class_names = class_names if (class_names is not None) else None
+ self.data_percentage = data_percentage
+
+ self.initialized = False
+
+ self.cfg = cfg
+
+ self.task_info = task_info
+ self.target_set = target_set
+ # for index_maping
+ self.idx2info = dict()
+
+ self.use_ceph = use_ceph
+ if self.use_ceph:
+ self.feats_folder = s3_path
+ print('debug info for imagenet{} {}'.format(self.ann_file, self.feats_folder))
+ from uniperceiver.datasets import TCSLoader
+ self.tcs_loader = TCSLoader(tcs_conf_path)
+
+ self.transform = build_transform(is_train=(self.stage == 'train'),
+ input_size=cfg.MODEL.IMG_INPUT_SIZE)
+
+ _temp_list =self.load_data(self.cfg)
+ self.datalist = pa.array(_temp_list)
+ if comm.is_main_process():
+ import sys
+ print("ImageNet1K Pretrain Dataset:")
+ print('!!! length of _temp_list: ', len(_temp_list))
+ print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
+ print('!!! size of pa database: ', sys.getsizeof(self.datalist))
+ del _temp_list
+
+ @classmethod
+ def from_config(cls, cfg, stage: str = "train"):
+ if 'SLURM_PROCID' in os.environ:
+ tcs_conf_path = cfg.DATALOADER.get("TCS_CONF_PATH", "slurm_tools/petreloss.config")
+ else:
+ # dev machine
+ tcs_conf_path = "slurm_tools/petreloss_local.config"
+ ann_files = {
+ "train": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "train.txt"),
+ "val": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "val.txt"),
+ "test": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "test.txt")
+ }
+
+ task_info = {
+ 'task_type' : cfg.DATASETS.TASK_TYPE,
+ 'dataset_name' : cfg.DATASETS.DATASET_NAME,
+ 'batch_size' : cfg.DATALOADER.TRAIN_BATCH_SIZE if stage == 'train' else cfg.DATALOADER.TEST_BATCH_SIZE,
+ 'sampling_weight': cfg.DATALOADER.SAMPLING_WEIGHT
+ }
+
+
+ ret = {
+ "cfg" : cfg,
+ "stage" : stage,
+ "anno_file" : ann_files[stage],
+ "feats_folder" : cfg.DATALOADER.FEATS_FOLDER,
+ 's3_path' : cfg.DATALOADER.S3_PATH,
+ "class_names" : load_pkl_file(cfg.DATALOADER.CLASS_NAME_FILE) if cfg.DATALOADER.CLASS_NAME_FILE else None,
+ "use_ceph" : getattr(cfg.DATALOADER, 'USE_CEPH', False),
+ "tcs_conf_path" : tcs_conf_path,
+ "data_percentage": cfg.DATALOADER.DATA_PERCENTAGE,
+ "task_info" : task_info,
+ "target_set" : cfg.DATASETS.TARGET_SET
+ }
+
+ return ret
+
+ def _preprocess_datalist(self, datalist):
+ return datalist
+
+ def load_data(self, cfg):
+ datalist = []
+
+ # local file reading
+ with open(self.ann_file, 'r') as f:
+ img_infos = f.readlines()
+
+ if self.stage == "train" and self.data_percentage < 1.0:
+ id2img = dict()
+ for idx, l in enumerate(img_infos):
+ name = int(l.replace('\n', '').split(' ')[1])
+ if name not in id2img:
+ id2img[name] = list()
+ id2img[name].append(idx)
+ self.idx2info[idx] = l.replace('\n', '').split(' ')[0]
+
+ datalist = list()
+ for k, v in id2img.items():
+ for idx in random.sample(v, k=int(len(v)*self.data_percentage)+1):
+ datalist.append({
+ 'image_id': idx,
+ 'class_id': k,
+ "file_path": self.idx2info[idx],
+ })
+ else:
+ datalist = [{
+ 'image_id': idx,
+ 'class_id': int(l.replace('\n', '').split(' ')[1]),
+ "file_path": l.replace('\n', '').split(' ')[0],
+ } for idx, l in enumerate(img_infos)]
+
+ datalist = self._preprocess_datalist(datalist)
+ return datalist
+
+ def __len__(self):
+ return len(self.datalist)
+
+ def __getitem__(self, index):
+ for i_try in range(100):
+ try:
+ dataset_dict =self.datalist[index].as_py()
+ image_id = dataset_dict['image_id']
+ class_id = dataset_dict['class_id']
+ image_name = dataset_dict['file_path']
+
+ # load image
+ image_path = os.path.join(self.feats_folder, self.stage, image_name)
+
+ if self.use_ceph:
+ img = self.tcs_loader(image_path).convert('RGB')
+
+ else:
+ img = Image.open(image_path).convert("RGB")
+
+
+ except Exception as e:
+ print(
+ "Failed to load image from {} with error {} ; trial {}".format(
+ image_path, e, i_try
+ )
+ )
+
+ # let's try another one
+ index = random.randint(0, len(self.datalist) - 1)
+ continue
+
+
+ img = self.transform(img)
+
+
+ ret = {
+ 'input_sample' : [{
+ 'data' : img,
+ 'invalid_mask': None,
+ 'modality' : 'image',
+ 'data_type': 'input',
+ 'sample_info' : {
+ 'id' : image_id,
+ 'path': image_path
+ }
+ }],
+ 'target_sample': [],
+ 'target_idx' : [class_id],
+ 'target_set' : copy.deepcopy(self.target_set),
+ 'task_info' : copy.deepcopy(self.task_info)
+
+ }
+ return ret
+
+
+
+
+@DATASETS_REGISTRY.register()
+class ImageNet22KDataset:
+ @configurable
+ def __init__(
+ self,
+ stage: str,
+ anno_file: str,
+ s3_path: str,
+ feats_folder: str,
+ use_ceph: bool,
+ tcs_conf_path: str,
+ cfg: str,
+ task_info,
+ target_set,
+ ):
+ self.cfg = cfg
+ self.stage = stage
+ self.ann_file = anno_file
+ self.feats_folder = feats_folder
+ self.task_info = task_info
+ self.target_set = target_set
+ self.initialized = False
+
+ self.use_ceph = use_ceph
+ if self.use_ceph:
+ self.feats_folder = s3_path
+ print('debug info for imagenet22k {} {}'.format(self.ann_file, self.feats_folder))
+ from uniperceiver.datasets import TCSLoader
+ self.tcs_loader = TCSLoader(tcs_conf_path)
+
+
+ self.transform = build_transform(is_train=(self.stage == 'train'),
+ input_size=cfg.MODEL.IMG_INPUT_SIZE)
+
+ _temp_list = self.load_data(self.cfg)
+ self.datalist = pa.array(_temp_list)
+ if comm.is_main_process():
+ import sys
+ print("ImageNet22K Pretrain Dataset:")
+ print('!!! length of _temp_list: ', len(_temp_list))
+ print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
+ print('!!! size of pa database: ', sys.getsizeof(self.datalist))
+ del _temp_list
+
+
+ @classmethod
+ def from_config(cls, cfg, stage: str = "train"):
+
+ ann_files = {
+ "train": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "imagenet_22k_filelist_short.txt"),
+ "val": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "imagenet_22k_filelist_short.txt"),
+ }
+
+
+ if 'SLURM_PROCID' in os.environ:
+ tcs_conf_path = cfg.DATALOADER.get("TCS_CONF_PATH", "slurm_tools/petreloss.config")
+ else:
+ # dev machine
+ tcs_conf_path = "slurm_tools/petreloss_local.config"
+
+ task_info = {
+ 'task_type' : cfg.DATASETS.TASK_TYPE,
+ 'dataset_name' : cfg.DATASETS.DATASET_NAME,
+ 'batch_size' : cfg.DATALOADER.TRAIN_BATCH_SIZE if stage == 'train' else cfg.DATALOADER.TEST_BATCH_SIZE,
+ 'sampling_weight': cfg.DATALOADER.SAMPLING_WEIGHT
+ }
+
+ ret = {
+ "cfg" : cfg,
+ "stage" : stage,
+ "anno_file" : ann_files[stage],
+ 's3_path' : cfg.DATALOADER.S3_PATH,
+ "feats_folder" : cfg.DATALOADER.FEATS_FOLDER,
+ "use_ceph" : getattr(cfg.DATALOADER, 'USE_CEPH', False),
+ "tcs_conf_path": tcs_conf_path,
+ "task_info" : task_info,
+ "target_set" : cfg.DATASETS.TARGET_SET
+ }
+
+ return ret
+
+ def _preprocess_datalist(self, datalist):
+ return datalist
+
+ def load_data(self, cfg):
+ datalist = []
+
+ # local file reading
+ with open(self.ann_file, 'r') as f:
+ img_infos = f.readlines()
+
+ datalist = []
+ for idx, l in enumerate(img_infos):
+ info_strip = l.replace('\n', '').split(' ')
+ wn_id = info_strip[0]
+ class_id = info_strip[2]
+ file_path = wn_id + '/' + wn_id + '_' + info_strip[1] + '.JPEG' # n01440764/n01440764_10074.JPEG
+
+ datalist.append(
+ {
+ 'image_id': idx,
+ 'file_path': file_path,
+ 'class_id': int(class_id)
+ }
+ )
+
+ datalist = self._preprocess_datalist(datalist)
+ return datalist
+
+ def __len__(self):
+ return len(self.datalist)
+
+ def __getitem__(self, index):
+ for i_try in range(100):
+ try:
+ dataset_dict =self.datalist[index].as_py()
+ image_id = dataset_dict['image_id']
+ class_id = dataset_dict['class_id']
+ image_name = dataset_dict['file_path']
+
+ # load image
+ image_path = os.path.join(self.feats_folder, image_name)
+
+ if self.use_ceph:
+ img = self.tcs_loader(image_path).convert('RGB')
+
+ else:
+ img = Image.open(image_path).convert("RGB")
+
+
+ except Exception as e:
+ print(
+ "Failed to load image from {} with error {} ; trial {}".format(
+ image_path, e, i_try
+ )
+ )
+
+ # let's try another one
+ index = random.randint(0, len(self.datalist) - 1)
+ continue
+
+ img = self.transform(img)
+
+ ret = {
+ 'input_sample': [{
+ 'data' : img,
+ 'invalid_mask': None,
+ 'modality' : 'image',
+ 'data_type': 'input',
+ 'sample_info' : {
+ 'id' : image_id,
+ 'path': image_path
+ }
+ }],
+ 'target_sample': [],
+ 'target_idx' : [class_id],
+ 'target_set' : copy.deepcopy(self.target_set),
+ 'task_info' : copy.deepcopy(self.task_info)
+ }
+
+ return ret
+
+
+
+def build_transform(is_train,
+ input_size=224,
+ color_jitter=0.4,
+ auto_augment='rand-m9-mstd0.5-inc1',
+ train_interpolation='bicubic',
+ re_prob=0.25,
+ re_mode='pixel',
+ re_count=1
+ ):
+ if is_train:
+ # this should always dispatch to transforms_imagenet_train
+ transform = create_transform(
+ input_size=input_size,
+ is_training=True,
+ color_jitter=color_jitter,
+ auto_augment=auto_augment,
+ interpolation=train_interpolation,
+ re_prob=re_prob,
+ re_mode=re_mode,
+ re_count=re_count
+ )
+
+ return transform
+
+ t = []
+ size = int((256 / 224) * input_size)
+ t.append(
+ transforms.Resize(size, interpolation=3), # to maintain same ratio w.r.t. 224 images
+ )
+ t.append(transforms.CenterCrop(input_size))
+
+ t.append(transforms.ToTensor())
+ t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
+ return transforms.Compose(t)
diff --git a/uniperceiver/datasets/task_dataset/mscoco_pretrain.py b/uniperceiver/datasets/task_dataset/mscoco_pretrain.py
new file mode 100644
index 0000000000000000000000000000000000000000..320a0266b844cb694a3491c101c244786c052abd
--- /dev/null
+++ b/uniperceiver/datasets/task_dataset/mscoco_pretrain.py
@@ -0,0 +1,486 @@
+import os
+import copy
+import pickle
+from PIL import Image
+from torchvision import transforms
+import random
+from torchvision.transforms.transforms import ToTensor
+from tqdm import tqdm
+import numpy as np
+from uniperceiver.config import configurable
+from uniperceiver.functional import read_np, dict_as_tensor, boxes_to_locfeats
+from ..build import DATASETS_REGISTRY
+import glob
+from uniperceiver.tokenization import ClipTokenizer
+import json
+from collections import defaultdict
+from uniperceiver.datasets.custom_transforms import clip_transforms
+import pyarrow as pa
+from uniperceiver.utils import comm
+
+__all__ = ["ImageTextPairDataset"]
+
+memorycache = False
+
+@DATASETS_REGISTRY.register()
+class ImageTextPairDataset:
+ @configurable
+ def __init__(
+ self,
+ cfg: str,
+ stage: str,
+ anno_file: str,
+ seq_per_img: int,
+ max_seq_len: int,
+ feats_folder: str,
+ relation_file: str,
+ gv_feat_file: str,
+ attribute_file: str,
+ transform,
+ tokenizer,
+ data_percentage,
+ tokenizer_name,
+ use_ceph: bool,
+ tcs_conf_path,
+ task_type,
+ preload_feats = None,
+ random_mask=False,
+ text_type_id=0,
+ ):
+ assert len(task_type)>0
+ self.cfg = cfg
+ self.stage = stage
+ self.anno_file = anno_file
+ self.seq_per_img = seq_per_img
+ assert self.seq_per_img == 1
+ self.use_ceph = use_ceph
+ self.task_type = task_type
+ if self.use_ceph:
+ self.feats_folder = 's3://coco'
+ print('debug info for coco pretrain: {} '.format(self.feats_folder))
+ from uniperceiver.datasets import TCSLoader
+ if 'SLURM_PROCID' in os.environ:
+ self.tcs_loader = TCSLoader(tcs_conf_path)
+ else:
+ self.tcs_loader = TCSLoader('petreloss_local.config')
+ else:
+ # local image folder
+ self.feats_folder = feats_folder
+ self.max_seq_len = max_seq_len
+ self.relation_file = relation_file
+ self.gv_feat_file = gv_feat_file
+ self.attribute_file = attribute_file
+
+ self.data_percentage = data_percentage
+ self.tokenizer = tokenizer
+ self.tokenizer_name = tokenizer_name
+ self.use_clip_tokenizer = tokenizer_name == 'clip'
+
+ self.initialized = False
+ self.transform = transform
+
+ self.loaded_feats = None
+ if preload_feats:
+ self.loaded_feats = self.pre_load_feats(preload_feats)
+
+ # for index_maping
+ self.idx2name = dict()
+ self.name2idx = dict()
+ # please
+ if isinstance(self.anno_file, list):
+ imageinfo = list()
+ for anno_file in self.anno_file:
+ imageinfo.extend(json.load(open(anno_file))["images"])
+ else:
+ imageinfo = json.load(open(self.anno_file))["images"]
+ for info in imageinfo:
+ self.idx2name[info['id']] = {
+ "split": info['file_path'],
+ "name": info['file_name']}
+ self.name2idx[info['file_name']] = info['id']
+ self.random_mask = random_mask
+
+ self.text_type_id = text_type_id
+
+ self.task_info = {
+ 'task_type' : self.cfg.DATASETS.TASK_TYPE,
+ 'dataset_name' : self.cfg.DATASETS.DATASET_NAME,
+ 'batch_size' : self.cfg.DATALOADER.TRAIN_BATCH_SIZE
+ if self.stage == 'train' else self.cfg.DATALOADER.TEST_BATCH_SIZE,
+ 'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT
+ }
+
+ _temp_list =self.load_data(self.cfg)
+ self.database = pa.array(_temp_list)
+ if comm.is_main_process():
+ import sys
+ print("MSCOCO Pretrain Dataset:")
+ print('!!! length of _temp_list: ', len(_temp_list))
+ print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
+ print('!!! size of pa database: ', sys.getsizeof(self.database))
+ del _temp_list
+
+
+ def pre_load_feats(self, preload_feat_folder):
+ loaded_feats = {}
+ file_list = glob.glob(os.path.join(preload_feat_folder, '*.pkl'))
+ for fname in file_list:
+ with open(fname, 'rb') as f:
+ feats = pickle.load(f)
+ loaded_feats.update(feats)
+ return loaded_feats
+
+ @classmethod
+ def from_config(cls, cfg, stage: str = "train"):
+ if 'SLURM_PROCID' in os.environ:
+ tcs_conf_path = cfg.DATALOADER.get("TCS_CONF_PATH", "petreloss.config")
+ else:
+ # dev machine
+ tcs_conf_path = "slurm_tools/petreloss_local.config"
+ ann_files = {
+ "train": [os.path.join(cfg.DATALOADER.ANNO_FOLDER, "captions_train113k.json"), os.path.join(cfg.DATALOADER.ANNO_FOLDER, "captions_val5k.json")],
+ # no validation
+ "test": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "captions_test5k.json")
+ }
+ if getattr(cfg.DATALOADER, 'TRANSFORM', None) == 'clip_transforms':
+ transform = clip_transforms(stage,
+ img_size=cfg.MODEL.IMG_INPUT_SIZE)
+ else:
+ transform = transforms.Compose([
+ transforms.Resize([224, 224]),
+ transforms.ToTensor(),
+ transforms.Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225))]
+ )
+ ret = {
+ "cfg" : cfg,
+ "stage" : stage,
+ "anno_file" : ann_files[stage],
+ "seq_per_img" : 1,
+ "feats_folder" : cfg.DATALOADER.FEATS_FOLDER,
+ "relation_file" : cfg.DATALOADER.RELATION_FILE,
+ "gv_feat_file" : cfg.DATALOADER.GV_FEAT_FILE,
+ "attribute_file" : cfg.DATALOADER.ATTRIBUTE_FILE,
+ "max_seq_len" : cfg.MODEL.MAX_SEQ_LEN,
+ "use_ceph" : getattr(cfg.DATALOADER, 'USE_CEPH', False),
+ "tcs_conf_path" : tcs_conf_path,
+ "transform" : transform,
+ 'task_type' : cfg.DATASETS.TASK_TYPE,
+ "random_mask" : getattr(cfg.DATALOADER, 'RANDOM_MASK', False),
+ "data_percentage": cfg.DATALOADER.DATA_PERCENTAGE,
+ "text_type_id" : getattr(cfg.DATALOADER, 'TYPE_EMBEDDING_ID', 0),
+ }
+
+ ret['tokenizer'] = ClipTokenizer()
+ ret['tokenizer_name'] = "clip"
+
+ return ret
+
+ def _preprocess_datalist(self, datalist):
+ return datalist
+
+ def load_data(self, cfg):
+ if self.stage == "test":
+ cache_path = os.path.join(
+ os.path.dirname(self.anno_file), "cache",
+ "mscoco_caption_w_testcap_%s.pkl" % ( self.stage)
+ )
+ if not os.path.exists(os.path.dirname(cache_path)):
+ os.makedirs(os.path.dirname(cache_path))
+ if not os.path.exists(cache_path):
+ datalist = self.load_raw_data(cfg, self.anno_file)
+ pickle.dump(datalist, open(cache_path, "wb"))
+ datalist = pickle.load(open(cache_path, "rb"))
+ else:
+ datalist = list()
+ assert self.stage == "train", "no validation now"
+ for i, stage in enumerate(["train", "val"]):
+ cache_path = os.path.join(
+ os.path.dirname(self.anno_file[i]), "cache",
+ "mscoco_caption_w_testcap_%s.pkl" % ( stage)
+ )
+ if not os.path.exists(os.path.dirname(cache_path)):
+ os.makedirs(os.path.dirname(cache_path))
+ if not os.path.exists(cache_path):
+ datalist_part = self.load_raw_data(cfg, self.anno_file[i])
+ pickle.dump(datalist_part, open(cache_path, "wb"))
+ datalist_part = pickle.load(open(cache_path, "rb"))
+ datalist.extend(datalist_part)
+
+ def _load_pkl_file(filepath):
+ return pickle.load(open(filepath, 'rb'), encoding='bytes') if len(filepath) > 0 else None
+
+ ext_data = {
+ "relation": _load_pkl_file(self.relation_file),
+ "attribute": _load_pkl_file(self.attribute_file),
+ "gv_feat": _load_pkl_file(self.gv_feat_file)
+ }
+ for i in range(len(datalist)):
+ image_id = int(datalist[i]['image_id'])
+ for data_type in ext_data:
+ if ext_data[data_type] is not None:
+ if str(image_id) in ext_data[data_type]:
+ datalist[i][data_type] = ext_data[data_type][str(image_id)]
+ elif image_id in ext_data[data_type]:
+ datalist[i][data_type] = ext_data[data_type][image_id]
+
+ if self.data_percentage < 1.0 and self.stage == 'train':
+ datalist = random.sample(datalist, k = int(self.data_percentage* len(datalist) ) )
+
+ return datalist
+
+
+ def load_raw_data(self, cfg, anno_file):
+ datalist = []
+ annoinfo = json.load(open(anno_file))
+ captions_train = sorted( annoinfo['annotations'], key=lambda x: x['id'])
+ image_caption_info = defaultdict(list)
+ for cap_info in captions_train:
+ image_caption_info[cap_info['image_id']].append(cap_info['caption'])
+
+ for im_id, caps in image_caption_info.items():
+ datalist.append(
+ {
+ "image_id": im_id,
+ "captions": caps,
+ }
+ )
+
+ return datalist
+
+ def random_word_wwm(self, tokens):
+ output_tokens = []
+ output_label = []
+
+ for i, token in enumerate(tokens):
+ if self.use_clip_tokenizer:
+ sub_tokens = self.tokenizer.encode_basic_tokenized_token(token)
+ else:
+ sub_tokens = self.tokenizer.wordpiece_tokenizer.tokenize(token)
+ prob = random.random()
+ # mask token with 15% probability
+ if prob < 0.15:
+ prob /= 0.15
+
+ # 80% randomly change token to mask token
+ if prob < 0.8:
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_tokens.append(self.tokenizer.encoder["<|spe|>"])
+ else:
+ output_tokens.append("[MASK]")
+ # 10% randomly change token to random token
+ elif prob < 0.9:
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_tokens.append(random.choice(list(range(len(self.tokenizer.encoder)))))
+ else:
+ output_tokens.append(random.choice(list(self.tokenizer.vocab.keys())))
+ # -> rest 10% randomly keep current token
+ else:
+ for sub_token in sub_tokens:
+ output_tokens.append(sub_token)
+
+ # append current token to output (we will predict these later)
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_label.append(sub_token)
+ else:
+ try:
+ output_label.append(self.tokenizer.vocab[sub_token])
+ except KeyError:
+ # For unknown words (should not occur with BPE vocab)
+ output_label.append(self.tokenizer.vocab["[UNK]"])
+ else:
+ for sub_token in sub_tokens:
+ # no masking token (will be ignored by loss function later)
+ output_tokens.append(sub_token)
+ output_label.append(-1)
+
+ return output_tokens, output_label
+
+ def __len__(self):
+ return len(self.database)
+
+ # def __call__(self, dataset_dict):
+ def __getitem__(self, index):
+ for i_try in range(100):
+ try:
+ dataset_dict = self.database[index].as_py()
+ image_id = dataset_dict['image_id']
+ image_split = self.idx2name[int(image_id)]['split']
+ image_name = self.idx2name[int(image_id)]['name']
+
+ # load image
+ image_path = os.path.join(self.feats_folder, image_split, image_name)
+
+ if self.use_ceph:
+ img = self.tcs_loader(image_path).convert('RGB')
+
+ else:
+ img = Image.open(image_path).convert("RGB")
+
+ break
+ except Exception as e:
+ print("Failed to load image from idb {} with error {} ; trial {};".format(self.database[index], e, i_try))
+ index = (index + 1) % len(self.database)
+ while (self.database[index].as_py() is None):
+ index = (index + 1) % len(self.database)
+ continue
+
+
+
+ img = self.transform(img)
+
+ ret = {
+ 'input_sample': [{
+ 'data' : img,
+ 'invalid_mask': None,
+ 'modality' : 'image',
+ 'data_type': 'input',
+ 'sample_info' :{'id': image_id, 'path': image_path}
+ }]
+ }
+
+ self.target_set = self.cfg.DATASETS.TARGET_SET
+
+ if self.task_type == 'image_caption' and self.stage != 'train':
+ ret.update({
+ 'target_set': copy.deepcopy(self.target_set),
+ 'target_sample': [],
+ 'target_idx': [],
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+ dict_as_tensor(ret)
+ return ret
+
+
+
+
+
+ if self.task_type =='image_retrieval' and self.stage != 'train':
+ captions = [caption + " <|endoftext|>" for caption in dataset_dict['captions']]
+ caption_tokens_raw = [ self.tokenizer.encode(caption) for caption in captions]
+
+ caption_tokens = [ caption_token[:(self.max_seq_len - 1)] + [caption_token[-1]]
+ if len(caption_token) > self.max_seq_len else caption_token
+ for caption_token in caption_tokens_raw ]
+
+
+ else:
+ caption = random.sample(dataset_dict['captions'], self.seq_per_img)[0]
+ # caption = ['pilgrims', 'coffee', 'house', '-', 'outside', 'the', 'store']
+ caption = caption + " <|endoftext|>"
+
+ if self.task_type == 'mlm':
+ u_mask_type = 1
+ elif self.task_type == 'image_caption':
+ u_mask_type = 0 # causal mask
+
+ if self.task_type=='image_caption' or self.task_type =='mlm':
+ if u_mask_type == 1: # mlm
+ caption_tokens = self.tokenizer.basic_tokenize(caption)
+ caption_tokens, mlm_labels = self.random_word_wwm(caption_tokens)
+ else:
+ # caption
+ caption_tokens = self.tokenizer.encode(caption)
+ mlm_labels = self.tokenizer.encode("<|spe|>")*len(caption_tokens)
+
+ else:
+ caption_tokens = self.tokenizer.encode(caption)
+
+ if len(caption_tokens) > self.max_seq_len:
+ # mlm task
+ text_len_keep = self.max_seq_len
+ caption_tokens = caption_tokens[:(text_len_keep - 1)] + [caption_tokens[-1]]
+ if self.task_type=='image_caption' or self.task_type == 'mlm':
+ mlm_labels = mlm_labels[:(text_len_keep - 1)] + [mlm_labels[-1]]
+
+ # self.task_info = {
+ # 'task_type' : self.cfg.DATASETS.TASK_TYPE,
+ # 'dataset_name' : self.cfg.DATASETS.DATASET_NAME,
+ # 'batch_size' : self.cfg.DATASETS.TRAIN_BATCH_SIZE
+ # if self.stage == 'train' else self.cfg.DATASETS.TEST_BATCH_SIZE,
+ # 'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT
+ # }
+
+ if self.task_type == 'image_caption':
+ source = np.array(caption_tokens, dtype=np.int64)
+ source2 = np.array(mlm_labels, dtype=np.int64)
+ ret['input_sample'].append({
+ 'data' :[source, source2],
+ 'invalid_mask' : None,
+ 'modality' : 'text',
+ 'data_type' : 'input',
+ 'sample_info' :
+ {
+ 'text_spe_cat': True,
+ }
+ })
+ ret.update({
+ "target_sample": [],
+ "target_idx" : [np.array(caption_tokens, dtype=np.int64)],
+ "target_set" : copy.deepcopy(self.target_set),
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+
+ elif self.task_type == 'mlm':
+ ret['input_sample'].append({
+ 'data' : [np.array(caption_tokens, dtype=np.int64)],
+ 'invalid_mask': None,
+ 'modality' : 'text',
+ 'data_type' : 'input',
+ 'sample_info' : {"text_token_padding_length": self.max_seq_len}
+ })
+ ret.update({
+ 'target_sample': [],
+ "target_idx" : [np.array(mlm_labels, dtype=np.int64)],
+ "target_set" : copy.deepcopy(self.target_set),
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+ elif self.task_type == 'image_retrieval':
+ if self.stage == 'train':
+ ret.update({
+ 'target_sample': [{
+ 'data' : [np.array(caption_tokens, dtype=np.int64)],
+ 'modality' : 'text',
+ 'invalid_mask': None,
+ 'data_type' : 'target',
+ 'sample_info' : {}
+ }],
+ 'target_idx' : [],
+ 'target_set' : [],
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+ else:
+ ret.update(
+ {
+ 'input_sample': [{
+ 'data' : img, 'invalid_mask': None, 'modality': 'image', 'data_type': 'input',
+ 'sample_info' : {
+ 'id' : (image_id, [image_id] * len(caption_tokens)),
+ 'path' : image_path
+ }
+ }],
+ 'target_sample': [{
+ 'data' : [np.array(caption_token, dtype=np.int64)
+ for caption_token in caption_tokens],
+ 'modality' : 'text',
+ 'invalid_mask': None,
+ 'data_type' : 'target',
+ 'sample_info' : {
+ 'sample_alone': True,
+ }
+
+ }],
+ 'target_idx' : [],
+ 'target_set' : [],
+ 'task_info' : copy.deepcopy(self.task_info)
+ }
+ )
+ else:
+ raise NotImplementedError
+
+ dict_as_tensor(ret)
+
+ return ret
diff --git a/uniperceiver/datasets/task_dataset/msrvtt.py b/uniperceiver/datasets/task_dataset/msrvtt.py
new file mode 100644
index 0000000000000000000000000000000000000000..d16bda83a3ef51fe85ffe472b38d12319197a33c
--- /dev/null
+++ b/uniperceiver/datasets/task_dataset/msrvtt.py
@@ -0,0 +1,569 @@
+# Copyright 2021 JD.com, Inc., JD AI
+"""
+@author: Yehao Li, Jingwen Chen
+@contact: yehaoli.sysu@gmail.com, chenjingwen.sysu@gmail.com
+"""
+import os
+import copy
+import pickle
+import random
+import numpy as np
+import torch
+from uniperceiver.config import configurable
+from uniperceiver.functional import read_np, dict_as_tensor
+from ..build import DATASETS_REGISTRY
+from uniperceiver.tokenization import ClipTokenizer
+from torchvision.transforms import Compose, RandomApply, ToTensor, Normalize, CenterCrop, Lambda, RandomHorizontalFlip, ColorJitter, Resize, RandomCrop
+from .video_transform import random_short_side_scale_jitter, uniform_crop
+import json
+from io import BytesIO
+import av
+from .video_raw import VideoDataSet
+import io
+from collections import defaultdict
+
+import pyarrow as pa
+from uniperceiver.utils import comm
+import copy
+
+__all__ = ["MSRVTTDataset"]
+
+def random_clip(video_frames, sampling_rate, frames_per_clip, fixed_offset=False):
+ """
+ Args:
+ video_frames (int): total frame number of a video
+ sampling_rate (int): sampling rate for clip, pick one every k frames
+ frames_per_clip (int): number of frames of a clip
+ fixed_offset (bool): used with sample offset to decide the offset value deterministically.
+ Returns:
+ list[int]: frame indices (started from zero)
+ """
+ new_sampling_rate = sampling_rate
+ highest_idx = video_frames - int(new_sampling_rate * frames_per_clip)
+ if highest_idx <= 0:
+ random_offset = 0
+ else:
+ if fixed_offset:
+ random_offset = (video_frames - int(new_sampling_rate * frames_per_clip)) // 2
+ else:
+ random_offset = int(np.random.randint(0, highest_idx, 1))
+ frame_idx = [int(random_offset + int(i * sampling_rate)) % video_frames for i in range(frames_per_clip)]
+ frame_idx = [x for x in frame_idx if x < video_frames]
+ return frame_idx
+
+
+@DATASETS_REGISTRY.register()
+class MSRVTTDataset(VideoDataSet):
+ @configurable
+ def __init__(
+ self,
+ stage: str,
+ anno_file: str,
+ seq_per_img: int,
+ max_feat_num: int,
+ max_seq_len: int,
+ feats_folder: str,
+ tokenizer,
+ tokenizer_name,
+ use_ceph: bool,
+ tcs_conf_path,
+ frames_per_clip, interval, num_clips, timesformer_aug,
+ task_type,
+ data_percentage,
+ target_fps=30,
+ random_mask=False,
+ cfg=None,
+ ):
+ self.cfg = cfg
+ self.stage = stage
+ self.anno_file = anno_file
+ self.seq_per_img = seq_per_img
+ self.max_feat_num = max_feat_num
+ self.feats_folder = feats_folder
+ self.max_seq_len = max_seq_len
+ self.task_type = task_type
+
+ self.initialized = False
+
+ # sample_list = list(self.fin.keys())
+ self.tokenizer = tokenizer
+ self.tokenizer_name = tokenizer_name
+ self.use_clip_tokenizer = self.tokenizer_name == 'clip'
+ # for index_maping
+ self.idx2name = dict()
+ self.name2idx = dict()
+
+ self.use_ceph = use_ceph
+ if isinstance(self.anno_file, list):
+ self.cache_dir = os.path.join(os.path.dirname(self.anno_file[0]), 'cache')
+ else:
+ self.cache_dir = os.path.join(os.path.dirname(self.anno_file), 'cache')
+ self.frames_per_clip = frames_per_clip
+ self.interval = interval
+
+ # self.MULTI_VEIW = self.cfg.DATALOADER.get('MULTI_VEIW', 'v0')
+ # self.MULTI_VEIW_NUM = self.cfg.DATALOADER.get('MULTI_VEIW_NUM', 1)
+ self.random_stride = self.cfg.DATALOADER.get('RANDON_STRIDE', False)
+
+ self.num_clips = num_clips
+ self.is_train = stage == 'train'
+ self.test_mode = stage != 'train'
+ self.transform = self._timesformer_transform() if timesformer_aug else self._transform()
+ self.target_fps = target_fps
+ self.data_percentage = data_percentage
+
+ if self.use_ceph:
+ self.feats_folder = 's3://msrvtt/videos/'
+ if isinstance(self.anno_file, list):
+ self.anno_file = [os.path.join('s3://msrvtt/annotations/', os.path.basename(anno_file)) for anno_file in self.anno_file]
+ else:
+ self.anno_file = os.path.join('s3://msrvtt/annotations/', os.path.basename(self.anno_file))
+ print('debug info for msrvtt pretrain: {} '.format(self.feats_folder))
+ from uniperceiver.datasets import TCSLoader
+ if 'SLURM_PROCID' in os.environ:
+ self.tcs_loader = TCSLoader(tcs_conf_path)
+ else:
+ self.tcs_loader = TCSLoader('slurm_tools/petreloss_local.config')
+ else:
+ # local image folder
+ self.feats_folder = feats_folder
+
+ if self.use_ceph:
+ if isinstance(self.anno_file, list):
+ videoinfo = list()
+ for anno_file in self.anno_file:
+ videoinfo.extend(json.load(BytesIO(self.tcs_loader.client.get(anno_file)))["images"])
+ else:
+ videoinfo = json.load(BytesIO(self.tcs_loader.client.get(self.anno_file)))["images"]
+ else:
+ if isinstance(self.anno_file, list):
+ videoinfo = list()
+ for anno_file in self.anno_file:
+ videoinfo.extend(json.load(open(anno_file))["images"])
+ else:
+ videoinfo = json.load(open(self.anno_file))["images"]
+ for vinfo in videoinfo:
+ self.idx2name[vinfo['id']] = vinfo['file_name']
+ self.name2idx[vinfo['file_name']] = vinfo['id']
+ self.random_mask = random_mask
+ pass
+
+ _temp_list =self.load_data(self.cfg)
+ self.video_list = pa.array(_temp_list)
+ if comm.is_main_process():
+ import sys
+ print(f"!!! Dataset {self.cfg.DATASETS.DATASET_NAME} with task {self.cfg.DATASETS.TASK_TYPE}:")
+ print('!!! length of _temp_list: ', len(_temp_list))
+ print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
+ print('!!! size of pa database: ', sys.getsizeof(self.video_list))
+ del _temp_list
+
+ self.task_info = {
+ 'task_type' : self.cfg.DATASETS.TASK_TYPE,
+ 'dataset_name' : self.cfg.DATASETS.DATASET_NAME,
+ 'batch_size' : self.cfg.DATALOADER.TRAIN_BATCH_SIZE if self.stage == 'train' else self.cfg.DATALOADER.TEST_BATCH_SIZE,
+ 'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT
+ }
+
+ self.target_set = self.cfg.DATASETS.TARGET_SET
+
+
+ @classmethod
+ def from_config(cls, cfg, stage: str = "train"):
+ if stage == "train":
+ ann_file = os.path.join(cfg.DATALOADER.ANNO_FOLDER, "caption_msrvtt_1k_trainval_cocostyle.json")
+ else:
+ assert stage == "test"
+ ann_file = os.path.join(cfg.DATALOADER.ANNO_FOLDER, "caption_msrvtt_1k_test_cocostyle.json")
+ feat_path = os.path.join(cfg.DATALOADER.FEATS_FOLDER, "MSRVTT_ResNet152_{}.hdf5".format(stage))
+
+ if 'SLURM_PROCID' in os.environ:
+ tcs_conf_path = cfg.DATALOADER.get("TCS_CONF_PATH", "petreloss.config")
+ else:
+ # dev machine
+ tcs_conf_path = "petreloss_local.config"
+
+ ret = {
+ "stage": stage,
+ "anno_file": ann_file,
+ "seq_per_img": cfg.DATALOADER.SEQ_PER_SAMPLE,
+ "max_feat_num": cfg.DATALOADER.MAX_FEAT_NUM,
+ "feats_folder": feat_path,
+ "max_seq_len": cfg.MODEL.MAX_SEQ_LEN,
+ "use_ceph": getattr(cfg.DATALOADER, 'USE_CEPH', False),
+ "tcs_conf_path": tcs_conf_path,
+ 'task_type': cfg.DATASETS.TASK_TYPE,
+ "frames_per_clip": cfg.DATALOADER.FRAMES_PER_CLIP,
+ "interval": cfg.DATALOADER.STRIDE,
+ "num_clips": 1 if stage == 'train' else cfg.INFERENCE.NUM_VIEWS,
+ "timesformer_aug": cfg.DATALOADER.TIMESFORMER_AUG,
+ "data_percentage": cfg.DATALOADER.DATA_PERCENTAGE,
+ "cfg": cfg,
+ }
+ if getattr(cfg.INFERENCE, "VOCAB", None) == 'CLIP':
+ ret['tokenizer'] = ClipTokenizer()
+ ret['tokenizer_name'] = "clip"
+ else:
+ raise NotImplementedError
+ return ret
+
+ def load_data(self, cfg):
+ if self.stage == "train":
+ total_datalist = list()
+ for i, stage in enumerate(["train", "val"]):
+ cache_path = os.path.join(
+ self.cache_dir,
+ "msrvtt_raw_caption_retrieval_%s_%s_%d.pkl" % (self.tokenizer_name, stage, self.max_seq_len)
+ )
+ if not os.path.exists(os.path.dirname(cache_path)):
+ os.makedirs(os.path.dirname(cache_path))
+ if not os.path.exists(cache_path):
+ datalist = self.load_raw_data(cfg, self.anno_file[i])
+ pickle.dump(datalist, open(cache_path, "wb"))
+ datalist = pickle.load(open(cache_path, "rb"))
+ if isinstance(datalist[0]['caption'], list):
+ new_datalist = list()
+ for data in datalist:
+ if isinstance(data['caption'], str):
+ new_datalist.append(data)
+ else:
+ video_id = data['video_id']
+ for caption in data['caption']:
+ new_datalist.append({
+ "video_id": video_id,
+ "caption": caption,
+ })
+ datalist = new_datalist
+ total_datalist.extend(datalist)
+
+ if self.data_percentage < 1.0 and self.stage == 'train':
+ datalist = random.sample(total_datalist, k = int(self.data_percentage* len(total_datalist) ) )
+ total_datalist = datalist
+
+ else:
+ assert self.stage == "test"
+ cache_path = os.path.join(
+ self.cache_dir,
+ "msrvtt_raw_caption_retrieval_%s_%s_%d.pkl" % (self.tokenizer_name, self.stage, self.max_seq_len)
+ )
+ if not os.path.exists(os.path.dirname(cache_path)):
+ os.makedirs(os.path.dirname(cache_path))
+ if not os.path.exists(cache_path):
+ datalist = self.load_raw_data(cfg, self.anno_file)
+ pickle.dump(datalist, open(cache_path, "wb"))
+ datalist = pickle.load(open(cache_path, "rb"))
+ total_datalist = datalist
+ return total_datalist
+
+
+ def load_raw_data(self, cfg, anno_file):
+ datalist = []
+ if self.stage == 'train':
+ if self.use_ceph:
+ annoinfo = json.load(BytesIO(self.tcs_loader.client.get(anno_file)))
+ else:
+ annoinfo = json.load(open(anno_file))
+ captions_train = sorted( annoinfo['annotations'], key=lambda x: x['id'])
+ for data in captions_train:
+ datalist.append(
+ {
+ 'video_id': data['image_id'],
+ 'caption': data['caption']
+ }
+ )
+
+ else:
+ if self.use_ceph:
+ annoinfo = json.load(BytesIO(self.tcs_loader.client.get(self.anno_file)))
+ else:
+ annoinfo = json.load(open(self.anno_file))
+ captions_train = sorted( annoinfo['annotations'], key=lambda x: x['id'])
+ video2caps = defaultdict(list)
+ for data in captions_train:
+ video2caps[data['image_id']].append(data['caption'])
+
+ for videoid, caps in video2caps.items():
+ datalist.append(
+ {
+ 'video_id': videoid,
+ 'caption': caps
+ }
+ )
+ return datalist
+
+ def _timesformer_transform(self):
+ transforms = [
+ Lambda(lambda frames: torch.stack([ToTensor()(frame.convert("RGB")) for frame in frames])),
+ ]
+ if self.test_mode:
+ test_scale = self.cfg.MODEL.IMG_INPUT_SIZE
+ transforms.extend([
+ Lambda(lambda frames: random_short_side_scale_jitter(
+ frames, test_scale, test_scale)[0]),
+ CenterCrop(test_scale),
+ # Lambda(lambda images: torch.stack([uniform_crop(images, 224, i)[0] for i in range(3)], 0))
+ ])
+ else:
+ min_scale = int((256 / 224)*self.cfg.MODEL.IMG_INPUT_SIZE)
+ max_scale = int((320 / 224)*self.cfg.MODEL.IMG_INPUT_SIZE)
+
+ transforms.extend([
+ Lambda(lambda frames: random_short_side_scale_jitter(frames, min_scale, max_scale)[0].unsqueeze(0)),
+ RandomHorizontalFlip(),
+ RandomCrop(self.cfg.MODEL.IMG_INPUT_SIZE)
+ ])
+ transforms.append(
+ # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ # change to imagenet default value to keep consistency with pretrained parameters
+ # Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ )
+ return Compose(transforms)
+
+ def _sample_frame(self, atten_feats):
+ interval = atten_feats.shape[0] / self.max_feat_num
+ selected_indexes = [int(i * interval) for i in range(self.max_feat_num)]
+ selected_frames = atten_feats[selected_indexes, :]
+ return selected_frames
+
+ def random_word_wwm(self, tokens):
+ output_tokens = []
+ output_label = []
+
+ for i, token in enumerate(tokens):
+ if self.use_clip_tokenizer:
+ sub_tokens = self.tokenizer.encode_basic_tokenized_token(token)
+ else:
+ sub_tokens = self.tokenizer.wordpiece_tokenizer.tokenize(token)
+ prob = random.random()
+ # mask token with 15% probability
+ if prob < 0.15:
+ prob /= 0.15
+
+ # 80% randomly change token to mask token
+ if prob < 0.8:
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_tokens.append(self.tokenizer.encoder["<|spe|>"])
+ else:
+ output_tokens.append("[MASK]")
+ # 10% randomly change token to random token
+ elif prob < 0.9:
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_tokens.append(random.choice(list(range(len(self.tokenizer.encoder)))))
+ else:
+ output_tokens.append(random.choice(list(self.tokenizer.vocab.keys())))
+ # -> rest 10% randomly keep current token
+ else:
+ for sub_token in sub_tokens:
+ output_tokens.append(sub_token)
+
+ # append current token to output (we will predict these later)
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_label.append(sub_token)
+ else:
+ try:
+ output_label.append(self.tokenizer.vocab[sub_token])
+ except KeyError:
+ # For unknown words (should not occur with BPE vocab)
+ output_label.append(self.tokenizer.vocab["[UNK]"])
+ else:
+ for sub_token in sub_tokens:
+ # no masking token (will be ignored by loss function later)
+ output_tokens.append(sub_token)
+ output_label.append(-1)
+
+ # if no word masked, random choose a word to mask
+ # if all([l_ == -1 for l_ in output_label]):
+ # choosed = random.randrange(0, len(output_label))
+ # output_label[choosed] = self.tokenizer.vocab[tokens[choosed]]
+
+ return output_tokens, output_label
+
+
+ def __getitem__(self, idx):
+
+ for i_try in range(100):
+ # try:
+ record = self.video_list[idx].as_py()
+ record = copy.deepcopy(record)
+ video_id = record['video_id']
+ # load video
+
+ video_path = os.path.join(self.feats_folder, self.idx2name[video_id] + '.mp4')
+ if self.use_ceph:
+ container = av.open(io.BytesIO(self.tcs_loader.client.get(video_path)))
+ else:
+ container = av.open(video_path)
+
+
+ # container.streams.video[0].thread_type = "AUTO"
+ stream = container.streams.video[0]
+ total_frames = stream.frames
+ fps = float(container.streams.video[0].average_rate)
+
+ if total_frames == 0:
+ # it returns 0 if not know, but that doesn't mean the video is null
+ for frame in container.decode(stream):
+ total_frames += 1
+ container.close()
+ container = av.open(video_path)
+ stream = container.streams.video[0]
+ # except Exception as e:
+ # print(
+ # "Failed to load video from {} with error {} ; trial {}".format(
+ # video_path, e, i_try
+ # )
+ # )
+
+ # let's try another one
+ # index = random.randint(0, len(self.data_list) - 1)
+ # record = self.data_list[index]
+ # continue
+
+ if self.stage=='train':
+ indices = [self._sample_indices(total_frames, fps)]
+ else:
+ indices = self._get_val_indices(total_frames, fps)
+
+ all_index = set()
+ for index in indices:
+ all_index = all_index.union(set(index))
+
+ start_index = min(all_index)
+ num_frames = len(all_index)
+
+ images = dict()
+
+ fetched = 0
+
+ for frame in container.decode(stream):
+ if frame.index not in all_index or frame.index in images:
+ continue
+ images[frame.index] = frame.to_rgb().to_image()
+ last = frame.index
+ fetched += 1
+ if fetched == num_frames:
+ break
+
+ container.close()
+
+ video_data = list()
+ for ind in indices:
+ seq = list()
+ for i in ind:
+ if i in images:
+ seq.append(images[i])
+ else:
+ seq.append(images[last])
+ video_data.append(self.transform(seq))
+ video_data = torch.cat(video_data, dim=0)
+
+ if video_data.dim() == 4:
+ video_data.unsqueeze_(0) # in case there is only one frame
+
+ ret = {
+ 'input_sample':[
+ {
+ 'data': video_data, 'invalid_mask': None, 'modality': 'video', 'data_type': 'input',
+ 'sample_info':{
+ 'id': video_id,
+ 'path': video_path,
+ 'num_views':num_frames,
+ 'cat_along_first_dim': True,
+ }
+ }
+ ],
+ 'target_sample': [],
+ }
+
+ if self.stage == 'train' and record['caption'] is not None:
+ caption = record['caption']
+ caption = caption + " <|endoftext|>"
+
+ if self.task_type == 'video_mlm':
+ u_mask_type = 1
+ elif self.task_type == 'video_caption':
+ u_mask_type = 0 # causal mask
+
+ if self.task_type=='video_caption' or self.task_type =='video_mlm':
+ if u_mask_type == 1: # mlm
+ caption_tokens = self.tokenizer.basic_tokenize(caption)
+ caption_tokens, mlm_labels = self.random_word_wwm(caption_tokens)
+ else:
+ # caption
+ caption_tokens = self.tokenizer.encode(caption)
+ mlm_labels = self.tokenizer.encode("<|spe|>")*len(caption_tokens)
+
+ else:
+ caption_tokens = self.tokenizer.encode(caption)
+
+
+ if len(caption_tokens) > self.max_seq_len:
+ # mlm task
+ text_len_keep = self.max_seq_len
+ caption_tokens = caption_tokens[:(text_len_keep - 1)] + [caption_tokens[-1]]
+ if self.task_type == 'video_caption' or self.task_type == 'video_mlm':
+ mlm_labels = mlm_labels[:(text_len_keep - 1)] + [mlm_labels[-1]]
+
+ ret = {
+ 'input_sample': [{
+ 'data': video_data, 'invalid_mask': None, 'modality': 'video', 'data_type': 'input',
+ 'sample_info':{
+ 'id': video_id,
+ 'path': video_path,
+ 'num_views':num_frames,
+ 'cat_along_first_dim': True,
+ }
+ }]
+ }
+
+ if self.task_type == 'video_caption':
+
+ source = np.array(caption_tokens, dtype=np.int64)
+ source2 = np.array(mlm_labels, dtype=np.int64)
+ ret['input_sample'].append({
+ 'data': [source, source2],
+ 'invalid_mask': None,
+ 'modality': 'text',
+ 'data_type': 'input',
+ 'sample_info': {
+ 'text_spe_cat': True,
+ }
+ })
+ ret.update({
+ 'target_sample': [],
+ 'target_idx' : [np.array(caption_tokens, dtype=np.int64)],
+ 'target_set' : copy.deepcopy(self.target_set),
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+
+ elif self.task_type == 'video_mlm':
+
+ raise NotImplementedError('no needed for masked language modeling when given video now.')
+
+
+ elif self.task_type == 'video_retrieval':
+ ret.update({
+ 'target_sample': [{
+ 'data' : [np.array(caption_tokens, dtype=np.int64)],
+ 'modality' : 'text',
+ 'data_type' : 'target',
+ 'invalid_mask': None,
+ 'sample_info' : {}
+ }],
+ 'target_idx' : [],
+ 'target_set' : [],
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+ else:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+
+
+
+ dict_as_tensor(ret)
+ return ret
diff --git a/uniperceiver/datasets/task_dataset/msvd.py b/uniperceiver/datasets/task_dataset/msvd.py
new file mode 100644
index 0000000000000000000000000000000000000000..c5c838215250ecc9938f4c26fb1fa6296e7ab3fd
--- /dev/null
+++ b/uniperceiver/datasets/task_dataset/msvd.py
@@ -0,0 +1,593 @@
+import os
+import copy
+import pickle
+import random
+import numpy as np
+import torch
+from uniperceiver.config import configurable
+from uniperceiver.functional import read_np, dict_as_tensor
+from ..build import DATASETS_REGISTRY
+from uniperceiver.tokenization import ClipTokenizer
+from torchvision.transforms import Compose, RandomApply, ToTensor, Normalize, CenterCrop, Lambda, RandomHorizontalFlip, ColorJitter, Resize, RandomCrop
+from .video_transform import random_short_side_scale_jitter, uniform_crop
+import json
+from io import BytesIO
+import av
+from .video_raw import VideoDataSet
+import io
+from collections import defaultdict
+
+import pyarrow as pa
+from uniperceiver.utils import comm
+import copy
+
+__all__ = ["MSVDDataset"]
+
+def random_clip(video_frames, sampling_rate, frames_per_clip, fixed_offset=False):
+ """
+ Args:
+ video_frames (int): total frame number of a video
+ sampling_rate (int): sampling rate for clip, pick one every k frames
+ frames_per_clip (int): number of frames of a clip
+ fixed_offset (bool): used with sample offset to decide the offset value deterministically.
+ Returns:
+ list[int]: frame indices (started from zero)
+ """
+ new_sampling_rate = sampling_rate
+ highest_idx = video_frames - int(new_sampling_rate * frames_per_clip)
+ if highest_idx <= 0:
+ random_offset = 0
+ else:
+ if fixed_offset:
+ random_offset = (video_frames - int(new_sampling_rate * frames_per_clip)) // 2
+ else:
+ random_offset = int(np.random.randint(0, highest_idx, 1))
+ frame_idx = [int(random_offset + int(i * sampling_rate)) % video_frames for i in range(frames_per_clip)]
+ frame_idx = [x for x in frame_idx if x < video_frames]
+ return frame_idx
+
+
+@DATASETS_REGISTRY.register()
+class MSVDDataset(VideoDataSet):
+ @configurable
+ def __init__(
+ self,
+ stage: str,
+ anno_file: str,
+ seq_per_img: int,
+ max_feat_num: int,
+ max_seq_len: int,
+ feats_folder: str,
+ tokenizer,
+ tokenizer_name,
+ use_ceph: bool,
+ tcs_conf_path,
+ frames_per_clip, interval, num_clips, timesformer_aug,
+ task_type,
+ data_percentage,
+ target_fps=30,
+ random_mask=False,
+ cfg=None,
+ ):
+ self.cfg = cfg
+ self.stage = stage
+ self.anno_file = anno_file
+ self.seq_per_img = seq_per_img
+ self.max_feat_num = max_feat_num
+ self.feats_folder = feats_folder
+ self.max_seq_len = max_seq_len
+ self.task_type = task_type
+
+ self.initialized = False
+
+ # sample_list = list(self.fin.keys())
+ self.tokenizer = tokenizer
+ self.tokenizer_name = tokenizer_name
+ self.use_clip_tokenizer = self.tokenizer_name == 'clip'
+ # for index_maping
+ self.idx2name = dict()
+ self.name2idx = dict()
+
+ self.use_ceph = use_ceph
+ if isinstance(self.anno_file, list):
+ self.cache_dir = os.path.join(os.path.dirname(self.anno_file[0]), 'cache')
+ else:
+ self.cache_dir = os.path.join(os.path.dirname(self.anno_file), 'cache')
+ self.frames_per_clip = frames_per_clip
+ self.interval = interval
+
+ # self.MULTI_VEIW = self.cfg.DATALOADER.get('MULTI_VEIW', 'v0')
+ # self.MULTI_VEIW_NUM = self.cfg.DATALOADER.get('MULTI_VEIW_NUM', 1)
+ self.random_stride = self.cfg.DATALOADER.get('RANDON_STRIDE', False)
+
+ self.num_clips = num_clips
+ self.is_train = stage == 'train'
+ self.test_mode = stage != 'train'
+ self.transform = self._timesformer_transform() if timesformer_aug else self._transform()
+ self.target_fps = target_fps
+ self.data_percentage = data_percentage
+
+ if self.use_ceph:
+ self.feats_folder = 's3://msvd/YouTubeClips/'
+ if isinstance(self.anno_file, list):
+ self.anno_file = [os.path.join('s3://msvd/annotations/', os.path.basename(anno_file)) for anno_file in self.anno_file]
+ else:
+ self.anno_file = os.path.join('s3://msvd/annotations/', os.path.basename(self.anno_file))
+ print('debug info for msvd pretrain: {} '.format(self.feats_folder))
+ from uniperceiver.datasets import TCSLoader
+ if 'SLURM_PROCID' in os.environ:
+ self.tcs_loader = TCSLoader(tcs_conf_path)
+ else:
+ self.tcs_loader = TCSLoader('slurm_tools/petreloss_local.config')
+ else:
+ # local image folder
+ self.feats_folder = feats_folder
+
+ if self.use_ceph:
+ if isinstance(self.anno_file, list):
+ videoinfo = list()
+ for anno_file in self.anno_file:
+ videoinfo.extend(json.load(BytesIO(self.tcs_loader.client.get(anno_file)))["images"])
+ else:
+ videoinfo = json.load(BytesIO(self.tcs_loader.client.get(self.anno_file)))["images"]
+ else:
+ if isinstance(self.anno_file, list):
+ videoinfo = list()
+ for anno_file in self.anno_file:
+ videoinfo.extend(json.load(open(anno_file))["images"])
+ else:
+ videoinfo = json.load(open(self.anno_file))["images"]
+ for vinfo in videoinfo:
+ self.idx2name[vinfo['id']] = vinfo['file_name']
+ self.name2idx[vinfo['file_name']] = vinfo['id']
+ self.random_mask = random_mask
+ pass
+
+ _temp_list =self.load_data(self.cfg)
+ self.video_list = pa.array(_temp_list)
+ if comm.is_main_process():
+ import sys
+ print(f"!!! Dataset {self.cfg.DATASETS.DATASET_NAME} with task {self.cfg.DATASETS.TASK_TYPE}:")
+ print('!!! length of _temp_list: ', len(_temp_list))
+ print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
+ print('!!! size of pa database: ', sys.getsizeof(self.video_list))
+ del _temp_list
+
+ self.task_info = {
+ 'task_type' : self.cfg.DATASETS.TASK_TYPE,
+ 'dataset_name' : self.cfg.DATASETS.DATASET_NAME,
+ 'batch_size' : self.cfg.DATALOADER.TRAIN_BATCH_SIZE if self.stage == 'train' else self.cfg.DATALOADER.TEST_BATCH_SIZE,
+ 'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT
+ }
+
+ self.target_set = self.cfg.DATASETS.TARGET_SET
+
+
+ @classmethod
+ def from_config(cls, cfg, stage: str = "train"):
+ if stage == "train":
+ ann_file = [os.path.join(cfg.DATALOADER.ANNO_FOLDER, "caption_msvd_train_cocostyle.json"),
+ os.path.join(cfg.DATALOADER.ANNO_FOLDER, "caption_msvd_val_cocostyle.json")]
+ else:
+ assert stage == "test"
+ ann_file = os.path.join(cfg.DATALOADER.ANNO_FOLDER, "caption_msvd_{}_cocostyle.json".format(stage))
+ feat_path = os.path.join(cfg.DATALOADER.FEATS_FOLDER, "MSVD_ResNet152_{}.hdf5".format(stage))
+
+ if 'SLURM_PROCID' in os.environ:
+ tcs_conf_path = cfg.DATALOADER.get("TCS_CONF_PATH", "slurm_tools/petreloss.config")
+ else:
+ # dev machine
+ tcs_conf_path = "slurm_tools/petreloss_local.config"
+
+ ret = {
+ "stage": stage,
+ "anno_file": ann_file,
+ "seq_per_img": cfg.DATALOADER.SEQ_PER_SAMPLE,
+ "max_feat_num": cfg.DATALOADER.MAX_FEAT_NUM,
+ "feats_folder": feat_path,
+ "max_seq_len": cfg.MODEL.MAX_SEQ_LEN,
+ "use_ceph": getattr(cfg.DATALOADER, 'USE_CEPH', False),
+ "tcs_conf_path": tcs_conf_path,
+ 'task_type': cfg.DATASETS.TASK_TYPE,
+ "frames_per_clip": cfg.DATALOADER.FRAMES_PER_CLIP,
+ "interval": cfg.DATALOADER.STRIDE,
+ "num_clips": 1 if stage == 'train' else cfg.INFERENCE.NUM_VIEWS,
+ "timesformer_aug": cfg.DATALOADER.TIMESFORMER_AUG,
+ "data_percentage": cfg.DATALOADER.DATA_PERCENTAGE,
+ "cfg": cfg,
+ }
+ if getattr(cfg.INFERENCE, "VOCAB", None) == 'CLIP':
+ ret['tokenizer'] = ClipTokenizer()
+ ret['tokenizer_name'] = "clip"
+ else:
+ raise NotImplementedError
+ return ret
+
+ def load_data(self, cfg):
+ if self.stage == "train":
+ total_datalist = list()
+ for i, stage in enumerate(["train", "val"]):
+ cache_path = os.path.join(
+ self.cache_dir,
+ "msvd_raw_caption_retrieval_%s_%s_%d.pkl" % (self.tokenizer_name, stage, self.max_seq_len)
+ )
+ if not os.path.exists(os.path.dirname(cache_path)):
+ os.makedirs(os.path.dirname(cache_path))
+ if not os.path.exists(cache_path):
+ datalist = self.load_raw_data(cfg, self.anno_file[i])
+ pickle.dump(datalist, open(cache_path, "wb"))
+ datalist = pickle.load(open(cache_path, "rb"))
+ if isinstance(datalist[0]['caption'], list):
+ new_datalist = list()
+ for data in datalist:
+ if isinstance(data['caption'], str):
+ new_datalist.append(data)
+ else:
+ video_id = data['video_id']
+ for caption in data['caption']:
+ new_datalist.append({
+ "video_id": video_id,
+ "caption": caption,
+ })
+ datalist = new_datalist
+ total_datalist.extend(datalist)
+
+ if self.data_percentage < 1.0 and self.stage == 'train':
+ datalist = random.sample(total_datalist, k = int(self.data_percentage* len(total_datalist) ) )
+ total_datalist = datalist
+
+ else:
+ assert self.stage == "test"
+ cache_path = os.path.join(
+ self.cache_dir,
+ "msvd_raw_caption_retrieval_%s_%s_%d.pkl" % (self.tokenizer_name, self.stage, self.max_seq_len)
+ )
+ if not os.path.exists(os.path.dirname(cache_path)):
+ os.makedirs(os.path.dirname(cache_path))
+ if not os.path.exists(cache_path):
+ datalist = self.load_raw_data(cfg, self.anno_file)
+ pickle.dump(datalist, open(cache_path, "wb"))
+ datalist = pickle.load(open(cache_path, "rb"))
+ total_datalist = datalist
+ return total_datalist
+
+
+ def load_raw_data(self, cfg, anno_file):
+ datalist = []
+ if self.stage == 'train':
+ if self.use_ceph:
+ annoinfo = json.load(BytesIO(self.tcs_loader.client.get(anno_file)))
+ else:
+ annoinfo = json.load(open(anno_file))
+ captions_train = sorted( annoinfo['annotations'], key=lambda x: x['id'])
+ for data in captions_train:
+ datalist.append(
+ {
+ 'video_id': data['image_id'],
+ 'caption': data['caption']
+ }
+ )
+
+ else:
+ if self.use_ceph:
+ annoinfo = json.load(BytesIO(self.tcs_loader.client.get(self.anno_file)))
+ else:
+ annoinfo = json.load(open(self.anno_file))
+ captions_train = sorted( annoinfo['annotations'], key=lambda x: x['id'])
+ video2caps = defaultdict(list)
+ for data in captions_train:
+ video2caps[data['image_id']].append(data['caption'])
+
+ for videoid, caps in video2caps.items():
+ datalist.append(
+ {
+ 'video_id': videoid,
+ 'caption': caps
+ }
+ )
+ return datalist
+
+ def _timesformer_transform(self):
+ transforms = [
+ Lambda(lambda frames: torch.stack([ToTensor()(frame.convert("RGB")) for frame in frames])),
+ ]
+ if self.test_mode:
+ test_scale = self.cfg.MODEL.IMG_INPUT_SIZE
+ transforms.extend([
+ Lambda(lambda frames: random_short_side_scale_jitter(
+ frames, test_scale, test_scale)[0]),
+ CenterCrop(test_scale),
+ # Lambda(lambda images: torch.stack([uniform_crop(images, 224, i)[0] for i in range(3)], 0))
+ ])
+ else:
+ min_scale = int((256 / 224)*self.cfg.MODEL.IMG_INPUT_SIZE)
+ max_scale = int((320 / 224)*self.cfg.MODEL.IMG_INPUT_SIZE)
+
+ transforms.extend([
+ Lambda(lambda frames: random_short_side_scale_jitter(frames, min_scale, max_scale)[0].unsqueeze(0)),
+ RandomHorizontalFlip(),
+ RandomCrop(self.cfg.MODEL.IMG_INPUT_SIZE)
+ ])
+ transforms.append(
+ # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ # change to imagenet default value to keep consistency with pretrained parameters
+ # Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ )
+ return Compose(transforms)
+
+ def _sample_frame(self, atten_feats):
+ interval = atten_feats.shape[0] / self.max_feat_num
+ selected_indexes = [int(i * interval) for i in range(self.max_feat_num)]
+ selected_frames = atten_feats[selected_indexes, :]
+ return selected_frames
+
+ def random_word_wwm(self, tokens):
+ output_tokens = []
+ output_label = []
+
+ for i, token in enumerate(tokens):
+ if self.use_clip_tokenizer:
+ sub_tokens = self.tokenizer.encode_basic_tokenized_token(token)
+ else:
+ sub_tokens = self.tokenizer.wordpiece_tokenizer.tokenize(token)
+ prob = random.random()
+ # mask token with 15% probability
+ if prob < 0.15:
+ prob /= 0.15
+
+ # 80% randomly change token to mask token
+ if prob < 0.8:
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_tokens.append(self.tokenizer.encoder["<|spe|>"])
+ else:
+ output_tokens.append("[MASK]")
+ # 10% randomly change token to random token
+ elif prob < 0.9:
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_tokens.append(random.choice(list(range(len(self.tokenizer.encoder)))))
+ else:
+ output_tokens.append(random.choice(list(self.tokenizer.vocab.keys())))
+ # -> rest 10% randomly keep current token
+ else:
+ for sub_token in sub_tokens:
+ output_tokens.append(sub_token)
+
+ # append current token to output (we will predict these later)
+ for sub_token in sub_tokens:
+ if self.use_clip_tokenizer:
+ output_label.append(sub_token)
+ else:
+ try:
+ output_label.append(self.tokenizer.vocab[sub_token])
+ except KeyError:
+ # For unknown words (should not occur with BPE vocab)
+ output_label.append(self.tokenizer.vocab["[UNK]"])
+ else:
+ for sub_token in sub_tokens:
+ # no masking token (will be ignored by loss function later)
+ output_tokens.append(sub_token)
+ output_label.append(-1)
+
+ # if no word masked, random choose a word to mask
+ # if all([l_ == -1 for l_ in output_label]):
+ # choosed = random.randrange(0, len(output_label))
+ # output_label[choosed] = self.tokenizer.vocab[tokens[choosed]]
+
+ return output_tokens, output_label
+
+
+ def __getitem__(self, idx):
+
+ for i_try in range(100):
+ try:
+ record = self.video_list[idx].as_py()
+ record = copy.deepcopy(record)
+ video_id = record['video_id']
+ # load video
+
+ video_path = os.path.join(self.feats_folder, self.idx2name[video_id] + '.avi')
+ if self.use_ceph:
+ container = av.open(io.BytesIO(self.tcs_loader.client.get(video_path)))
+ else:
+ container = av.open(video_path)
+
+
+ # container.streams.video[0].thread_type = "AUTO"
+ stream = container.streams.video[0]
+ total_frames = stream.frames
+ fps = float(container.streams.video[0].average_rate)
+
+ if total_frames == 0:
+ # it returns 0 if not know, but that doesn't mean the video is null
+ for frame in container.decode(stream):
+ total_frames += 1
+ container.close()
+ container = av.open(video_path)
+ stream = container.streams.video[0]
+ except Exception as e:
+ print(
+ "Failed to load video from {} with error {} ; trial {}".format(
+ video_path, e, i_try
+ )
+ )
+
+ # let's try another one
+ index = random.randint(0, len(self.data_list) - 1)
+ record = self.data_list[index]
+ continue
+
+ if self.stage=='train':
+ indices = [self._sample_indices(total_frames, fps)]
+ else:
+ indices = self._get_val_indices(total_frames, fps)
+
+ all_index = set()
+ for index in indices:
+ all_index = all_index.union(set(index))
+
+ start_index = min(all_index)
+ num_frames = len(all_index)
+
+ images = dict()
+
+ fetched = 0
+
+ for frame in container.decode(stream):
+ if frame.index not in all_index or frame.index in images:
+ continue
+ images[frame.index] = frame.to_rgb().to_image()
+ last = frame.index
+ fetched += 1
+ if fetched == num_frames:
+ break
+
+ container.close()
+
+ video_data = list()
+ for ind in indices:
+ seq = list()
+ for i in ind:
+ if i in images:
+ seq.append(images[i])
+ else:
+ seq.append(images[last])
+ video_data.append(self.transform(seq))
+ video_data = torch.cat(video_data, dim=0)
+
+ if video_data.dim() == 4:
+ video_data.unsqueeze_(0) # in case there is only one frame
+
+ ret = {
+ 'input_sample': [{
+ 'data': video_data, 'invalid_mask': None, 'modality': 'video', 'data_type': 'input',
+ 'sample_info':{
+ 'id': video_id,
+ 'path': video_path,
+ 'num_views':num_frames,
+ 'cat_along_first_dim': True,
+ }
+ }]
+ }
+
+ if self.stage == 'train' and record['caption'] is not None:
+ caption = record['caption']
+ caption = caption + " <|endoftext|>"
+
+ if self.task_type == 'video_mlm':
+ u_mask_type = 1
+ elif self.task_type == 'video_caption':
+ u_mask_type = 0 # causal mask
+
+ if self.task_type=='video_caption' or self.task_type =='video_mlm':
+ if u_mask_type == 1: # mlm
+ caption_tokens = self.tokenizer.basic_tokenize(caption)
+ caption_tokens, mlm_labels = self.random_word_wwm(caption_tokens)
+ else:
+ # caption
+ caption_tokens = self.tokenizer.encode(caption)
+ mlm_labels = self.tokenizer.encode("<|spe|>")*len(caption_tokens)
+
+ else:
+ caption_tokens = self.tokenizer.encode(caption)
+
+
+ if len(caption_tokens) > self.max_seq_len:
+ # mlm task
+ text_len_keep = self.max_seq_len
+ caption_tokens = caption_tokens[:(text_len_keep - 1)] + [caption_tokens[-1]]
+ if self.task_type == 'video_caption' or self.task_type == 'video_mlm':
+ mlm_labels = mlm_labels[:(text_len_keep - 1)] + [mlm_labels[-1]]
+
+ if self.task_type == 'video_caption':
+
+ source = np.array(caption_tokens, dtype=np.int64)
+ source2 = np.array(mlm_labels, dtype=np.int64)
+ ret['input_sample'].append({
+ 'data': [source, source2],
+ 'invalid_mask': None,
+ 'modality': 'text',
+ 'data_type': 'input',
+ 'sample_info': {
+ 'text_spe_cat': True,
+ }
+ })
+ ret.update({
+ 'target_sample': [],
+ 'target_idx' : [np.array(caption_tokens, dtype=np.int64)],
+ 'target_set' : copy.deepcopy(self.target_set),
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+
+ elif self.task_type == 'video_mlm':
+
+ raise NotImplementedError('no needed for masked language modeling when given video now.')
+
+
+ elif self.task_type == 'video_retrieval':
+ ret.update({
+ 'target_sample': [{
+ 'data' : [np.array(caption_tokens, dtype=np.int64)],
+ 'modality' : 'text',
+ 'data_type' : 'target',
+ 'invalid_mask': None,
+ 'sample_info' : {}
+ }],
+ 'target_idx' : [],
+ 'target_set' : [],
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+ else:
+ raise NotImplementedError
+
+ elif self.stage != 'train':
+ if self.task_type == 'video_caption':
+ ret.update({
+ 'target_set': copy.deepcopy(self.target_set),
+ 'target_sample': [],
+ 'target_idx': [],
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+ elif self.task_type=='video_retrieval':
+ captions = [caption + " <|endoftext|>" for caption in record['caption']]
+ caption_tokens_raw = [ self.tokenizer.encode(caption) for caption in captions]
+
+ caption_tokens = [ caption_token[:(self.max_seq_len - 1)] + [caption_token[-1]]
+ if len(caption_token) > self.max_seq_len else caption_token
+ for caption_token in caption_tokens_raw ]
+ ret.update(
+ {
+ 'input_sample': [{
+ 'data' : video_data, 'invalid_mask': None, 'modality': 'video', 'data_type': 'input',
+ 'sample_info' : {
+ 'id' : (video_id, [video_id] * len(caption_tokens)),
+ 'path' : video_path,
+ 'num_views':num_frames,
+ 'cat_along_first_dim': True,
+ }
+ }],
+ 'target_sample': [{
+ 'data' : [np.array(caption_token, dtype=np.int64)
+ for caption_token in caption_tokens],
+ 'modality' : 'text',
+ 'invalid_mask': None,
+ 'data_type' : 'target',
+ 'sample_info' : {
+ 'sample_alone': True,
+ }
+
+ }],
+ 'target_idx' : [],
+ 'target_set' : [],
+ 'task_info' : copy.deepcopy(self.task_info)
+ }
+ )
+ else:
+ raise NotImplementedError
+
+
+
+ dict_as_tensor(ret)
+ return ret
diff --git a/uniperceiver/datasets/task_dataset/video_raw.py b/uniperceiver/datasets/task_dataset/video_raw.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e0aa8345d21b4ebfb5fd8f0ec52e679ff4308fd
--- /dev/null
+++ b/uniperceiver/datasets/task_dataset/video_raw.py
@@ -0,0 +1,426 @@
+import os
+import random
+import numpy as np
+import torch
+import pickle
+from PIL import Image
+import torch.utils.data as data
+import torch.nn.functional as F
+from torchvision.transforms import Compose, RandomApply, ToTensor, Normalize, CenterCrop, Lambda, RandomHorizontalFlip, ColorJitter, Resize, RandomCrop
+import json
+import av
+from torchvision.transforms.transforms import RandomResizedCrop
+from uniperceiver.tokenization import ClipTokenizer
+from uniperceiver.config import configurable
+from ..build import DATASETS_REGISTRY
+from uniperceiver.functional import dict_as_tensor
+from .video_transform import random_short_side_scale_jitter, uniform_crop
+import pyarrow as pa
+from uniperceiver.utils import comm
+import copy
+
+import io
+
+__all__ = ["VideoDataSet", "random_clip"]
+
+
+def load_pkl_file(filepath):
+ return pickle.load(open(filepath, 'rb'), encoding='bytes') if len(filepath) > 0 else None
+
+
+def random_clip(video_frames, sampling_rate, frames_per_clip, fixed_offset=False):
+ """
+ Args:
+ video_frames (int): total frame number of a video
+ sampling_rate (int): sampling rate for clip, pick one every k frames
+ frames_per_clip (int): number of frames of a clip
+ fixed_offset (bool): used with sample offset to decide the offset value deterministically.
+ Returns:
+ list[int]: frame indices (started from zero)
+ """
+ new_sampling_rate = sampling_rate
+ highest_idx = video_frames - int(new_sampling_rate * (frames_per_clip - 1) + 1)
+ if highest_idx <= 0:
+ random_offset = 0
+ else:
+ if fixed_offset:
+ random_offset = (video_frames - int(new_sampling_rate * frames_per_clip)) // 2
+ else:
+ random_offset = int(np.random.randint(0, highest_idx, 1))
+ frame_idx = [int(random_offset + int(i * sampling_rate)) % video_frames for i in range(frames_per_clip)]
+ frame_idx = [x for x in frame_idx if x < video_frames]
+ return frame_idx
+
+@DATASETS_REGISTRY.register()
+class VideoDataSet(data.Dataset):
+
+ @configurable
+ def __init__(self, cfg, stage, root_path, s3_path, list_file, category_file, use_ceph, tcs_conf_path,
+ tokenizer, tokenizer_name, data_percentage,
+ frames_per_clip=64, interval=4, num_clips=1,
+ is_train=True, test_mode=False, num_classes=None, target_fps=30, timesformer_aug=False, minibatches=1):
+ """
+ Args:
+ root_path (str): the file path to the root of video folder
+ list_file (str): the file list, each line with folder_path, start_frame, end_frame, label_id
+ frames_per_clip (int): number of frames per data sample
+ interval (int): interval between frames
+ is_train (bool): shuffle the video but keep the causality
+ test_mode (bool): testing mode, no label
+ """
+
+ self.cfg = cfg
+ self.stage = stage
+ self.root_path = root_path
+ self.s3_path = s3_path
+ self.list_file = list_file
+ self.category_file = category_file
+ self.frames_per_clip = frames_per_clip
+ self.interval = interval
+ self.num_clips = num_clips
+ self.is_train = is_train
+ self.test_mode = test_mode
+ self.num_classes = num_classes
+ self.target_fps = target_fps
+ self.minibatches = minibatches
+ self.data_percentage = data_percentage
+
+ # self.class_names = class_names if (class_names is not None) else None
+ self.tokenizer = tokenizer
+ self.tokenizer_name = tokenizer_name
+
+ self.transform = self._timesformer_transform() if timesformer_aug else self._transform()
+
+ self.use_ceph = use_ceph
+ if self.use_ceph:
+ # get dataset
+ # dataset_name = self.root_path.split('/')[-2]
+ self.data_path = self.s3_path
+ print('debug info for {} {} '.format(self.cfg.DATASETS.DATASET_NAME, self.data_path))
+ from uniperceiver.datasets import TCSLoader
+
+ self.tcs_loader = TCSLoader(tcs_conf_path)
+ else:
+ self.data_path = self.root_path
+
+ _temp_list =self.load_data(self.cfg)
+ self.video_list = pa.array(_temp_list)
+ if comm.is_main_process():
+ import sys
+ print(f"!!! Dataset {self.cfg.DATASETS.DATASET_NAME} with task {self.cfg.DATASETS.TASK_TYPE}:")
+ print('!!! length of _temp_list: ', len(_temp_list))
+ print('!!! size of _temp_list: ', sys.getsizeof(_temp_list))
+ print('!!! size of pa database: ', sys.getsizeof(self.video_list))
+ del _temp_list
+
+ self.testing_multi_view = self.cfg.DATALOADER.get('MULTI_VEIW', 'v0')
+ self.temporal_num_view = self.cfg.DATALOADER.get('MULTI_VEIW_NUM', 1)
+
+ self.random_stride = self.cfg.DATALOADER.get('RANDON_STRIDE', False)
+
+ if self.test_mode:
+ self.frames_per_clip = int(self.frames_per_clip*self.temporal_num_view)
+ self.interval = int(self.interval/self.temporal_num_view)
+
+ self.task_info = {
+ 'task_type' : self.cfg.DATASETS.TASK_TYPE,
+ 'dataset_name' : self.cfg.DATASETS.DATASET_NAME,
+ 'batch_size' : self.cfg.DATALOADER.TRAIN_BATCH_SIZE if self.stage == 'train' else self.cfg.DATALOADER.TEST_BATCH_SIZE,
+ 'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT,
+
+ }
+
+ self.target_set = self.cfg.DATASETS.TARGET_SET
+
+
+ def _transform(self):
+ assert False, 'use timesformer augmentation'
+ transforms = [
+ Lambda(lambda frames: torch.stack([ToTensor()(frame.convert("RGB")) for frame in frames])),
+ ]
+ if self.test_mode:
+ transforms.extend([
+ RandomResizedCrop(224, scale=(0.75, 0.75), ratio=(1.0, 1.0)),
+ # CenterCrop(224)
+ # RandomApply(torch.nn.ModuleList([ColorJitter(0.4, 0.4, 0.4)]), 0.8),
+ ])
+ else:
+ transforms.extend([
+ # scale jitter as in vivit: (0.9, 1.33)
+ RandomResizedCrop(224, scale=(0.56, 0.95), ratio=(1.0, 1.0)),
+ RandomHorizontalFlip(),
+ # only p=0.8 is specified in vivit paper, using deit default parameters
+ RandomApply(torch.nn.ModuleList([ColorJitter(0.4, 0.4, 0.4)]), 0.8),
+ ])
+ transforms.append(
+ # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ # change to imagenet default value to keep consistency with pretrained parameters
+ Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ )
+ return Compose(transforms)
+
+ def _timesformer_transform(self):
+ transforms = [
+ Lambda(lambda frames: torch.stack([ToTensor()(frame.convert("RGB")) for frame in frames])),
+ ]
+ if self.test_mode:
+ test_scale = self.cfg.MODEL.IMG_INPUT_SIZE
+ transforms.extend([
+ Lambda(lambda frames: random_short_side_scale_jitter(frames, test_scale, test_scale)[0]),
+ Lambda(lambda images: torch.stack([uniform_crop(images, test_scale, i)[0] for i in range(3)], 0))
+ ])
+ else:
+ min_scale = int((256 / 224)*self.cfg.MODEL.IMG_INPUT_SIZE)
+ max_scale = int((320 / 224)*self.cfg.MODEL.IMG_INPUT_SIZE)
+ transforms.extend([
+ # Lambda(lambda frames: random_short_side_scale_jitter(frames, 256, 320)[0].unsqueeze(0)),
+ Lambda(lambda frames: random_short_side_scale_jitter(frames, min_scale, max_scale)[0].unsqueeze(0)),
+ RandomHorizontalFlip(),
+ RandomCrop(self.cfg.MODEL.IMG_INPUT_SIZE)
+ ])
+ transforms.append(
+ # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ # change to imagenet default value to keep consistency with pretrained parameters
+ # Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+ Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ )
+ return Compose(transforms)
+
+
+ @classmethod
+ def from_config(cls, cfg, stage: str = "train"):
+ if 'SLURM_PROCID' in os.environ:
+ tcs_conf_path = cfg.DATALOADER.get("TCS_CONF_PATH", "petreloss.config")
+ else:
+ # dev machine
+ tcs_conf_path = "slurm_tools/petreloss_local.config"
+ ret = {
+ "cfg": cfg,
+ "stage": stage,
+ "list_file": os.path.join(cfg.DATALOADER.ANNO_FOLDER, cfg.DATALOADER.ANNO_FILE),
+ "category_file": os.path.join(cfg.DATALOADER.ANNO_FOLDER, "category_mapping.txt"),
+ "root_path": os.path.join(cfg.DATALOADER.FEATS_FOLDER, "training" if stage == "train" else "validation"),
+ "s3_path": os.path.join(cfg.DATALOADER.S3_PATH, "training" if stage == "train" else "validation"),
+ "frames_per_clip": cfg.DATALOADER.FRAMES_PER_CLIP,
+ "interval": cfg.DATALOADER.STRIDE,
+ "num_clips": 1 if stage == 'train' else cfg.INFERENCE.NUM_VIEWS,
+ "is_train": stage == 'train',
+ "test_mode": stage != 'train',
+ "num_classes": cfg.MODEL.NUM_CLASSES,
+ "timesformer_aug": cfg.DATALOADER.TIMESFORMER_AUG,
+ "minibatches": cfg.DATALOADER.MINI_BATCHES,
+ "use_ceph": getattr(cfg.DATALOADER, 'USE_CEPH', False),
+ "tcs_conf_path": tcs_conf_path,
+ "data_percentage": cfg.DATALOADER.DATA_PERCENTAGE,
+ }
+
+
+ ret['tokenizer'] = ClipTokenizer()
+ ret['tokenizer_name'] = "clip"
+
+
+ return ret
+
+ def load_data(self, cfg):
+ # usualy it is [video_id, num_frames, class_idx]
+ # or [video_id, start_frame, end_frame, list of class_idx]
+ self.cls2idx = dict()
+ self.idx2cls = dict()
+ self.class_names = list()
+ with open(self.category_file, 'r') as f:
+ for line in f.readlines():
+ class_name, idx = line.strip().split('\t')
+ # for annotations
+ class_name = class_name.replace(" ", "_") # replace(" ", "_") for kinetics dataset
+ self.cls2idx[class_name] = int(idx)
+ self.idx2cls[int(idx)] = class_name
+ # processed_name = class_name.replace("_", " ").lower()
+ # if cfg.NAME in ["K700", "K400"]:
+ # processed_name = processed_name.replace("american football", "football").replace("(", "").replace(")", "")
+ # self.class_names.append(processed_name)
+
+ # self.class_name_tokens = [np.array(self.tokenizer.encode(x + " <|endoftext|>"), dtype=np.int64) for x in self.class_names]
+
+ # self.class_name_type_tokens = [np.zeros(len(x), dtype=np.int64) for x in self.class_name_tokens]
+
+ # load the exclude list
+ # TODO: move this to the config file
+ exclude_list = list()
+ if os.path.exists(os.path.join(os.path.dirname(self.list_file), "exclude_list.txt")):
+ with open(os.path.join(os.path.dirname(self.list_file), "exclude_list.txt"), 'r') as f:
+ exclude_list = list(f)
+ exclude_list = [t.strip() for t in exclude_list]
+
+ video_list = []
+ count = 0
+ with open(self.list_file) as f:
+ data_file = json.load(f)
+ for name, info in data_file['database'].items():
+ # if count > 1000:
+ # break
+ # else:
+ # count =+ 1
+ video_path = os.path.join(self.data_path, info["annotations"]['label'], name+cfg.DATALOADER.FILE_EXTENSION)
+ # program will stop if there isn't an exclude list!
+ if os.path.basename(video_path) in exclude_list:
+ continue
+ if (self.is_train and info['subset'] == "training") or (not self.is_train and info['subset'] == "validation") :
+ inst = {
+ "video_path" : video_path,
+ "id": name
+ }
+ # if not self.test_mode:
+ label = info['annotations']['label']
+ inst["target_label"] = label
+ assert label in self.cls2idx
+ video_list.append(inst)
+
+ if self.is_train and self.data_percentage < 1.0:
+ video_dict = dict()
+ for video in video_list:
+ if video["target_label"] not in video_dict:
+ video_dict[video["target_label"]] = list()
+ video_dict[video["target_label"]].append(video)
+ new_list = list()
+ for k, v in video_dict.items():
+ new_list.extend(random.sample(v, k=int(len(v)*self.data_percentage)+1))
+ video_list = new_list
+
+ num = len(video_list)
+ print("The number of videos is {}".format(num), flush=True)
+ assert (num > 0)
+ return video_list
+
+ def _sample_indices(self, total_frames, fps):
+ """
+ Used for training.
+ Args:
+ - record (VideoRecord):
+ Returns:
+ list: frame index, index starts from 1.
+ """
+ if self.random_stride:
+ interval = random.sample([8, 16, 32], k=1)[0]
+ else:
+ interval = self.interval
+ frame_idx = np.asarray(random_clip(total_frames, interval * fps / self.target_fps , self.frames_per_clip))
+ return frame_idx
+
+ def _get_val_indices(self, total_frames, fps):
+ max_frame_idx = total_frames - 1
+ sample_pos = max(0, 1 + max_frame_idx - int(self.interval * fps / self.target_fps * self.frames_per_clip))
+ start_list = np.linspace(0, sample_pos - 1, num=self.num_clips, dtype=int)
+ frame_idx = []
+ for start_idx in start_list.tolist():
+ # ! changed by zhujinguo for torch.cat multi-views
+ ids = [int(idx * self.interval * fps / self.target_fps + start_idx)%total_frames for idx in range(self.frames_per_clip)]
+ ids = [x for x in ids if x < total_frames]
+ frame_idx.append(ids)
+ return frame_idx
+
+ def __getitem__(self, index):
+ for i_try in range(100):
+ try:
+ record = self.video_list[index].as_py()
+ if self.use_ceph:
+ container = av.open(io.BytesIO(self.tcs_loader.client.get(record["video_path"])))
+ else:
+ container = av.open(record["video_path"])
+ # container.streams.video[0].thread_type = "AUTO"
+ stream = container.streams.video[0]
+ total_frames = stream.frames
+ fps = float(container.streams.video[0].average_rate)
+
+ if total_frames == 0:
+ # it returns 0 if not know, but that doesn't mean the video is null
+ for frame in container.decode(stream):
+ total_frames += 1
+ container.close()
+ container = av.open(record["video_path"])
+ stream = container.streams.video[0]
+ except Exception as e:
+ print(
+ "Failed to load video from {} with error {} ; trial {}".format(
+ record["video_path"], e, i_try
+ )
+ )
+
+ # let's try another one
+ index = random.randint(0, len(self.video_list) - 1)
+ continue
+
+
+ if self.is_train:
+ indices = [self._sample_indices(total_frames, fps)]
+ else:
+ indices = self._get_val_indices(total_frames, fps)
+
+ all_index = set()
+ for index in indices:
+ all_index = all_index.union(set(index))
+
+ start_index = min(all_index)
+ num_frames = len(all_index)
+
+ images = dict()
+
+ fetched = 0
+
+ for frame in container.decode(stream):
+ if frame.index not in all_index or frame.index in images:
+ continue
+ images[frame.index] = frame.to_rgb().to_image()
+ last = frame.index
+ fetched += 1
+ if fetched == num_frames:
+ break
+
+ container.close()
+
+ video_data = list()
+ for ind in indices:
+ seq = list()
+ for i in ind:
+ if i in images:
+ seq.append(images[i])
+ else:
+ seq.append(images[last])
+ video_data.append(self.transform(seq))
+ video_data = torch.cat(video_data, dim=0)
+ # num_views, num_frames, 3, 224, 224
+ if not self.is_train:
+ if self.testing_multi_view == 'v1' and self.temporal_num_view > 1:
+ video_data = video_data.reshape(video_data.shape[0] * self.temporal_num_view, -1, *video_data.shape[-3:])
+ num_frames = num_frames // self.temporal_num_view
+ elif self.testing_multi_view == 'v2' and self.temporal_num_view > 1:
+ video_data = video_data.reshape(video_data.shape[0], -1, self.temporal_num_view,
+ *video_data.shape[-3:]).transpose(1, 2).reshape(video_data.shape[0] * self.temporal_num_view, -1,
+ *video_data.shape[-3:])
+ num_frames = num_frames // self.temporal_num_view
+
+
+ ret = {
+ 'input_sample':[
+ {
+ 'data': video_data, 'invalid_mask': None, 'modality': 'video', 'data_type': 'input',
+ 'sample_info':{
+ 'id': record['id'],
+ 'path': record['video_path'],
+ 'num_frames': num_frames,
+ 'num_views': video_data.shape[0],
+ 'cat_along_first_dim': True,
+ }
+ }
+ ],
+ 'target_sample': [],
+ 'target_idx': [self.cls2idx[record['target_label']]],
+ 'target_set': copy.deepcopy(self.target_set),
+ 'task_info': copy.deepcopy(self.task_info)
+
+ }
+
+ # dict_as_tensor(ret)
+ return ret
+
+ def __len__(self):
+ return len(self.video_list)
diff --git a/uniperceiver/datasets/task_dataset/video_transform.py b/uniperceiver/datasets/task_dataset/video_transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fac7a4328dae66a3652d044bc8f57e4844df102
--- /dev/null
+++ b/uniperceiver/datasets/task_dataset/video_transform.py
@@ -0,0 +1,671 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+
+import math
+import numpy as np
+import torch
+
+import numpy as np
+from PIL import Image
+# pytorch=1.7.1
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+# pip install opencv-python
+import cv2
+import random
+try:
+ import ffmpeg
+except:
+ pass
+import av
+import math
+
+
+
+def random_short_side_scale_jitter(
+ images, min_size, max_size, boxes=None, inverse_uniform_sampling=False
+):
+ """
+ Perform a spatial short scale jittering on the given images and
+ corresponding boxes.
+ Args:
+ images (tensor): images to perform scale jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ min_size (int): the minimal size to scale the frames.
+ max_size (int): the maximal size to scale the frames.
+ boxes (ndarray): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ inverse_uniform_sampling (bool): if True, sample uniformly in
+ [1 / max_scale, 1 / min_scale] and take a reciprocal to get the
+ scale. If False, take a uniform sample from [min_scale, max_scale].
+ Returns:
+ (tensor): the scaled images with dimension of
+ `num frames` x `channel` x `new height` x `new width`.
+ (ndarray or None): the scaled boxes with dimension of
+ `num boxes` x 4.
+ """
+ if inverse_uniform_sampling:
+ size = int(
+ round(1.0 / np.random.uniform(1.0 / max_size, 1.0 / min_size))
+ )
+ else:
+ size = int(round(np.random.uniform(min_size, max_size)))
+
+ height = images.shape[2]
+ width = images.shape[3]
+ if (width <= height and width == size) or (
+ height <= width and height == size
+ ):
+ return images, boxes
+ new_width = size
+ new_height = size
+ if width < height:
+ new_height = int(math.floor((float(height) / width) * size))
+ if boxes is not None:
+ boxes = boxes * float(new_height) / height
+ else:
+ new_width = int(math.floor((float(width) / height) * size))
+ if boxes is not None:
+ boxes = boxes * float(new_width) / width
+
+ return (
+ torch.nn.functional.interpolate(
+ images,
+ size=(new_height, new_width),
+ mode="bilinear",
+ align_corners=False,
+ ),
+ boxes,
+ )
+
+
+def crop_boxes(boxes, x_offset, y_offset):
+ """
+ Peform crop on the bounding boxes given the offsets.
+ Args:
+ boxes (ndarray or None): bounding boxes to peform crop. The dimension
+ is `num boxes` x 4.
+ x_offset (int): cropping offset in the x axis.
+ y_offset (int): cropping offset in the y axis.
+ Returns:
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ cropped_boxes = boxes.copy()
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
+
+ return cropped_boxes
+
+
+def random_crop(images, size, boxes=None):
+ """
+ Perform random spatial crop on the given images and corresponding boxes.
+ Args:
+ images (tensor): images to perform random crop. The dimension is
+ `num frames` x `channel` x `height` x `width`.
+ size (int): the size of height and width to crop on the image.
+ boxes (ndarray or None): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ Returns:
+ cropped (tensor): cropped images with dimension of
+ `num frames` x `channel` x `size` x `size`.
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ if images.shape[2] == size and images.shape[3] == size:
+ return images, None
+ height = images.shape[2]
+ width = images.shape[3]
+ y_offset = 0
+ if height > size:
+ y_offset = int(np.random.randint(0, height - size))
+ x_offset = 0
+ if width > size:
+ x_offset = int(np.random.randint(0, width - size))
+ cropped = images[
+ :, :, y_offset : y_offset + size, x_offset : x_offset + size
+ ]
+
+ cropped_boxes = (
+ crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
+ )
+
+ return cropped, cropped_boxes
+
+
+def horizontal_flip(prob, images, boxes=None):
+ """
+ Perform horizontal flip on the given images and corresponding boxes.
+ Args:
+ prob (float): probility to flip the images.
+ images (tensor): images to perform horizontal flip, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ boxes (ndarray or None): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ Returns:
+ images (tensor): images with dimension of
+ `num frames` x `channel` x `height` x `width`.
+ flipped_boxes (ndarray or None): the flipped boxes with dimension of
+ `num boxes` x 4.
+ """
+ if boxes is None:
+ flipped_boxes = None
+ else:
+ flipped_boxes = boxes.copy()
+
+ if np.random.uniform() < prob:
+ images = images.flip((-1))
+
+ width = images.shape[3]
+ if boxes is not None:
+ flipped_boxes[:, [0, 2]] = width - boxes[:, [2, 0]] - 1
+
+ return images, flipped_boxes
+
+
+def uniform_crop(images, size, spatial_idx, boxes=None):
+ """
+ Perform uniform spatial sampling on the images and corresponding boxes.
+ Args:
+ images (tensor): images to perform uniform crop. The dimension is
+ `num frames` x `channel` x `height` x `width`.
+ size (int): size of height and weight to crop the images.
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
+ crop if height is larger than width.
+ boxes (ndarray or None): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ Returns:
+ cropped (tensor): images with dimension of
+ `num frames` x `channel` x `size` x `size`.
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ assert spatial_idx in [0, 1, 2]
+ height = images.shape[2]
+ width = images.shape[3]
+
+ y_offset = int(math.ceil((height - size) / 2))
+ x_offset = int(math.ceil((width - size) / 2))
+
+ if height > width:
+ if spatial_idx == 0:
+ y_offset = 0
+ elif spatial_idx == 2:
+ y_offset = height - size
+ else:
+ if spatial_idx == 0:
+ x_offset = 0
+ elif spatial_idx == 2:
+ x_offset = width - size
+ cropped = images[
+ :, :, y_offset : y_offset + size, x_offset : x_offset + size
+ ]
+
+ cropped_boxes = (
+ crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
+ )
+
+ return cropped, cropped_boxes
+
+
+def uniform_crop_2crops(images, size, spatial_idx, boxes=None):
+ """
+ Perform uniform spatial sampling on the images and corresponding boxes.
+ Args:
+ images (tensor): images to perform uniform crop. The dimension is
+ `num frames` x `channel` x `height` x `width`.
+ size (int): size of height and weight to crop the images.
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
+ crop if height is larger than width.
+ boxes (ndarray or None): optional. Corresponding boxes to images.
+ Dimension is `num boxes` x 4.
+ Returns:
+ cropped (tensor): images with dimension of
+ `num frames` x `channel` x `size` x `size`.
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
+ `num boxes` x 4.
+ """
+ assert spatial_idx in [0, 1, 2]
+ height = images.shape[2]
+ width = images.shape[3]
+
+
+ if height > width:
+ x_offset = 0
+ if height > size * 2:
+ if spatial_idx == 0:
+ y_offset = int((height - size * 2) // 2)
+ elif spatial_idx == 1:
+ y_offset = int(height - size - ((height - size * 2) // 2))
+ else:
+ if spatial_idx == 0:
+ y_offset = 0
+ elif spatial_idx == 1:
+ y_offset = height - size
+ else:
+ y_offset = 0
+ if width > size * 2:
+ if spatial_idx == 0:
+ x_offset = int((width - size * 2) // 2)
+ elif spatial_idx == 1:
+ x_offset = int(width - size - ((width - size * 2) // 2))
+ else:
+ if spatial_idx == 0:
+ x_offset = 0
+ elif spatial_idx == 1:
+ x_offset = width - size
+
+ cropped = images[
+ :, :, y_offset : y_offset + size, x_offset : x_offset + size
+ ]
+
+ cropped_boxes = (
+ crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
+ )
+
+ return cropped, cropped_boxes
+
+def clip_boxes_to_image(boxes, height, width):
+ """
+ Clip an array of boxes to an image with the given height and width.
+ Args:
+ boxes (ndarray): bounding boxes to perform clipping.
+ Dimension is `num boxes` x 4.
+ height (int): given image height.
+ width (int): given image width.
+ Returns:
+ clipped_boxes (ndarray): the clipped boxes with dimension of
+ `num boxes` x 4.
+ """
+ clipped_boxes = boxes.copy()
+ clipped_boxes[:, [0, 2]] = np.minimum(
+ width - 1.0, np.maximum(0.0, boxes[:, [0, 2]])
+ )
+ clipped_boxes[:, [1, 3]] = np.minimum(
+ height - 1.0, np.maximum(0.0, boxes[:, [1, 3]])
+ )
+ return clipped_boxes
+
+
+def blend(images1, images2, alpha):
+ """
+ Blend two images with a given weight alpha.
+ Args:
+ images1 (tensor): the first images to be blended, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ images2 (tensor): the second images to be blended, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ alpha (float): the blending weight.
+ Returns:
+ (tensor): blended images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ return images1 * alpha + images2 * (1 - alpha)
+
+
+def grayscale(images):
+ """
+ Get the grayscale for the input images. The channels of images should be
+ in order BGR.
+ Args:
+ images (tensor): the input images for getting grayscale. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ Returns:
+ img_gray (tensor): blended images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ # R -> 0.299, G -> 0.587, B -> 0.114.
+ img_gray = torch.tensor(images)
+ gray_channel = (
+ 0.299 * images[:, 2] + 0.587 * images[:, 1] + 0.114 * images[:, 0]
+ )
+ img_gray[:, 0] = gray_channel
+ img_gray[:, 1] = gray_channel
+ img_gray[:, 2] = gray_channel
+ return img_gray
+
+
+def color_jitter(images, img_brightness=0, img_contrast=0, img_saturation=0):
+ """
+ Perfrom a color jittering on the input images. The channels of images
+ should be in order BGR.
+ Args:
+ images (tensor): images to perform color jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ img_brightness (float): jitter ratio for brightness.
+ img_contrast (float): jitter ratio for contrast.
+ img_saturation (float): jitter ratio for saturation.
+ Returns:
+ images (tensor): the jittered images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+
+ jitter = []
+ if img_brightness != 0:
+ jitter.append("brightness")
+ if img_contrast != 0:
+ jitter.append("contrast")
+ if img_saturation != 0:
+ jitter.append("saturation")
+
+ if len(jitter) > 0:
+ order = np.random.permutation(np.arange(len(jitter)))
+ for idx in range(0, len(jitter)):
+ if jitter[order[idx]] == "brightness":
+ images = brightness_jitter(img_brightness, images)
+ elif jitter[order[idx]] == "contrast":
+ images = contrast_jitter(img_contrast, images)
+ elif jitter[order[idx]] == "saturation":
+ images = saturation_jitter(img_saturation, images)
+ return images
+
+
+def brightness_jitter(var, images):
+ """
+ Perfrom brightness jittering on the input images. The channels of images
+ should be in order BGR.
+ Args:
+ var (float): jitter ratio for brightness.
+ images (tensor): images to perform color jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ Returns:
+ images (tensor): the jittered images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ alpha = 1.0 + np.random.uniform(-var, var)
+
+ img_bright = torch.zeros(images.shape)
+ images = blend(images, img_bright, alpha)
+ return images
+
+
+def contrast_jitter(var, images):
+ """
+ Perfrom contrast jittering on the input images. The channels of images
+ should be in order BGR.
+ Args:
+ var (float): jitter ratio for contrast.
+ images (tensor): images to perform color jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ Returns:
+ images (tensor): the jittered images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ alpha = 1.0 + np.random.uniform(-var, var)
+
+ img_gray = grayscale(images)
+ img_gray[:] = torch.mean(img_gray, dim=(1, 2, 3), keepdim=True)
+ images = blend(images, img_gray, alpha)
+ return images
+
+
+def saturation_jitter(var, images):
+ """
+ Perfrom saturation jittering on the input images. The channels of images
+ should be in order BGR.
+ Args:
+ var (float): jitter ratio for saturation.
+ images (tensor): images to perform color jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ Returns:
+ images (tensor): the jittered images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ alpha = 1.0 + np.random.uniform(-var, var)
+ img_gray = grayscale(images)
+ images = blend(images, img_gray, alpha)
+
+ return images
+
+
+def lighting_jitter(images, alphastd, eigval, eigvec):
+ """
+ Perform AlexNet-style PCA jitter on the given images.
+ Args:
+ images (tensor): images to perform lighting jitter. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ alphastd (float): jitter ratio for PCA jitter.
+ eigval (list): eigenvalues for PCA jitter.
+ eigvec (list[list]): eigenvectors for PCA jitter.
+ Returns:
+ out_images (tensor): the jittered images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ if alphastd == 0:
+ return images
+ # generate alpha1, alpha2, alpha3.
+ alpha = np.random.normal(0, alphastd, size=(1, 3))
+ eig_vec = np.array(eigvec)
+ eig_val = np.reshape(eigval, (1, 3))
+ rgb = np.sum(
+ eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0),
+ axis=1,
+ )
+ out_images = torch.zeros_like(images)
+ for idx in range(images.shape[1]):
+ out_images[:, idx] = images[:, idx] + rgb[2 - idx]
+
+ return out_images
+
+
+def color_normalization(images, mean, stddev):
+ """
+ Perform color nomration on the given images.
+ Args:
+ images (tensor): images to perform color normalization. Dimension is
+ `num frames` x `channel` x `height` x `width`.
+ mean (list): mean values for normalization.
+ stddev (list): standard deviations for normalization.
+ Returns:
+ out_images (tensor): the noramlized images, the dimension is
+ `num frames` x `channel` x `height` x `width`.
+ """
+ assert len(mean) == images.shape[1], "channel mean not computed properly"
+ assert (
+ len(stddev) == images.shape[1]
+ ), "channel stddev not computed properly"
+
+ out_images = torch.zeros_like(images)
+ for idx in range(len(mean)):
+ out_images[:, idx] = (images[:, idx] - mean[idx]) / stddev[idx]
+
+ return out_images
+
+
+
+
+
+class RawVideoExtractorCV2():
+ def __init__(self, centercrop=False, size=224, framerate=-1, ):
+ self.centercrop = centercrop
+ self.size = size
+ self.framerate = framerate
+ self.transform = self._transform(self.size)
+ # Normalize((0.48145466, 0.4578275, 0.40821073), ),
+ self.mean = np.array((0.48145466, 0.4578275, 0.40821073)).reshape(1, 1, 1, 3) * 255
+ self.std = np.array((0.26862954, 0.26130258, 0.27577711)).reshape(1, 1, 1, 3) * 255
+
+ def _transform(self, n_px):
+ return Compose([
+ Resize(n_px, interpolation=Image.BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert("RGB"),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+ def video_to_tensor(self, video_file, preprocess, sample_fp=0, num_frames=50, sample_offset=0, start_time=None, end_time=None, impl="pyav"):
+ if start_time is not None or end_time is not None:
+ assert isinstance(start_time, int) and isinstance(end_time, int) \
+ and start_time > -1 and end_time > start_time
+ # assert sample_fp > -1
+
+ if impl == "cv2":
+ # Samples a frame sample_fp X frames.
+ cap = cv2.VideoCapture(video_file)
+ frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ fps = int(cap.get(cv2.CAP_PROP_FPS))
+
+ total_duration = (frameCount + fps - 1) // fps
+ start_sec, end_sec = 0, total_duration
+
+ if start_time is not None:
+ start_sec, end_sec = start_time, end_time if end_time <= total_duration else total_duration
+ cap.set(cv2.CAP_PROP_POS_FRAMES, int(start_time * fps))
+
+ ret = True
+ images, included = [], []
+
+
+ if sample_fp > -1:
+
+ if impl == "cv2":
+
+ # sample by fixed interval
+ interval = 1
+ if sample_fp > 0:
+ interval = fps // sample_fp
+ else:
+ sample_fp = fps
+ if interval == 0: interval = 1
+
+ inds = [ind for ind in np.arange(0, fps, interval)]
+ assert len(inds) >= sample_fp
+ inds = inds[:sample_fp]
+
+ offset = min(sample_offset, interval - 1) if sample_offset > 0 else random.randint(0, interval - 1)
+ for sec in np.arange(start_sec, end_sec + 1):
+ if not ret: break
+ # sec_base = int(sec * fps)
+ sec_base = int(sec * fps + offset)
+ for ind in inds:
+ cap.set(cv2.CAP_PROP_POS_FRAMES, sec_base + ind)
+ ret, frame = cap.read()
+ if not ret: break
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB")))
+
+ if len(images) > 0:
+ video_data = torch.tensor(np.stack(images))
+ else:
+ video_data = torch.zeros(1)
+ cap.release()
+
+ elif impl == "ffmpeg":
+
+ if sample_fp == 0:
+ sample_fp = 1000 # sample every frame
+
+ out, _ = (
+ ffmpeg
+ .input(video_file)
+ .filter('select', 'isnan(prev_selected_t)+gte(t-prev_selected_t,{})'.format(1 / sample_fp))
+ .filter('crop', 'min(in_w, in_h)', 'min(in_w, in_h)', '(in_w - min(in_w, in_h)) / 2', '(in_h - min(in_w, in_h)) / 2') # w, h, x, y, center crop
+ .filter('scale', self.size, self.size) # resize
+ .output('pipe:', format='rawvideo', pix_fmt='rgb24', vsync='vfr')
+ .global_args('-loglevel', 'quiet')
+ .run(capture_stdout=True)
+ )
+ video = (
+ np
+ .frombuffer(out, np.uint8)
+ .reshape([-1, self.size, self.size, 3])
+ )
+
+ video = (video - self.mean) / self.std
+ video_data = torch.as_tensor(video).permute(0, 3, 1, 2)
+
+ elif impl == 'pyav':
+ images = list()
+ container = av.open(video_file)
+ container.streams.video[0].thread_type = "AUTO"
+ stream = container.streams.video[0]
+ total_frames = stream.frames
+ assert total_frames != 0
+ duration = int(stream.duration * stream.time_base)
+
+ if sample_fp > 0:
+ interval = max(int(total_frames / duration / sample_fp), 1)
+ else:
+ interval = 1
+ for frame in container.decode(stream):
+ if frame.index % interval != 0:
+ continue
+ images.append(preprocess(frame.to_rgb().to_image()))
+
+ if len(images) > 0:
+ video_data = torch.stack(images) # th.tensor(np.stack(images))
+ else:
+ video_data = torch.zeros(1)
+ container.close()
+
+ else:
+ raise NotImplementedError
+
+ else:
+ if impl == "cv2":
+ # sample fixed number of frames
+ interval = max(frameCount // num_frames, 1) # this interval is int
+ start = min(sample_offset, interval - 1) if sample_offset > -1 else random.randint(0, interval - 1)
+ interval = frameCount / num_frames # the second interval is float
+ for i in range(num_frames):
+ cap.set(cv2.CAP_PROP_POS_FRAMES, start + int(i * interval))
+ ret, frame = cap.read()
+ if not ret: break
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ images.append(preprocess(Image.fromarray(frame_rgb).convert("RGB")))
+
+ if len(images) > 0:
+ video_data = torch.tensor(np.stack(images))
+ else:
+ video_data = torch.zeros(1)
+ cap.release()
+ elif impl == "pyav":
+ images = list()
+ container = av.open(video_file)
+ container.streams.video[0].thread_type = "AUTO"
+ stream = container.streams.video[0]
+ total_frames = stream.frames
+ assert total_frames != 0
+ interval = max(total_frames // num_frames, 1) # this interval is int
+ for frame in container.decode(stream):
+ if frame.index % interval != 0:
+ continue
+ images.append(preprocess(frame.to_rgb().to_image()))
+
+ if len(images) > 0:
+ video_data = torch.stack(images) # th.tensor(np.stack(images))
+ else:
+ video_data = torch.zeros(1)
+ container.close()
+ else:
+ raise NotImplementedError
+
+ return {'video': video_data}
+
+ def get_video_data(self, video_path, num_frames, sample_offset, start_time=None, end_time=None):
+ image_input = self.video_to_tensor(video_path, self.transform, sample_fp=self.framerate, num_frames=num_frames, sample_offset=sample_offset, start_time=start_time, end_time=end_time)
+ return image_input
+
+ def process_raw_data(self, raw_video_data):
+ tensor_size = raw_video_data.size()
+ tensor = raw_video_data.view(-1, 1, tensor_size[-3], tensor_size[-2], tensor_size[-1])
+ return tensor
+
+ def process_frame_order(self, raw_video_data, frame_order=0):
+ # 0: ordinary order; 1: reverse order; 2: random order.
+ if frame_order == 0:
+ pass
+ elif frame_order == 1:
+ reverse_order = np.arange(raw_video_data.size(0) - 1, -1, -1)
+ raw_video_data = raw_video_data[reverse_order, ...]
+ elif frame_order == 2:
+ random_order = np.arange(raw_video_data.size(0))
+ np.random.shuffle(random_order)
+ raw_video_data = raw_video_data[random_order, ...]
+
+ return raw_video_data
+
+# An ordinary video frame extractor based CV2
+RawVideoExtractor = RawVideoExtractorCV2
diff --git a/uniperceiver/datasets/task_dataset/vqa.py b/uniperceiver/datasets/task_dataset/vqa.py
new file mode 100644
index 0000000000000000000000000000000000000000..37cad8e010e445c82d66964596fa47b3ddb30604
--- /dev/null
+++ b/uniperceiver/datasets/task_dataset/vqa.py
@@ -0,0 +1,460 @@
+import os
+import copy
+import pickle
+import random
+import json
+import glob
+from uniperceiver.utils import comm
+from numpy.random import choice
+import pyarrow as pa
+from PIL import Image
+from torchvision import transforms
+import numpy as np
+from uniperceiver.config import configurable
+from uniperceiver.functional import read_np, dict_as_tensor, boxes_to_locfeats
+from uniperceiver.tokenization import ClipTokenizer
+from ..build import DATASETS_REGISTRY
+import torch
+from uniperceiver.datasets.custom_transforms import clip_transforms
+
+__all__ = ["VQADataset"]
+
+memorycache = False
+try:
+ if "SLURM_JOB_ID" in os.environ:
+ import mc
+ import io
+ memorycache = True
+# print("VQA using memory cache")
+ else:
+ # print("missing memory cache")
+ pass
+except:
+ # print("missing memory cache")
+ pass
+
+@DATASETS_REGISTRY.register()
+class VQADataset:
+ @configurable
+ def __init__(
+ self,
+ cfg,
+ dataset_name,
+ task_type,
+ stage: str,
+ anno_folder: str,
+ ans2label_path: str,
+ label2ans_path: str,
+ feats_folder: str,
+ max_feat_num: int,
+ max_seq_len: int,
+ use_global_v: bool,
+ tokenizer,
+ tokenizer_name,
+ use_ceph,
+ transform,
+ as_gen,
+ inf_input,
+ single_class,
+ small_val,
+ block_vq,
+ data_percentage,
+ two_eot,
+ ):
+ self.stage = stage
+ self.anno_folder = anno_folder
+ self.ans2label = pickle.load(open(ans2label_path, "rb"))
+ self.label2ans = pickle.load(open(label2ans_path, "rb"))
+ self.feats_folder = feats_folder
+ self.max_feat_num = max_feat_num
+ self.max_seq_len = max_seq_len
+ self.use_global_v = use_global_v
+ self.tokenizer = tokenizer
+ self.tokenizer_name = tokenizer_name
+ self.num_labels = len(self.ans2label)
+ self.cfg = cfg
+ self.dataset_name = dataset_name
+ self.task_type = task_type
+
+ self.id2path = self.load_img_info(self.anno_folder)
+
+ self.initialized = False
+ self.transform = transform
+ self.as_gen = as_gen
+ self.inf_input = inf_input
+ self.single_class = single_class
+ self.small_val = small_val
+ self.block_vq = block_vq
+ self.data_percentage = data_percentage
+ self.two_eot = two_eot
+ # if as_retrieval:
+ if self.tokenizer_name == "clip":
+ self.mask_tokens = [tokenizer.encoder["<|spe|>"]]
+ else:
+ raise NotImplementedError
+ # remove the first null answer, we are not using complementay dataset
+ self.answer_tokens = self.tokenize_answer()
+ self.answer_type_tokens = [np.zeros(len(x), dtype=np.int64) for x in self.answer_tokens]
+
+ self.use_ceph = use_ceph
+ if self.use_ceph:
+ self.feats_folder = "s3://coco"
+ print('debug info for vqa {}'.format( self.feats_folder))
+ from uniperceiver.datasets import TCSLoader
+ if 'SLURM_PROCID' in os.environ:
+ tcs_conf_path = cfg.DATALOADER.get("TCS_CONF_PATH", "slurm_tools/petreloss.config")
+ else:
+ # dev machine
+ tcs_conf_path = "slurm_tools/petreloss_local.config"
+ self.tcs_loader = TCSLoader(tcs_conf_path)
+
+ self.load_data(self.cfg)
+
+ self.task_info = {
+ 'task_type' : self.cfg.DATASETS.TASK_TYPE,
+ 'dataset_name' : self.cfg.DATASETS.DATASET_NAME,
+ 'batch_size' : self.cfg.DATALOADER.TRAIN_BATCH_SIZE if self.stage == 'train' else self.cfg.DATALOADER.TEST_BATCH_SIZE,
+ 'sampling_weight': self.cfg.DATALOADER.SAMPLING_WEIGHT,
+ 'single_class' : self.cfg.DATALOADER.SINGLE_CLASS
+ }
+
+ def _init_memcached(self):
+ if not self.initialized:
+ server_list_config_file = "/mnt/cache/share/memcached_client/server_list.conf"
+ client_config_file = "/mnt/cache/share/memcached_client/client.conf"
+ self.mclient = mc.MemcachedClient.GetInstance(server_list_config_file, client_config_file)
+ self.initialized = True
+
+
+ @classmethod
+ def from_config(cls, cfg, stage: str = "train"):
+ ans2label_path = os.path.join(cfg.DATALOADER.ANNO_FOLDER, "trainval_ans2label.pkl")
+ label2ans_path = os.path.join(cfg.DATALOADER.ANNO_FOLDER, "trainval_label2ans.pkl")
+
+ feats_folder = cfg.DATALOADER.FEATS_FOLDER
+ # if stage == "test":
+ # feats_folder = feats_folder + "/test2015"
+
+ if getattr(cfg.DATALOADER, 'TRANSFORM', None) == 'clip_transforms':
+ transform = clip_transforms(stage, flip_prob=0.0)
+ else:
+ transform = transforms.Compose([
+ transforms.Resize([224, 224]),
+ transforms.ToTensor(),
+ transforms.Normalize((0.485, 0.456, 0.406),
+ (0.229, 0.224, 0.225))]
+ )
+
+ ret = {
+ 'cfg': cfg,
+ 'dataset_name': cfg.DATASETS.DATASET_NAME,
+ 'task_type': cfg.DATASETS.TASK_TYPE,
+ "stage": stage,
+ "anno_folder": cfg.DATALOADER.ANNO_FOLDER,
+ "ans2label_path": ans2label_path,
+ "label2ans_path": label2ans_path,
+ "feats_folder": feats_folder,
+ "max_feat_num": cfg.DATALOADER.MAX_FEAT_NUM,
+ "max_seq_len": cfg.MODEL.MAX_SEQ_LEN,
+ "use_global_v": cfg.DATALOADER.USE_GLOBAL_V,
+ "use_ceph": getattr(cfg.DATALOADER, 'USE_CEPH', False),
+ "transform": transform,
+ "as_gen": cfg.DATALOADER.DO_AS_GEN,
+ "inf_input": cfg.DATALOADER.VQA_INPUT,
+ "single_class": cfg.DATALOADER.SINGLE_CLASS,
+ "small_val": cfg.DATALOADER.SMALL_VAL,
+ "block_vq": cfg.DATALOADER.BLOCK_VQ,
+ "data_percentage": cfg.DATALOADER.DATA_PERCENTAGE,
+ "two_eot": cfg.DATALOADER.TWO_EOT,
+ }
+
+ ret['tokenizer'] = ClipTokenizer()
+ ret['tokenizer_name'] = "clip"
+
+ return ret
+
+ def load_img_info(self, anno_folder):
+ id2path = {}
+
+ coco_map = json.load(open(os.path.join(anno_folder, "coco_map.json")))
+ for k, v in coco_map.items():
+ id2path[int(k)] = v
+
+ return id2path
+
+ def load_data(self, cfg):
+
+ cache_path = os.path.join(
+ self.anno_folder, "cache",
+ "VQA_sep_%s_%s_%d%s.pkl" % (self.tokenizer_name, self.stage, self.max_seq_len, "_full_val" if self.stage == "val" and not self.small_val else "")
+ )
+ if not os.path.exists(os.path.dirname(cache_path)):
+ os.makedirs(os.path.dirname(cache_path))
+ if not os.path.exists(cache_path):
+ datalist = self.load_raw_data(cfg)
+ self.tokenize(datalist)
+ pickle.dump(datalist, open(cache_path, "wb"))
+ datalist = pickle.load(open(cache_path, "rb"))
+ if self.data_percentage < 1.0 and self.stage == "train":
+ labels2l = dict()
+ for data in datalist:
+ if not data['answer']['labels']:
+ continue
+ ans = data['answer']['labels'][0]
+ if ans not in labels2l:
+ labels2l[ans] = list()
+ labels2l[ans].append(data)
+ datalist = []
+ for v in labels2l.values():
+ datalist.extend(random.sample(v, k=int(self.data_percentage * len(v)+1)))
+
+ self.database = pa.array(datalist)
+ self.datalist = datalist
+
+ if comm.is_main_process():
+ import sys
+ print(f"!!! Dataset {self.dataset_name} with task {self.task_type}:")
+ print('!!! length of _temp_list: ', len(datalist))
+ print('!!! size of _temp_list: ', sys.getsizeof(datalist))
+ print('!!! size of pa database: ', sys.getsizeof(self.database))
+ del datalist
+
+
+ def tokenize(self, datalist):
+ for entry in datalist:
+ tokens = self.tokenizer.encode(entry["question"])
+ tokens = tokens[: self.max_seq_len - 2]
+ # tokens = self.tokenizer.add_special_tokens_single_sentence(tokens)
+ entry["question"] = tokens
+
+ def tokenize_answer(self):
+ output = list()
+ for answer in self.label2ans:
+ answer_tokens = self.tokenizer.encode(answer + " <|endoftext|>")
+ # answer_tokens = self.tokenizer.add_special_tokens_single_sentence(answer_tokens)
+ output.append(answer_tokens)
+ return output
+
+ def load_raw_data(self, cfg):
+ if self.stage == 'train': # trainval mode
+ question_path_train = os.path.join(self.anno_folder, "v2_OpenEnded_mscoco_train2014_questions.json")
+ questions_train = sorted(
+ json.load(open(question_path_train))["questions"],
+ key=lambda x: x["question_id"],
+ )
+ answer_path_train = os.path.join(self.anno_folder, "train_target.pkl")
+ answers_train = pickle.load(open(answer_path_train, "rb"))
+ answers_train = sorted(answers_train, key=lambda x: x["question_id"])
+
+ question_path_val = os.path.join(self.anno_folder, "v2_OpenEnded_mscoco_val2014_questions.json")
+ questions_val = sorted(
+ json.load(open(question_path_val))["questions"],
+ key=lambda x: x["question_id"],
+ )
+ answer_path_val = os.path.join(self.anno_folder, "val_target.pkl")
+ answers_val = pickle.load(open(answer_path_val, "rb"))
+ answers_val = sorted(answers_val, key=lambda x: x["question_id"])
+
+ # VG
+ vg_question_path_train = os.path.join(self.anno_folder, "VG_questions2.json")
+ vg_questions_train = sorted(
+ json.load(open(vg_question_path_train))["questions"],
+ key=lambda x: x["question_id"],
+ )
+ vg_answer_path_train = os.path.join(self.anno_folder, "vg_target.pkl")
+ vg_answers_train = pickle.load(open(vg_answer_path_train, "rb"))
+ vg_answers_train = sorted(vg_answers_train, key=lambda x: x["question_id"])
+
+ questions = questions_train + questions_val[:-3000] + vg_questions_train
+ answers = answers_train + answers_val[:-3000] + vg_answers_train
+ elif self.stage == "val": # minval
+ question_path_val = os.path.join(self.anno_folder, "v2_OpenEnded_mscoco_val2014_questions.json")
+ questions_val = sorted(
+ json.load(open(question_path_val))["questions"],
+ key=lambda x: x["question_id"],
+ )
+ answer_path_val = os.path.join(self.anno_folder, "val_target.pkl")
+ answers_val = pickle.load(open(answer_path_val, "rb"))
+ answers_val = sorted(answers_val, key=lambda x: x["question_id"])
+ if self.small_val:
+ questions = questions_val[-3000:]
+ answers = answers_val[-3000:]
+ else:
+ questions = questions_val
+ answers = answers_val
+ else:
+ question_path_test = os.path.join(self.anno_folder, "v2_OpenEnded_mscoco_test2015_questions.json")
+ # question_path_test = os.path.join(self.anno_folder, "v2_OpenEnded_mscoco_test-dev2015_questions.json")
+ questions_test = sorted(
+ json.load(open(question_path_test))["questions"],
+ key=lambda x: x["question_id"],
+ )
+ questions = questions_test
+
+ datalist = []
+ if self.stage == "test":
+ for question in questions:
+ datalist.append({
+ "question_id": str(question["question_id"]),
+ "image_id": str(question["image_id"]),
+ "question": question["question"],
+ })
+ else:
+ assert len(questions) == len(answers)
+ for question, answer in zip(questions, answers):
+ assert question["question_id"] == answer["question_id"]
+ assert question["image_id"] == answer["image_id"]
+
+ answer.pop("image_id")
+ answer.pop("question_id")
+ datalist.append({
+ "question_id": str(question["question_id"]),
+ "image_id": str(question["image_id"]),
+ "question": question["question"],
+ "answer": answer,
+ })
+ return datalist
+
+ def __len__(self):
+ return len(self.database)
+
+ def __getitem__(self, index):
+
+ for i_try in range(100):
+ try:
+ dataset_dict = self.database[index].as_py()
+ image_id = dataset_dict['image_id']
+ question_id = dataset_dict["question_id"]
+
+ global memorycache
+
+ image_path = os.path.join(self.feats_folder, self.id2path[int(image_id)])
+ ### LOAD IMAGE ###
+
+ if self.use_ceph:
+ img = self.tcs_loader(image_path).convert('RGB')
+
+ elif not memorycache:
+ img = Image.open(image_path).convert("RGB")
+ else:
+ # memcached
+ self._init_memcached()
+ value = mc.pyvector()
+ self.mclient.Get(image_path, value)
+ value_str = mc.ConvertBuffer(value)
+ buff = io.BytesIO(value_str)
+ img = Image.open(buff).convert("RGB")
+ except Exception as e:
+ print(
+ "Failed to load video from {} with error {} ; trial {}".format(
+ image_path, e, i_try
+ )
+ )
+
+ # let's try another one
+ index = random.randint(0, len(self.datalist) - 1)
+ dataset_dict = self.datalist[index]
+ continue
+
+
+ img = self.transform(img)
+
+ prob = random.random()
+ if prob > 0.5 and self.stage == 'train':
+ # img = img[:, :, ::-1]
+ img = torch.flip(img, [2])
+
+ question = dataset_dict["question"]
+ if self.as_gen:
+ if self.two_eot:
+ question = question + self.tokenizer.encode("<|endoftext|>")
+ question = question + self.tokenizer.encode("<|spe|> <|endoftext|>")
+ index = len(question) - 2
+
+ question = np.array(question, dtype=np.int64)
+
+ #######################################################
+ if prob > 0.5 and self.stage == 'train':
+ for i in range(1, len(question)):
+ if self.tokenizer_name == "clip":
+ left = self.tokenizer.encoder["left"]
+ right = self.tokenizer.encoder["right"]
+ if question[i] == left:
+ question[i] = right
+ elif question[i] == right:
+ question[i] = left
+ else:
+ raise NotImplementedError
+
+ if 'image' in self.inf_input:
+ ret = {
+ 'input_sample': [{
+ 'data' : img,
+ 'invalid_mask': None,
+ 'modality' : 'image',
+ 'data_type' : 'input',
+ 'sample_info' : {
+ 'id': image_id,
+ 'path': image_path
+ }
+ }]
+ }
+
+ self.target_set = self.cfg.DATASETS.TARGET_SET
+
+ target = 0
+
+ if "answer" in dataset_dict:
+ answer = dataset_dict["answer"]
+ labels = answer["labels"]
+ scores = answer["scores"]
+
+ #######################################################
+ if prob > 0.5 and self.stage == 'train':
+ for i in range(len(labels)):
+ if labels[i] == self.ans2label['left']:
+ labels[i] = self.ans2label['right']
+ elif labels[i] == self.ans2label['right']:
+ labels[i] = self.ans2label['left']
+ #######################################################
+
+
+ if self.single_class:
+ if len(labels) < 1:
+ target = 0
+ else:
+ s = sum(scores)
+ # probabilty
+ p = [t / s for t in scores]
+ # sample
+ target = choice(labels, 1, p=p).item()
+ else:
+ target = np.zeros(self.num_labels)
+ if len(labels) > 0:
+ for label, score in zip(labels, scores):
+ target[label] = score
+ target = np.array(target, dtype=np.float32)
+
+
+ if self.as_gen:
+ # caption like
+ ret['input_sample'].append({
+ 'data': [question],
+ 'invalid_mask': None,
+ 'modality': 'text',
+ 'data_type': 'input',
+ 'sample_info': {
+ 'spe_index': index,
+ 'question_id': question_id
+ }
+ })
+ ret.update({
+ 'target_sample': [],
+ 'target_idx' : [target],
+ 'target_set' : copy.deepcopy(self.target_set),
+ 'task_info' : copy.deepcopy(self.task_info)
+ })
+
+
+ dict_as_tensor(ret)
+ return ret
diff --git a/uniperceiver/datasets/tcsreader.py b/uniperceiver/datasets/tcsreader.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a6aa7addf30f1e440c5b52490125987ad5a68e8
--- /dev/null
+++ b/uniperceiver/datasets/tcsreader.py
@@ -0,0 +1,42 @@
+import io
+from PIL import Image
+import cv2
+import numpy as np
+try:
+ from petrel_client.client import Client
+except ImportError as E:
+ "petrel_client.client cannot be imported"
+ pass
+
+
+def pil_loader(img_str):
+ buff = io.BytesIO(img_str)
+ return Image.open(buff)
+
+def cv2_loader(img_bytes):
+ # assert(img_bytes is not None)
+ img_mem_view = memoryview(img_bytes)
+ img_array = np.frombuffer(img_mem_view, np.uint8)
+ imgcv2=cv2.imdecode(img_array, cv2.IMREAD_COLOR)
+ imgcv2=cv2.cvtColor(imgcv2, cv2.COLOR_BGR2RGB)
+ return Image.fromarray(imgcv2)
+
+class TCSLoader(object):
+
+ def __init__(self, conf_path):
+ self.client = Client(conf_path)
+
+ def __call__(self, fn):
+ try:
+ img_value_str = self.client.get(fn)
+ img = pil_loader(img_value_str)
+ except:
+ try:
+ img = cv2_loader(img_value_str)
+ except:
+ print('Read image failed ({})'.format(fn))
+ return None
+ else:
+ return img
+ else:
+ return img
diff --git a/uniperceiver/datasets/unified_dataset.py b/uniperceiver/datasets/unified_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffe84db758a3b18adcf0040c3384d77cd669b83a
--- /dev/null
+++ b/uniperceiver/datasets/unified_dataset.py
@@ -0,0 +1,62 @@
+import os
+import copy
+import torch
+
+from .build import DATASETS_REGISTRY
+from uniperceiver.config import configurable
+
+
+
+import numpy as np
+"""
+
+ only standard dataset support now
+ class dataset:
+ def __len__(self):
+ pass
+
+ def __getitem__(self, index):
+ pass
+
+"""
+
+class UnifiedDataset:
+
+ def __init__(self, cfg, task_cfg, stage, **kwargs):
+ self.cfg = cfg
+ self.task_cfg =task_cfg
+ assert stage == 'train', f'now only training dataset is supported'
+
+ datasets = dict()
+ for name, new_cfg in self.task_cfg.items():
+ datasets[name] = self.build_unit_dataset(new_cfg,
+ new_cfg.DATASETS.TRAIN,
+ stage=stage)
+ self.datasets = datasets
+
+ self.dataset_name = list(self.datasets.keys())
+ self.dataset_list = list(self.datasets.values())
+
+
+ self.dataset_length = np.array([len(ds) for ds in self.datasets.values()])
+
+ self.dataset_scale = np.array([0] + np.cumsum(self.dataset_length).tolist()[:-1])
+ # [0, 876, 128000, ....]
+
+
+ pass
+
+ def build_unit_dataset(self, cfg, name, stage):
+ dataset_mapper = DATASETS_REGISTRY.get(name)(cfg, stage)
+ return dataset_mapper
+
+ def __len__(self):
+ return np.cumsum(self.dataset_length).tolist()[-1]
+
+
+ def __getitem__(self, index):
+ dataset_index = (index >= self.dataset_scale).sum() - 1 # the dataset index
+ offset = self.dataset_scale[dataset_index]
+ ret = self.dataset_list[dataset_index][index-offset]
+ ret.update({"task_name": self.dataset_name[dataset_index]})
+ return ret
diff --git a/uniperceiver/datasets/zipreader.py b/uniperceiver/datasets/zipreader.py
new file mode 100644
index 0000000000000000000000000000000000000000..c120a801890e9e541108bca589c3e55807a1c1bc
--- /dev/null
+++ b/uniperceiver/datasets/zipreader.py
@@ -0,0 +1,85 @@
+import zipfile
+import os
+import io
+import time
+from PIL import Image
+
+
+class ZipReader(object):
+ zip_bank = dict()
+
+ def __init__(self):
+ super(ZipReader, self).__init__()
+
+ @staticmethod
+ def get_zipfile(path):
+ zip_bank = ZipReader.zip_bank
+ if path in zip_bank:
+ return zip_bank[path]
+ else:
+ print("creating new zip_bank")
+ zfile = zipfile.ZipFile(path, 'r')
+ zip_bank[path] = zfile
+ return zip_bank[path]
+
+ @staticmethod
+ def split_zip_style_path(path):
+ pos_zip_at = path.index('.zip@')
+ if pos_zip_at == len(path):
+ print("character '@' is not found from the given path '%s'" % (path))
+ assert 0
+ pos_at = pos_zip_at + len('.zip@') - 1
+
+ zip_path = path[0: pos_at]
+ folder_path = path[pos_at + 1:]
+ folder_path = str.strip(folder_path, '/')
+ return zip_path, folder_path
+
+ @staticmethod
+ def list_folder(path):
+ zip_path, folder_path = ZipReader.split_zip_style_path(path)
+
+ zfile = ZipReader.get_zipfile(zip_path)
+ folder_list = []
+ for file_foler_name in zfile.namelist():
+ file_foler_name = str.strip(file_foler_name, '/')
+ if file_foler_name.startswith(folder_path) and \
+ len(os.path.splitext(file_foler_name)[-1]) == 0 and \
+ file_foler_name != folder_path:
+ if len(folder_path) == 0:
+ folder_list.append(file_foler_name)
+ else:
+ folder_list.append(file_foler_name[len(folder_path)+1:])
+
+ return folder_list
+
+ @staticmethod
+ def list_files(path, extension=['.*']):
+ zip_path, folder_path = ZipReader.split_zip_style_path(path)
+
+ zfile = ZipReader.get_zipfile(zip_path)
+ file_lists = []
+ for file_foler_name in zfile.namelist():
+ file_foler_name = str.strip(file_foler_name, '/')
+ if file_foler_name.startswith(folder_path) and str.lower(os.path.splitext(file_foler_name)[-1]) in extension:
+ if len(folder_path) == 0:
+ file_lists.append(file_foler_name)
+ else:
+ file_lists.append(file_foler_name[len(folder_path)+1:])
+
+ return file_lists
+
+ @staticmethod
+ def imread(path):
+ zip_path, path_img = ZipReader.split_zip_style_path(path)
+ zfile = ZipReader.get_zipfile(zip_path)
+ data = zfile.read(path_img)
+ im = Image.open(io.BytesIO(data))
+ return im
+
+ @staticmethod
+ def read(path):
+ zip_path, path_img = ZipReader.split_zip_style_path(path)
+ zfile = ZipReader.get_zipfile(zip_path)
+ data = zfile.read(path_img)
+ return data
diff --git a/uniperceiver/engine/__init__.py b/uniperceiver/engine/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ba37a4b3cdb2c83ce39251293be027ab928f1d2
--- /dev/null
+++ b/uniperceiver/engine/__init__.py
@@ -0,0 +1,11 @@
+from .launch import *
+from .train_loop import *
+
+__all__ = [k for k in globals().keys() if not k.startswith("_")]
+
+
+from .hooks import *
+from .defaults import *
+from .unified_trainer import UnifiedTrainer
+
+from .build import build_engine
diff --git a/uniperceiver/engine/build.py b/uniperceiver/engine/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cc7507dd60403a22b407cd278600226c9a6e610
--- /dev/null
+++ b/uniperceiver/engine/build.py
@@ -0,0 +1,11 @@
+
+from uniperceiver.utils.registry import Registry
+
+ENGINE_REGISTRY = Registry("ENGINE")
+ENGINE_REGISTRY.__doc__ = """
+Registry for engine
+"""
+
+def build_engine(cfg):
+ engine = ENGINE_REGISTRY.get(cfg.ENGINE.NAME)(cfg)
+ return engine
\ No newline at end of file
diff --git a/uniperceiver/engine/defaults.py b/uniperceiver/engine/defaults.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5cfa69528beee38c532fab4fc90143fb6b54ab4
--- /dev/null
+++ b/uniperceiver/engine/defaults.py
@@ -0,0 +1,699 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+"""
+This file contains components with some default boilerplate logic user may need
+in training / testing. They will not work for everyone, but many users may find them useful.
+
+The behavior of functions/classes in this file is subject to change,
+since they are meant to represent the "common default behavior" people need in their projects.
+"""
+
+import argparse
+import logging
+import os
+import sys
+import weakref
+from collections import OrderedDict
+from typing import Optional
+import torch
+from fvcore.nn.precise_bn import get_bn_modules
+from omegaconf import OmegaConf
+from torch.nn.parallel import DistributedDataParallel
+from uniperceiver.config import CfgNode
+
+
+
+from uniperceiver.datasets import (
+ build_dataset_mapper,
+ build_standard_train_loader,
+ build_standard_valtest_loader,
+ build_unified_train_loader,
+)
+
+
+from uniperceiver.evaluation import (
+ DatasetEvaluator,
+ inference_on_dataset,
+ print_csv_format,
+ verify_results,
+)
+from uniperceiver.modeling import build_model
+from uniperceiver.lr_scheduler import build_lr_scheduler
+from uniperceiver.optim import build_optimizer
+from uniperceiver.utils import comm
+from uniperceiver.utils.collect_env import collect_env_info
+from uniperceiver.utils.env import seed_all_rng
+from uniperceiver.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
+from uniperceiver.utils.file_io import PathManager
+from uniperceiver.utils.logger import setup_logger
+
+from . import hooks
+from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase
+
+__all__ = [
+ "create_ddp_model",
+ "default_argument_parser",
+ "default_setup",
+ "default_writers",
+ "DefaultTrainer",
+ "add_moe_arguments",
+]
+
+
+def create_ddp_model(model, *, fp16_compression=False, **kwargs):
+ """
+ Create a DistributedDataParallel model if there are >1 processes.
+
+ Args:
+ model: a torch.nn.Module
+ fp16_compression: add fp16 compression hooks to the ddp object.
+ See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook
+ kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`.
+ """ # noqa
+ if comm.get_world_size() == 1:
+ return model
+ if "device_ids" not in kwargs:
+ kwargs["device_ids"] = [comm.get_local_rank()]
+ ddp = DistributedDataParallel(model, **kwargs)
+ if fp16_compression:
+ from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks
+
+ ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook)
+ return ddp
+
+
+def default_argument_parser(epilog=None):
+ """
+ Create a parser with some common arguments used by detectron2 users.
+
+ Args:
+ epilog (str): epilog passed to ArgumentParser describing the usage.
+
+ Returns:
+ argparse.ArgumentParser:
+ """
+ parser = argparse.ArgumentParser(
+ epilog=epilog
+ or f"""
+Examples:
+
+Run on single machine:
+ $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml
+
+Change some config options:
+ $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001
+
+Run on multiple machines:
+ (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags]
+ (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags]
+""",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ )
+ parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
+ parser.add_argument(
+ "--resume",
+ action="store_true",
+ help="Whether to attempt to resume from the checkpoint directory. "
+ "See documentation of `DefaultTrainer.resume_or_load()` for what it means.",
+ )
+ parser.add_argument("--eval-only", action="store_true", help="perform evaluation only")
+ parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*")
+ parser.add_argument("--num-machines", type=int, default=1, help="total number of machines")
+ parser.add_argument(
+ "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)"
+ )
+
+ # PyTorch still may leave orphan processes in multi-gpu training.
+ # Therefore we use a deterministic way to obtain port,
+ # so that users are aware of orphan processes by seeing the port occupied.
+ port = 2**15 + 2**14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2**14
+ parser.add_argument(
+ "--dist-url",
+ default="tcp://127.0.0.1:{}".format(port),
+ help="initialization URL for pytorch distributed backend. See "
+ "https://pytorch.org/docs/stable/distributed.html for details.",
+ )
+ parser.add_argument(
+ "opts",
+ help="""
+Modify config options at the end of the command. For Yacs configs, use
+space-separated "PATH.KEY VALUE" pairs.
+For python-based LazyConfig, use "path.key=value".
+ """.strip(),
+ default=None,
+ nargs=argparse.REMAINDER,
+ )
+ return parser
+
+
+def add_moe_arguments(parser):
+ """
+ Arguments:
+ parser: argument parser
+ Return:
+ parser: Updated Parser
+ """
+ group = parser.add_argument_group('MOE', 'DeepSpeed MOE configurations')
+ group.add_argument('--moe',
+ default=False,
+ type=bool,
+ help='use deepspeed mixture of experts (moe)')
+
+ group.add_argument('--ep-world-size',
+ default=1,
+ type=int,
+ help='(moe) expert parallel world size')
+ group.add_argument('--num-experts',
+ default=1,
+ type=int,
+ help='(moe) number of total experts')
+ group.add_argument('--top-k',
+ default=1,
+ type=int,
+ help='(moe) gating top 1 and 2 supported')
+ group.add_argument(
+ '--min-capacity',
+ default=0,
+ type=int,
+ help=
+ '(moe) minimum capacity of an expert regardless of the capacity_factor'
+ )
+ group.add_argument(
+ '--noisy-gate-policy',
+ default=None,
+ type=str,
+ help=
+ '(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter'
+ )
+ group.add_argument(
+ '--moe-param-group',
+ default=False,
+ action='store_true',
+ help=
+ '(moe) create separate moe param groups, required when using ZeRO w. MoE'
+ )
+ return parser
+
+
+def _try_get_key(cfg, *keys, default=None):
+ """
+ Try select keys from cfg until the first key that exists. Otherwise return default.
+ """
+ if isinstance(cfg, CfgNode):
+ cfg = OmegaConf.create(cfg.dump())
+ for k in keys:
+ none = object()
+ p = OmegaConf.select(cfg, k, default=none)
+ if p is not none:
+ return p
+ return default
+
+
+def _highlight(code, filename):
+ try:
+ import pygments
+ except ImportError:
+ return code
+
+ from pygments.lexers import Python3Lexer, YamlLexer
+ from pygments.formatters import Terminal256Formatter
+
+ lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer()
+ code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai"))
+ return code
+
+
+def default_setup(cfg, args):
+ """
+ Perform some basic common setups at the beginning of a job, including:
+
+ 1. Set up the detectron2 logger
+ 2. Log basic information about environment, cmdline arguments, and config
+ 3. Backup the config to the output directory
+
+ Args:
+ cfg (CfgNode or omegaconf.DictConfig): the full config to be used
+ args (argparse.NameSpace): the command line arguments to be logged
+ """
+ # error with following code
+ # output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir")
+ output_dir = cfg.OUTPUT_DIR
+ if comm.is_main_process() and output_dir:
+ PathManager.mkdirs(output_dir)
+
+ rank = comm.get_rank()
+ setup_logger(output_dir, distributed_rank=rank, name="fvcore")
+ logger = setup_logger(output_dir, distributed_rank=rank)
+
+ logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
+ logger.info("Environment info:\n" + collect_env_info())
+
+ logger.info("Command line arguments: " + str(args))
+ if hasattr(args, "config_file") and args.config_file != "":
+ logger.info(
+ "Contents of args.config_file={}:\n{}".format(
+ args.config_file,
+ _highlight(PathManager.open(args.config_file, "r").read(), args.config_file),
+ )
+ )
+
+ if comm.is_main_process() and output_dir:
+ # Note: some of our scripts may expect the existence of
+ # config.yaml in output directory
+ path = os.path.join(output_dir, "config.yaml")
+
+ logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml")))
+ with PathManager.open(path, "w") as f:
+ f.write(cfg.dump())
+
+ logger.info("Full config saved to {}".format(path))
+
+ # make sure each worker has a different, yet deterministic seed if specified
+ # error with following code
+ # seed = _try_get_key(cfg, "SEED", "train.seed", default=-1)
+ seed = cfg.SEED
+ seed_all_rng(None if seed < 0 else seed + rank)
+
+ # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
+ # typical validation set.
+ if not (hasattr(args, "eval_only") and args.eval_only):
+ # error with following code
+ # torch.backends.cudnn.benchmark = _try_get_key(
+ # cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False
+ # )
+ torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
+
+
+def default_writers(output_dir: str, max_iter: Optional[int] = None):
+ """
+ Build a list of :class:`EventWriter` to be used.
+ It now consists of a :class:`CommonMetricPrinter`,
+ :class:`TensorboardXWriter` and :class:`JSONWriter`.
+
+ Args:
+ output_dir: directory to store JSON metrics and tensorboard events
+ max_iter: the total number of iterations
+
+ Returns:
+ list[EventWriter]: a list of :class:`EventWriter` objects.
+ """
+ PathManager.mkdirs(output_dir)
+ return [
+ # It may not always print what you want to see, since it prints "common" metrics only.
+ CommonMetricPrinter(max_iter),
+ JSONWriter(os.path.join(output_dir, "metrics.json")),
+ TensorboardXWriter(output_dir),
+ ]
+
+
+
+class DefaultTrainer(TrainerBase):
+ """
+ A trainer with default training logic. It does the following:
+
+ 1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader
+ defined by the given config. Create a LR scheduler defined by the config.
+ 2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when
+ `resume_or_load` is called.
+ 3. Register a few common hooks defined by the config.
+
+ It is created to simplify the **standard model training workflow** and reduce code boilerplate
+ for users who only need the standard training workflow, with standard features.
+ It means this class makes *many assumptions* about your training logic that
+ may easily become invalid in a new research. In fact, any assumptions beyond those made in the
+ :class:`SimpleTrainer` are too much for research.
+
+ The code of this class has been annotated about restrictive assumptions it makes.
+ When they do not work for you, you're encouraged to:
+
+ 1. Overwrite methods of this class, OR:
+ 2. Use :class:`SimpleTrainer`, which only does minimal SGD training and
+ nothing else. You can then add your own hooks if needed. OR:
+ 3. Write your own training loop similar to `tools/plain_train_net.py`.
+
+ See the :doc:`/tutorials/training` tutorials for more details.
+
+ Note that the behavior of this class, like other functions/classes in
+ this file, is not stable, since it is meant to represent the "common default behavior".
+ It is only guaranteed to work well with the standard models and training workflow in detectron2.
+ To obtain more stable behavior, write your own training logic with other public APIs.
+
+ Examples:
+ ::
+ trainer = DefaultTrainer(cfg)
+ trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS
+ trainer.train()
+
+ Attributes:
+ scheduler:
+ checkpointer (DetectionCheckpointer):
+ cfg (CfgNode):
+ """
+
+ def __init__(self, cfg):
+ """
+ Args:
+ cfg (CfgNode):
+ """
+ super().__init__()
+ logger = logging.getLogger("uniperceiver")
+ if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2
+ setup_logger()
+ cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
+
+ # Assume these objects must be constructed in this order.
+ model = self.build_model(cfg)
+ optimizer = self.build_optimizer(cfg, model)
+ data_loader = self.build_train_loader(cfg)
+
+ model = create_ddp_model(model, broadcast_buffers=False)
+ self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
+ model, data_loader, optimizer
+ )
+
+ self.scheduler = self.build_lr_scheduler(cfg, optimizer)
+ self.start_iter = 0
+ self.max_iter = cfg.SOLVER.MAX_ITER
+ self.cfg = cfg
+
+ self.register_hooks(self.build_hooks())
+
+ def resume_or_load(self, resume=True):
+ """
+ If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by
+ a `last_checkpoint` file), resume from the file. Resuming means loading all
+ available states (eg. optimizer and scheduler) and update iteration counter
+ from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used.
+
+ Otherwise, this is considered as an independent training. The method will load model
+ weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start
+ from iteration 0.
+
+ Args:
+ resume (bool): whether to do resume or not
+ """
+ self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume)
+ if resume and self.checkpointer.has_checkpoint():
+ # The checkpoint stores the training iteration that just finished, thus we start
+ # at the next iteration
+ self.start_iter = self.iter + 1
+
+ def build_hooks(self):
+ """
+ Build a list of default hooks, including timing, evaluation,
+ checkpointing, lr scheduling, precise BN, writing events.
+
+ Returns:
+ list[HookBase]:
+ """
+ cfg = self.cfg.clone()
+ cfg.defrost()
+ cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
+
+ ret = [
+ hooks.IterationTimer(),
+ hooks.LRScheduler(),
+ hooks.PreciseBN(
+ # Run at the same freq as (but before) evaluation.
+ cfg.TEST.EVAL_PERIOD,
+ self.model,
+ # Build a new data loader to not affect training
+ self.build_train_loader(cfg),
+ cfg.TEST.PRECISE_BN.NUM_ITER,
+ )
+ if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
+ else None,
+ ]
+
+ # Do PreciseBN before checkpointer, because it updates the model and need to
+ # be saved by checkpointer.
+ # This is not always the best: if checkpointing has a different frequency,
+ # some checkpoints may have more precise statistics than others.
+ if comm.is_main_process():
+ ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD))
+
+ def test_and_save_results():
+ self._last_eval_results = self.test(self.cfg, self.model)
+ return self._last_eval_results
+
+ # Do evaluation after checkpointer, because then if it fails,
+ # we can use the saved checkpoint to debug.
+ ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
+
+ if comm.is_main_process():
+ # Here the default print/log frequency of each writer is used.
+ # run writers in the end, so that evaluation metrics are written
+ ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
+ return ret
+
+ def build_writers(self):
+ """
+ Build a list of writers to be used using :func:`default_writers()`.
+ If you'd like a different list of writers, you can overwrite it in
+ your trainer.
+
+ Returns:
+ list[EventWriter]: a list of :class:`EventWriter` objects.
+ """
+ return default_writers(self.cfg.OUTPUT_DIR, self.max_iter)
+
+ def train(self):
+ """
+ Run training.
+
+ Returns:
+ OrderedDict of results, if evaluation is enabled. Otherwise None.
+ """
+ super().train(self.start_iter, self.max_iter)
+ if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
+ assert hasattr(
+ self, "_last_eval_results"
+ ), "No evaluation results obtained during training!"
+ verify_results(self.cfg, self._last_eval_results)
+ return self._last_eval_results
+
+ def run_step(self):
+ self._trainer.iter = self.iter
+ self._trainer.run_step()
+
+ def state_dict(self):
+ ret = super().state_dict()
+ ret["_trainer"] = self._trainer.state_dict()
+ return ret
+
+ def load_state_dict(self, state_dict):
+ super().load_state_dict(state_dict)
+ self._trainer.load_state_dict(state_dict["_trainer"])
+
+ @classmethod
+ def build_model(cls, cfg):
+ """
+ Returns:
+ torch.nn.Module:
+
+ It now calls :func:`detectron2.modeling.build_model`.
+ Overwrite it if you'd like a different model.
+ """
+ model = build_model(cfg)
+ logger = logging.getLogger(__name__)
+ logger.info("Model:\n{}".format(model))
+ return model
+
+ @classmethod
+ def build_optimizer(cls, cfg, model):
+ """
+ Returns:
+ torch.optim.Optimizer:
+
+ It now calls :func:`detectron2.solver.build_optimizer`.
+ Overwrite it if you'd like a different optimizer.
+ """
+ return build_optimizer(cfg, model)
+
+ @classmethod
+ def build_lr_scheduler(cls, cfg, optimizer):
+ """
+ It now calls :func:`detectron2.solver.build_lr_scheduler`.
+ Overwrite it if you'd like a different scheduler.
+ """
+ return build_lr_scheduler(cfg, optimizer)
+
+ @classmethod
+ def build_train_loader(cls, cfg):
+ """
+ Returns:
+ iterable
+
+ It now calls :func:`detectron2.data.build_detection_train_loader`.
+ Overwrite it if you'd like a different data loader.
+ """
+ return build_standard_train_loader(cfg)
+
+ @classmethod
+ def build_test_loader(cls, cfg, dataset_name):
+ """
+ Returns:
+ iterable
+
+ It now calls :func:`detectron2.data.build_detection_test_loader`.
+ Overwrite it if you'd like a different data loader.
+ """
+ return build_standard_valtest_loader(cfg, dataset_name)
+
+ @classmethod
+ def build_evaluator(cls, cfg, dataset_name):
+ """
+ Returns:
+ DatasetEvaluator or None
+
+ It is not implemented by default.
+ """
+ raise NotImplementedError(
+ """
+If you want DefaultTrainer to automatically run evaluation,
+please implement `build_evaluator()` in subclasses (see train_net.py for example).
+Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example).
+"""
+ )
+
+ @classmethod
+ def test(cls, cfg, model, evaluators=None):
+ """
+ Evaluate the given model. The given model is expected to already contain
+ weights to evaluate.
+
+ Args:
+ cfg (CfgNode):
+ model (nn.Module):
+ evaluators (list[DatasetEvaluator] or None): if None, will call
+ :meth:`build_evaluator`. Otherwise, must have the same length as
+ ``cfg.DATASETS.TEST``.
+
+ Returns:
+ dict: a dict of result metrics
+ """
+ logger = logging.getLogger(__name__)
+ if isinstance(evaluators, DatasetEvaluator):
+ evaluators = [evaluators]
+ if evaluators is not None:
+ assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
+ len(cfg.DATASETS.TEST), len(evaluators)
+ )
+
+ results = OrderedDict()
+ for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
+ data_loader = cls.build_test_loader(cfg, dataset_name)
+ # When evaluators are passed in as arguments,
+ # implicitly assume that evaluators can be created before data_loader.
+ if evaluators is not None:
+ evaluator = evaluators[idx]
+ else:
+ try:
+ evaluator = cls.build_evaluator(cfg, dataset_name)
+ except NotImplementedError:
+ logger.warn(
+ "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
+ "or implement its `build_evaluator` method."
+ )
+ results[dataset_name] = {}
+ continue
+ results_i = inference_on_dataset(model, data_loader, evaluator)
+ results[dataset_name] = results_i
+ if comm.is_main_process():
+ assert isinstance(
+ results_i, dict
+ ), "Evaluator must return a dict on the main process. Got {} instead.".format(
+ results_i
+ )
+ logger.info("Evaluation results for {} in csv format:".format(dataset_name))
+ print_csv_format(results_i)
+
+ if len(results) == 1:
+ results = list(results.values())[0]
+ return results
+
+ @staticmethod
+ def auto_scale_workers(cfg, num_workers: int):
+ """
+ When the config is defined for certain number of workers (according to
+ ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of
+ workers currently in use, returns a new cfg where the total batch size
+ is scaled so that the per-GPU batch size stays the same as the
+ original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``.
+
+ Other config options are also scaled accordingly:
+ * training steps and warmup steps are scaled inverse proportionally.
+ * learning rate are scaled proportionally, following :paper:`ImageNet in 1h`.
+
+ For example, with the original config like the following:
+
+ .. code-block:: yaml
+
+ IMS_PER_BATCH: 16
+ BASE_LR: 0.1
+ REFERENCE_WORLD_SIZE: 8
+ MAX_ITER: 5000
+ STEPS: (4000,)
+ CHECKPOINT_PERIOD: 1000
+
+ When this config is used on 16 GPUs instead of the reference number 8,
+ calling this method will return a new config with:
+
+ .. code-block:: yaml
+
+ IMS_PER_BATCH: 32
+ BASE_LR: 0.2
+ REFERENCE_WORLD_SIZE: 16
+ MAX_ITER: 2500
+ STEPS: (2000,)
+ CHECKPOINT_PERIOD: 500
+
+ Note that both the original config and this new config can be trained on 16 GPUs.
+ It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``).
+
+ Returns:
+ CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``.
+ """
+ old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE
+ if old_world_size == 0 or old_world_size == num_workers:
+ return cfg
+ cfg = cfg.clone()
+ frozen = cfg.is_frozen()
+ cfg.defrost()
+
+ assert (
+ cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0
+ ), "Invalid REFERENCE_WORLD_SIZE in config!"
+ scale = num_workers / old_world_size
+ bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale))
+ lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale
+ max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale))
+ warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale))
+ cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS)
+ cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale))
+ cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale))
+ cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers # maintain invariant
+ logger = logging.getLogger(__name__)
+ logger.info(
+ f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, "
+ f"max_iter={max_iter}, warmup={warmup_iter}."
+ )
+
+ if frozen:
+ cfg.freeze()
+ return cfg
+
+
+# Access basic attributes from the underlying trainer
+for _attr in ["model", "data_loader", "optimizer"]:
+ setattr(
+ DefaultTrainer,
+ _attr,
+ property(
+ # getter
+ lambda self, x=_attr: getattr(self._trainer, x),
+ # setter
+ lambda self, value, x=_attr: setattr(self._trainer, x, value),
+ ),
+ )
diff --git a/uniperceiver/engine/hooks.py b/uniperceiver/engine/hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..62bed27f4149740dca5adef91d6a9d8a38503acf
--- /dev/null
+++ b/uniperceiver/engine/hooks.py
@@ -0,0 +1,646 @@
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import datetime
+import itertools
+import logging
+import os
+import tempfile
+import time
+from collections import Counter
+import torch
+from fvcore.common.param_scheduler import ParamScheduler
+from fvcore.common.timer import Timer
+from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats
+
+import uniperceiver.utils.comm as comm
+from uniperceiver.checkpoint import PeriodicEpochCheckpointer as _PeriodicCheckpointer
+from uniperceiver.evaluation.testing import flatten_results_dict
+from uniperceiver.utils.events import EventStorage, EventWriter
+from uniperceiver.utils.file_io import PathManager
+
+from .train_loop import HookBase
+
+__all__ = [
+ "CallbackHook",
+ "IterationTimer",
+ "PeriodicWriter",
+ "PeriodicCheckpointer",
+ "LRScheduler",
+ "AutogradProfiler",
+ "EvalHook",
+ "PreciseBN",
+ "ModelWeightsManipulating",
+ "MultiGPUEvalHook"
+]
+
+
+"""
+Implement some common hooks.
+"""
+
+
+class CallbackHook(HookBase):
+ """
+ Create a hook using callback functions provided by the user.
+ """
+
+ def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
+ """
+ Each argument is a function that takes one argument: the trainer.
+ """
+ self._before_train = before_train
+ self._before_step = before_step
+ self._after_step = after_step
+ self._after_train = after_train
+
+ def before_train(self):
+ if self._before_train:
+ self._before_train(self.trainer)
+
+ def after_train(self):
+ if self._after_train:
+ self._after_train(self.trainer)
+ # The functions may be closures that hold reference to the trainer
+ # Therefore, delete them to avoid circular reference.
+ del self._before_train, self._after_train
+ del self._before_step, self._after_step
+
+ def before_step(self):
+ if self._before_step:
+ self._before_step(self.trainer)
+
+ def after_step(self):
+ if self._after_step:
+ self._after_step(self.trainer)
+
+class ModelWeightsManipulating(HookBase):
+ """
+ Init or bind weights after loading a model
+ """
+ def __init__(self):
+ super().__init__()
+
+ def before_train(self):
+ if hasattr(self.trainer.model, 'bind_or_init_weights'):
+ self.trainer.model.bind_or_init_weights()
+
+ if hasattr(self.trainer.model, 'compute_attributemoe_decision'):
+ self.trainer.model.compute_attributemoe_decision()
+
+ pass
+
+class ScheduledSampling(HookBase):
+ def __init__(self, start_iter, inc_every_iter, inc_prob, max_prob):
+ self._start_iter = start_iter
+ self._inc_every_iter = inc_every_iter
+ self._inc_prob = inc_prob
+ self._max_prob = max_prob
+
+ def after_step(self):
+ next_iter = self.trainer.iter + 1
+ if next_iter > self._start_iter:
+ frac = (next_iter - self._start_iter) // self._inc_every_iter
+ ss_prob = min(self._inc_prob * frac, self._max_prob)
+ self.trainer.ss_prob = ss_prob
+
+
+class IterationTimer(HookBase):
+ """
+ Track the time spent for each iteration (each run_step call in the trainer).
+ Print a summary in the end of training.
+
+ This hook uses the time between the call to its :meth:`before_step`
+ and :meth:`after_step` methods.
+ Under the convention that :meth:`before_step` of all hooks should only
+ take negligible amount of time, the :class:`IterationTimer` hook should be
+ placed at the beginning of the list of hooks to obtain accurate timing.
+ """
+
+ def __init__(self, warmup_iter=3):
+ """
+ Args:
+ warmup_iter (int): the number of iterations at the beginning to exclude
+ from timing.
+ """
+ self._warmup_iter = warmup_iter
+ self._step_timer = Timer()
+ self._start_time = time.perf_counter()
+ self._total_timer = Timer()
+
+ def before_train(self):
+ self._start_time = time.perf_counter()
+ self._total_timer.reset()
+ self._total_timer.pause()
+
+ def after_train(self):
+ logger = logging.getLogger(__name__)
+ total_time = time.perf_counter() - self._start_time
+ total_time_minus_hooks = self._total_timer.seconds()
+ hook_time = total_time - total_time_minus_hooks
+
+ num_iter = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter
+
+ if num_iter > 0 and total_time_minus_hooks > 0:
+ # Speed is meaningful only after warmup
+ # NOTE this format is parsed by grep in some scripts
+ logger.info(
+ "Overall training speed: {} iterations in {} ({:.4f} s / it)".format(
+ num_iter,
+ str(datetime.timedelta(seconds=int(total_time_minus_hooks))),
+ total_time_minus_hooks / num_iter,
+ )
+ )
+
+ logger.info(
+ "Total training time: {} ({} on hooks)".format(
+ str(datetime.timedelta(seconds=int(total_time))),
+ str(datetime.timedelta(seconds=int(hook_time))),
+ )
+ )
+
+ def before_step(self):
+ self._step_timer.reset()
+ self._total_timer.resume()
+
+ def after_step(self):
+ # +1 because we're in after_step, the current step is done
+ # but not yet counted
+ iter_done = self.trainer.iter - self.trainer.start_iter + 1
+ if iter_done >= self._warmup_iter:
+ sec = self._step_timer.seconds()
+ self.trainer.storage.put_scalars(time=sec)
+ else:
+ self._start_time = time.perf_counter()
+ self._total_timer.reset()
+
+ self._total_timer.pause()
+
+
+class PeriodicWriter(HookBase):
+ """
+ Write events to EventStorage (by calling ``writer.write()``) periodically.
+
+ It is executed every ``period`` iterations and after the last iteration.
+ Note that ``period`` does not affect how data is smoothed by each writer.
+ """
+
+ def __init__(self, writers, period=20):
+ """
+ Args:
+ writers (list[EventWriter]): a list of EventWriter objects
+ period (int):
+ """
+ self._writers = writers
+ for w in writers:
+ assert isinstance(w, EventWriter), w
+ self._period = period
+
+ def after_step(self):
+ if (self.trainer.iter + 1) % self._period == 0 or (
+ self.trainer.iter == self.trainer.max_iter - 1
+ ):
+ for writer in self._writers:
+ writer.write()
+
+ def after_train(self):
+ for writer in self._writers:
+ # If any new data is found (e.g. produced by other after_train),
+ # write them before closing
+ writer.write()
+ writer.close()
+
+
+class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase):
+ """
+ Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook.
+
+ Note that when used as a hook,
+ it is unable to save additional data other than what's defined
+ by the given `checkpointer`.
+
+ It is executed every ``period`` iterations and after the last iteration.
+ """
+
+ def before_train(self):
+ self.max_iter = self.trainer.max_iter
+
+ def after_step(self):
+ # No way to use **kwargs
+ self.step(iteration = self.trainer.iter, epoch=(self.trainer.iter+1) // self.trainer.iters_per_epoch)
+
+class LRScheduler(HookBase):
+ """
+ A hook which executes a torch builtin LR scheduler and summarizes the LR.
+ It is executed after every iteration.
+ """
+
+ def __init__(self, optimizer=None, scheduler=None):
+ """
+ Args:
+ optimizer (torch.optim.Optimizer):
+ scheduler (torch.optim.LRScheduler):
+
+ If any argument is not given, will try to obtain it from the trainer.
+ """
+ self._optimizer = optimizer
+ self._scheduler = scheduler
+
+
+ def before_train(self):
+ self._optimizer = self._optimizer or self.trainer.optimizer
+ self._scheduler = self._scheduler or self.trainer.scheduler
+
+ # self.ds_engine = self.trainer.ds_engine
+ # if self.ds_engine:
+ # print("scheduler.step() will be done by the deepspeed engine.")
+
+ # NOTE: some heuristics on what LR to summarize
+ # summarize the param group with most parameters
+ largest_group = max(len(g["params"]) for g in self._optimizer.param_groups)
+
+ if largest_group == 1:
+ # If all groups have one parameter,
+ # then find the most common initial LR, and use it for summary
+ lr_count = Counter([g["lr"] for g in self._optimizer.param_groups])
+ lr = lr_count.most_common()[0][0]
+ for i, g in enumerate(self._optimizer.param_groups):
+ if g["lr"] == lr:
+ self._best_param_group_id = i
+ break
+ else:
+ for i, g in enumerate(self._optimizer.param_groups):
+ if len(g["params"]) == largest_group:
+ self._best_param_group_id = i
+ break
+
+ def after_step(self):
+ lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
+ self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
+ self._scheduler.step()
+ # otherwise we will step by the deepspeed
+
+
+class AutogradProfiler(HookBase):
+ """
+ A hook which runs `torch.autograd.profiler.profile`.
+
+ Examples:
+ ::
+ hooks.AutogradProfiler(
+ lambda trainer: trainer.iter > 10 and trainer.iter < 20, self.cfg.OUTPUT_DIR
+ )
+
+ The above example will run the profiler for iteration 10~20 and dump
+ results to ``OUTPUT_DIR``. We did not profile the first few iterations
+ because they are typically slower than the rest.
+ The result files can be loaded in the ``chrome://tracing`` page in chrome browser.
+
+ Note:
+ When used together with NCCL on older version of GPUs,
+ autograd profiler may cause deadlock because it unnecessarily allocates
+ memory on every device it sees. The memory management calls, if
+ interleaved with NCCL calls, lead to deadlock on GPUs that do not
+ support ``cudaLaunchCooperativeKernelMultiDevice``.
+ """
+
+ def __init__(self, enable_predicate, output_dir, *, use_cuda=True):
+ """
+ Args:
+ enable_predicate (callable[trainer -> bool]): a function which takes a trainer,
+ and returns whether to enable the profiler.
+ It will be called once every step, and can be used to select which steps to profile.
+ output_dir (str): the output directory to dump tracing files.
+ use_cuda (bool): same as in `torch.autograd.profiler.profile`.
+ """
+ self._enable_predicate = enable_predicate
+ self._use_cuda = use_cuda
+ self._output_dir = output_dir
+
+ def before_step(self):
+ if self._enable_predicate(self.trainer):
+ self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda)
+ self._profiler.__enter__()
+ else:
+ self._profiler = None
+
+ def after_step(self):
+ if self._profiler is None:
+ return
+ self._profiler.__exit__(None, None, None)
+ PathManager.mkdirs(self._output_dir)
+ out_file = os.path.join(
+ self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter)
+ )
+ if "://" not in out_file:
+ self._profiler.export_chrome_trace(out_file)
+ else:
+ # Support non-posix filesystems
+ with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d:
+ tmp_file = os.path.join(d, "tmp.json")
+ self._profiler.export_chrome_trace(tmp_file)
+ with open(tmp_file) as f:
+ content = f.read()
+ with PathManager.open(out_file, "w") as f:
+ f.write(content)
+
+
+class EvalHook(HookBase):
+ """
+ Run an evaluation function periodically, and at the end of training.
+
+ It is executed every ``eval_period`` iterations and after the last iteration.
+ """
+
+ def __init__(self, eval_period, eval_start, eval_function, iters_per_epoch, stage, multi_gpu_eval):
+ """
+ Args:
+ eval_period (int): the period to run `eval_function`. Set to 0 to
+ not evaluate periodically (but still after the last iteration).
+ eval_function (callable): a function which takes no arguments, and
+ returns a nested dict of evaluation metrics.
+
+ Note:
+ This hook must be enabled in all or none workers.
+ If you would like only certain workers to perform evaluation,
+ give other workers a no-op function (`eval_function=lambda: None`).
+ """
+ self._period = eval_period * iters_per_epoch
+ self._func = eval_function
+ self._stage = stage
+ self._eval_start = eval_start
+ self._multi_gpu_eval = multi_gpu_eval
+
+ def _do_eval(self, epoch):
+ if self._multi_gpu_eval:
+ results = self._func(epoch)
+ else:
+ if comm.is_main_process():
+ results = self._func(epoch)
+
+ # Evaluation may take different time among workers.
+ # A barrier make them start the next iteration together.
+ comm.synchronize()
+
+ if comm.is_main_process():
+ if results:
+ assert isinstance(results, dict), "Eval function must return a dict. Got {} instead.".format(results)
+ flattened_results = flatten_results_dict(results)
+ for k, v in flattened_results.items():
+ try:
+ v = float(v)
+ except Exception as e:
+ raise ValueError(
+ "[EvalHook] eval_function should return a nested dict of float. "
+ "Got '{}: {}' instead.".format(k, v)
+ ) from e
+ self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
+
+ logger = logging.getLogger(__name__)
+ epoch = (self.trainer.iter + 1) // self.trainer.iters_per_epoch
+ logger.info('############################## {}: {} ##############################'.format(self._stage, epoch))
+ for k, v in flattened_results.items():
+ logger.info("{}: {}".format(k, v))
+
+
+ # Evaluation may take different time among workers.
+ # A barrier make them start the next iteration together.
+ comm.synchronize()
+
+ def after_step(self):
+ next_iter = self.trainer.iter + 1
+ epoch = int(next_iter // self.trainer.iters_per_epoch)
+ if (self._period > 0) and (next_iter % self._period == 0) and (epoch > self._eval_start):
+ self._do_eval(epoch)
+
+ def after_train(self):
+ next_iter = self.trainer.iter + 1
+ epoch = int(next_iter // self.trainer.iters_per_epoch)
+ # This condition is to prevent the eval from running after a failed training
+ if self.trainer.iter + 1 >= self.trainer.max_iter:
+ self._do_eval(epoch)
+ # func is likely a closure that holds reference to the trainer
+ # therefore we clean it to avoid circular reference in the end
+ del self._func
+
+class IterEvalHook(HookBase):
+ """
+ Run an evaluation function periodically, and at the end of training.
+
+ It is executed every ``eval_period`` iterations and after the last iteration.
+ """
+
+ def __init__(self, eval_period, eval_start, eval_function, stage, multi_gpu_eval):
+ """
+ Args:
+ eval_period (int): the period to run `eval_function`. Set to 0 to
+ not evaluate periodically (but still after the last iteration).
+ eval_function (callable): a function which takes no arguments, and
+ returns a nested dict of evaluation metrics.
+
+ Note:
+ This hook must be enabled in all or none workers.
+ If you would like only certain workers to perform evaluation,
+ give other workers a no-op function (`eval_function=lambda: None`).
+ """
+ self._period = eval_period
+ self._func = eval_function
+ self._stage = stage
+ self._eval_start = eval_start
+ self._multi_gpu_eval = multi_gpu_eval
+
+ def _do_eval(self, iter_):
+ if self._multi_gpu_eval:
+ results = self._func(iter_)
+ else:
+ if comm.is_main_process():
+ results = self._func(iter_)
+
+ # Evaluation may take different time among workers.
+ # A barrier make them start the next iteration together.
+ comm.synchronize()
+
+ if comm.is_main_process():
+ if results:
+ assert isinstance(results, dict), "Eval function must return a dict. Got {} instead.".format(results)
+ flattened_results = flatten_results_dict(results)
+ for k, v in flattened_results.items():
+ try:
+ v = float(v)
+ except Exception as e:
+ raise ValueError(
+ "[EvalHook] eval_function should return a nested dict of float. "
+ "Got '{}: {}' instead.".format(k, v)
+ ) from e
+ self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
+
+ logger = logging.getLogger(__name__)
+ iter_ = self.trainer.iter + 1
+ logger.info('############################## {}: {} ##############################'.format(self._stage, iter_))
+ for k, v in flattened_results.items():
+ logger.info("{}: {}".format(k, v))
+
+
+ # Evaluation may take different time among workers.
+ # A barrier make them start the next iteration together.
+ comm.synchronize()
+
+ def after_step(self):
+ next_iter = self.trainer.iter + 1
+ if (self._period > 0) and (next_iter % self._period == 0) and (next_iter > self._eval_start):
+ self._do_eval(next_iter)
+
+ def after_train(self):
+ next_iter = self.trainer.iter + 1
+ # This condition is to prevent the eval from running after a failed training
+ if self.trainer.iter + 1 >= self.trainer.max_iter:
+ self._do_eval(next_iter)
+ # func is likely a closure that holds reference to the trainer
+ # therefore we clean it to avoid circular reference in the end
+ del self._func
+
+
+class MultiGPUEvalHook(HookBase):
+ """
+ Run an evaluation function periodically, and at the end of training.
+
+ It is executed every ``eval_period`` iterations and after the last iteration.
+ """
+
+ def __init__(self, eval_period, eval_start, eval_function, iters_per_epoch, stage, multi_gpu_eval):
+ """
+ Args:
+ eval_period (int): the period to run `eval_function`. Set to 0 to
+ not evaluate periodically (but still after the last iteration).
+ eval_function (callable): a function which takes no arguments, and
+ returns a nested dict of evaluation metrics.
+
+ Note:
+ This hook must be enabled in all or none workers.
+ If you would like only certain workers to perform evaluation,
+ give other workers a no-op function (`eval_function=lambda: None`).
+ """
+ self._period = eval_period * iters_per_epoch
+ self._func = eval_function
+ self._stage = stage
+ self._eval_start = eval_start
+ self._multi_gpu_eval = multi_gpu_eval
+
+ def _do_eval(self, epoch):
+ results = self._func(epoch)
+
+ # Evaluation may take different time among workers.
+ # A barrier make them start the next iteration together.
+ comm.synchronize()
+
+ if comm.is_main_process():
+ if results:
+ assert isinstance(results, dict), "Eval function must return a dict. Got {} instead.".format(results)
+ flattened_results = flatten_results_dict(results)
+ for k, v in flattened_results.items():
+ try:
+ v = float(v)
+ except Exception as e:
+ raise ValueError(
+ "[EvalHook] eval_function should return a nested dict of float. "
+ "Got '{}: {}' instead.".format(k, v)
+ ) from e
+ self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)
+
+ logger = logging.getLogger(__name__)
+ epoch = (self.trainer.iter + 1) // self.trainer.iters_per_epoch
+ logger.info('############################## {}: {} ##############################'.format(self._stage, epoch))
+ for k, v in flattened_results.items():
+ logger.info("{}: {}".format(k, v))
+
+
+ # Evaluation may take different time among workers.
+ # A barrier make them start the next iteration together.
+ comm.synchronize()
+
+ def after_step(self):
+ next_iter = self.trainer.iter + 1
+ epoch = int(next_iter // self.trainer.iters_per_epoch)
+ if (self._period > 0) and (next_iter % self._period == 0) and (epoch > self._eval_start):
+ self._do_eval(epoch)
+
+ def after_train(self):
+ next_iter = self.trainer.iter + 1
+ epoch = int(next_iter // self.trainer.iters_per_epoch)
+ # This condition is to prevent the eval from running after a failed training
+ if self.trainer.iter + 1 >= self.trainer.max_iter:
+ self._do_eval(epoch)
+ # func is likely a closure that holds reference to the trainer
+ # therefore we clean it to avoid circular reference in the end
+ del self._func
+
+class PreciseBN(HookBase):
+ """
+ The standard implementation of BatchNorm uses EMA in inference, which is
+ sometimes suboptimal.
+ This class computes the true average of statistics rather than the moving average,
+ and put true averages to every BN layer in the given model.
+
+ It is executed every ``period`` iterations and after the last iteration.
+ """
+
+ def __init__(self, period, model, data_loader, num_iter):
+ """
+ Args:
+ period (int): the period this hook is run, or 0 to not run during training.
+ The hook will always run in the end of training.
+ model (nn.Module): a module whose all BN layers in training mode will be
+ updated by precise BN.
+ Note that user is responsible for ensuring the BN layers to be
+ updated are in training mode when this hook is triggered.
+ data_loader (iterable): it will produce data to be run by `model(data)`.
+ num_iter (int): number of iterations used to compute the precise
+ statistics.
+ """
+ self._logger = logging.getLogger(__name__)
+ if len(get_bn_modules(model)) == 0:
+ self._logger.info(
+ "PreciseBN is disabled because model does not contain BN layers in training mode."
+ )
+ self._disabled = True
+ return
+
+ self._model = model
+ self._data_loader = data_loader
+ self._num_iter = num_iter
+ self._period = period
+ self._disabled = False
+
+ self._data_iter = None
+
+ def after_step(self):
+ next_iter = self.trainer.iter + 1
+ is_final = next_iter == self.trainer.max_iter
+ if is_final or (self._period > 0 and next_iter % self._period == 0):
+ self.update_stats()
+
+ def update_stats(self):
+ """
+ Update the model with precise statistics. Users can manually call this method.
+ """
+ if self._disabled:
+ return
+
+ if self._data_iter is None:
+ self._data_iter = iter(self._data_loader)
+
+ def data_loader():
+ for num_iter in itertools.count(1):
+ if num_iter % 100 == 0:
+ self._logger.info(
+ "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
+ )
+ # This way we can reuse the same iterator
+ yield next(self._data_iter)
+
+ with EventStorage(): # capture events in a new storage to discard them
+ self._logger.info(
+ "Running precise-BN for {} iterations... ".format(self._num_iter)
+ + "Note that this could produce different statistics every time."
+ )
+ update_bn_stats(self._model, data_loader(), self._num_iter)
diff --git a/uniperceiver/engine/launch.py b/uniperceiver/engine/launch.py
new file mode 100644
index 0000000000000000000000000000000000000000..f772142fd18b57285072ded37cc14b28a9fd68ba
--- /dev/null
+++ b/uniperceiver/engine/launch.py
@@ -0,0 +1,126 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+from datetime import timedelta
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from uniperceiver.utils import comm
+
+__all__ = ["DEFAULT_TIMEOUT", "launch"]
+
+DEFAULT_TIMEOUT = timedelta(minutes=30)
+
+
+def _find_free_port():
+ import socket
+
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ # Binding to port 0 will cause the OS to find an available port for us
+ sock.bind(("", 0))
+ port = sock.getsockname()[1]
+ sock.close()
+ # NOTE: there is still a chance the port could be taken by other processes.
+ return port
+
+
+def launch(
+ main_func,
+ num_gpus_per_machine,
+ num_machines=1,
+ machine_rank=0,
+ dist_url=None,
+ args=(),
+ timeout=DEFAULT_TIMEOUT,
+):
+ """
+ Launch multi-gpu or distributed training.
+ This function must be called on all machines involved in the training.
+ It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.
+
+ Args:
+ main_func: a function that will be called by `main_func(*args)`
+ num_gpus_per_machine (int): number of GPUs per machine
+ num_machines (int): the total number of machines
+ machine_rank (int): the rank of this machine
+ dist_url (str): url to connect to for distributed jobs, including protocol
+ e.g. "tcp://127.0.0.1:8686".
+ Can be set to "auto" to automatically select a free port on localhost
+ timeout (timedelta): timeout of the distributed workers
+ args (tuple): arguments passed to main_func
+ """
+ world_size = num_machines * num_gpus_per_machine
+ if world_size > 1:
+ # https://github.com/pytorch/pytorch/pull/14391
+ # TODO prctl in spawned processes
+
+ if dist_url == "auto":
+ assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
+ port = _find_free_port()
+ dist_url = f"tcp://127.0.0.1:{port}"
+ if num_machines > 1 and dist_url.startswith("file://"):
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
+ )
+
+ mp.spawn(
+ _distributed_worker,
+ nprocs=num_gpus_per_machine,
+ args=(
+ main_func,
+ world_size,
+ num_gpus_per_machine,
+ machine_rank,
+ dist_url,
+ args,
+ timeout,
+ ),
+ daemon=False,
+ )
+ else:
+ main_func(*args)
+
+
+def _distributed_worker(
+ local_rank,
+ main_func,
+ world_size,
+ num_gpus_per_machine,
+ machine_rank,
+ dist_url,
+ args,
+ timeout=DEFAULT_TIMEOUT,
+):
+ assert torch.cuda.is_available(), "cuda is not available. Please check your installation."
+ global_rank = machine_rank * num_gpus_per_machine + local_rank
+ try:
+ dist.init_process_group(
+ backend="NCCL",
+ init_method=dist_url,
+ world_size=world_size,
+ rank=global_rank,
+ timeout=timeout,
+ )
+ except Exception as e:
+ logger = logging.getLogger(__name__)
+ logger.error("Process group URL: {}".format(dist_url))
+ raise e
+
+ # Setup the local process group (which contains ranks within the same machine)
+ assert comm._LOCAL_PROCESS_GROUP is None
+ num_machines = world_size // num_gpus_per_machine
+ for i in range(num_machines):
+ ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
+ pg = dist.new_group(ranks_on_i)
+ if i == machine_rank:
+ comm._LOCAL_PROCESS_GROUP = pg
+
+ assert num_gpus_per_machine <= torch.cuda.device_count()
+ torch.cuda.set_device(local_rank)
+
+ # synchronize is needed here to prevent a possible timeout after calling init_process_group
+ # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
+ comm.synchronize()
+
+ main_func(*args)
diff --git a/uniperceiver/engine/train_loop.py b/uniperceiver/engine/train_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..220997f7e571b28271beee450c53c255abb3cd62
--- /dev/null
+++ b/uniperceiver/engine/train_loop.py
@@ -0,0 +1,417 @@
+# -*- coding: utf-8 -*-
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+import logging
+import numpy as np
+import time
+import weakref
+from typing import List, Mapping, Optional
+import torch
+from torch.nn.parallel import DataParallel, DistributedDataParallel
+
+import uniperceiver.utils.comm as comm
+from uniperceiver.utils.events import EventStorage, get_event_storage
+from uniperceiver.utils.logger import _log_api_usage
+
+__all__ = ["HookBase", "TrainerBase", "SimpleTrainer", "AMPTrainer"]
+
+
+class HookBase:
+ """
+ Base class for hooks that can be registered with :class:`TrainerBase`.
+
+ Each hook can implement 4 methods. The way they are called is demonstrated
+ in the following snippet:
+ ::
+ hook.before_train()
+ for iter in range(start_iter, max_iter):
+ hook.before_step()
+ trainer.run_step()
+ hook.after_step()
+ iter += 1
+ hook.after_train()
+
+ Notes:
+ 1. In the hook method, users can access ``self.trainer`` to access more
+ properties about the context (e.g., model, current iteration, or config
+ if using :class:`DefaultTrainer`).
+
+ 2. A hook that does something in :meth:`before_step` can often be
+ implemented equivalently in :meth:`after_step`.
+ If the hook takes non-trivial time, it is strongly recommended to
+ implement the hook in :meth:`after_step` instead of :meth:`before_step`.
+ The convention is that :meth:`before_step` should only take negligible time.
+
+ Following this convention will allow hooks that do care about the difference
+ between :meth:`before_step` and :meth:`after_step` (e.g., timer) to
+ function properly.
+
+ """
+
+ trainer: "TrainerBase" = None
+ """
+ A weak reference to the trainer object. Set by the trainer when the hook is registered.
+ """
+
+ def before_train(self):
+ """
+ Called before the first iteration.
+ """
+ pass
+
+ def after_train(self):
+ """
+ Called after the last iteration.
+ """
+ pass
+
+ def before_step(self):
+ """
+ Called before each iteration.
+ """
+ pass
+
+ def after_step(self):
+ """
+ Called after each iteration.
+ """
+ pass
+
+ def state_dict(self):
+ """
+ Hooks are stateless by default, but can be made checkpointable by
+ implementing `state_dict` and `load_state_dict`.
+ """
+ return {}
+
+
+class TrainerBase:
+ """
+ Base class for iterative trainer with hooks.
+
+ The only assumption we made here is: the training runs in a loop.
+ A subclass can implement what the loop is.
+ We made no assumptions about the existence of dataloader, optimizer, model, etc.
+
+ Attributes:
+ iter(int): the current iteration.
+
+ start_iter(int): The iteration to start with.
+ By convention the minimum possible value is 0.
+
+ max_iter(int): The iteration to end training.
+
+ storage(EventStorage): An EventStorage that's opened during the course of training.
+ """
+
+ def __init__(self) -> None:
+ self._hooks: List[HookBase] = []
+ self.iter: int = 0
+ self.start_iter: int = 0
+ self.max_iter: int
+ self.storage: EventStorage
+ _log_api_usage("trainer." + self.__class__.__name__)
+
+ def register_hooks(self, hooks: List[Optional[HookBase]]) -> None:
+ """
+ Register hooks to the trainer. The hooks are executed in the order
+ they are registered.
+
+ Args:
+ hooks (list[Optional[HookBase]]): list of hooks
+ """
+ hooks = [h for h in hooks if h is not None]
+ for h in hooks:
+ assert isinstance(h, HookBase)
+ # To avoid circular reference, hooks and trainer cannot own each other.
+ # This normally does not matter, but will cause memory leak if the
+ # involved objects contain __del__:
+ # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/
+ h.trainer = weakref.proxy(self)
+ self._hooks.extend(hooks)
+
+ def train(self, start_iter: int, max_iter: int):
+ """
+ Args:
+ start_iter, max_iter (int): See docs above
+ """
+ logger = logging.getLogger(__name__)
+ logger.info("Starting training from iteration {}".format(start_iter))
+
+ self.iter = self.start_iter = start_iter
+ self.max_iter = max_iter
+
+ with EventStorage(start_iter) as self.storage:
+ try:
+ self.before_train()
+ for self.iter in range(start_iter, max_iter):
+ self.before_step()
+ self.run_step()
+ self.after_step()
+ # self.iter == max_iter can be used by `after_train` to
+ # tell whether the training successfully finished or failed
+ # due to exceptions.
+ self.iter += 1
+ except Exception:
+ logger.exception("Exception during training:")
+ raise
+ finally:
+ self.after_train()
+
+ def before_train(self):
+ for h in self._hooks:
+ h.before_train()
+
+ def after_train(self):
+ self.storage.iter = self.iter
+ for h in self._hooks:
+ h.after_train()
+
+ def before_step(self):
+ # Maintain the invariant that storage.iter == trainer.iter
+ # for the entire execution of each step
+ self.storage.iter = self.iter
+
+ for h in self._hooks:
+ h.before_step()
+
+ def after_step(self):
+ for h in self._hooks:
+ h.after_step()
+
+ def run_step(self):
+ raise NotImplementedError
+
+ def state_dict(self):
+ ret = {"iteration": self.iter}
+ hooks_state = {}
+ for h in self._hooks:
+ sd = h.state_dict()
+ if sd:
+ name = type(h).__qualname__
+ if name in hooks_state:
+ # TODO handle repetitive stateful hooks
+ continue
+ hooks_state[name] = sd
+ if hooks_state:
+ ret["hooks"] = hooks_state
+ return ret
+
+ def load_state_dict(self, state_dict):
+ logger = logging.getLogger(__name__)
+ self.iter = state_dict["iteration"]
+ for key, value in state_dict.get("hooks", {}).items():
+ for h in self._hooks:
+ try:
+ name = type(h).__qualname__
+ except AttributeError:
+ continue
+ if name == key:
+ h.load_state_dict(value)
+ break
+ else:
+ logger.warning(f"Cannot find the hook '{key}', its state_dict is ignored.")
+
+
+class SimpleTrainer(TrainerBase):
+ """
+ A simple trainer for the most common type of task:
+ single-cost single-optimizer single-data-source iterative optimization,
+ optionally using data-parallelism.
+ It assumes that every step, you:
+
+ 1. Compute the loss with a data from the data_loader.
+ 2. Compute the gradients with the above loss.
+ 3. Update the model with the optimizer.
+
+ All other tasks during training (checkpointing, logging, evaluation, LR schedule)
+ are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`.
+
+ If you want to do anything fancier than this,
+ either subclass TrainerBase and implement your own `run_step`,
+ or write your own training loop.
+ """
+
+ def __init__(self, model, data_loader, optimizer):
+ """
+ Args:
+ model: a torch Module. Takes a data from data_loader and returns a
+ dict of losses.
+ data_loader: an iterable. Contains data to be used to call model.
+ optimizer: a torch optimizer.
+ """
+ super().__init__()
+
+ """
+ We set the model to training mode in the trainer.
+ However it's valid to train a model that's in eval mode.
+ If you want your model (or a submodule of it) to behave
+ like evaluation during training, you can overwrite its train() method.
+ """
+ model.train()
+
+ self.model = model
+ self.data_loader = data_loader
+ self._data_loader_iter = iter(data_loader)
+ self.optimizer = optimizer
+
+ def run_step(self):
+ """
+ Implement the standard training logic described above.
+ """
+ assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
+ start = time.perf_counter()
+ """
+ If you want to do something with the data, you can wrap the dataloader.
+ """
+ data = next(self._data_loader_iter)
+ data_time = time.perf_counter() - start
+
+ """
+ If you want to do something with the losses, you can wrap the model.
+ """
+ loss_dict = self.model(data)
+ if isinstance(loss_dict, torch.Tensor):
+ losses = loss_dict
+ loss_dict = {"total_loss": loss_dict}
+ else:
+ losses = sum(loss_dict.values())
+
+ """
+ If you need to accumulate gradients or do something similar, you can
+ wrap the optimizer with your custom `zero_grad()` method.
+ """
+ self.optimizer.zero_grad()
+ losses.backward()
+
+ self._write_metrics(loss_dict, data_time)
+
+ """
+ If you need gradient clipping/scaling or other processing, you can
+ wrap the optimizer with your custom `step()` method. But it is
+ suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4
+ """
+ self.optimizer.step()
+
+ def _write_metrics(
+ self,
+ loss_dict: Mapping[str, torch.Tensor],
+ data_time: float,
+ prefix: str = "",
+ ) -> None:
+ SimpleTrainer.write_metrics(loss_dict, data_time, prefix)
+
+ @staticmethod
+ def write_metrics(
+ loss_dict: Mapping[str, torch.Tensor],
+ data_time: float,
+ prefix: str = "",
+ ) -> None:
+ """
+ Args:
+ loss_dict (dict): dict of scalar losses
+ data_time (float): time taken by the dataloader iteration
+ prefix (str): prefix for logging keys
+ """
+ metrics_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()}
+ metrics_dict["data_time"] = data_time
+
+ # Gather metrics among all workers for logging
+ # This assumes we do DDP-style training, which is currently the only
+ # supported method in detectron2.
+ all_metrics_dict = comm.gather(metrics_dict)
+
+ if comm.is_main_process():
+ storage = get_event_storage()
+
+ # data_time among workers can have high variance. The actual latency
+ # caused by data_time is the maximum among workers.
+ data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
+ storage.put_scalar("data_time", data_time)
+
+ # average the rest metrics
+ metrics_dict = {
+ k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys()
+ }
+ total_losses_reduced = sum(metrics_dict.values())
+ if not np.isfinite(total_losses_reduced):
+ raise FloatingPointError(
+ f"Loss became infinite or NaN at iteration={storage.iter}!\n"
+ f"loss_dict = {metrics_dict}"
+ )
+
+ storage.put_scalar("{}total_loss".format(prefix), total_losses_reduced)
+ if len(metrics_dict) > 1:
+ storage.put_scalars(**metrics_dict)
+
+ def state_dict(self):
+ ret = super().state_dict()
+ ret["optimizer"] = self.optimizer.state_dict()
+ return ret
+
+ def load_state_dict(self, state_dict):
+ super().load_state_dict(state_dict)
+ self.optimizer.load_state_dict(state_dict["optimizer"])
+
+
+class AMPTrainer(SimpleTrainer):
+ """
+ Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision
+ in the training loop.
+ """
+
+ def __init__(self, model, data_loader, optimizer, grad_scaler=None):
+ """
+ Args:
+ model, data_loader, optimizer: same as in :class:`SimpleTrainer`.
+ grad_scaler: torch GradScaler to automatically scale gradients.
+ """
+ unsupported = "AMPTrainer does not support single-process multi-device training!"
+ if isinstance(model, DistributedDataParallel):
+ assert not (model.device_ids and len(model.device_ids) > 1), unsupported
+ assert not isinstance(model, DataParallel), unsupported
+
+ super().__init__(model, data_loader, optimizer)
+
+ if grad_scaler is None:
+ from torch.cuda.amp import GradScaler
+
+ grad_scaler = GradScaler()
+ self.grad_scaler = grad_scaler
+
+ def run_step(self):
+ """
+ Implement the AMP training logic.
+ """
+ assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
+ assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
+ from torch.cuda.amp import autocast
+
+ start = time.perf_counter()
+ data = next(self._data_loader_iter)
+ data_time = time.perf_counter() - start
+
+ with autocast():
+ loss_dict = self.model(data)
+ if isinstance(loss_dict, torch.Tensor):
+ losses = loss_dict
+ loss_dict = {"total_loss": loss_dict}
+ else:
+ losses = sum(loss_dict.values())
+
+ self.optimizer.zero_grad()
+ self.grad_scaler.scale(losses).backward()
+
+ self._write_metrics(loss_dict, data_time)
+
+ self.grad_scaler.step(self.optimizer)
+ self.grad_scaler.update()
+
+ def state_dict(self):
+ ret = super().state_dict()
+ ret["grad_scaler"] = self.grad_scaler.state_dict()
+ return ret
+
+ def load_state_dict(self, state_dict):
+ super().load_state_dict(state_dict)
+ self.grad_scaler.load_state_dict(state_dict["grad_scaler"])
diff --git a/uniperceiver/engine/unified_tester.py b/uniperceiver/engine/unified_tester.py
new file mode 100644
index 0000000000000000000000000000000000000000..3df00b3379653dc392935684692c6894d3fb4adb
--- /dev/null
+++ b/uniperceiver/engine/unified_tester.py
@@ -0,0 +1,340 @@
+from torch.functional import Tensor
+import tqdm
+import os
+import pickle
+import sys
+import numpy as np
+import itertools
+import random
+import torch
+from torch.cuda.amp import autocast
+import shutil
+import uniperceiver.utils.comm as comm
+from timm.utils import accuracy
+from collections import defaultdict, deque
+import torch.distributed as dist
+
+
+class SmoothedValue(object):
+ """Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+
+ # borrowed from diet and mae
+ """
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not comm.is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value)
+
+
+def tester(task_cfg, model, test_data_loader, evaluator, epoch, amp_fp16, apex_fp16):
+ results = dict()
+ for task in test_data_loader.keys():
+ comm._LOCAL_CURRENT_TASK = task # used for other script
+ if test_data_loader[task] is None:
+ continue
+ if comm.is_main_process():
+ print('val/test task {}'.format(task))
+ if 'to_task' in dir(model):
+ model.to_task(task)
+ else:
+ model.module.to_task(task)
+ task_type = task_cfg[task]['DATASETS']['TASK_TYPE']
+ if task_type in ["image_retrieval", 'video_retrieval']:
+ results[task] = test_retrieval(task_cfg[task], model, test_data_loader[task], evaluator[task], epoch, amp_fp16, task)
+ else:
+ results[task] = test_cls(task_cfg[task], model, test_data_loader[task], evaluator[task], epoch, amp_fp16, task)
+
+ if 'reset_attr' in dir(model):
+ model.reset_attr()
+ else:
+ model.module.reset_attr()
+ return results
+
+
+# TODO write eval func for each task_type
+def test_cls(cfg, model, test_data_loader, evaluator, epoch, amp_fp16, task=None):
+ # only one works
+ # if not comm.is_main_process():
+ # return None
+ model.eval()
+ results = []
+
+ if not os.path.exists(comm.temp_dir):
+ os.mkdir(comm.temp_dir)
+
+ # shared_seed = comm.shared_random_seed() this simply does not work!
+ shared_seed = random.randint(0, sys.maxsize)
+ shared_seed = torch.tensor(shared_seed, device=next(model.parameters()).device)
+ torch.distributed.broadcast(shared_seed, src=0)
+ shared_seed = shared_seed.item()
+ if comm.is_main_process():
+ os.makedirs(os.path.join(comm.temp_dir, str(shared_seed)))
+ comm.synchronize()
+
+ # remove the cached embedding for word vocab
+ if isinstance(getattr(comm.unwrap_model(model), 'beam_searcher', None), torch.nn.Module):
+ if hasattr(getattr(comm.unwrap_model(model), 'beam_searcher', None), 'pre_computed_word_embeds'):
+ del comm.unwrap_model(model).beam_searcher.pre_computed_word_embeds
+ comm.unwrap_model(model).beam_searcher.pre_computed_word_embeds = None
+
+ meters = defaultdict(SmoothedValue)
+ with torch.no_grad():
+
+ for i, data in tqdm.tqdm(enumerate(test_data_loader)) if comm.is_main_process() else enumerate(test_data_loader):
+ # data = comm.unwrap_model(model).preprocess_batch(data)
+ # if i > 10:
+ # break
+ # model.train()
+ # return {}
+ if task is not None:
+ data["task_info"]['task_name'] = task
+ data = move_to_cuda(data)
+ task_type = data['task_info']['task_type']
+
+ sample_infos = data['input_sample_list'][0].get('sample_info', None)
+ with autocast(amp_fp16):
+ if cfg.INFERENCE.GENERATION_MODE:
+ res = model(data, use_beam_search=True, output_sents=True)
+ else:
+ res = model(data)
+
+ if isinstance(res["output"], torch.Tensor) and res["output"].dtype != torch.float32:
+ res["output"] = res["output"].float()
+
+ outputs = res["output"]
+
+ if task_type == 'vqa':
+ u_logits = res["output"]
+ outputs = torch.softmax(u_logits, dim=-1)
+ outputs = torch.max(outputs, 1)[1].data
+
+ if isinstance(data['input_sample_list'][0]['sample_info'], dict):
+ # single gpu; changes for data['input_sample_list'][0]['sample_info']
+ sample_infos = data['input_sample_list'][0]['sample_info']['sample_info_per_sample'][1]
+ elif isinstance(data['input_sample_list'][0]['sample_info'], list):
+ # multi gpu; original data
+ sample_infos = data['input_sample_list'][1]['sample_info']
+
+ for sample_info_pers_ample, output in zip(sample_infos, outputs):
+ if isinstance(output, torch.Tensor):
+ output = output.cpu()
+ # results.append({ "task_name": task, "answer": output, "question_id": int(sample_info_pers_ample['question_id'])})
+ results.append({ "answer": output, "question_id": int(sample_info_pers_ample['question_id'])})
+
+ elif task_type in ['image_classification']:
+ # targets in the input data
+ targets = data['target_idx_list'][0]
+ acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
+ bs = targets.shape[0]
+ meters['acc1'].update(acc1.item(), n=bs)
+ meters['acc5'].update(acc5.item(), n=bs)
+
+ pass
+ "an early version for evaluating Imagenet-1K "
+ """
+ # rely on ids to retrive label
+ outputs = outputs.cpu()
+ if isinstance(data['input_sample_list'][0]['sample_info'], dict):
+ # single gpu; changes for data['input_sample_list'][0]['sample_info']
+ sample_infos = data['input_sample_list'][0]['sample_info']['sample_info_per_sample'][0]
+ elif isinstance(data['input_sample_list'][0]['sample_info'], list):
+ # multi gpu; original data
+ sample_infos = data['input_sample_list'][0]['sample_info']
+ else:
+ raise NotImplementedError('please check')
+
+ for idx, si in enumerate(sample_infos):
+ results.append({cfg.INFERENCE.ID_KEY: si['id'], cfg.INFERENCE.VALUE: outputs[idx]})
+ """
+ elif task_type in ['image_caption', 'video_caption']:
+ ids = res["IDS"]
+ for id, output in zip(ids, outputs):
+ results.append({"image_id": int(id.item()), "caption": output})
+ elif task_type in ['text_classification']:
+ for label, output in zip(data['target_idx_list'][0], outputs):
+ results.append({"label": int(label), "pred": output})
+
+ elif task_type in ['video_classification']:
+ # targets in the input data
+ targets = data['target_idx_list'][0]
+ outputs = torch.softmax(outputs, -1).view(-1, sample_infos[0]['num_views'], outputs.size(-1)).mean(1)
+ acc1 = accuracy(outputs, targets, topk=(1,))[0]
+ bs = targets.shape[0]
+ meters['acc1'].update(acc1.item(), n=bs)
+
+ else:
+ raise NotImplementedError
+
+
+ if task_type in ['image_classification']:
+ for meter in meters.values():
+ meter.synchronize_between_processes()
+ eval_res = {'Acc@1': meters['acc1'].global_avg, 'Acc@5': meters['acc5'].global_avg}
+ elif task_type in ['video_classification']:
+ for meter in meters.values():
+ meter.synchronize_between_processes()
+ eval_res = {'Acc@1': meters['acc1'].global_avg}
+ else:
+ with open(os.path.join(comm.temp_dir, str(shared_seed), "rank_{}.pkl".format(comm.get_rank())), 'wb') as f:
+ # json.dump(results, f)
+ pickle.dump(results, f)
+ comm.synchronize()
+ if comm.is_main_process():
+ results_all = list()
+ for i in range(comm.get_world_size()):
+ with open(os.path.join(comm.temp_dir, str(shared_seed), "rank_{}.pkl".format(i)), 'rb') as f:
+ # results_all += json.load(f)
+ results_all += pickle.load(f)
+
+ results = results_all
+
+ if evaluator is not None:
+ eval_res = evaluator.eval(results, epoch)
+ else:
+ eval_res = ''
+
+ # remove cached files
+ shutil.rmtree(os.path.join(comm.temp_dir, str(shared_seed)))
+
+ model.train()
+ comm.synchronize()
+ if comm.is_main_process():
+ return eval_res
+ else:
+ return None
+
+
+def test_retrieval(cfg, model, test_data_loader, evaluator, epoch, amp_fp16, task=None):
+
+ if evaluator is not None:
+ if not comm.is_main_process():
+ comm.synchronize()
+ return None
+ ret = {}
+ model.eval()
+ ids = []
+ vfeats = []
+ tfeats = []
+ with torch.no_grad():
+ for data in tqdm.tqdm(test_data_loader):
+ if task is not None:
+ data["task_info"]['task_name'] = task
+ data = move_to_cuda(data)
+ # task_type = data['task_info']['task_type']
+
+ ids_local = [si['id'] for si in data['input_sample_list'][0]['sample_info']]
+ with autocast(amp_fp16):
+ outputs = model(data)
+ ids += ids_local
+ vfeats.append(outputs["input_feats"])
+ tfeats.append(outputs["tgt_feats"])
+
+ iids = [i[0] for i in ids]
+ cids = [i[1] for i in ids]
+ cids = list(itertools.chain.from_iterable(cids))
+ labels = np.expand_dims(cids, axis=1) == np.expand_dims(iids, axis=0)
+ labels = labels.astype(int)
+ vfeats = torch.cat(vfeats, dim=0)
+ tfeats = torch.cat(tfeats, dim=0)
+
+ ret.update(evaluator.eval(vfeats, tfeats, labels, 't2i'))
+ ret.update(evaluator.eval(tfeats, vfeats, labels.T, 'i2t'))
+ model.train()
+ comm.synchronize()
+ return ret
+
+ else:
+ raise NotImplementedError('please use \'RetrievalEvaler\'.')
+
+
+def move_to_cuda(data):
+ if isinstance(data, dict):
+ for key in data:
+ data[key] = move_to_cuda(data[key])
+ return data
+ elif isinstance(data, list):
+ return [move_to_cuda(item) for item in data]
+ elif isinstance(data, torch.Tensor):
+ return data.cuda(non_blocking=True)
+ else:
+ # let alone variable with other type
+ return data
+
+
+def dict_to_cuda(input_dict):
+ for key in input_dict:
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ elif isinstance(input_dict[key], dict):
+ input_dict[key] = dict_to_cuda(input_dict[key])
+ return input_dict
+
+
+def list_to_cuda(input_list):
+ # e.g., shared_targets
+ return [dict_to_cuda(item) if isinstance(item, dict) else item for item in input_list]
+
+
+def data_to_cuda(data):
+ data = dict_to_cuda(data)
+ data['net_input']['shared_targets'] = list_to_cuda(data['net_input']['shared_targets'])
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+ """
+ Performs all_gather operation on the provided tensors.
+ *** Warning ***: torch.distributed.all_gather has no gradient.
+ """
+ tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())]
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+ output = torch.cat(tensors_gather, dim=0)
+ return output
diff --git a/uniperceiver/engine/unified_trainer.py b/uniperceiver/engine/unified_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..62301cdf585d14a57b2314cd36102b5735e6170b
--- /dev/null
+++ b/uniperceiver/engine/unified_trainer.py
@@ -0,0 +1,520 @@
+import time
+import tqdm
+import os
+import json
+import pickle
+import sys
+import copy
+import numpy as np
+import itertools
+import random
+import torch
+import io
+from torch.nn.parallel import DistributedDataParallel
+from torch.cuda.amp import autocast
+from .unified_tester import tester, dict_to_cuda, list_to_cuda, move_to_cuda
+from collections import OrderedDict
+from uniperceiver.evaluation import build_evaluation
+import uniperceiver.utils.comm as comm
+from uniperceiver.utils.engine_util import *
+from .build import ENGINE_REGISTRY
+from uniperceiver.datasets import (
+ build_standard_valtest_loader,
+ build_unified_train_loader,
+)
+
+from uniperceiver.utils.events import get_event_storage
+from uniperceiver.utils.events import EventStorage
+from omegaconf import DictConfig
+from uniperceiver.losses import build_losses
+from uniperceiver.optim import build_optimizer
+from uniperceiver.modeling import build_model
+from uniperceiver.lr_scheduler import build_lr_scheduler
+from torch.cuda.amp import autocast
+from uniperceiver.checkpoint import TorchCheckpointer
+
+import logging
+import math
+import weakref
+
+from uniperceiver.config import CfgNode
+
+
+from . import hooks
+
+
+from timm.data import Mixup
+from timm.utils import ModelEma
+from uniperceiver.utils.misc import NativeScalerWithGradNormCount as NativeScaler
+from uniperceiver.utils.misc import ApexScalerWithGradNormCount as ApexScaler
+
+from collections import defaultdict
+from .train_loop import TrainerBase
+from uniperceiver.utils.logger import setup_logger
+
+try:
+ from apex import amp
+ APEX_INSTALLED = True
+except:
+ print('apex has not been installed.')
+ APEX_INSTALLED = False
+
+__all__ = ['UnifiedTrainer']
+
+
+@ENGINE_REGISTRY.register()
+class UnifiedTrainer(TrainerBase):
+ def __init__(self, cfg):
+ super().__init__()
+ self.logger = logging.getLogger(__name__)
+ if not self.logger.isEnabledFor(
+ logging.INFO): # setup_logger is not called for d2
+ setup_logger()
+
+ self.task_cfg = dict()
+ self.task_names = []
+ for task in cfg.TASKS:
+ name = task['NAME']
+ self.task_names.append(name)
+
+ # self.task_cfg[name] = new_cfg
+ self.task_cfg[name] = CfgNode(task)
+
+ self.cfg = cfg
+
+ # Assume these objects must be constructed in this order.
+ model = self.build_model(cfg)
+ self.logger.info("Model Creation Done")
+
+ self.apex_need_reload = False
+
+ self.optimizer = self.build_optimizer(cfg, model)
+
+ if cfg.SOLVER.APEX_FP16 and APEX_INSTALLED:
+ self.apex_fp16 = True
+
+ model, self.optimizer = amp.initialize(model,
+ self.optimizer,
+ opt_level=self.cfg.SOLVER.APEX_OPT_LEVEL,
+ master_weights=self.cfg.SOLVER.APEX_MASTER_WEIGHTS,
+ min_loss_scale=self.cfg.SOLVER.MIN_LOSS_SCLE,
+ loss_scale="dynamic")
+
+ # For training, wrap with DDP. But don't need this for inference.
+ if comm.get_world_size() > 1:
+ model = DistributedDataParallel(
+ model,
+ find_unused_parameters=cfg.find_unused_parameters,
+ device_ids=[comm.get_local_rank()],
+ broadcast_buffers=False)
+ self.model = model
+
+
+ self.model.train()
+
+ self.train_data_loader = build_train_loader(cfg, self.task_cfg, self.model)
+ self.val_data_loader = build_val_loader(cfg, self.task_cfg)
+ self.test_data_loader = build_test_loader(cfg, self.task_cfg)
+
+ if isinstance(self.train_data_loader, list):
+ self.iters_per_epoch_list = [
+ len(loader) for loader in self.train_data_loader
+ ]
+ self._train_data_loader_iter_list = [
+ iter(loader) for loader in self.train_data_loader
+ ]
+
+ self.iters_per_epoch = len(self.train_data_loader[0])
+ self._train_data_loader_iter = iter(self.train_data_loader[0])
+ else:
+ self.iters_per_epoch = len(self.train_data_loader)
+ self._train_data_loader_iter = iter(self.train_data_loader)
+
+ if self.val_data_loader is not None:
+ self.val_evaluator = build_evaluation(cfg,
+ cfg.INFERENCE.VAL_ANNFILE,
+ None)
+ else:
+ self.val_evaluator = None
+
+ if self.test_data_loader is not None:
+ self.test_evaluator = build_evaluation(cfg,
+ cfg.INFERENCE.TEST_ANNFILE,
+ cfg.OUTPUT_DIR)
+ else:
+ self.test_evaluator = None
+
+ self.ss_prob = 0.0
+
+
+ self.model_ema = None
+ if cfg.MODEL.MODEL_EMA:
+ self.model_ema = ModelEma(
+ self.model,
+ decay=cfg.MODEL.MODEL_EMA_DECAY,
+ device='cpu' if cfg.MODEL.MODEL_EMA_FORCE_CPU else '',
+ resume='')
+
+ self.checkpointer = TorchCheckpointer(
+ # Assume you want to save checkpoints together with logs/statistics
+ self.model,
+ self.model_ema,
+ cfg.OUTPUT_DIR,
+ trainer=weakref.proxy(self),
+ checkpoint_mapping=cfg.SOLVER.CHECKPOINT_MAPPING,
+ mapping=cfg.SOLVER.CHECKPOINT_MAP,
+ resume_tau=cfg.SOLVER.RESUME_TAU,
+ ceph_save=cfg.SOLVER.CHECKPOINT_CEPH_SAVE,
+ ceph_config=cfg.DATALOADER.get("TCS_CONF_PATH",
+ "petreloss.config"),
+ )
+ self.checkpointer.add_checkpointable('optimizer', self.optimizer)
+
+ if cfg.MODEL.MODEL_EMA:
+ self.checkpointer.add_checkpointable('ema_model',self.model_ema.ema)
+
+ self.start_iter = 0
+ self.max_iter = cfg.SOLVER.EPOCH * self.iters_per_epoch
+ self.register_hooks(self.build_hooks())
+
+ if cfg.SOLVER.AMP_FP16:
+ # Creates a GradScaler once at the beginning of training.
+ self.amp_scaler = NativeScaler(enabled=True, growth_interval=cfg.SOLVER.LOSS_SCALE_WINDOW)
+ self.amp_fp16=True
+ else:
+ self.amp_scaler = NativeScaler(enabled=False)
+ self.amp_fp16=False
+
+ if cfg.SOLVER.APEX_FP16 and APEX_INSTALLED:
+
+ self.amp_scaler = ApexScaler(enabled=True)
+
+ else:
+ self.apex_fp16 = False
+
+ self.fp16 = cfg.SOLVER.AMP_FP16 or cfg.SOLVER.APEX_FP16
+ self.bf16 = cfg.SOLVER.BF16
+ if self.fp16:
+ assert not self.bf16
+
+ if self.amp_scaler is not None:
+ self.checkpointer.add_checkpointable('amp_scaler', self.amp_scaler)
+
+
+ self.val_evaluator = dict()
+ self.test_evaluator = dict()
+ self.mixup_fn = dict()
+ for name, new_cfg in self.task_cfg.items():
+ if self.val_data_loader[name]:
+ self.val_evaluator[name] = build_evaluation(
+ new_cfg, new_cfg.INFERENCE.VAL_ANNFILE, cfg.OUTPUT_DIR)
+ else:
+ self.val_evaluator[name] = None
+ if self.test_data_loader[name]:
+ self.test_evaluator[name] = build_evaluation(new_cfg, new_cfg.INFERENCE.TEST_ANNFILE, cfg.OUTPUT_DIR)
+ else:
+ self.test_evaluator[name] = None
+
+ if new_cfg.DATALOADER.MIXUP > 0 or new_cfg.DATALOADER.CUTMIX > 0:
+ self.mixup_fn[name] = Mixup(
+ mixup_alpha=new_cfg.DATALOADER.MIXUP, cutmix_alpha=new_cfg.DATALOADER.CUTMIX, cutmix_minmax=None,
+ prob=new_cfg.DATALOADER.MIXUP_PROB, switch_prob=new_cfg.DATALOADER.MIXUP_SWITCH_PROB, mode=new_cfg.DATALOADER.MIXUP_MODE,
+ label_smoothing=new_cfg.DATALOADER.MIXUP_LABEL_SMOOTHING, num_classes=new_cfg.MODEL.LABELS_NUM)
+ else:
+ self.mixup_fn[name] = None
+
+ if cfg.DATALOADER.USE_WEIGHTED_SAMPLER:
+ # this is to avoid strange behaviors.
+ self.iters_per_epoch = 1
+ # override the previous scheduler
+
+ self.scheduler = self.build_lr_scheduler(cfg, self.optimizer, self.iters_per_epoch)
+ self.checkpointer.add_checkpointable('scheduler', self.scheduler)
+
+ self.accum_iter = max(1, cfg.SOLVER.ACCUM_ITER)
+ self.step_index = 0
+
+ self.grad_print = getattr(cfg.SOLVER, "GRAD_PRINT", False)
+
+ if self.cfg.SOLVER.GradHistogram:
+ assert self.cfg.SOLVER.TORCH_OPTIMIZER and self.cfg.SOLVER.PARAMS_SEPERATE
+
+ def resume_or_load(self, resume=True):
+
+ self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS,
+ resume=resume,
+ resume_optmizer=self.cfg.SOLVER.RESUME_OPTIMIZER)
+ if resume and self.checkpointer.has_checkpoint():
+ self.start_iter = self.iter + 1
+ # make apex resume work
+ if self.apex_fp16:
+ self.apex_need_reload = True
+
+ @classmethod
+ def build_losses(cls, cfg):
+ losses = {}
+ for task_config in cfg.TASKS:
+ task_config = DictConfig(task_config)
+ losses[task_config.NAME] = build_losses(task_config)
+
+ return losses
+
+ def build_hooks(self):
+
+ self.max_iter = self.cfg.SOLVER.MAX_ITER
+ cfg = self.cfg.clone()
+ cfg.defrost()
+ cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
+
+ ret = [
+ hooks.IterationTimer(),
+ hooks.LRScheduler(),
+ hooks.ModelWeightsManipulating()
+ ]
+
+ # Do PreciseBN before checkpointer, because it updates the model and need to
+ # be saved by checkpointer.
+ # This is not always the best: if checkpointing has a different frequency,
+ # some checkpoints may have more precise statistics than others.
+ if comm.is_main_process():
+ ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD,
+ max_to_keep= cfg.SOLVER.CHECKPOINT_MAX_SAVE ))
+
+ def test_and_save_results(epoch):
+ eval_results = self.test(self.cfg, self.model, self.test_data_loader, self.test_evaluator, epoch)
+ return eval_results
+
+ def val_and_save_results(epoch):
+ eval_results = self.test(self.cfg, self.model, self.val_data_loader, self.val_evaluator, epoch)
+ return eval_results
+
+ if self.model_ema is not None:
+
+ def test_and_save_results_ema(epoch):
+ eval_results = self.test(self.cfg, self.model_ema.ema,
+ self.test_data_loader,
+ self.test_evaluator, epoch)
+ ema_results = {}
+ for taskname, taskresults in eval_results.items():
+ if isinstance(taskresults, dict):
+ taskresults = {
+ f'{k}_ema': v
+ for k, v in taskresults.items()
+ }
+ ema_results[taskname] = taskresults
+
+ return ema_results
+
+ def val_and_save_results_ema(epoch):
+ eval_results = self.test(self.cfg, self.model_ema.ema,
+ self.val_data_loader,
+ self.val_evaluator, epoch)
+ ema_results = {}
+ for taskname, taskresults in eval_results.items():
+ if isinstance(taskresults, dict):
+ taskresults = {f'{k}_ema': v for k, v in taskresults.items()}
+ ema_results[taskname] = taskresults
+
+ return ema_results
+
+ # Do evaluation after checkpointer, because then if it fails,
+ # we can use the saved checkpoint to debug.
+ if self.val_data_loader is not None:
+ ret.append(
+ hooks.IterEvalHook(
+ eval_period = cfg.SOLVER.EVAL_PERIOD,
+ eval_start = cfg.INFERENCE.VAL_EVAL_START,
+ eval_function = val_and_save_results,
+ stage = 'val',
+ multi_gpu_eval=True
+ ))
+ if self.model_ema is not None:
+ ret.append(
+ hooks.IterEvalHook(
+ eval_period = cfg.SOLVER.EVAL_PERIOD,
+ eval_start = cfg.INFERENCE.VAL_EVAL_START,
+ eval_function = val_and_save_results_ema,
+ stage = 'val',
+ multi_gpu_eval=True
+ ))
+
+ if self.test_data_loader is not None:
+ ret.append(
+ hooks.IterEvalHook(
+ eval_period = cfg.SOLVER.EVAL_PERIOD,
+ eval_start = cfg.INFERENCE.TEST_EVAL_START,
+ eval_function = test_and_save_results,
+ stage = 'test',
+ multi_gpu_eval=True
+ ))
+ if self.model_ema is not None:
+ ret.append(
+ hooks.IterEvalHook(
+ eval_period=cfg.SOLVER.EVAL_PERIOD,
+ eval_start=cfg.INFERENCE.TEST_EVAL_START,
+ eval_function=test_and_save_results_ema,
+ stage='test',
+ multi_gpu_eval=True))
+
+ if comm.is_main_process():
+ # Here the default print/log frequency of each writer is used.
+ # run writers in the end, so that evaluation metrics are written
+ ret.append(hooks.PeriodicWriter(build_writers(cfg, self.max_iter), period=cfg.SOLVER.WRITE_PERIOD))
+
+ return ret
+
+ def train(self):
+ """
+ Args:
+ start_iter, max_iter (int): See docs above
+ """
+ start_iter = self.start_iter
+ max_iter = self.max_iter
+ logger = logging.getLogger(__name__)
+ logger.info("Starting training from iteration {}".format(start_iter))
+
+ self.iter = self.start_iter = start_iter
+ self.max_iter = max_iter
+
+ with EventStorage(start_iter) as self.storage:
+ try:
+
+ self.before_train()
+ for self.iter in range(start_iter, max_iter):
+ self.before_step()
+
+ self.run_step_torch()
+
+ self.after_step()
+
+ if self.apex_need_reload:
+ optimizer_state_dict = torch.load(self.checkpointer.get_checkpoint_file())['optimizer']
+ self.optimizer.load_state_dict(optimizer_state_dict)
+ self.apex_need_reload = False
+
+ self.iter += 1
+ except Exception:
+ logger.exception("Exception during training:")
+ raise
+ finally:
+ self.after_train()
+
+ @classmethod
+ def build_model(cls, cfg):
+ model = build_model(cfg)
+ logger = logging.getLogger(__name__)
+ logger.info("Model:\n{}".format(model))
+ return model
+
+ @classmethod
+ def build_optimizer(cls, cfg, model):
+ logger = logging.getLogger(__name__)
+ logger.info("building optimizer...")
+ return build_optimizer(cfg, model)
+
+ @classmethod
+ def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch):
+ logger = logging.getLogger(__name__)
+ logger.info("building lr_scheduler...")
+ return build_lr_scheduler(cfg, optimizer, iters_per_epoch)
+
+ def run_step_torch(self):
+ if self.accum_iter > 1:
+ for micro_step in range(self.accum_iter):
+ self.micro_step = micro_step
+ self.run_min_batch()
+ else:
+ self.micro_step = 0
+ self.run_min_batch()
+
+ def run_min_batch(self):
+ timer_fn = time.perf_counter
+ assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
+ torch.cuda.synchronize()
+
+ start = timer_fn()
+ data = get_batch_data(self.cfg, self._train_data_loader_iter, self.train_data_loader)
+ data_time = time.perf_counter() - start
+
+ task = data['task_info']['task_name']
+ data = move_to_cuda(data)
+
+ #TODO: task specifix code, move into model
+ if self.mixup_fn[task] is not None:
+ # imagenet
+ data['input_sample_list'][0]["data"], data[
+ 'target_idx_list'][0] = self.mixup_fn[task](
+ data['input_sample_list'][0]["data"], data["target_idx_list"][0])
+
+ if not self.amp_fp16:
+ losses_dict = self.model(data)
+
+ else:
+ with autocast(self.amp_fp16):
+ losses_dict = self.model(data)
+
+ losses = sum(losses_dict.values())
+
+ # for accum iter
+ losses /= self.accum_iter
+
+ total_grad = self.amp_scaler(losses, self.optimizer, clip_grad=self.cfg.SOLVER.GRAD_CLIP,
+ parameters=self.model.parameters(), create_graph=False,
+ update_grad=(self.micro_step + 1 == self.accum_iter), fp16=self.fp16, iter=self.iter,
+ min_loss_scale=self.cfg.SOLVER.MIN_LOSS_SCLE,
+ loss_scale_window=self.cfg.SOLVER.LOSS_SCALE_WINDOW)
+
+ if self.micro_step + 1 != self.accum_iter:
+ return
+
+ if self.micro_step + 1 == self.accum_iter:
+ write_metrics(losses_dict, data_time, task + '/')
+
+ if comm.is_main_process():
+ storage = get_event_storage()
+ if torch.logical_or(total_grad.isnan(), total_grad.isinf()):
+ logger = logging.getLogger(__name__)
+ logger.info('grad to nan or inf in task {} {}'.format(task, total_grad))
+ storage.put_scalar("total_grad", total_grad, smoothing_hint=False)
+
+ if self.apex_need_reload:
+ pass
+ else:
+ self.amp_scaler.step(self.optimizer)
+
+ if comm.is_main_process():
+ storage.put_scalar("amp_scale", self.amp_scaler.get_scale(), smoothing_hint=False)
+ if hasattr(comm.unwrap_model(self.model).loss_prepare, 'temperature_dict'):
+ if isinstance(comm.unwrap_model(self.model).loss_prepare, torch.nn.ModuleList):
+ temperature_dict = comm.unwrap_model(self.model).loss_prepare[-1].temperature_dict
+ else:
+ temperature_dict = comm.unwrap_model(self.model).loss_prepare.temperature_dict
+ storage.put_scalars(**temperature_dict, smoothing_hint=False)
+
+ if self.amp_fp16:
+ self.amp_scaler.update()
+
+
+ self.optimizer.zero_grad()
+ if self.model_ema is not None:
+ self.model_ema.update(self.model)
+ torch.cuda.synchronize()
+
+ def cast_layers(self):
+ logger = self.logger
+ if self.cfg.MODEL.LN_FP32:
+ logger.info("cast LN to fp32")
+
+ def cast_ln_fp32(module):
+ if isinstance(module, CustomLayernorm):
+ module.float()
+
+ self.model_engine.module.apply(cast_ln_fp32)
+
+ if self.iter == 0:
+ comm.unwrap_model(self.model).operatedweight()
+
+
+
+ def test(self, cfg, model, test_data_loader, evaluator, epoch):
+ return tester(self.task_cfg, model, test_data_loader, evaluator, epoch, self.amp_fp16, self.apex_fp16)
diff --git a/uniperceiver/evaluation/__init__.py b/uniperceiver/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73143fc6d21605d0cb18cdf36c70ca924d80a09e
--- /dev/null
+++ b/uniperceiver/evaluation/__init__.py
@@ -0,0 +1,12 @@
+from .build import build_evaluation
+
+from .imagenet_evaler import ImageNetEvaler
+from .coco_evaler import COCOEvaler
+from .retrieval_evaler import RetrievalEvaler
+from .mit_evaler import MiTEvaler
+from .vqa_eval import VQAEvaler
+from .glue_evaler import GLUEEvaler
+
+from .evaluator import DatasetEvaluator, DatasetEvaluators, inference_context, inference_on_dataset
+
+from .testing import print_csv_format, verify_results
\ No newline at end of file
diff --git a/uniperceiver/evaluation/build.py b/uniperceiver/evaluation/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbe3a6498503239537425e169acda3e8139c6fcb
--- /dev/null
+++ b/uniperceiver/evaluation/build.py
@@ -0,0 +1,13 @@
+
+import torch
+
+from uniperceiver.utils.registry import Registry
+
+EVALUATION_REGISTRY = Registry("EVALUATION")
+EVALUATION_REGISTRY.__doc__ = """
+Registry for evaluation
+"""
+
+def build_evaluation(cfg, annfile, output_dir):
+ evaluation = EVALUATION_REGISTRY.get(cfg.INFERENCE.NAME)(cfg, annfile, output_dir) if len(cfg.INFERENCE.NAME) > 0 else None
+ return evaluation
\ No newline at end of file
diff --git a/uniperceiver/evaluation/coco_evaler.py b/uniperceiver/evaluation/coco_evaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4d606e3a97e0c7f23ca8e14c535c2035f73a1aa
--- /dev/null
+++ b/uniperceiver/evaluation/coco_evaler.py
@@ -0,0 +1,52 @@
+import os
+import sys
+import tempfile
+import json
+from json import encoder
+from uniperceiver.config import configurable
+from .build import EVALUATION_REGISTRY
+from pycocotools.coco import COCO
+from pycocoevalcap.eval import COCOEvalCap
+from uniperceiver.utils import comm
+
+@EVALUATION_REGISTRY.register()
+class COCOEvaler(object):
+ def __init__(self, cfg, annfile, output_dir):
+ super(COCOEvaler, self).__init__()
+
+ self.coco = COCO(annfile) if annfile != '' else None
+ if not os.path.exists("./data/temp") and comm.is_main_process():
+ os.makedirs("./data/temp")
+
+ if output_dir is not None:
+ self.output_dir = os.path.join(output_dir, 'results')
+ if not os.path.exists(self.output_dir) and comm.is_main_process():
+ os.makedirs(self.output_dir)
+ else:
+ self.output_dir = None
+
+ def eval(self, results_input, epoch):
+ image_ids = []
+ results = []
+ for result in results_input:
+ if result['image_id'] not in image_ids:
+ results.append(result)
+ image_ids.append(result['image_id'])
+
+ if self.output_dir is not None:
+ json.dump(results, open(os.path.join(self.output_dir, str(epoch) + '.json'), "w"))
+
+ in_file = tempfile.NamedTemporaryFile(mode='w',
+ delete=False,
+ dir="./data/temp")
+ json.dump(results, in_file)
+ in_file.close()
+
+ if self.coco is None:
+ return {}
+
+ cocoRes = self.coco.loadRes(in_file.name)
+ cocoEval = COCOEvalCap(self.coco, cocoRes)
+ cocoEval.evaluate()
+ os.remove(in_file.name)
+ return cocoEval.eval
\ No newline at end of file
diff --git a/uniperceiver/evaluation/evaluator.py b/uniperceiver/evaluation/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..10af1db2b53507fed8bdcf7e8787a528265959ba
--- /dev/null
+++ b/uniperceiver/evaluation/evaluator.py
@@ -0,0 +1,224 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import datetime
+import logging
+import time
+from collections import OrderedDict, abc
+from contextlib import ExitStack, contextmanager
+from typing import List, Union
+import torch
+from torch import nn
+
+from uniperceiver.utils.comm import get_world_size, is_main_process
+from uniperceiver.utils.logger import log_every_n_seconds
+
+
+class DatasetEvaluator:
+ """
+ Base class for a dataset evaluator.
+
+ The function :func:`inference_on_dataset` runs the model over
+ all samples in the dataset, and have a DatasetEvaluator to process the inputs/outputs.
+
+ This class will accumulate information of the inputs/outputs (by :meth:`process`),
+ and produce evaluation results in the end (by :meth:`evaluate`).
+ """
+
+ def reset(self):
+ """
+ Preparation for a new round of evaluation.
+ Should be called before starting a round of evaluation.
+ """
+ pass
+
+ def process(self, inputs, outputs):
+ """
+ Process the pair of inputs and outputs.
+ If they contain batches, the pairs can be consumed one-by-one using `zip`:
+
+ .. code-block:: python
+
+ for input_, output in zip(inputs, outputs):
+ # do evaluation on single input/output pair
+ ...
+
+ Args:
+ inputs (list): the inputs that's used to call the model.
+ outputs (list): the return value of `model(inputs)`
+ """
+ pass
+
+ def evaluate(self):
+ """
+ Evaluate/summarize the performance, after processing all input/output pairs.
+
+ Returns:
+ dict:
+ A new evaluator class can return a dict of arbitrary format
+ as long as the user can process the results.
+ In our train_net.py, we expect the following format:
+
+ * key: the name of the task (e.g., bbox)
+ * value: a dict of {metric name: score}, e.g.: {"AP50": 80}
+ """
+ pass
+
+
+class DatasetEvaluators(DatasetEvaluator):
+ """
+ Wrapper class to combine multiple :class:`DatasetEvaluator` instances.
+
+ This class dispatches every evaluation call to
+ all of its :class:`DatasetEvaluator`.
+ """
+
+ def __init__(self, evaluators):
+ """
+ Args:
+ evaluators (list): the evaluators to combine.
+ """
+ super().__init__()
+ self._evaluators = evaluators
+
+ def reset(self):
+ for evaluator in self._evaluators:
+ evaluator.reset()
+
+ def process(self, inputs, outputs):
+ for evaluator in self._evaluators:
+ evaluator.process(inputs, outputs)
+
+ def evaluate(self):
+ results = OrderedDict()
+ for evaluator in self._evaluators:
+ result = evaluator.evaluate()
+ if is_main_process() and result is not None:
+ for k, v in result.items():
+ assert (
+ k not in results
+ ), "Different evaluators produce results with the same key {}".format(k)
+ results[k] = v
+ return results
+
+
+def inference_on_dataset(
+ model, data_loader, evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None]
+):
+ """
+ Run model on the data_loader and evaluate the metrics with evaluator.
+ Also benchmark the inference speed of `model.__call__` accurately.
+ The model will be used in eval mode.
+
+ Args:
+ model (callable): a callable which takes an object from
+ `data_loader` and returns some outputs.
+
+ If it's an nn.Module, it will be temporarily set to `eval` mode.
+ If you wish to evaluate a model in `training` mode instead, you can
+ wrap the given model and override its behavior of `.eval()` and `.train()`.
+ data_loader: an iterable object with a length.
+ The elements it generates will be the inputs to the model.
+ evaluator: the evaluator(s) to run. Use `None` if you only want to benchmark,
+ but don't want to do any evaluation.
+
+ Returns:
+ The return value of `evaluator.evaluate()`
+ """
+ num_devices = get_world_size()
+ logger = logging.getLogger(__name__)
+ logger.info("Start inference on {} batches".format(len(data_loader)))
+
+ total = len(data_loader) # inference data loader must have a fixed length
+ if evaluator is None:
+ # create a no-op evaluator
+ evaluator = DatasetEvaluators([])
+ if isinstance(evaluator, abc.MutableSequence):
+ evaluator = DatasetEvaluators(evaluator)
+ evaluator.reset()
+
+ num_warmup = min(5, total - 1)
+ start_time = time.perf_counter()
+ total_data_time = 0
+ total_compute_time = 0
+ total_eval_time = 0
+ with ExitStack() as stack:
+ if isinstance(model, nn.Module):
+ stack.enter_context(inference_context(model))
+ stack.enter_context(torch.no_grad())
+
+ start_data_time = time.perf_counter()
+ for idx, inputs in enumerate(data_loader):
+ total_data_time += time.perf_counter() - start_data_time
+ if idx == num_warmup:
+ start_time = time.perf_counter()
+ total_data_time = 0
+ total_compute_time = 0
+ total_eval_time = 0
+
+ start_compute_time = time.perf_counter()
+ outputs = model(inputs)
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ total_compute_time += time.perf_counter() - start_compute_time
+
+ start_eval_time = time.perf_counter()
+ evaluator.process(inputs, outputs)
+ total_eval_time += time.perf_counter() - start_eval_time
+
+ iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
+ data_seconds_per_iter = total_data_time / iters_after_start
+ compute_seconds_per_iter = total_compute_time / iters_after_start
+ eval_seconds_per_iter = total_eval_time / iters_after_start
+ total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start
+ if idx >= num_warmup * 2 or compute_seconds_per_iter > 5:
+ eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1)))
+ log_every_n_seconds(
+ logging.INFO,
+ (
+ f"Inference done {idx + 1}/{total}. "
+ f"Dataloading: {data_seconds_per_iter:.4f} s/iter. "
+ f"Inference: {compute_seconds_per_iter:.4f} s/iter. "
+ f"Eval: {eval_seconds_per_iter:.4f} s/iter. "
+ f"Total: {total_seconds_per_iter:.4f} s/iter. "
+ f"ETA={eta}"
+ ),
+ n=5,
+ )
+ start_data_time = time.perf_counter()
+
+ # Measure the time only for this worker (before the synchronization barrier)
+ total_time = time.perf_counter() - start_time
+ total_time_str = str(datetime.timedelta(seconds=total_time))
+ # NOTE this format is parsed by grep
+ logger.info(
+ "Total inference time: {} ({:.6f} s / iter per device, on {} devices)".format(
+ total_time_str, total_time / (total - num_warmup), num_devices
+ )
+ )
+ total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
+ logger.info(
+ "Total inference pure compute time: {} ({:.6f} s / iter per device, on {} devices)".format(
+ total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
+ )
+ )
+
+ results = evaluator.evaluate()
+ # An evaluator may return None when not in main process.
+ # Replace it by an empty dict instead to make it easier for downstream code to handle
+ if results is None:
+ results = {}
+ return results
+
+
+@contextmanager
+def inference_context(model):
+ """
+ A context where the model is temporarily changed to eval mode,
+ and restored to previous mode afterwards.
+
+ Args:
+ model: a torch Module
+ """
+ training_mode = model.training
+ model.eval()
+ yield
+ model.train(training_mode)
diff --git a/uniperceiver/evaluation/glue_evaler.py b/uniperceiver/evaluation/glue_evaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..4763ba04bb4a9874129cc6a399e5609958964dca
--- /dev/null
+++ b/uniperceiver/evaluation/glue_evaler.py
@@ -0,0 +1,71 @@
+import os
+import sys
+import pickle
+import json
+from json import encoder
+from uniperceiver.config import configurable
+from .build import EVALUATION_REGISTRY
+
+import numpy as np
+from sklearn.metrics import f1_score, matthews_corrcoef
+from scipy.stats import pearsonr, spearmanr
+
+def simple_accuracy(preds, labels):
+ return (preds == labels).mean()
+
+@EVALUATION_REGISTRY.register()
+class GLUEEvaler(object):
+ def __init__(self, cfg, *args, **kwargs):
+ super(GLUEEvaler, self).__init__()
+ self.task_name = cfg.DATASETS.DATASET_NAME
+ self.tasks = [""]
+
+
+
+ def eval(self, results, epoch):
+ preds = []
+ labels = []
+ for result in results:
+ # cls task
+ if self.task_name != 'STS-B':
+ preds.append(result["pred"].argmax().item())
+ labels.append(int(result["label"]))
+
+ else:
+ # regression task
+ preds.append(float(result["pred"].sigmoid().item()))
+ labels.append(float(result["label"]))
+
+ preds = np.array(preds)
+ labels = np.array(labels)
+
+ if self.task_name == 'CoLA':
+ acc = simple_accuracy(preds, labels)
+ matthewscorr = matthews_corrcoef(labels, preds)
+ result = {
+ "accuracy": acc,
+ "matthews_corrcoef": matthewscorr,
+ }
+ elif self.task_name in [ 'QNLI', 'RTE', 'SST-2'] or self.task_name.startswith("MNLI"):
+ acc = simple_accuracy(preds, labels)
+ result = {
+ "accuracy": acc,
+ }
+ elif self.task_name in ['MRPC', 'QQP']:
+ acc = simple_accuracy(preds, labels)
+ f1 = f1_score(y_true=labels, y_pred=preds)
+ result = {
+ "accuracy": acc,
+ "f1_score": f1,
+ }
+ elif self.task_name in ['STS-B']:
+ pearson_corr = pearsonr(preds, labels)[0]
+ spearman_corr = spearmanr(preds, labels)[0]
+ result ={
+ "pearson_corr": pearson_corr,
+ "spearman_corr": spearman_corr,
+ }
+ else:
+ raise NotImplementedError
+
+ return result
diff --git a/uniperceiver/evaluation/imagenet_evaler.py b/uniperceiver/evaluation/imagenet_evaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..abd0a742f35ba480cd18a92b0924f9acd2f1db66
--- /dev/null
+++ b/uniperceiver/evaluation/imagenet_evaler.py
@@ -0,0 +1,47 @@
+
+import os
+import sys
+import tempfile
+import json
+from json import encoder
+
+import torch
+from uniperceiver.config import configurable
+from .build import EVALUATION_REGISTRY
+
+# from timm.utils import accuracy
+
+from uniperceiver.utils import comm
+
+def accuracy(output, target, topk=(1,)):
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ maxk = max(topk)
+ batch_size = target.size(0)
+ _, pred = output.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
+ return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
+
+@EVALUATION_REGISTRY.register()
+class ImageNetEvaler(object):
+ def __init__(self, cfg, annfile, output_dir):
+ super(ImageNetEvaler, self).__init__()
+ self.ann_file = annfile
+ with open(self.ann_file, 'r') as f:
+ img_infos = f.readlines()
+
+ target = [int(info.replace('\n', '').split(' ')[1]) for info in img_infos]
+ self.target = torch.tensor(target)
+
+
+ def eval(self, results, epoch):
+
+ # sort the result for multi-gpu evaluation
+ results = {res['image_id']: res['cls_logits'] for res in results}
+ results = [results[i] for i in sorted(results.keys())]
+
+ results = torch.stack(results)
+
+ acc1, acc5 = accuracy(results, self.target.to(device=results.device), topk=(1, 5))
+ # acc1, acc5 = accuracy(results, self.target[:results.size(0)].to(device=results.device), topk=(1, 5))
+ return {'Acc@1': acc1, 'Acc@5': acc5}
\ No newline at end of file
diff --git a/uniperceiver/evaluation/mit_evaler.py b/uniperceiver/evaluation/mit_evaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa640538ae85803aa5ff57ae45463dd9a26ed1e1
--- /dev/null
+++ b/uniperceiver/evaluation/mit_evaler.py
@@ -0,0 +1,55 @@
+import os
+import tempfile
+import json
+from uniperceiver.config import configurable
+from .build import EVALUATION_REGISTRY
+from uniperceiver.utils import comm
+import numpy as np
+
+def simple_accuracy(preds, labels):
+ return (preds == labels).mean()
+
+@EVALUATION_REGISTRY.register()
+class MiTEvaler(object):
+ def __init__(self, cfg, annfile, output_dir):
+ super(MiTEvaler, self).__init__()
+ self.video_dict = dict()
+
+ self.cls2idx = dict()
+ with open(os.path.join(os.path.dirname(annfile), "category_mapping.txt"), 'r') as f:
+ for line in f.readlines():
+ class_name, idx = line.strip().split('\t')
+ class_name = class_name.replace(" ", "_")
+ self.cls2idx[class_name] = int(idx)
+
+ with open(annfile) as f:
+ data_file = json.load(f)
+ for name, info in data_file['database'].items():
+ # if info['subset'] == "validation" or True: # debug
+ if info['subset'] == "validation":
+ self.video_dict[name] = self.cls2idx[info['annotations']['label']]
+
+ if not os.path.exists(comm.TEMP_DIR):
+ os.mkdir(comm.TEMP_DIR)
+
+ if output_dir is not None:
+ self.output_dir = os.path.join(output_dir, 'results')
+ if not os.path.exists(self.output_dir) and comm.is_main_process():
+ os.mkdir(self.output_dir)
+ else:
+ self.output_dir = None
+
+ def eval(self, results, epoch):
+ preds = []
+ labels = []
+ for result in results:
+ labels.append(self.video_dict[result["video_name"]])
+ preds.append(result["label"].item())
+ preds = np.array(preds)
+ labels = np.array(labels)
+ acc = simple_accuracy(preds, labels)
+ # if self.output_dir is not None:
+ # json.dump(results, open(os.path.join(self.output_dir, str(epoch) + '.json'), "w"))
+
+
+ return {"accuracy": acc}
diff --git a/uniperceiver/evaluation/retrieval_evaler.py b/uniperceiver/evaluation/retrieval_evaler.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a92fd2a1e0e276a1f4d8ae32f3a3e4d852dd668
--- /dev/null
+++ b/uniperceiver/evaluation/retrieval_evaler.py
@@ -0,0 +1,58 @@
+
+import os
+import sys
+import numpy as np
+import torch
+from uniperceiver.config import configurable
+from .build import EVALUATION_REGISTRY
+
+@EVALUATION_REGISTRY.register()
+class RetrievalEvaler(object):
+ def __init__(self, cfg, annfile, output_dir,):
+ super(RetrievalEvaler, self).__init__()
+ self.eval_bs = cfg.INFERENCE.EVAL_BS
+ pass
+
+ def eval(self, vfeats, tfeats, labels, prefix=None):
+ count = 0
+ batch_size = self.eval_bs
+ batch_num = tfeats.size(0) // batch_size
+ rank_matrix = np.ones((tfeats.size(0))) * vfeats.size(0)
+ for i in range(batch_num):
+ if i == batch_num - 1:
+ b_tfeats = tfeats[i*batch_size:]
+ else:
+ b_tfeats = tfeats[i*batch_size:(i+1)*batch_size]
+
+ with torch.no_grad():
+ scores = (b_tfeats.unsqueeze(1) * vfeats.unsqueeze(0)).sum(dim=-1).cpu().numpy()
+ for score in scores:
+ # rank = np.where((np.argsort(-score) == np.where(labels[count]==1)[0][0]) == 1)[0][0]
+ rank = min([10] + [np.where((np.argsort(-score) == index) == 1)[0][0] for index in np.where(labels[count]==1)[0]])
+ rank_matrix[count] = rank
+ count += 1
+
+ r1 = 100.0 * np.sum(rank_matrix < 1) / len(rank_matrix)
+ r5 = 100.0 * np.sum(rank_matrix < 5) / len(rank_matrix)
+ r10 = 100.0 * np.sum(rank_matrix < 10) / len(rank_matrix)
+
+ rmean = (r1+r5+r10)/3
+
+ # medr = np.floor(np.median(rank_matrix) + 1)
+ # meanr = np.mean(rank_matrix) + 1
+ if prefix is None:
+ return {
+ "r1": r1,
+ "r5": r5,
+ "r10": r10,
+ "rmean": rmean,
+ # "meanr": meanr
+ }
+ else:
+ return {
+ prefix+ "-r1": r1,
+ prefix+ "-r5": r5,
+ prefix+ "-r10": r10,
+ prefix+ "-rmean": rmean,
+ # prefix+ "-meanr": meanr
+ }
diff --git a/uniperceiver/evaluation/testing.py b/uniperceiver/evaluation/testing.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e5ae625bb0593fc20739dd3ea549157e4df4f3d
--- /dev/null
+++ b/uniperceiver/evaluation/testing.py
@@ -0,0 +1,85 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import logging
+import numpy as np
+import pprint
+import sys
+from collections.abc import Mapping
+
+
+def print_csv_format(results):
+ """
+ Print main metrics in a format similar to Detectron,
+ so that they are easy to copypaste into a spreadsheet.
+
+ Args:
+ results (OrderedDict[dict]): task_name -> {metric -> score}
+ unordered dict can also be printed, but in arbitrary order
+ """
+ assert isinstance(results, Mapping) or not len(results), results
+ logger = logging.getLogger(__name__)
+ for task, res in results.items():
+ if isinstance(res, Mapping):
+ # Don't print "AP-category" metrics since they are usually not tracked.
+ important_res = [(k, v) for k, v in res.items() if "-" not in k]
+ logger.info("copypaste: Task: {}".format(task))
+ logger.info("copypaste: " + ",".join([k[0] for k in important_res]))
+ logger.info("copypaste: " + ",".join(["{0:.4f}".format(k[1]) for k in important_res]))
+ else:
+ logger.info(f"copypaste: {task}={res}")
+
+
+def verify_results(cfg, results):
+ """
+ Args:
+ results (OrderedDict[dict]): task_name -> {metric -> score}
+
+ Returns:
+ bool: whether the verification succeeds or not
+ """
+ expected_results = cfg.TEST.EXPECTED_RESULTS
+ if not len(expected_results):
+ return True
+
+ ok = True
+ for task, metric, expected, tolerance in expected_results:
+ actual = results[task].get(metric, None)
+ if actual is None:
+ ok = False
+ continue
+ if not np.isfinite(actual):
+ ok = False
+ continue
+ diff = abs(actual - expected)
+ if diff > tolerance:
+ ok = False
+
+ logger = logging.getLogger(__name__)
+ if not ok:
+ logger.error("Result verification failed!")
+ logger.error("Expected Results: " + str(expected_results))
+ logger.error("Actual Results: " + pprint.pformat(results))
+
+ sys.exit(1)
+ else:
+ logger.info("Results verification passed.")
+ return ok
+
+
+def flatten_results_dict(results):
+ """
+ Expand a hierarchical dict of scalars into a flat dict of scalars.
+ If results[k1][k2][k3] = v, the returned dict will have the entry
+ {"k1/k2/k3": v}.
+
+ Args:
+ results (dict):
+ """
+ r = {}
+ for k, v in results.items():
+ if isinstance(v, Mapping):
+ v = flatten_results_dict(v)
+ for kk, vv in v.items():
+ r[k + "/" + kk] = vv
+ else:
+ r[k] = v
+ return r
diff --git a/uniperceiver/evaluation/vqa_eval.py b/uniperceiver/evaluation/vqa_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..4043cd174d2b98b3af6e0dade1e4f5370c27fb9b
--- /dev/null
+++ b/uniperceiver/evaluation/vqa_eval.py
@@ -0,0 +1,80 @@
+# Copyright 2021 JD.com, Inc., JD AI
+"""
+@author: Yehao Li
+@contact: yehaoli.sysu@gmail.com
+"""
+import os
+import sys
+import pickle
+import json
+from json import encoder
+from .build import EVALUATION_REGISTRY
+
+@EVALUATION_REGISTRY.register()
+class VQAEvaler(object):
+ def __init__(self, cfg, annfile, output_dir):
+ super(VQAEvaler, self).__init__()
+ label2ans_path = os.path.join(cfg.DATALOADER.ANNO_FOLDER, "trainval_label2ans.pkl")
+ ori_annotation = json.load(open(os.path.join(cfg.DATALOADER.ANNO_FOLDER, "v2_mscoco_val2014_annotations.json")))
+ self.id2type = {t['question_id']: t['question_type'] for t in ori_annotation['annotations']}
+ self.id2type_answer = {t['question_id']: t['answer_type'] for t in ori_annotation['annotations']}
+ self.label2ans = pickle.load(open(label2ans_path, "rb"))
+
+ self.id2label = {}
+ if len(annfile) > 0:
+ answers_val = pickle.load(open(annfile, "rb"))
+ for datum in answers_val:
+ quesid = datum['question_id']
+ self.id2label[quesid] = {}
+ for i, label in enumerate(datum['labels']):
+ label_str = self.label2ans[label]
+ self.id2label[quesid][label_str] = datum['scores'][i]
+
+ if output_dir is not None:
+ self.output_dir = os.path.join(output_dir, 'results')
+ if not os.path.exists(self.output_dir):
+ os.mkdir(self.output_dir)
+ else:
+ self.output_dir = None
+
+ def eval(self, results, epoch):
+ for res in results:
+ res['answer'] = self.label2ans[res['answer']]
+ if self.output_dir is not None:
+ json.dump(results, open(os.path.join(self.output_dir, str(epoch) + '.json'), "w", encoding="utf-8"))
+
+ accuracy = 0.
+ acc_by_type = dict()
+ acc_by_type_answer = dict()
+ for result in results:
+ quesid = result['question_id']
+ ans = result['answer']
+ if quesid not in self.id2type:
+ print("Test Stage has no target")
+ return { "accuracy": 0.0 }
+ type = self.id2type[quesid]
+ ans_type = self.id2type_answer[quesid]
+ if type not in acc_by_type:
+ acc_by_type[type] = [0,0]
+ if ans_type not in acc_by_type_answer:
+ acc_by_type_answer[ans_type] = [0,0]
+ if quesid not in self.id2label:
+ return { "accuracy": 0.0 }
+
+ datum = self.id2label[quesid]
+ acc_by_type[type][1] += 1
+ acc_by_type_answer[ans_type][1] += 1
+ if ans in datum:
+ accuracy += datum[ans]
+ acc_by_type[type][0] += datum[ans]
+ acc_by_type_answer[ans_type][0] += datum[ans]
+
+ accuracy = accuracy / len(results)
+ print('vqa acc: {}'.format(accuracy*100))
+ for k in acc_by_type.keys():
+ acc_by_type[k] = acc_by_type[k][0] / acc_by_type[k][1] * 100
+ for k in acc_by_type_answer.keys():
+ acc_by_type_answer[k] = acc_by_type_answer[k][0] / acc_by_type_answer[k][1] * 100
+ print('vqa acc by question type: ', acc_by_type)
+ print("vqa acc by answer type: ", acc_by_type_answer)
+ return { "accuracy": accuracy }
diff --git a/uniperceiver/functional/__init__.py b/uniperceiver/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f9192c0feadb099b1964123d51d0dbe3352ea74
--- /dev/null
+++ b/uniperceiver/functional/__init__.py
@@ -0,0 +1,10 @@
+from .func_caption import (
+ decode_sequence,
+ decode_sequence_bert
+)
+from .func_feats import (iou, boxes_to_locfeats, dict_as_tensor, dict_to_cuda,
+ pad_tensor, expand_tensor, clip_v_inputs,
+ clip_t_inputs)
+from .func_others import (flat_list_of_lists)
+
+from .func_io import (load_vocab, read_np)
diff --git a/uniperceiver/functional/func_caption.py b/uniperceiver/functional/func_caption.py
new file mode 100644
index 0000000000000000000000000000000000000000..56b5f68c1fcc847fe2069b71343e2e7b8fe43305
--- /dev/null
+++ b/uniperceiver/functional/func_caption.py
@@ -0,0 +1,31 @@
+
+import os
+
+def decode_sequence(vocab, seq):
+ N, T = seq.size()
+ sents = []
+ for n in range(N):
+ words = []
+ for t in range(T):
+ ix = seq[n, t]
+ if ix == 0:
+ break
+ words.append(vocab[ix])
+ sent = ' '.join(words)
+ sents.append(sent)
+ return sents
+
+def decode_sequence_bert(tokenizer, seq, sep_token_id):
+ N, T = seq.size()
+ seq = seq.data.cpu().numpy()
+ sents = []
+ for n in range(N):
+ words = []
+ for t in range(T):
+ ix = seq[n, t]
+ if ix == sep_token_id:
+ break
+ words.append(tokenizer.ids_to_tokens[ix])
+ sent = tokenizer.convert_tokens_to_string(words)
+ sents.append(sent)
+ return sents
\ No newline at end of file
diff --git a/uniperceiver/functional/func_feats.py b/uniperceiver/functional/func_feats.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd287e018b161a351a3d2faf73ace3c3dee2be4e
--- /dev/null
+++ b/uniperceiver/functional/func_feats.py
@@ -0,0 +1,155 @@
+import itertools
+import numpy as np
+import torch
+# from torch.nn.utils.rnn import pad_sequence
+
+def pad_sequence(sequences, batch_first=False, padding_value=0.0, padding_length=None):
+ """
+ modified from torch.nn.utils.rnn.pad_sequence
+
+ """
+
+ # assuming trailing dimensions and type of all the Tensors
+ # in sequences are same and fetching those from sequences[0]
+ max_size = sequences[0].size()
+ trailing_dims = max_size[1:]
+ max_len = max([s.size(0) for s in sequences]) if padding_length is None else padding_length
+ if batch_first:
+ out_dims = (len(sequences), max_len) + trailing_dims
+ else:
+ out_dims = (max_len, len(sequences)) + trailing_dims
+
+ out_tensor = sequences[0].new_full(out_dims, padding_value)
+ for i, tensor in enumerate(sequences):
+ length = tensor.size(0)
+ # use index notation to prevent duplicate references to the tensor
+ if batch_first:
+ out_tensor[i, :length, ...] = tensor
+ else:
+ out_tensor[:length, i, ...] = tensor
+
+ return out_tensor
+
+
+def pad_tensor(tensor, padding_value, use_mask, padding_length=None):
+ if isinstance(tensor[0], list):
+ tensor = list(itertools.chain.from_iterable(tensor))
+
+ out = pad_sequence(tensor, batch_first=True, padding_value=padding_value, padding_length=padding_length)
+ if use_mask:
+ lengths = [t.size(0) for t in tensor]
+ max_lengths = max(lengths) if padding_length is None else padding_length
+ mask = torch.zeros((out.size(0), max_lengths), dtype=torch.uint8)
+ for i, length in enumerate(lengths):
+ mask[i, 0:length] = 1
+ return out, mask
+ else:
+ return out
+
+def dict_to_cuda(input_dict):
+ for key in input_dict:
+ if isinstance(input_dict[key], torch.Tensor):
+ input_dict[key] = input_dict[key].cuda(non_blocking=True)
+ elif isinstance(input_dict[key], dict):
+ input_dict[key] = dict_to_cuda(input_dict[key])
+
+
+
+def dict_as_tensor(input_data):
+ if isinstance(input_data, str) or isinstance(input_data, tuple) or isinstance(input_data, int) or input_data is None:
+ pass
+ elif isinstance(input_data, dict):
+ for key in input_data:
+ input_data[key] = dict_as_tensor(input_data[key])
+ elif isinstance(input_data, list):
+ input_data = [dict_as_tensor(item) for item in input_data]
+ else:
+ input_data = torch.as_tensor(input_data)
+ return input_data
+
+
+def boxes_to_locfeats(boxes, image_w, image_h):
+ image_location = np.zeros((boxes.shape[0], 5), dtype=np.float32)
+ image_location[:, :4] = boxes
+ image_location[:, 4] = (
+ (image_location[:, 3] - image_location[:, 1])
+ * (image_location[:, 2] - image_location[:, 0])
+ / (float(image_w) * float(image_h))
+ )
+
+ image_location[:, 0] = image_location[:, 0] / float(image_w)
+ image_location[:, 1] = image_location[:, 1] / float(image_h)
+ image_location[:, 2] = image_location[:, 2] / float(image_w)
+ image_location[:, 3] = image_location[:, 3] / float(image_h)
+ return image_location
+
+def expand_tensor(tensor, size, dim=1):
+ if size == 1 or tensor is None:
+ return tensor
+ tensor = tensor.unsqueeze(dim)
+ if dim == 0:
+ tensor = tensor.expand([size] + [-1] + list(tensor.shape[2:]))
+ tensor = tensor.reshape([-1] + list(tensor.shape[2:]))
+ else:
+ tensor = tensor.expand(list(tensor.shape[:dim]) + [size] + list(tensor.shape[dim+1:]))
+ tensor = tensor.reshape(list(tensor.shape[:dim-1]) + [-1] + list(tensor.shape[dim+1:]))
+ return tensor
+
+def iou(anchors, gt_boxes):
+ """
+ anchors: (N, 4) ndarray of float
+ gt_boxes: (K, 4) ndarray of float
+ overlaps: (N, K) ndarray of overlap between boxes and query_boxes
+ """
+ N = anchors.shape[0]
+ K = gt_boxes.shape[0]
+
+ gt_boxes_area = (
+ (gt_boxes[:, 2] - gt_boxes[:, 0] + 1) * (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)
+ ).reshape(1, K)
+
+ anchors_area = (
+ (anchors[:, 2] - anchors[:, 0] + 1) * (anchors[:, 3] - anchors[:, 1] + 1)
+ ).reshape(N, 1)
+
+ boxes = np.repeat(anchors.reshape(N, 1, 4), K, axis=1)
+ query_boxes = np.repeat(gt_boxes.reshape(1, K, 4), N, axis=0)
+
+ iw = (
+ np.minimum(boxes[:, :, 2], query_boxes[:, :, 2])
+ - np.maximum(boxes[:, :, 0], query_boxes[:, :, 0])
+ + 1
+ )
+ iw[iw < 0] = 0
+
+ ih = (
+ np.minimum(boxes[:, :, 3], query_boxes[:, :, 3])
+ - np.maximum(boxes[:, :, 1], query_boxes[:, :, 1])
+ + 1
+ )
+ ih[ih < 0] = 0
+
+ ua = anchors_area + gt_boxes_area - (iw * ih)
+ overlaps = iw * ih / ua
+
+ return overlaps
+
+
+def get_max_len_from_mask(mask):
+ return int(mask.sum(1).max().item())
+
+
+def clip_v_inputs(v_feats, spatials, image_mask):
+ max_len = get_max_len_from_mask(image_mask)
+ v_feats = v_feats[:, :max_len]
+ spatials = spatials[:, :max_len]
+ image_mask = image_mask[:, :max_len]
+ return v_feats, spatials, image_mask
+
+
+def clip_t_inputs(input_txt, segment_ids, input_mask):
+ max_len = get_max_len_from_mask(input_mask)
+ input_txt = input_txt[:, :max_len]
+ segment_ids = segment_ids[:, :max_len]
+ input_mask = input_mask[:, :max_len]
+ return input_txt, segment_ids, input_mask
\ No newline at end of file
diff --git a/uniperceiver/functional/func_io.py b/uniperceiver/functional/func_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..5aa9deec6670f6be26d36c5e84b78f14977b5719
--- /dev/null
+++ b/uniperceiver/functional/func_io.py
@@ -0,0 +1,59 @@
+
+import os
+import numpy as np
+from .func_feats import boxes_to_locfeats
+
+import pdb
+
+def read_lines(path):
+ with open(path, 'r') as fid:
+ lines = [line.strip() for line in fid]
+ return lines
+
+def read_lines_set(path):
+ lines = read_lines(path)
+ lines = set(lines)
+ return lines
+
+# "features", "cls_prob", "boxes", "image_h", "image_w"
+def read_np(path, preload=None):
+ if preload:
+ content = preload[path]
+ else:
+ content = np.load(path, allow_pickle=True)
+ if isinstance(content, np.ndarray):
+ return { "features": content }
+
+ keys = content.keys()
+ if len(keys) == 1:
+ return { "features": content[list(keys)[0]] }
+ return content
+
+def read_np_bbox(path, max_feat_num, use_global_v=True, preload=None):
+ content = read_np(path, preload)
+ features = content['features'][0:max_feat_num - 1]
+ boxes = content['boxes'][0:max_feat_num - 1]
+ image_h = content['image_h'][0]
+ image_w = content['image_w'][0]
+ num_boxes = len(boxes)
+
+ if use_global_v:
+ g_feat = np.sum(features, axis=0) / num_boxes
+ features = np.concatenate([np.expand_dims(g_feat, axis=0), features], axis=0)
+
+ image_locations = boxes_to_locfeats(boxes, image_w, image_h)
+ if use_global_v:
+ g_location = np.array([0, 0, 1, 1, 1])
+ image_locations = np.concatenate([np.expand_dims(g_location, axis=0), image_locations], axis=0)
+ return features, image_locations
+
+
+
+def load_vocab(path):
+ if len(path) == 0:
+ return None
+ vocab = ['.']
+ with open(path, 'r') as fid:
+ for line in fid:
+ vocab.append(line.strip())
+ return vocab
\ No newline at end of file
diff --git a/uniperceiver/functional/func_others.py b/uniperceiver/functional/func_others.py
new file mode 100644
index 0000000000000000000000000000000000000000..85e460aa6a5f4f7e70ff5a5e4da319a9b7ca58e2
--- /dev/null
+++ b/uniperceiver/functional/func_others.py
@@ -0,0 +1,9 @@
+# Copyright 2021 JD.com, Inc., JD AI
+"""
+@author: Jianjie Luo
+@contact: jianjieluo.sysu@gmail.com
+"""
+
+def flat_list_of_lists(l):
+ """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
+ return [item for sublist in l for item in sublist]
\ No newline at end of file
diff --git a/uniperceiver/losses/__init__.py b/uniperceiver/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1082f1c0a93bc492c17300d00bd367844765b33
--- /dev/null
+++ b/uniperceiver/losses/__init__.py
@@ -0,0 +1,7 @@
+from .build import build_losses
+
+from .label_smoothing import LabelSmoothingCrossEntropy
+from .soft_target_cross_entropy import SoftTargetCrossEntropy
+from .cross_entropy import CrossEntropy
+from .accuracy import Accuracy
+from .bce_logitis import BCEWithLogits
diff --git a/uniperceiver/losses/accuracy.py b/uniperceiver/losses/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3630607b32d3a6398ba2bfd22b935bb6e6989e9
--- /dev/null
+++ b/uniperceiver/losses/accuracy.py
@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from uniperceiver.config import configurable
+from .build import LOSSES_REGISTRY
+
+@LOSSES_REGISTRY.register()
+class Accuracy(nn.Module):
+ @configurable
+ def __init__(
+ self
+ ):
+ super(Accuracy, self).__init__()
+
+ @classmethod
+ def from_config(cls, cfg):
+ return {
+ }
+
+ def Forward(self, logits, targets):
+ pred = torch.argmax(logits.view(-1, logits.shape[-1]), -1)
+ targets = targets.view(-1)
+ return torch.mean((pred == targets).float())
+
+ def forward(self, outputs_dict):
+
+ ret = {}
+ for logit, target, loss_identification in zip(outputs_dict['logits'],
+ outputs_dict['targets'],
+ outputs_dict['loss_names']):
+ if logit.shape == target.shape:
+ # for mixup
+ target = torch.argmax(target, dim=-1)
+ acc = self.Forward(logit, target)
+ loss_name = 'Accuracy'
+ if len(loss_identification) > 0:
+ loss_name = loss_name + f' ({loss_identification})'
+ ret.update({loss_name: acc})
+
+ return ret
\ No newline at end of file
diff --git a/uniperceiver/losses/bce_logitis.py b/uniperceiver/losses/bce_logitis.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfe2f75d88ffd6e7a1aab3f5030912473ca76dfc
--- /dev/null
+++ b/uniperceiver/losses/bce_logitis.py
@@ -0,0 +1,38 @@
+import torch
+import torch.nn as nn
+from uniperceiver.config import configurable
+from .build import LOSSES_REGISTRY
+
+@LOSSES_REGISTRY.register()
+class BCEWithLogits(nn.Module):
+ @configurable
+ def __init__(self, loss_weight=1.0):
+ super(BCEWithLogits, self).__init__()
+ self.criterion = nn.BCEWithLogitsLoss(reduction="mean")
+ if not isinstance(loss_weight, float):
+ self.loss_weight = 1.0
+ else:
+ self.loss_weight = loss_weight
+
+ @classmethod
+ def from_config(cls, cfg):
+ return {
+ 'loss_weight' : getattr(cfg.LOSSES, 'LOSS_WEIGHT', 1.0)
+ }
+
+ def forward(self, outputs_dict):
+ ret = {}
+ for logit, target, loss_identification in zip(outputs_dict['logits'],
+ outputs_dict['targets'],
+ outputs_dict['loss_names']):
+
+ loss = self.criterion(logit, target)
+ if self.loss_weight != 1.0:
+ loss *= self.loss_weight
+ loss_name = 'BCEWithLogits_Loss'
+ if len(loss_identification) > 0:
+ loss_name = loss_name+ f' ({loss_identification})'
+ ret.update({ loss_name: loss })
+
+ return ret
+
\ No newline at end of file
diff --git a/uniperceiver/losses/build.py b/uniperceiver/losses/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7987ff48fc8ecdac094191de5ae22b63a34956c
--- /dev/null
+++ b/uniperceiver/losses/build.py
@@ -0,0 +1,15 @@
+
+from uniperceiver.utils.registry import Registry
+
+LOSSES_REGISTRY = Registry("LOSSES")
+LOSSES_REGISTRY.__doc__ = """
+Registry for losses
+"""
+
+def build_losses(cfg):
+ losses = []
+ for name in cfg.LOSSES.NAMES:
+ loss = LOSSES_REGISTRY.get(name)(cfg)
+ losses.append(loss)
+ return losses
+
diff --git a/uniperceiver/losses/cross_entropy.py b/uniperceiver/losses/cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..60f31e51040139b3fc07a7c134bc2e362f7021c6
--- /dev/null
+++ b/uniperceiver/losses/cross_entropy.py
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+from uniperceiver.config import configurable
+from .build import LOSSES_REGISTRY
+
+@LOSSES_REGISTRY.register()
+class CrossEntropy(nn.Module):
+ @configurable
+ def __init__(self, loss_weight=1.0, reduction='mean', loss_fp32=False):
+ super(CrossEntropy, self).__init__()
+ if reduction is None:
+ reduction = 'mean'
+ self.criterion_func = nn.CrossEntropyLoss(ignore_index=-1, reduction=reduction)
+ if not isinstance(loss_weight, float):
+ self.loss_weight = 1.0
+ else:
+ self.loss_weight = loss_weight
+ self.reduction = reduction
+ self.loss_fp32 = loss_fp32
+
+ def criterion(self, x, target):
+ if self.loss_fp32 and x.dtype != torch.float32:
+ loss = self.criterion_func(x.to(torch.float32), target).to(x.dtype)
+ else:
+ loss = self.criterion_func(x, target)
+ return loss.mean()
+
+ @classmethod
+ def from_config(cls, cfg):
+ return {
+ 'loss_weight': getattr(cfg.LOSSES, 'LOSS_WEIGHT', None),
+ 'reduction': getattr(cfg.LOSSES, 'REDUCTION', 'mean'),
+ 'loss_fp32': getattr(cfg.LOSSES, 'LOSS_FP32', False),
+ }
+
+ @classmethod
+ def add_config(cls, cfg):
+ cfg.LOSSES.LOSS_WEIGHT = None
+ cfg.LOSSES.REDUCTION = 'mean'
+
+ def forward(self, outputs_dict):
+ ret = {}
+
+ for logit, target, loss_identification in zip(outputs_dict['logits'],
+ outputs_dict['targets'],
+ outputs_dict['loss_names']):
+
+ loss = self.criterion(logit, target)
+ if self.loss_weight != 1.0:
+ loss *= self.loss_weight
+ loss_name = 'CrossEntropy_Loss'
+ if len(loss_identification) > 0:
+ loss_name = loss_name+ f' ({loss_identification})'
+ ret.update({ loss_name: loss })
+
+
+ return ret
diff --git a/uniperceiver/losses/label_smoothing.py b/uniperceiver/losses/label_smoothing.py
new file mode 100644
index 0000000000000000000000000000000000000000..db56e0fe09db38b8ac9f6783d7cf9b50ec144fbe
--- /dev/null
+++ b/uniperceiver/losses/label_smoothing.py
@@ -0,0 +1,66 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from uniperceiver.config import configurable
+from .build import LOSSES_REGISTRY
+
+
+
+
+@LOSSES_REGISTRY.register()
+class LabelSmoothingCrossEntropy(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ *,
+ label_smoothing,
+ loss_weight,
+ loss_fp32,
+ ):
+ super(LabelSmoothingCrossEntropy, self).__init__()
+ self.label_smoothing = label_smoothing
+ self.confidence = 1.0 - self.label_smoothing
+ if not isinstance(loss_weight, float):
+ self.loss_weight = 1.0
+ else:
+ self.loss_weight = loss_weight
+ self.loss_fp32 = loss_fp32
+
+ @classmethod
+ def from_config(cls, cfg):
+ return {
+ "label_smoothing": cfg.LOSSES.LABELSMOOTHING,
+ 'loss_weight': getattr(cfg.LOSSES, 'LOSS_WEIGHT', None),
+ 'loss_fp32': getattr(cfg.LOSSES, 'LOSS_FP32', False),
+ }
+
+ def Forward(self, x, target):
+ if self.loss_fp32 and x.dtype != torch.float32:
+ logprobs = F.log_softmax(x, dim=-1,
+ dtype=torch.float32).to(x.dtype)
+ else:
+ logprobs = F.log_softmax(x, dim=-1)
+ nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
+ nll_loss = nll_loss.squeeze(1)
+ smooth_loss = -logprobs.mean(dim=-1)
+ loss = self.confidence * nll_loss + self.label_smoothing * smooth_loss
+ return loss.mean()
+
+ def forward(self, outputs_dict):
+ ret = {}
+
+ for logit, target, loss_identification in zip(outputs_dict['logits'],
+ outputs_dict['targets'],
+ outputs_dict['loss_names']):
+
+
+ loss = self.Forward(logit, target)
+ if self.loss_weight != 1.0:
+ loss *= self.loss_weight
+ loss_name = 'LabelSmoothing'
+ if len(loss_identification) > 0:
+ loss_name = loss_name + f' ({loss_identification})'
+ ret.update({loss_name: loss})
+
+
+ return ret
diff --git a/uniperceiver/losses/soft_target_cross_entropy.py b/uniperceiver/losses/soft_target_cross_entropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9d6f2dcf4fb8cf7e92694cf45a9bcf50965aec6
--- /dev/null
+++ b/uniperceiver/losses/soft_target_cross_entropy.py
@@ -0,0 +1,55 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from uniperceiver.config import configurable
+from .build import LOSSES_REGISTRY
+
+class CrossEntropyWithSoftTarget(nn.Module):
+
+ def __init__(self, loss_fp32):
+ super(CrossEntropyWithSoftTarget, self).__init__()
+ self.loss_fp32 = loss_fp32
+
+ def forward(self, x, target):
+ if self.loss_fp32 and x.dtype != torch.float32:
+ loss = torch.sum(-target * F.log_softmax(x, dim=-1, dtype=torch.float32), dim=-1).to(x.dtype)
+ else:
+ loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
+ return loss.mean()
+
+
+@LOSSES_REGISTRY.register()
+class SoftTargetCrossEntropy(nn.Module):
+ @configurable
+ def __init__(self, loss_weight=1.0, loss_fp32=False):
+ super(SoftTargetCrossEntropy, self).__init__()
+ self.criterion = CrossEntropyWithSoftTarget(loss_fp32)
+ if not isinstance(loss_weight, float):
+ self.loss_weight = 1.0
+ else:
+ self.loss_weight = loss_weight
+
+ @classmethod
+ def from_config(cls, cfg):
+ return {
+ 'loss_weight' : getattr(cfg.LOSSES, 'LOSS_WEIGHT', None),
+ 'loss_fp32' : getattr(cfg.LOSSES, 'LOSS_FP32', False),
+ }
+
+ def forward(self, outputs_dict):
+ ret = {}
+ for logit, target, loss_identification in zip(outputs_dict['logits'],
+ outputs_dict['targets'],
+ outputs_dict['loss_names']):
+
+ loss = self.criterion(logit, target)
+ if self.loss_weight != 1.0:
+ loss *= self.loss_weight
+ loss_name = 'SoftTargetCrossEntropy_Loss'
+ if len(loss_identification) > 0:
+ loss_name = loss_name+ f' ({loss_identification})'
+ ret.update({ loss_name: loss })
+
+
+ return ret
diff --git a/uniperceiver/lr_scheduler/__init__.py b/uniperceiver/lr_scheduler/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3006863777a8226854b2b2af35b6d7ab2640b88
--- /dev/null
+++ b/uniperceiver/lr_scheduler/__init__.py
@@ -0,0 +1,11 @@
+
+from .build import build_lr_scheduler
+
+
+from .warmup_lr import (
+ WarmupConstant,
+ WarmupLinear,
+ WarmupCosine,
+ WarmupCosineWithHardRestarts,
+ WarmupMultiStepLR
+)
diff --git a/uniperceiver/lr_scheduler/build.py b/uniperceiver/lr_scheduler/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..47b833291ac0ec4eed36fa167373eb7faf190ea6
--- /dev/null
+++ b/uniperceiver/lr_scheduler/build.py
@@ -0,0 +1,11 @@
+
+from uniperceiver.utils.registry import Registry
+
+LR_SCHEDULER_REGISTRY = Registry("LR_SCHEDULER")
+LR_SCHEDULER_REGISTRY.__doc__ = """
+Registry for lr scheduler
+"""
+
+def build_lr_scheduler(cfg, optimizer, data_size):
+ lr_scheduler = LR_SCHEDULER_REGISTRY.get(cfg.LR_SCHEDULER.NAME)(cfg, optimizer, data_size)
+ return lr_scheduler
diff --git a/uniperceiver/lr_scheduler/warmup_lr.py b/uniperceiver/lr_scheduler/warmup_lr.py
new file mode 100644
index 0000000000000000000000000000000000000000..b76bc191da4f9a10a0fa0673a5a1fdc5b2b7bc24
--- /dev/null
+++ b/uniperceiver/lr_scheduler/warmup_lr.py
@@ -0,0 +1,223 @@
+import math
+from bisect import bisect_right
+import warnings
+import torch
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LambdaLR
+from uniperceiver.config import configurable
+from .build import LR_SCHEDULER_REGISTRY
+from uniperceiver.utils import comm
+
+
+@LR_SCHEDULER_REGISTRY.register()
+class WarmupConstant(LambdaLR):
+ """ Linear warmup and then constant.
+ Linearly increases learning rate schedule from 0 to 1 over `warmup_steps` training steps.
+ Keeps learning rate schedule equal to 1. after warmup_steps.
+ """
+ @configurable
+ def __init__(
+ self,
+ *,
+ optimizer,
+ warmup_steps,
+ last_epoch=-1):
+ self.warmup_steps = warmup_steps
+ super(WarmupConstant, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+ @classmethod
+ def from_config(cls, cfg, optimizer, data_size):
+ return {
+ "optimizer": optimizer,
+ "warmup_steps": cfg.LR_SCHEDULER.WARMUP * data_size,
+ "last_epoch": -1
+ }
+
+ def lr_lambda(self, step):
+ if step < self.warmup_steps:
+ return float(step) / float(max(1.0, self.warmup_steps))
+ return 1.
+
+@LR_SCHEDULER_REGISTRY.register()
+class WarmupLinear(LambdaLR):
+ """ Linear warmup and then linear decay.
+ Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
+ Linearly decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps.
+ """
+ @configurable
+ def __init__(
+ self,
+ *,
+ optimizer,
+ min_lr,
+ warmup_steps,
+ t_total,
+ last_epoch=-1):
+
+ self.warmup_steps = warmup_steps
+ self.t_total = t_total
+ self.min_lr = min_lr
+ super(WarmupLinear, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+ @classmethod
+ def from_config(cls, cfg, optimizer, data_size):
+ return {
+ "optimizer": optimizer,
+ "min_lr": cfg.LR_SCHEDULER.MIN_LR / cfg.SOLVER.BASE_LR,
+ "warmup_steps": cfg.LR_SCHEDULER.WARMUP if cfg.INFERENCE.ITER_BASED else (cfg.LR_SCHEDULER.WARMUP * data_size),
+ "t_total": cfg.SOLVER.MAX_ITER if cfg.INFERENCE.ITER_BASED else (cfg.SOLVER.EPOCH * data_size), # total iterations
+ "last_epoch": -1
+ }
+
+ def lr_lambda(self, step):
+ if step < self.warmup_steps:
+ return float(step) / float(max(1, self.warmup_steps))
+ return max(self.min_lr, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
+
+@LR_SCHEDULER_REGISTRY.register()
+class WarmupCosine(LambdaLR):
+ """ Linear warmup and then cosine decay.
+ Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
+ Decreases learning rate from 1. to 0. over remaining `t_total - warmup_steps` steps following a cosine curve.
+ If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
+ """
+ @configurable
+ def __init__(
+ self,
+ *,
+ optimizer,
+ min_lr,
+ warmup_steps,
+ t_total,
+ cycles=.5,
+ last_epoch=-1):
+
+ self.warmup_steps = warmup_steps
+ self.t_total = t_total
+ self.cycles = cycles
+ self.min_lr = min_lr
+ super(WarmupCosine, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+ if comm.get_rank() == 0:
+ print('warmup cosine lr, warmup_steps: {} t_total {} .'.format(warmup_steps, t_total))
+
+ @classmethod
+ def from_config(cls, cfg, optimizer, data_size):
+ return {
+ "optimizer": optimizer,
+ "min_lr": cfg.LR_SCHEDULER.MIN_LR / cfg.SOLVER.BASE_LR,
+ "warmup_steps": cfg.LR_SCHEDULER.WARMUP if cfg.INFERENCE.ITER_BASED else (cfg.LR_SCHEDULER.WARMUP * data_size),
+ "t_total": cfg.SOLVER.MAX_ITER if cfg.INFERENCE.ITER_BASED else (cfg.SOLVER.EPOCH * data_size), # total iterations
+ "cycles": .5,
+ "last_epoch": -1
+ }
+
+ def lr_lambda(self, step):
+ if step < self.warmup_steps:
+ return float(step) / float(max(1.0, self.warmup_steps))
+ # progress after warmup
+ progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
+ return max(self.min_lr, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
+
+
+
+
+@LR_SCHEDULER_REGISTRY.register()
+class WarmupCosineWithHardRestarts(LambdaLR):
+ """ Linear warmup and then cosine cycles with hard restarts.
+ Linearly increases learning rate from 0 to 1 over `warmup_steps` training steps.
+ If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
+ learning rate (with hard restarts).
+ """
+ @configurable
+ def __init__(
+ self,
+ *,
+ optimizer,
+ warmup_steps,
+ t_total,
+ cycles=1.,
+ last_epoch=-1):
+
+ self.warmup_steps = warmup_steps
+ self.t_total = t_total
+ self.cycles = cycles
+ super(WarmupCosineWithHardRestarts, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+ @classmethod
+ def from_config(cls, cfg, optimizer, data_size):
+ return {
+ "optimizer": optimizer,
+ "warmup_steps": cfg.LR_SCHEDULER.WARMUP * data_size,
+ "t_total": cfg.SOLVER.MAX_ITER if cfg.INFERENCE.ITER_BASED else (cfg.SOLVER.EPOCH * data_size), # total iterations
+ "cycles": 1.,
+ "last_epoch": -1
+ }
+
+ def lr_lambda(self, step):
+ if step < self.warmup_steps:
+ return float(step) / float(max(1, self.warmup_steps))
+ # progress after warmup
+ progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
+ if progress >= 1.0:
+ return 0.0
+ return max(0.0, 0.5 * (1. + math.cos(math.pi * ((float(self.cycles) * progress) % 1.0))))
+
+@LR_SCHEDULER_REGISTRY.register()
+class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
+ @configurable
+ def __init__(
+ self,
+ *,
+ optimizer,
+ milestones,
+ gamma=0.1,
+ warmup_factor=1.0 / 3,
+ warmup_iters=500,
+ warmup_method="linear",
+ last_epoch=-1,
+ ):
+ if not list(milestones) == sorted(milestones):
+ raise ValueError(
+ "Milestones should be a list of" " increasing integers. Got {}",
+ milestones,
+ )
+
+ if warmup_method not in ("constant", "linear"):
+ raise ValueError(
+ "Only 'constant' or 'linear' warmup_method accepted"
+ "got {}".format(warmup_method)
+ )
+ self.milestones = milestones
+ self.gamma = gamma
+ self.warmup_factor = warmup_factor
+ self.warmup_iters = warmup_iters
+ self.warmup_method = warmup_method
+ super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch)
+
+ @classmethod
+ def from_config(cls, cfg, optimizer, data_size):
+ steps = [step * data_size for step in cfg.LR_SCHEDULER.STEPS]
+ return {
+ "optimizer": optimizer,
+ "milestones": steps,
+ "gamma": cfg.LR_SCHEDULER.GAMMA,
+ "warmup_factor": cfg.LR_SCHEDULER.WARMUP_FACTOR,
+ "warmup_iters": cfg.LR_SCHEDULER.WARMUP * data_size,
+ "warmup_method": cfg.LR_SCHEDULER.WARMUP_METHOD,
+ "last_epoch": -1
+ }
+
+ def get_lr(self):
+ warmup_factor = 1
+ if self.last_epoch < self.warmup_iters:
+ if self.warmup_method == "constant":
+ warmup_factor = self.warmup_factor
+ elif self.warmup_method == "linear":
+ alpha = self.last_epoch / self.warmup_iters
+ warmup_factor = self.warmup_factor * (1 - alpha) + alpha
+ return [
+ base_lr
+ * warmup_factor
+ * self.gamma ** bisect_right(self.milestones, self.last_epoch)
+ for base_lr in self.base_lrs
+ ]
diff --git a/uniperceiver/modeling/__init__.py b/uniperceiver/modeling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcc420d44d30996d709b989cf12d827f4861ab46
--- /dev/null
+++ b/uniperceiver/modeling/__init__.py
@@ -0,0 +1,3 @@
+
+
+from .meta_arch import (build_model, add_config)
diff --git a/uniperceiver/modeling/decode_strategy/__init__.py b/uniperceiver/modeling/decode_strategy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b56d787e1a95d63479b5c15f82e529d5cd42ee1f
--- /dev/null
+++ b/uniperceiver/modeling/decode_strategy/__init__.py
@@ -0,0 +1,7 @@
+
+
+from .build import build_beam_searcher, build_greedy_decoder
+from .caption_beam_searcher_v2 import CaptionBeamSearcherV2
+from .caption_beam_searcher_v3 import CaptionBeamSearcherV3
+
+__all__ = list(globals().keys())
diff --git a/uniperceiver/modeling/decode_strategy/build.py b/uniperceiver/modeling/decode_strategy/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..d00b75e5ffad30e43c674939c5abf4f53433a00b
--- /dev/null
+++ b/uniperceiver/modeling/decode_strategy/build.py
@@ -0,0 +1,15 @@
+
+from uniperceiver.utils.registry import Registry
+
+DECODE_STRATEGY_REGISTRY = Registry("DECODE_STRATEGY")
+DECODE_STRATEGY_REGISTRY.__doc__ = """
+Registry for decode strategy
+"""
+
+def build_beam_searcher(cfg):
+ beam_search = None if cfg.DECODE_STRATEGY.NAME.lower() == "none" else DECODE_STRATEGY_REGISTRY.get(cfg.DECODE_STRATEGY.NAME)(cfg)
+ return beam_search
+
+def build_greedy_decoder(cfg):
+ greedy_decoder = DECODE_STRATEGY_REGISTRY.get("GreedyDecoder")(cfg)
+ return greedy_decoder
\ No newline at end of file
diff --git a/uniperceiver/modeling/decode_strategy/caption_beam_searcher_v2.py b/uniperceiver/modeling/decode_strategy/caption_beam_searcher_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0afc1b1c10006c3356a5294e232439ff1fc0be7e
--- /dev/null
+++ b/uniperceiver/modeling/decode_strategy/caption_beam_searcher_v2.py
@@ -0,0 +1,334 @@
+
+# Copyright (c) 2019, AImageLab
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from uniperceiver.config import configurable
+from uniperceiver.functional import expand_tensor
+from .decode_strategy import DecodeStrategy
+from .build import DECODE_STRATEGY_REGISTRY
+from uniperceiver.utils import comm
+import math
+
+
+@DECODE_STRATEGY_REGISTRY.register()
+class CaptionBeamSearcherV2(DecodeStrategy):
+
+ def data_half(self, data):
+ if self.fp16:
+ for k, v in data.items():
+ if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
+ data[k] = v.half()
+ # print(k)
+ return data
+ else:
+ return data
+
+
+
+
+ def _select(self, batch_size, beam_size, t, candidate_logprob):
+ selected_logprob, selected_idx = torch.sort(candidate_logprob.view(batch_size, -1), -1, descending=True)
+ selected_logprob, selected_idx = selected_logprob[:, :beam_size], selected_idx[:, :beam_size]
+ return selected_idx, selected_logprob
+
+ def _expand_state(self, states, selected_beam, batch_size, beam_size, cur_beam_size):
+ for i in range(len(states)):
+ shape = list(states[i].shape)
+ beam = selected_beam
+ for _ in shape[1:]:
+ beam = beam.unsqueeze(-1)
+ states[i] = torch.gather(states[i].view(*([batch_size, cur_beam_size] + shape[1:])), 1,
+ beam.expand(*([batch_size, beam_size] + shape[1:])))
+ states[i] = states[i].view(*([-1, ] + shape[1:]))
+
+
+ def _forward(self, batched_inputs, model):
+ # only two caption tasks are generative task now!
+ # for caption tasks, the computations are:
+ # 1. encode the image sequence; save for further use.
+ # 2. if no cached encoded dictionary, encode the dictionary and save; otherwise reuse cache.
+ # 3. compute attention. We use cross attention insted of self attention.
+
+ # batched_inputs[kfg.IMAGE] = batched_inputs.pop(kfg.VIDEO).squeeze(1)
+
+ inputs = batched_inputs
+ inputs = self.data_half(inputs)
+
+
+ out_size = batched_inputs.get('OUT_SIZE', 1)
+
+
+ # 0. token embedding
+ if model.visual_embed is not None:
+ # ve_out = model.visual_embed(batched_inputs)
+ # inputs.update(ve_out)
+ model.visual_embed(inputs)
+
+
+ if model.video_embed is not None:
+ # ve_out = model.video_embed(batched_inputs)
+ # inputs.update(ve_out)
+ model.video_embed(inputs)
+
+ if model.token_embed is not None:
+ # te_out = model.token_embed(batched_inputs)
+ # inputs.update(te_out)
+ model.token_embed(inputs)
+
+ prompt_data = {}
+ if model.prompt_embed is not None:
+ prompt_data = model.prompt_embed(batched_inputs)
+ prompt_data[kfg.DEEP_PROMPT] = model.prompt and model.deep_prompt
+ inputs.update(prompt_data)
+
+ # 1. encode the image/video sequence.
+ # bs = inputs[kfg.ATT_FEATS].size(0)
+ bs = inputs['images'].size(0)
+
+ v_input = []
+ # v_input.append(model._get_sep_embed(inputs, bs))
+ v_input.append(model._get_sep_embed(inputs.get('spe_token_embed', None), bs))
+ # v_input.append(inputs[kfg.ATT_FEATS])
+ # comm._LOCAL_IMAGE_LENGTH = inputs[kfg.ATT_FEATS].shape[1]
+ comm._LOCAL_IMAGE_LENGTH = inputs['images'].shape[-1]
+ # add by zjg
+ if kfg.PROMPT_EMBED in inputs and not model.deep_prompt:
+ v_input.append(batched_inputs[kfg.PROMPT_EMBED])
+
+ v_input = torch.cat(v_input, dim=1)
+
+ # ext_u_tmasks = torch.ones((bs, v_input.shape[1], v_input.shape[1]), dtype=v_input.dtype, device=v_input.device)
+ # ext_u_tmasks = ext_u_tmasks.unsqueeze(1)
+ # ext_u_tmasks = (1.0 - ext_u_tmasks) * -10000.0
+ # for img encoder, do not need mask
+ v_input = {
+ kfg.MM_EMBEDS: v_input,
+ # kfg.ATT_FEATS: inputs[kfg.ATT_FEATS],
+ kfg.TEXT_GEN_MODE: False,
+ kfg.EXT_U_TOKENS_MASKS: None,
+ }
+
+ # for deep prompt tuning
+ if prompt_data.get(kfg.DEEP_PROMPT, False):
+ v_input.update(prompt_data)
+
+
+ # masks = model.get_extended_attention_mask(v_input)
+ # v_input.update(masks)
+
+ # v_input.update( {kfg.EXT_U_TOKENS_MASKS: v_input[kfg.EXT_U_TOKENS_MASKS][:, :, :, 1:]} ) # remove the mask for special token
+ # vfeats = model.encoder(v_input)[kfg.U_HIDDEN_STATES]
+
+ # 2. encode the dictionary - if no pre-computed, add that into input
+ if getattr(self, 'pre_computed_word_embeds', None) is None:
+ w_input = []
+ vocab_size = model.token_embed.embeddings.num_embeddings
+ w_input.append(model._get_sep_embed(inputs.get('spe_token_embed', None), vocab_size))
+
+ # range_slice = torch.arange(start=0, end=vocab_size).unsqueeze(1).to(inputs[kfg.ATT_FEATS].device)
+ range_slice = torch.arange(start=0, end=vocab_size).unsqueeze(1).to(inputs['images'].device)
+ # - [HACK] we hardcode the EOT token
+ eot_to_append = range_slice.new_full(range_slice.shape, 49407)
+ range_slice_concat_eot = torch.cat([range_slice, eot_to_append], dim=1)
+ # temp = {
+ # kfg.U_TOKENS_IDS: range_slice_concat_eot,
+ # kfg.U_TOKENS_TYPE: torch.zeros_like(range_slice_concat_eot)
+ # }
+ temp = {
+ "shared_targets": [{
+ "shared_tgt_tokens":range_slice_concat_eot,
+ },
+ ]
+ # kfg.U_TOKENS_TYPE: torch.zeros_like(range_slice_concat_eot)
+ }
+
+ # word_embeddings = model.token_embed(temp)['shared_tgt_token_embed']
+ model.token_embed(temp)
+ word_embeddings = temp["shared_targets"][0]['shared_tgt_token_embed']
+
+ w_input.append(word_embeddings)
+ w_input = torch.cat(w_input, dim=1)
+ v_input.update({ kfg.WORD_EMBEDS: w_input })
+
+ v_input = self.data_half(v_input)
+
+ model.add_tag_embedding(v_input)
+
+ enc_out = model.encoder(v_input, return_all=True)
+ self.pre_computed_word_embeds = enc_out[kfg.WORD_HIDDEN_STATES]
+ vfeats = enc_out[kfg.U_HIDDEN_STATES]
+ else:
+ v_input = self.data_half(v_input)
+ vfeats = model.encoder(v_input, return_all=True)[kfg.U_HIDDEN_STATES]
+
+ # 3. compute attention
+
+ comm._CAPTION_GEN_MODE = True
+
+ beam_size = self.beam_size
+ log_probs = []
+ selected_words = None
+ seq_logprob = torch.zeros((bs, 1, 1)).cuda() # bs, 1, 1
+ seq_mask = torch.ones((bs, beam_size, 1)).cuda()
+ wt = Variable(torch.zeros(bs, dtype=torch.long).cuda().unsqueeze(1)) + self.spe_token_id
+ u_tokens_type = wt.new_zeros(wt.shape) # [Note] we assume the type tokens are 0.
+
+ history_states = vfeats[:-1]
+ len_prefix = history_states[0].shape[1]
+ total_history_states = [ history_states[0].new_zeros(beam_size * bs, vfeats[0].shape[1] + self.max_seq_len, vfeats[0].shape[2]) for _ in history_states]
+ for i, ths in enumerate(total_history_states):
+ hs = history_states[i]
+ ths[:hs.shape[0], :hs.shape[1], :] = hs
+
+ outputs = []
+ for t in range(self.max_seq_len):
+ cur_beam_size = 1 if t == 0 else beam_size
+
+ history_states = [ ths[ :(cur_beam_size*bs), :(len_prefix+t), :] for ths in total_history_states]
+ t_input = {
+ kfg.U_TOKENS_IDS: wt,
+ kfg.U_TOKENS_TYPE: u_tokens_type,
+ kfg.EXT_U_TOKENS_MASKS: None,
+ kfg.HISTORY_STATES: history_states,
+ kfg.TIME_STEP: t
+ }
+
+ vt_out = model.token_embed(t_input)
+ t_input.update(vt_out)
+
+ t_input.update({ kfg.MM_EMBEDS: t_input[kfg.U_TOKEN_EMBED] })
+
+ if prompt_data.get(kfg.DEEP_PROMPT, False) and prompt_data['PROMPT_EMBED'].shape[1] != t_input[
+ 'MM_EMBEDS'].shape[0]:
+ prompt_data['PROMPT_EMBED'] = prompt_data[
+ 'PROMPT_EMBED'][:, :1].expand(
+ -1, t_input['MM_EMBEDS'].shape[0], -1, -1)
+ t_input.update(prompt_data)
+
+ t_input = self.data_half(t_input)
+ encoder_out = model.encoder(t_input, return_all=True)
+
+ pred_input = {
+ kfg.TEXT_GEN_MODE: True,
+ kfg.WORD_HIDDEN_STATES: self.pre_computed_word_embeds,
+ kfg.U_HIDDEN_STATES: encoder_out[kfg.U_HIDDEN_STATES],
+ kfg.TASK_NAME: batched_inputs[kfg.TASK_NAME]
+ }
+
+ logit = model.predictor(pred_input, force_spe_first=True)[kfg.OUTPUT]
+ word_logprob = F.log_softmax(logit, dim=-1)
+ word_logprob = word_logprob.view(bs, cur_beam_size, -1)
+ candidate_logprob = seq_logprob + word_logprob
+
+ # # Mask sequence if it reaches EOS
+ # if t > 0:
+ # mask = (selected_words.view(bs, cur_beam_size) != 0).float().unsqueeze(-1) # 为什么是不等于0
+ # seq_mask = seq_mask * mask
+ # word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
+ # old_seq_logprob = seq_logprob.expand_as(candidate_logprob).contiguous()
+ # old_seq_logprob[:, :, 1:] = -999
+ # candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (1 - seq_mask)
+
+ eos_id = 49407
+ if t > 0:
+ mask = (selected_words.view(bs, cur_beam_size) != eos_id).float().unsqueeze(-1)
+ seq_mask = seq_mask * mask
+ word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
+ old_seq_logprob = seq_logprob.expand_as(candidate_logprob).contiguous()
+ old_seq_logprob[:, :, :eos_id] = -999
+ old_seq_logprob[:, :, eos_id + 1:] = -999
+ candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (1 - seq_mask)
+
+ selected_idx, selected_logprob = self._select(bs, beam_size, t, candidate_logprob) # bs beam
+ selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode='floor')
+ selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]
+
+ self._expand_state(history_states, selected_beam, bs, beam_size, cur_beam_size)
+
+ seq_logprob = selected_logprob.unsqueeze(-1)
+ seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
+ outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
+ outputs.append(selected_words.unsqueeze(-1))
+
+ this_word_logprob = torch.gather(word_logprob, 1,
+ selected_beam.unsqueeze(-1).expand(bs, beam_size, word_logprob.shape[-1]))
+ this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
+ log_probs = list(
+ torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(bs, beam_size, 1)) for o in log_probs)
+ log_probs.append(this_word_logprob)
+ selected_words = selected_words.view(-1, 1)
+ # wt = selected_words
+
+ if t == 0:
+ u_tokens_type = expand_tensor(u_tokens_type, beam_size)
+ wt = expand_tensor(wt, beam_size)
+
+ selected_t_input = {
+ kfg.U_TOKENS_IDS: selected_words,
+ kfg.U_TOKENS_TYPE: u_tokens_type,
+ kfg.EXT_U_TOKENS_MASKS: None,
+ kfg.HISTORY_STATES: history_states,
+ kfg.TIME_STEP: t
+ }
+ selected_vt_out = model.token_embed(selected_t_input)
+ selected_t_input.update(selected_vt_out)
+
+ selected_t_input.update({ kfg.MM_EMBEDS: selected_t_input[kfg.U_TOKEN_EMBED] })
+
+ selected_t_prompt_data = dict(prompt_data)
+ if selected_t_prompt_data.get(kfg.DEEP_PROMPT, False) and selected_t_prompt_data['PROMPT_EMBED'].shape[1] != selected_t_input['MM_EMBEDS'].shape[0]:
+ selected_t_prompt_data['PROMPT_EMBED'] = selected_t_prompt_data['PROMPT_EMBED'][:, :1].expand(
+ -1, selected_t_input['MM_EMBEDS'].shape[0], -1, -1)
+ selected_t_input.update(selected_t_prompt_data)
+
+ selected_t_input = self.data_half(selected_t_input)
+ selected_encoder_out = model.encoder(selected_t_input, return_all=True)
+
+ for i, ths in enumerate(total_history_states):
+ hs = history_states[i]
+ ths[:hs.shape[0], :hs.shape[1], :] = hs
+ ths[:hs.shape[0], hs.shape[1], :] = selected_encoder_out[kfg.U_HIDDEN_STATES][i].squeeze(1)
+
+ # expand_keys = {
+ # kfg.ATT_FEATS,
+ # kfg.GLOBAL_FEATS,
+ # kfg.ATT_MASKS,
+ # kfg.EXT_ATT_MASKS,
+ # kfg.P_ATT_FEATS,
+ # kfg.EXT_G_TOKENS_MASKS,
+ # kfg.G_TOKENS_TYPE
+ # }
+ # for key in expand_keys:
+ # if key in inputs:
+ # if isinstance(inputs[key], list):
+ # inputs[key] = inputs[key][-1] # usually is ATT_FEATS in TDEN
+ # tensor = expand_tensor(inputs[key], beam_size)
+ # inputs.update({ key: tensor })
+
+ outputs = torch.cat(outputs, -1)
+
+
+ if self.len_penalty > 0:
+ step = outputs.ne(49407).sum(-1, keepdim=True) + 1
+ seq_logprob /= step ** self.len_penalty
+ seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
+
+ outputs = torch.gather(outputs, 1, sort_idxs.expand(bs, beam_size, self.max_seq_len))
+ log_probs = torch.cat(log_probs, -1)
+ log_probs = torch.gather(log_probs, 1, sort_idxs.expand(bs, beam_size, self.max_seq_len))
+
+ outputs = outputs.contiguous()[:, :out_size]
+ log_probs = log_probs.contiguous()[:, :out_size]
+ if out_size == 1:
+ outputs = outputs.squeeze(1)
+ log_probs = log_probs.squeeze(1)
+
+ comm._CAPTION_GEN_MODE = False
+
+ return {
+ kfg.IDS: batched_inputs[kfg.IDS],
+ kfg.G_SENTS_IDS: outputs,
+ kfg.G_LOGP: log_probs
+ }
diff --git a/uniperceiver/modeling/decode_strategy/caption_beam_searcher_v3.py b/uniperceiver/modeling/decode_strategy/caption_beam_searcher_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..950781bb6226495bf6c6012f48a5494da2b611d3
--- /dev/null
+++ b/uniperceiver/modeling/decode_strategy/caption_beam_searcher_v3.py
@@ -0,0 +1,221 @@
+# Copyright (c) 2019, AImageLab
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+from uniperceiver.config import configurable
+from uniperceiver.functional import expand_tensor
+from .decode_strategy import DecodeStrategy
+from .build import DECODE_STRATEGY_REGISTRY
+from uniperceiver.utils import comm
+import math
+from torch.cuda.amp import autocast
+
+@DECODE_STRATEGY_REGISTRY.register()
+class CaptionBeamSearcherV3(DecodeStrategy):
+
+ def data_half(self, data):
+ if self.fp16:
+ for k, v in data.items():
+ if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
+ data[k] = v.half()
+ # print(k)
+ return data
+ else:
+ return data
+
+
+
+
+ def _select(self, batch_size, beam_size, t, candidate_logprob):
+ selected_logprob, selected_idx = torch.sort(candidate_logprob.view(batch_size, -1), -1, descending=True)
+ selected_logprob, selected_idx = selected_logprob[:, :beam_size], selected_idx[:, :beam_size]
+ return selected_idx, selected_logprob
+
+ def _expand_state(self, states, selected_beam, batch_size, beam_size, cur_beam_size):
+ for i in range(len(states)):
+ shape = list(states[i].shape)
+ beam = selected_beam
+ for _ in shape[1:]:
+ beam = beam.unsqueeze(-1)
+ states[i] = torch.gather(states[i].view(*([batch_size, cur_beam_size] + shape[1:])), 1,
+ beam.expand(*([batch_size, beam_size] + shape[1:])))
+ states[i] = states[i].view(*([-1, ] + shape[1:]))
+
+
+ def _forward(self, batched_inputs, model):
+ # only two caption tasks are generative task now!
+ # for caption tasks, the computations are:
+ # 1. encode the image sequence; save for further use.
+ # 2. if no cached encoded dictionary, encode the dictionary and save; otherwise reuse cache.
+ # 3. compute attention. We use cross attention insted of self attention.
+
+ # batched_inputs[kfg.IMAGE] = batched_inputs.pop(kfg.VIDEO).squeeze(1)
+
+ inputs = batched_inputs
+ inputs = self.data_half(inputs)
+
+
+ out_size = batched_inputs.get('OUT_SIZE', 1)
+
+ task_info = inputs['task_info']
+ bs = task_info['batch_size']
+ if isinstance(bs, torch.Tensor):
+ bs = bs.item()
+
+ image_input = inputs['input_sample_list']
+ vocab_input = inputs['shared_target_sets'][self.vocab_name]
+
+
+ # 1. encode the image/video sequence.
+ moe_embedding = None
+ for image_data in image_input:
+ if 'moe_embedding' in image_data:
+ moe_embedding = image_data['moe_embedding']
+ image_encode = model._forward_data(image_input, task_info=task_info, return_all=True)[0]['data']
+
+
+ # 2. encode the vocabulary - if no pre-computed, add that into input
+ if getattr(self, 'pre_computed_word_embeds', None) is None:
+ vocab_encode = model._forward_data(vocab_input, task_info=task_info, return_all=False)[0]
+ self.pre_computed_word_embeds = vocab_encode
+ else:
+ vocab_encode = self.pre_computed_word_embeds
+
+ # 3. compute attention
+
+ comm._CAPTION_GEN_MODE = True
+ task_info.update({"prefix_spe_before_fuse": False})
+
+ beam_size = self.beam_size
+ log_probs = []
+ selected_words = None
+ seq_logprob = torch.zeros((bs, 1, 1)).cuda() # bs, 1, 1
+ seq_mask = torch.ones((bs, beam_size, 1)).cuda()
+ wt = Variable(torch.zeros(bs, dtype=torch.long).cuda().unsqueeze(1)) + self.spe_token_id
+ u_tokens_type = wt.new_zeros(wt.shape) # [Note] we assume the type tokens are 0.
+
+ history_states = image_encode[:-1]
+ len_prefix = history_states[0].shape[1]
+ total_history_states = [ history_states[0].new_zeros(beam_size * bs, image_encode[0].shape[1] + self.max_seq_len, image_encode[0].shape[2]) for _ in history_states]
+ for i, ths in enumerate(total_history_states):
+ hs = history_states[i]
+ ths[:hs.shape[0], :hs.shape[1], :] = hs
+
+ outputs = []
+ common_info = {
+ "modality": "text",
+ 'data_type': 'input',
+ 'moe_embedding': moe_embedding,
+
+ }
+ for t in range(self.max_seq_len):
+ cur_beam_size = 1 if t == 0 else beam_size
+
+ history_states = [ ths[ :(cur_beam_size*bs), :(len_prefix+t), :] for ths in total_history_states]
+
+ step_data = { "data": wt,
+ "time_step": t,
+ "sample_info":
+ {
+ "data_cum_length": [1, len_prefix, len_prefix+t+1]
+ }
+ }
+ step_data.update(common_info)
+
+ step_encode = model._forward_data([step_data], task_info=task_info, history_states=history_states, return_all=False)
+
+ step_predictor_input = {
+ "input_sample_list": step_encode,
+ "target_sample_list": [],
+ "shared_target_sets": {self.vocab_name: [vocab_encode]},
+ "target_set_list": [self.vocab_name],
+ "target_idx_list": [],
+ "task_info": task_info
+ }
+ logit = model.loss_prepare(**step_predictor_input)['output']
+
+ with autocast(enabled=not self.cfg.SOLVER.FORCE_SOFTMAX_FP16):
+ word_logprob = F.log_softmax(logit, dim=-1)
+ word_logprob = word_logprob.view(bs, cur_beam_size, -1)
+ candidate_logprob = seq_logprob + word_logprob
+
+ # # Mask sequence if it reaches EOS
+ # if t > 0:
+ # mask = (selected_words.view(bs, cur_beam_size) != 0).float().unsqueeze(-1) # 为什么是不等于0
+ # seq_mask = seq_mask * mask
+ # word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
+ # old_seq_logprob = seq_logprob.expand_as(candidate_logprob).contiguous()
+ # old_seq_logprob[:, :, 1:] = -999
+ # candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (1 - seq_mask)
+
+ if t > 0:
+ mask = (selected_words.view(bs, cur_beam_size) != self.eos_token_id).float().unsqueeze(-1)
+ seq_mask = seq_mask * mask
+ word_logprob = word_logprob * seq_mask.expand_as(word_logprob)
+ old_seq_logprob = seq_logprob.expand_as(candidate_logprob).contiguous()
+ old_seq_logprob[:, :, :self.eos_token_id] = -999
+ old_seq_logprob[:, :, self.eos_token_id + 1:] = -999
+ candidate_logprob = seq_mask * candidate_logprob + old_seq_logprob * (1 - seq_mask)
+
+ selected_idx, selected_logprob = self._select(bs, beam_size, t, candidate_logprob) # bs beam
+ selected_beam = torch.div(selected_idx, candidate_logprob.shape[-1], rounding_mode='floor')
+ selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]
+
+ self._expand_state(history_states, selected_beam, bs, beam_size, cur_beam_size)
+
+ seq_logprob = selected_logprob.unsqueeze(-1)
+ seq_mask = torch.gather(seq_mask, 1, selected_beam.unsqueeze(-1))
+ outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
+ outputs.append(selected_words.unsqueeze(-1))
+
+ this_word_logprob = torch.gather(word_logprob, 1,
+ selected_beam.unsqueeze(-1).expand(bs, beam_size, word_logprob.shape[-1]))
+ this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
+ log_probs = list(
+ torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(bs, beam_size, 1)) for o in log_probs)
+ log_probs.append(this_word_logprob)
+ selected_words = selected_words.view(-1, 1)
+ # wt = selected_words
+
+ if t == 0:
+ u_tokens_type = expand_tensor(u_tokens_type, beam_size)
+ wt = expand_tensor(wt, beam_size)
+
+ step_selected_data = {"data": selected_words, "time_step": t, "sample_info": {"data_cum_length": [1, len_prefix, len_prefix+t+1]}}
+ step_selected_data.update(common_info)
+
+ step_selected_encode = model._forward_data([step_selected_data], task_info=task_info, history_states=history_states, return_all=True)
+
+ for i, ths in enumerate(total_history_states):
+ hs = history_states[i]
+ ths[:hs.shape[0], :hs.shape[1], :] = hs
+ ths[:hs.shape[0], hs.shape[1], :] = step_selected_encode[0]['data'][i].squeeze(1)
+
+ outputs = torch.cat(outputs, -1)
+
+
+ if self.len_penalty > 0:
+ step = outputs.ne(self.eos_token_id).sum(-1, keepdim=True) + 1
+ seq_logprob /= step ** self.len_penalty
+ seq_logprob, sort_idxs = torch.sort(seq_logprob, 1, descending=True)
+
+ outputs = torch.gather(outputs, 1, sort_idxs.expand(bs, beam_size, self.max_seq_len))
+ log_probs = torch.cat(log_probs, -1)
+ log_probs = torch.gather(log_probs, 1, sort_idxs.expand(bs, beam_size, self.max_seq_len))
+
+ outputs = outputs.contiguous()[:, :out_size]
+ log_probs = log_probs.contiguous()[:, :out_size]
+ if out_size == 1:
+ outputs = outputs.squeeze(1)
+ log_probs = log_probs.squeeze(1)
+
+ comm._CAPTION_GEN_MODE = False
+
+ ids = torch.stack([torch.tensor(v['id']) for v in inputs['input_sample_list'][0]['sample_info']])
+
+ return {
+ "IDS": ids,
+ "G_SENTS_IDS": outputs,
+ "G_LOGP": log_probs
+ }
diff --git a/uniperceiver/modeling/decode_strategy/decode_strategy.py b/uniperceiver/modeling/decode_strategy/decode_strategy.py
new file mode 100644
index 0000000000000000000000000000000000000000..6aafb19ace60e50eb8c750d4703e6243aa0dbdcd
--- /dev/null
+++ b/uniperceiver/modeling/decode_strategy/decode_strategy.py
@@ -0,0 +1,101 @@
+
+from abc import ABCMeta, abstractmethod
+import torch
+import torch.nn as nn
+from uniperceiver.config import configurable
+from uniperceiver.functional import load_vocab, decode_sequence, decode_sequence_bert
+# from uniperceiver.tokenization import BertTokenizer
+from uniperceiver.tokenization import ClipTokenizer
+
+class DecodeStrategy(nn.Module, metaclass=ABCMeta):
+ @configurable
+ def __init__(
+ self,
+ *,
+ vocab_path,
+ vocab_name,
+ beam_size,
+ max_seq_len,
+ tokenizer,
+ bos_token_id,
+ eos_token_id,
+ spe_token_id = None,
+ fp16=False,
+ cfg=None,
+ ):
+ super().__init__()
+ self.beam_size = beam_size
+ if tokenizer is None:
+ self.vocab = load_vocab(vocab_path)
+ else:
+ self.vocab = None
+
+ if len(vocab_name) > 1:
+ raise NotImplementedError("Only support caption inference on a single vocabulary!")
+ else:
+ self.vocab_name = vocab_name[0]
+ self.max_seq_len = max_seq_len
+ self.tokenizer = tokenizer
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.spe_token_id = spe_token_id
+ self.fp16 = fp16
+ self.cfg = cfg
+ self.len_penalty = self.cfg.DECODE_STRATEGY.get('LEN_PENALTY', 0.0) # do not normalize
+ pass
+
+ @classmethod
+ def from_config(cls, cfg):
+ tokenizer_map = {
+ # 'BERT': BertTokenizer,
+ 'CLIP': ClipTokenizer,
+ }
+
+ tokenizer_cls = tokenizer_map.get(cfg.INFERENCE.VOCAB, None)
+ spe_token_id = None
+ if tokenizer_cls is None:
+ tokenizer = None
+ bos_token_id = 0
+ eos_token_id = 0
+ elif cfg.INFERENCE.VOCAB == 'CLIP':
+ tokenizer = tokenizer_cls()
+ bos_token_id = tokenizer.vocab['<|startoftext|>']
+ eos_token_id = tokenizer.vocab['<|endoftext|>']
+ spe_token_id = tokenizer.vocab['<|spe|>']
+ elif cfg.INFERENCE.VOCAB == 'CLIP_CAPTION':
+ tokenizer = tokenizer_cls()
+ bos_token_id = tokenizer.vocab['<|gen|>']
+ eos_token_id = tokenizer.vocab['<|endoftext|>']
+ else:
+ tokenizer = tokenizer_cls.from_pretrained(cfg.MODEL.PRETRAINING.MODEL_NAME, do_lower_case=cfg.MODEL.PRETRAINING.DO_LOWER_CASE)
+ if cfg.INFERENCE.VOCAB == 'BERT':
+ raise NotImplementedError
+ bos_token_id = tokenizer.vocab["[CLS]"]
+ eos_token_id = tokenizer.vocab["[SEP]"]
+
+ return {
+ "vocab_path": cfg.INFERENCE.VOCAB,
+ "vocab_name": cfg.DATASETS.TARGET_SET,
+ "beam_size": cfg.DECODE_STRATEGY.BEAM_SIZE,
+ "max_seq_len": cfg.MODEL.EVAL_MAX_SEQ_LEN if 'EVAL_MAX_SEQ_LEN' in cfg.MODEL else cfg.MODEL.MAX_SEQ_LEN,
+ 'tokenizer': tokenizer,
+ "bos_token_id": bos_token_id,
+ "eos_token_id": eos_token_id,
+ "spe_token_id": spe_token_id,
+ "cfg": cfg,
+ # "fp16": cfg.SOLVER.AMP_FP16,
+ }
+
+ @abstractmethod
+ def _forward(self, batched_inputs, model):
+ pass
+
+ def forward(self, batched_inputs, output_sents, model):
+ ret = self._forward(batched_inputs, model)
+ if output_sents:
+ if self.vocab:
+ sents = decode_sequence(self.vocab, ret["G_SENTS_IDS"])
+ else:
+ sents = decode_sequence_bert(self.tokenizer, ret["G_SENTS_IDS"], self.eos_token_id)
+ ret.update({ "output": sents })
+ return ret
\ No newline at end of file
diff --git a/uniperceiver/modeling/embedding/__init__.py b/uniperceiver/modeling/embedding/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..995a51689d598d2657f487b02eeb38968686ad65
--- /dev/null
+++ b/uniperceiver/modeling/embedding/__init__.py
@@ -0,0 +1,7 @@
+
+from .build import build_embeddings
+from .token_embed import TokenBaseEmbedding
+from .video_embed import VideoBaseEmbedding
+from .prompt_embed import PrefixPromptEmbedding
+
+__all__ = list(globals().keys())
\ No newline at end of file
diff --git a/uniperceiver/modeling/embedding/build.py b/uniperceiver/modeling/embedding/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c8487d3d7bb6add72954b3fd6f6c9c8a9f582b3
--- /dev/null
+++ b/uniperceiver/modeling/embedding/build.py
@@ -0,0 +1,11 @@
+
+from uniperceiver.utils.registry import Registry
+
+EMBEDDING_REGISTRY = Registry("EMBEDDING")
+EMBEDDING_REGISTRY.__doc__ = """
+Registry for embedding
+"""
+
+def build_embeddings(cfg, name):
+ embeddings = None if name.lower() == "none" else EMBEDDING_REGISTRY.get(name)(cfg)
+ return embeddings
\ No newline at end of file
diff --git a/uniperceiver/modeling/embedding/position_embedding.py b/uniperceiver/modeling/embedding/position_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..17ca787108db4bb39e8f76f41e6cf8bf93b3561f
--- /dev/null
+++ b/uniperceiver/modeling/embedding/position_embedding.py
@@ -0,0 +1,53 @@
+import math
+import torch
+from torch import nn
+
+from uniperceiver.utils.registry import Registry
+
+POSITION_ENC_REGISTRY = Registry("POSITION_ENC")
+POSITION_ENC_REGISTRY.__doc__ = """
+Registry for positional encoding
+"""
+
+__all__ = ["SinusoidEncoding", "NNEmbeddingEncoding"]
+
+def build_position_encoding(cfg, dim, max_len):
+ name = cfg.MODEL.TOKEN_EMBED.POSITION
+ return POSITION_ENC_REGISTRY.get(name)(dim, max_len)
+
+@POSITION_ENC_REGISTRY.register()
+class SinusoidEncoding(nn.Module):
+ def __init__(self, dim, max_len):
+ super(SinusoidEncoding, self).__init__()
+ pe = torch.zeros(max_len, dim)
+ position = torch.arange(0, max_len).unsqueeze(1).float()
+ div_term = torch.exp(torch.arange(0, dim, 2).float() *
+ -(math.log(max_len * 2.0) / dim))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ if isinstance(x, int):
+ return self.pe[:, x]
+ else:
+ x_size = x.size(1)
+ return self.pe[:, :x_size]
+
+@POSITION_ENC_REGISTRY.register()
+class NNEmbeddingEncoding(nn.Module):
+ def __init__(self, dim, max_len):
+ super(NNEmbeddingEncoding, self).__init__()
+ self.position_embeddings = nn.Embedding(max_len, dim)
+
+ def forward(self, x, start_time=0):
+ if isinstance(x, int):
+ position_embeddings = self.position_embeddings(torch.tensor([x], dtype=torch.long).cuda())
+ elif isinstance(x, torch.Tensor) and x.dim()==1:
+ position_embeddings = self.position_embeddings(x)
+ else:
+ x_size = x.size(1)
+ position_ids = torch.arange(x_size, dtype=torch.long, device=x.device) + start_time
+ position_embeddings = self.position_embeddings(position_ids)
+ return position_embeddings
\ No newline at end of file
diff --git a/uniperceiver/modeling/embedding/prompt_embed.py b/uniperceiver/modeling/embedding/prompt_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6149861d9df12016361ce4e9d6416954bbf1684
--- /dev/null
+++ b/uniperceiver/modeling/embedding/prompt_embed.py
@@ -0,0 +1,126 @@
+import torch
+from torch import nn
+from einops import rearrange
+
+from uniperceiver.config import configurable
+
+from ..layers.create_act import get_act_layer
+from .build import EMBEDDING_REGISTRY
+
+__all__ = ["PrefixPromptEmbedding"]
+
+@EMBEDDING_REGISTRY.register()
+class PrefixPromptEmbedding(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ *,
+ cfg,
+ dim: int,
+ query_size: int, # include /
+ label_prompt,
+ label_size,
+ **kwargs):
+ super(PrefixPromptEmbedding, self).__init__()
+
+ self.cfg = cfg
+ self.input_prompt = cfg.MODEL.PROMPT_EMBED.INPUT_PROMPT and cfg.MODEL.PROMPT_EMBED.PROMPT_LENGTH > 0
+ self.target_prompt = cfg.MODEL.PROMPT_EMBED.TARGET_PROMPT and cfg.MODEL.PROMPT_EMBED.TARGET_PROMPT_LENGTH > 0
+ self.label_prompt = label_prompt
+
+ if self.input_prompt:
+ self.embeddings = nn.Embedding(cfg.MODEL.PROMPT_EMBED.PROMPT_LENGTH, cfg.MODEL.ENCODER_DIM)
+ if self.target_prompt:
+ if self.label_prompt:
+ self.target_embedding = nn.Parameter(
+ torch.zeros((cfg.MODEL.PROMPT_EMBED.LABEL_SIZE, 1, cfg.MODEL.ENCODER_DIM)))
+ else:
+ self.target_embedding = nn.Embedding(cfg.MODEL.PROMPT_EMBED.TARGET_PROMPT_LENGTH, cfg.MODEL.ENCODER_DIM)
+
+ self.embeddings_act = kwargs.pop("embeddings_act", None)
+ self.embeddings_norm = kwargs.pop("embeddings_norm", None)
+ self.embeddings_dropout = kwargs.pop("embeddings_dropout", None)
+ self.prompt_with_pos = kwargs.pop('prompt_with_pos', None)
+
+ @classmethod
+ def from_config(cls, cfg):
+ kwargs = {
+ "dim": cfg.MODEL.ENCODER_DIM,
+ "query_size": cfg.MODEL.PROMPT_EMBED.PROMPT_LENGTH,
+ "label_prompt": cfg.MODEL.PROMPT_EMBED.LABLE_PROMPT,
+ "label_size": cfg.MODEL.PROMPT_EMBED.LABEL_SIZE,
+ "target_prompt": cfg.MODEL.PROMPT_EMBED.TARGET_PROMPT,
+ "num_layers": cfg.MODEL.BERT.NUM_HIDDEN_LAYERS,
+ }
+
+ activation_name = (cfg.MODEL.PROMPT_EMBED.ACTIVATION).lower()
+ if activation_name != "none":
+ activation = get_act_layer(activation_name)
+ assert activation is not None
+
+ act_kwargs = {}
+ if activation_name in {"elu", "celu"}:
+ act_kwargs["alpha"] = cfg.MODEL.PROMPT_EMBED.ELU_ALPHA
+ embeddings_act = activation(**act_kwargs)
+ kwargs['embeddings_act'] = embeddings_act
+
+ if cfg.MODEL.PROMPT_EMBED.DROPOUT > 0:
+ embeddings_dropout = nn.Dropout(cfg.MODEL.PROMPT_EMBED.DROPOUT)
+ kwargs['embeddings_dropout'] = embeddings_dropout
+
+ if cfg.MODEL.PROMPT_EMBED.USE_NORM:
+ embeddings_norm = nn.LayerNorm(cfg.MODEL.PROMPT_EMBED.DIM)
+ kwargs['embeddings_norm'] = embeddings_norm
+
+ kwargs['prompt_with_pos'] = cfg.MODEL.PROMPT_EMBED.WITH_POS
+ kwargs['cfg'] = cfg
+
+ return kwargs
+
+ def forward(self, data_list):
+ bs = data_list[0]['data'].shape[0]
+ prompt_embed = self._forward(bs, data_type=data_list[0]['data_type'])
+
+ if prompt_embed is None:
+ return
+
+
+ insert_data = {
+ 'data': prompt_embed,
+ 'invalid_mask': None,
+ 'modality': None,
+ 'data_type': data_list[0]['data_type'],
+
+ }
+ data_list.insert(0, insert_data)
+
+ #TODO label prompt
+
+
+
+
+ def _forward(self, bs, data_type:str = None):
+ assert data_type in ['input', 'target']
+ if data_type == 'input' and self.input_prompt:
+ embeddings = self.embeddings.weight.unsqueeze(0).repeat(bs, 1, 1)
+ elif data_type == 'target' and self.target_prompt:
+ if not self.label_prompt:
+ embeddings = self.target_embedding.weight.unsqueeze(0).repeat(bs, 1, 1)
+ elif self.label_prompt:
+ embeddings = self.target_embedding
+ else:
+ # target will not have prompt_embedding
+ return None
+ else:
+ return None
+
+ if self.embeddings_act is not None:
+ embeddings = self.embeddings_act(embeddings)
+
+ if self.embeddings_norm is not None:
+ embeddings = self.embeddings_norm(embeddings)
+
+ if self.embeddings_dropout is not None:
+ embeddings = self.embeddings_dropout(embeddings)
+
+ return embeddings
diff --git a/uniperceiver/modeling/embedding/token_embed.py b/uniperceiver/modeling/embedding/token_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffafe97c3822927971b6649310ed38036eb46329
--- /dev/null
+++ b/uniperceiver/modeling/embedding/token_embed.py
@@ -0,0 +1,197 @@
+import torch
+from torch import nn
+
+from uniperceiver.config import configurable
+from ..layers.create_act import get_act_layer
+from .build import EMBEDDING_REGISTRY
+from .position_embedding import build_position_encoding
+# from uniperceiver.modeling.layers import LayerNorm
+from uniperceiver.utils import comm
+import copy
+from uniperceiver.modeling.layers import FP16LayerNorm
+
+
+__all__ = ["TokenBaseEmbedding"]
+
+@EMBEDDING_REGISTRY.register()
+class TokenBaseEmbedding(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ *,
+ dim: int,
+ vocab_size: int, # include /
+ **kwargs
+ ):
+ super(TokenBaseEmbedding, self).__init__()
+ self.embeddings = nn.Embedding(vocab_size, dim)
+ self.embeddings_act = kwargs.pop("embeddings_act", None)
+ self.embeddings_norm = kwargs.pop("embeddings_norm", None)
+ self.embeddings_dropout = kwargs.pop("embeddings_dropout", None)
+ self.embeddings_pos = kwargs.pop("embeddings_pos", None)
+ self.embeddings_token_type = kwargs.pop('embeddings_token_type', None)
+ self.embeddings_token_seg = kwargs.pop('embeddings_token_seg', None)
+ self.bw_own_embed = kwargs.pop('bw_own_embed', False)
+ self.pos_before = kwargs.pop('pos_before', True)
+ self.cfg = kwargs.pop('cfg', None)
+
+ if self.bw_own_embed:
+ # only for debugging
+ self.bw_embeddings = copy.deepcopy(self.embeddings)
+ self.bw_embeddings_norm = copy.deepcopy(self.embeddings_norm)
+ self.bw_embeddings_pos = copy.deepcopy(self.embeddings_pos)
+ self.bw_embeddings_token_type = copy.deepcopy(self.embeddings_token_type)
+ self.s_token_bias = None
+
+ @classmethod
+ def from_config(cls, cfg):
+ kwargs = {
+ "dim": cfg.MODEL.TOKEN_EMBED.DIM,
+ "vocab_size": cfg.MODEL.VOCAB_SIZE
+ }
+
+ activation_name = (cfg.MODEL.TOKEN_EMBED.ACTIVATION).lower()
+ if activation_name != "none":
+ activation = get_act_layer(activation_name)
+ assert activation is not None
+
+ act_kwargs = {}
+ if activation_name in { "elu", "celu" }:
+ act_kwargs["alpha"] = cfg.MODEL.TOKEN_EMBED.ELU_ALPHA
+ embeddings_act = activation(**act_kwargs)
+ kwargs['embeddings_act'] = embeddings_act
+
+ if cfg.MODEL.TOKEN_EMBED.DROPOUT > 0:
+ embeddings_dropout = nn.Dropout(cfg.MODEL.TOKEN_EMBED.DROPOUT)
+ kwargs['embeddings_dropout'] = embeddings_dropout
+
+ if cfg.MODEL.TOKEN_EMBED.USE_NORM:
+ if cfg.SOLVER.FORCE_LN_FP16:
+ embeddings_norm = FP16LayerNorm(cfg.MODEL.TOKEN_EMBED.DIM)
+ else:
+ embeddings_norm = nn.LayerNorm(cfg.MODEL.TOKEN_EMBED.DIM)
+ kwargs['embeddings_norm'] = embeddings_norm
+
+ if (cfg.MODEL.TOKEN_EMBED.POSITION).lower() != 'none':
+ embeddings_pos = build_position_encoding(cfg,
+ cfg.MODEL.TOKEN_EMBED.DIM, cfg.MODEL.TOKEN_EMBED.POSITION_MAX_LEN)
+ kwargs['embeddings_pos'] = embeddings_pos
+
+ if cfg.MODEL.TOKEN_EMBED.TYPE_VOCAB_SIZE > 0:
+ embeddings_token_type = nn.Embedding(
+ cfg.MODEL.TOKEN_EMBED.TYPE_VOCAB_SIZE, cfg.MODEL.TOKEN_EMBED.DIM)
+ kwargs['embeddings_token_type'] = embeddings_token_type
+
+ if cfg.MODEL.TOKEN_EMBED.TYPE_SEG_SIZE > 0:
+ embeddings_token_seg = nn.Embedding(
+ cfg.MODEL.TOKEN_EMBED.TYPE_SEG_SIZE, cfg.MODEL.TOKEN_EMBED.DIM)
+ kwargs['embeddings_token_seg'] = embeddings_token_seg
+
+ # for debug
+ kwargs['bw_own_embed'] = cfg.MODEL.BW_OWD_EMBED
+ kwargs['pos_before'] = cfg.MODEL.POS_BEFORE
+ kwargs['cfg'] = cfg
+ return kwargs
+
+ def get_time_step(self, data, sample_info, task_info=None):
+ """
+ data: Bs, L
+ task_info: {
+ task_type: str
+ }
+ """
+ # TODO: the position embedding for caption text should be handled in a different way. 0,1, n/2,0,1, n/2,
+ if task_info is None:
+ task_type = ''
+ else:
+ task_type = task_info.get('task_type', None)
+ time_step = None
+ if isinstance(sample_info, list):
+ sample_info = sample_info[0]
+ if task_type in ['image_caption', 'video_caption'] and sample_info.get('text_spe_cat', False):
+ text_length = data.shape[1]
+ time_step = torch.cat([
+ torch.arange(text_length // 2,
+ dtype=torch.long,
+ device=data.device) for _ in range(2)
+ ])
+ elif task_type == 'VQA' and sample_info.get('text_spe_cat', False):
+ text_length = data.shape[1]
+ time_step = torch.cat([
+ torch.arange(text_length - 1,
+ dtype=torch.long,
+ device=data.device),
+ torch.arange(1, dtype=torch.long, device=data.device)
+ ])
+
+
+ return time_step
+
+ def forward(self, data, sample_info={}, task_info={}, **kwargs):
+
+
+ time_step = kwargs.pop('time_step', None)
+ if time_step is None:
+ time_step = self.get_time_step(data, sample_info, task_info)
+
+ if kwargs.pop("prompt_with_pos", False):
+ raise NotImplementedError
+ else:
+ start_time = 0
+
+ type_embed = kwargs.get('type_embed', True)
+ pos_emb = kwargs.get('pos_embed', True)
+
+ data = self._forward(data,
+ type_embed=type_embed,
+ pos_emb=pos_emb,
+ token_seg_ids=None,
+ time_step=time_step,
+ start_time=start_time)
+
+ return data
+
+
+
+ def set_s_token_bias(self, s_token_bias):
+ self.s_token_bias = s_token_bias
+
+ def _forward(self, input_ids, type_embed=True, token_seg_ids=None, time_step=None, pos_emb=True, start_time=0, ):
+
+ embeddings = self.embeddings(input_ids)
+ if self.cfg.SOLVER.FORCE_EMBED_FP16:
+ embeddings = embeddings.half()
+
+ if self.s_token_bias is not None:
+ # learnable
+ embeddings[input_ids == 49410] = embeddings[input_ids == 49410] + self.s_token_bias
+
+ if self.embeddings_pos is not None and pos_emb and self.pos_before:
+ pos_inputs = input_ids if time_step is None else time_step
+ position_embeddings = self.embeddings_pos(pos_inputs, start_time=start_time)
+ embeddings = embeddings + position_embeddings.to(embeddings.dtype)
+
+ if self.embeddings_token_type is not None and type_embed:
+
+ embeddings_token_type = self.embeddings_token_type.weight[0].unsqueeze(0).unsqueeze(1)
+ embeddings = embeddings + embeddings_token_type.to(embeddings.dtype)
+
+ if (self.embeddings_token_seg is not None) and (token_seg_ids is not None):
+ embeddings_token_seg = self.embeddings_token_seg(token_seg_ids)
+ embeddings = embeddings + embeddings_token_seg
+
+ if self.embeddings_act is not None:
+ embeddings = self.embeddings_act(embeddings)
+
+ if self.embeddings_norm is not None:
+ embeddings = self.embeddings_norm(embeddings)
+
+ if self.embeddings_pos is not None and pos_emb and not self.pos_before:
+ pos_inputs = input_ids if time_step is None else time_step
+ position_embeddings = self.embeddings_pos(pos_inputs, start_time=start_time)
+ embeddings = embeddings + position_embeddings.to(embeddings.dtype)
+
+ if self.embeddings_dropout is not None:
+ embeddings = self.embeddings_dropout(embeddings)
+
+ return embeddings
diff --git a/uniperceiver/modeling/embedding/video_embed.py b/uniperceiver/modeling/embedding/video_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba58523e5ff9a57b8634d3c6345cfd941c9501d6
--- /dev/null
+++ b/uniperceiver/modeling/embedding/video_embed.py
@@ -0,0 +1,190 @@
+import torch
+from torch import nn
+
+from uniperceiver.config import configurable
+from ..layers.create_act import get_act_layer
+from .build import EMBEDDING_REGISTRY
+from .position_embedding import NNEmbeddingEncoding
+from einops import rearrange, repeat
+from uniperceiver.modeling.layers import FP16LayerNorm
+
+
+__all__ = ["VideoBaseEmbedding"]
+
+@EMBEDDING_REGISTRY.register()
+class VideoBaseEmbedding(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ *,
+ cfg: dict,
+ in_dim: int,
+ out_dim: int,
+ patch_size: int,
+ time_span: int,
+ max_time_len: int,
+ max_spatial_size = 196,
+ **kwargs
+ ):
+ super(VideoBaseEmbedding, self).__init__()
+ self.cfg = cfg
+ self.embeddings = nn.Linear(in_dim, out_dim)
+ self.embeddings_act = kwargs.pop("embeddings_act", None)
+ self.embeddings_norm = kwargs.pop("embeddings_norm", None)
+ self.embeddings_dropout = kwargs.pop("embeddings_dropout", None)
+ self.embeddings_pos = kwargs.pop('embeddings_pos', None)
+ self.embeddings_type = kwargs.pop("embeddings_token_type", None)
+ self.random_temporal_pos = kwargs.pop("random_temporal_pos", True)
+ self.patch_size = patch_size
+ self.time_span = time_span
+ self.pos_before = kwargs.pop('pos_before', True)
+ self.add_type_embedding = cfg.MODEL.VIDEO_EMBED.ADD_TYPE_EMBED
+ if self.add_type_embedding:
+ assert self.embeddings_type is not None
+
+ self.embeddings_st_pos = None
+ self.max_spatial_size = max_spatial_size
+ if isinstance(self.embeddings_pos, str):
+ if self.embeddings_pos == 'divide_st_pos':
+ self.embeddings_st_pos = Divide_ST_POS(
+ max_spatial_size, max_time_len, out_dim,
+ self.random_temporal_pos)
+ self.embeddings_pos = None
+ del self.embeddings
+ self.embeddings = nn.Conv2d(in_dim//(self.patch_size**2), out_dim, kernel_size=self.patch_size, stride=self.patch_size)
+ pass
+
+ def replace_weight(self, visual_embed):
+ if visual_embed is not None:
+ del self.embeddings
+ self.embeddings = visual_embed.patch_embed.proj
+
+ def share_spatial_pos(self, visual_embed):
+ if self.embeddings_st_pos is not None and visual_embed is not None:
+
+ if self.embeddings_st_pos.spatial_pos_embed.weight.shape[0] == visual_embed.patch_embed.pos_embed.weight.shape[0]:
+ self.embeddings_st_pos.spatial_pos_embed_index = 0
+ else:
+ # cls token in image patch tokenizer
+ self.embeddings_st_pos.spatial_pos_embed_index = 1
+ del self.embeddings_st_pos.spatial_pos_embed
+ self.embeddings_st_pos.spatial_pos_embed = visual_embed.patch_embed.pos_embed
+ pass
+
+ @classmethod
+ def from_config(cls, cfg):
+ kwargs = {
+ "in_dim": cfg.MODEL.VIDEO_EMBED.IN_DIM,
+ "out_dim": cfg.MODEL.VIDEO_EMBED.OUT_DIM,
+ "patch_size": cfg.MODEL.PATCH_SIZE,
+ "time_span": cfg.MODEL.VIDEO_EMBED.PATCH_SIZE_T,
+ "max_time_len": cfg.MODEL.VIDEO_EMBED.MAX_FRAMES,
+ }
+ max_spatial_size = int(cfg.MODEL.IMG_INPUT_SIZE/cfg.MODEL.PATCH_SIZE)**2
+ kwargs['max_spatial_size'] = max_spatial_size
+ activation_name = (cfg.MODEL.VIDEO_EMBED.ACTIVATION).lower()
+ if activation_name != "none":
+ activation = get_act_layer(activation_name)
+ assert activation is not None
+
+ act_kwargs = {}
+ if activation_name in { "elu", "celu" }:
+ act_kwargs["alpha"] = cfg.MODEL.VIDEO_EMBED.ELU_ALPHA
+ embeddings_act = activation(**act_kwargs)
+ kwargs['embeddings_act'] = embeddings_act
+
+ if cfg.MODEL.VIDEO_EMBED.DROPOUT > 0:
+ embeddings_dropout = nn.Dropout(cfg.MODEL.VIDEO_EMBED.DROPOUT)
+ kwargs['embeddings_dropout'] = embeddings_dropout
+
+ if cfg.MODEL.VIDEO_EMBED.USE_NORM:
+ if cfg.SOLVER.FORCE_LN_FP16:
+ embeddings_norm = FP16LayerNorm(cfg.MODEL.VIDEO_EMBED.OUT_DIM)
+ else:
+ embeddings_norm = nn.LayerNorm(cfg.MODEL.VIDEO_EMBED.OUT_DIM)
+ kwargs['embeddings_norm'] = embeddings_norm
+
+ if cfg.MODEL.VIDEO_EMBED.DIVIDE_ST_POS:
+ kwargs['embeddings_pos'] = "divide_st_pos"
+
+ elif cfg.MODEL.VIDEO_EMBED.POSITION.lower() != 'none':
+ embeddings_pos = NNEmbeddingEncoding(cfg.MODEL.VIDEO_EMBED.OUT_DIM, cfg.MODEL.VIDEO_EMBED.MAX_LENGTH)
+ kwargs['embeddings_pos'] = embeddings_pos
+
+ if cfg.MODEL.VIDEO_EMBED.TYPE_SIZE > 0:
+ embeddings_token_type = nn.Embedding(
+ cfg.MODEL.VIDEO_EMBED.TYPE_SIZE, cfg.MODEL.VIDEO_EMBED.OUT_DIM)
+ kwargs['embeddings_token_type'] = embeddings_token_type
+ kwargs['random_temporal_pos'] = cfg.MODEL.VIDEO_EMBED.POS_RANDOM
+ kwargs['pos_before'] = cfg.MODEL.POS_BEFORE
+ kwargs['cfg'] = cfg
+ return kwargs
+
+ def forward(self, data, **kwargs):
+
+ if data.dim() == 4:
+ #images
+ data = data.unsqueeze(1) # BS, 3, 224, 224
+
+
+ if self.embeddings_st_pos is not None:
+ bs = data.size(0)
+ x = self.embeddings(data.flatten(0, 1)) # b*t, dim, 14, 14
+ x = x.flatten(2) # .flatten(2)
+ embeddings = rearrange(x, '(b t s) c hw -> b t hw (s c)', b=bs, s = self.time_span)
+ embeddings_pos = self.embeddings_st_pos(embeddings).unsqueeze(
+ 0).flatten(1, 2)
+ embeddings = embeddings.flatten(1, 2)
+ if self.pos_before:
+ embeddings = embeddings + embeddings_pos.to(embeddings.dtype)
+
+
+ if self.embeddings_pos is not None:
+ x = rearrange(data, 'b (t s) c (h p1) (w p2) -> b (t h w) (s c p1 p2)', s = self.time_span, p1 = self.patch_size, p2 = self.patch_size)
+ embeddings = self.embeddings(x)
+ embeddings_pos = self.embeddings_pos(x).unsqueeze(0)
+ if self.pos_before:
+ embeddings = embeddings + embeddings_pos.to(embeddings.dtype)
+
+ if self.add_type_embedding:
+ embeddings = embeddings + self.embeddings_type.weight[0].unsqueeze(0).unsqueeze(1).to(embeddings.dtype)
+
+ if self.embeddings_act is not None:
+ embeddings = self.embeddings_act(embeddings)
+
+ if self.embeddings_norm is not None:
+ embeddings = self.embeddings_norm(embeddings)
+
+ if not self.pos_before:
+ embeddings = embeddings + embeddings_pos
+
+ if self.embeddings_dropout is not None:
+ embeddings = self.embeddings_dropout(embeddings)
+
+ return embeddings
+
+
+
+class Divide_ST_POS(nn.Module):
+ def __init__(self, num_patches, max_time_len, out_dim,
+ random_temporal_pos):
+ super(Divide_ST_POS, self).__init__()
+ self.spatial_pos_embed = nn.Embedding(num_patches, out_dim)
+ self.temporal_pos_embed = nn.Embedding(max_time_len, out_dim)
+ self.spatial_pos_embed_index = 0 # sometimes image has cls_token
+ self.max_frames = max_time_len
+ self.random_temporal_pos = random_temporal_pos
+
+ def forward(self, x):
+ dtype = x.dtype
+ temp_len, spatial_size = x.size(1), x.size(2)
+
+ if self.training and self.random_temporal_pos:
+ temporal_pos_ids = torch.arange(temp_len, dtype=torch.long, device=x.device) + \
+ torch.randint(0, self.max_frames - temp_len + 1, size=(1,), dtype=torch.long, device=x.device)
+ else:
+ temporal_pos_ids = torch.arange(temp_len, dtype=torch.long, device=x.device)
+
+ pos_embed = self.temporal_pos_embed(temporal_pos_ids).unsqueeze(1).to(dtype=dtype) + \
+ self.spatial_pos_embed( torch.arange(start= self.spatial_pos_embed_index, end=spatial_size + self.spatial_pos_embed_index , dtype=torch.long, device=x.device)).unsqueeze(0).to(dtype=dtype)
+ return pos_embed
diff --git a/uniperceiver/modeling/encoder/__init__.py b/uniperceiver/modeling/encoder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b12eaec6eec77dc90d285f5333a4af871d19e83
--- /dev/null
+++ b/uniperceiver/modeling/encoder/__init__.py
@@ -0,0 +1,7 @@
+
+from .build import build_encoder, add_encoder_config, build_unfused_encoders
+from .unified_bert_encoder import UnifiedBertEncoder
+from .standard_vit_encoder import StandardViT, TextEncoder, VisualEncoder
+
+
+__all__ = list(globals().keys())
diff --git a/uniperceiver/modeling/encoder/build.py b/uniperceiver/modeling/encoder/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1a898b4144a9b075e65663b6b64455b06dd2398
--- /dev/null
+++ b/uniperceiver/modeling/encoder/build.py
@@ -0,0 +1,28 @@
+from uniperceiver.utils.registry import Registry
+from torch import ModuleDict
+
+
+ENCODER_REGISTRY = Registry("ENCODER")
+ENCODER_REGISTRY.__doc__ = """
+Registry for encoder
+"""
+
+def build_encoder(cfg):
+ encoder = ENCODER_REGISTRY.get(cfg.MODEL.ENCODER)(cfg) if len(cfg.MODEL.ENCODER) > 0 else None
+ return encoder
+
+def build_unfused_encoders(cfg):
+ from uniperceiver.config import CfgNode
+ encoder_dict = {}
+ for config in cfg.ENCODERS:
+ tmg_config = CfgNode(config)
+ encoder = ENCODER_REGISTRY.get(
+ tmg_config.TYPE)(tmg_config, cfg) if len(tmg_config.TYPE) > 0 else None
+ encoder_dict[tmg_config.NAME] = encoder
+
+ return encoder_dict
+
+
+def add_encoder_config(cfg, tmp_cfg):
+ if len(tmp_cfg.MODEL.ENCODER) > 0:
+ ENCODER_REGISTRY.get(tmp_cfg.MODEL.ENCODER).add_config(cfg)
\ No newline at end of file
diff --git a/uniperceiver/modeling/encoder/standard_vit_encoder.py b/uniperceiver/modeling/encoder/standard_vit_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b3dd3e071defc1225639b1cfce2281d4c3933c4
--- /dev/null
+++ b/uniperceiver/modeling/encoder/standard_vit_encoder.py
@@ -0,0 +1,137 @@
+import torch
+from torch import nn
+
+from uniperceiver.config import configurable
+
+from ..layers.transformer_encoder_layer import TransformerEncoderLayer
+from .build import ENCODER_REGISTRY
+
+import uniperceiver.utils.comm as comm
+
+__all__ = ["StandardViT", "TextEncoder", "VisualEncoder"]
+
+
+
+@ENCODER_REGISTRY.register()
+class StandardViT(nn.Module):
+ @configurable
+ def __init__(self, *, num_hidden_layers: int, bert_layers, cfg):
+ super(StandardViT, self).__init__()
+ self.num_hidden_layers = num_hidden_layers
+ self.layers = bert_layers
+ self.cfg = cfg
+ self.name = cfg.NAME
+
+ @classmethod
+ def from_config(cls, cfg, global_cfg):
+ if cfg.DROP_PATH_PROB_FIXED:
+ dpr = [cfg.DROP_PATH_PROB for _ in range(cfg.NUM_HIDDEN_LAYERS)]
+ else:
+ dpr = [x.item() for x in torch.linspace(0, cfg.DROP_PATH_PROB, cfg.NUM_HIDDEN_LAYERS)]
+
+ layers = []
+ for i in range(cfg.NUM_HIDDEN_LAYERS):
+ layers.append(
+ TransformerEncoderLayer(
+ d_model=cfg.HIDDEN_SIZE,
+ nhead=cfg.NUM_ATTENTION_HEADS,
+ dim_feedforward=cfg.INTERMEDIATE_SIZE,
+ dropout=cfg.HIDDEN_DROPOUT_PROB,
+ drop_path_ratio=dpr[i],
+ activation=cfg.HIDDEN_ACT,
+ layer_scale=global_cfg.MODEL.LAYER_SCALE,
+ ls_init_values=global_cfg.MODEL.LAYER_SCALE_INIT,
+ batch_first=True,
+ norm_first=True,
+ cfg=cfg,
+ ))
+
+ bert_layers = nn.ModuleList(
+ layers
+ )
+ return {
+ "num_hidden_layers": cfg.NUM_HIDDEN_LAYERS,
+ "bert_layers": bert_layers,
+ "cfg": cfg
+ }
+
+ @classmethod
+ def add_config(cls, cfg):
+ pass
+
+ def _forward(self, x, attn_mask=None, key_padding_masks=None, history_states=None, *kwargs):
+
+ for l, layer_module in enumerate(self.layers):
+ x = layer_module(
+ src=x, src_mask=attn_mask, src_key_padding_mask=key_padding_masks
+ )
+
+ return x
+
+
+ def forward(self, batched_inputs, return_all=False):
+
+ raise NotImplementedError
+
+@ENCODER_REGISTRY.register()
+class VisualEncoder(StandardViT):
+
+ @staticmethod
+ def _construct_attention_masks( data, sample_info, task_info):
+
+ return None
+
+ def forward(self, data, invalid_mask, sample_info, task_info, **kwargs):
+ #TODO: prepare attn mask for each task type
+ # used for visual encoder
+ attn_mask = self._construct_attention_masks(data, sample_info, task_info)
+ history_states = kwargs.pop('history_states', None)
+ out = self._forward(data,
+ attn_mask,
+ invalid_mask,
+ history_states=history_states,
+ **kwargs,
+ )
+
+ return out
+
+
+@ENCODER_REGISTRY.register()
+class TextEncoder(StandardViT):
+
+ @staticmethod
+ def _construct_attention_masks( data, sample_info, task_info):
+ mask_type = torch.bool
+ device = data.device
+
+ attn_mask = None
+ if isinstance(sample_info, list):
+ sample_info = sample_info[0]
+ if task_info['task_type'] in ['image_caption', 'video_caption'] and sample_info.get('text_spe_cat', False):
+ total_length = data.shape[1]
+ attn_mask = torch.ones((total_length, total_length), dtype=mask_type, device=device)
+ attn_mask[:total_length // 2, :total_length // 2] = torch.ones(
+ (total_length // 2, total_length // 2), dtype=mask_type, device=device).triu_(diagonal=1)
+ attn_mask[total_length // 2:, : total_length // 2] = torch.ones(
+ (total_length // 2, total_length // 2),
+ dtype=mask_type,
+ device=device).triu_(diagonal=0)
+ attn_mask[total_length // 2:, total_length // 2:] = ~torch.ones(
+ (total_length // 2),
+ dtype=mask_type,
+ device=device).diag()
+
+ return attn_mask
+
+ def forward(self, data, invalid_mask, sample_info, task_info, **kwargs):
+ #TODO: prepare attn mask for each task type
+ # used for text encoder
+ attn_mask = self._construct_attention_masks(data, sample_info, task_info)
+ history_states = kwargs.pop('history_states', None)
+ out = self._forward(data,
+ attn_mask,
+ invalid_mask,
+ history_states=history_states,
+ **kwargs)
+
+ return out
diff --git a/uniperceiver/modeling/encoder/unified_bert_encoder.py b/uniperceiver/modeling/encoder/unified_bert_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ea76cb5e1282f9aa8095ccd7e3a0de69181b7c4
--- /dev/null
+++ b/uniperceiver/modeling/encoder/unified_bert_encoder.py
@@ -0,0 +1,173 @@
+import torch
+from torch import nn
+
+from uniperceiver.config import configurable
+from ..layers.transformer_encoder_layer import TransformerEncoderLayer
+from ..layers.transformer_encoder_moe_layer import MoETransformerEncoderLayer
+from .build import ENCODER_REGISTRY
+import uniperceiver.utils.comm as comm
+
+
+
+__all__ = ["UnifiedBertEncoder"]
+
+def _construct_attention_masks( data, sample_info, task_info):
+ mask_type = torch.bool
+ device = data.device
+
+ attn_mask = None
+ if isinstance(sample_info, list):
+ sample_info = sample_info[0]
+ if task_info['task_type'] in ['image_caption', 'video_caption'] and sample_info.get('text_spe_cat', False):
+
+ # the extra 1 length for spe token
+ spe_length, img_length, text_total_length = sample_info['data_length']
+ text_length = text_total_length//2
+
+ attn_mask = torch.ones((spe_length + img_length + text_total_length,
+ spe_length + img_length + text_total_length), dtype=mask_type, device=device)
+
+ attn_mask[:spe_length + img_length + text_total_length, :spe_length+img_length] = False
+ attn_mask[spe_length + img_length:spe_length + img_length + text_length, spe_length + img_length:spe_length + img_length + text_length] = torch.ones(
+ (text_length, text_length), dtype=mask_type, device=device).triu_(diagonal=1)
+ attn_mask[spe_length + img_length + text_length:, spe_length + img_length:spe_length + img_length + text_length] = torch.ones(
+ (text_length, text_length),
+ dtype=mask_type,
+ device=device).triu_(diagonal=0)
+ attn_mask[spe_length + img_length + text_length:,
+ spe_length + img_length + text_length:] = ~torch.ones(
+ (text_length), dtype=mask_type,
+ device=device).diag()
+
+ return attn_mask
+
+
+@ENCODER_REGISTRY.register()
+class UnifiedBertEncoder(nn.Module):
+ @configurable
+ def __init__(self, *, num_hidden_layers: int, bert_layers,
+ skip_target_encode, word_balance_losses,
+ bookswiki_word_alone, cfg):
+ super(UnifiedBertEncoder, self).__init__()
+ self.num_hidden_layers = num_hidden_layers
+ self.layers = bert_layers
+ self.skip_target_encode = skip_target_encode
+ self.word_balance_losses = word_balance_losses
+ self.bookswiki_word_alone = bookswiki_word_alone
+ self.cfg = cfg
+
+
+
+ @classmethod
+ def from_config(cls, cfg):
+ if cfg.MODEL.BERT.DROP_PATH_PROB_FIXED:
+ dpr = [cfg.MODEL.BERT.DROP_PATH_PROB for _ in range(cfg.MODEL.BERT.NUM_HIDDEN_LAYERS)]
+ else:
+ dpr = [x.item() for x in torch.linspace(0, cfg.MODEL.BERT.DROP_PATH_PROB, cfg.MODEL.BERT.NUM_HIDDEN_LAYERS)]
+
+ layers = []
+ for layer_idx in range(cfg.MODEL.BERT.NUM_HIDDEN_LAYERS):
+ if not cfg.MOE.MOE:
+ layers.append(
+ TransformerEncoderLayer(
+ d_model=cfg.MODEL.BERT.HIDDEN_SIZE,
+ nhead=cfg.MODEL.BERT.NUM_ATTENTION_HEADS,
+ dim_feedforward=cfg.MODEL.BERT.INTERMEDIATE_SIZE,
+ dropout=cfg.MODEL.BERT.HIDDEN_DROPOUT_PROB,
+ drop_path_ratio=dpr[layer_idx],
+ activation=cfg.MODEL.BERT.HIDDEN_ACT,
+ layer_scale=cfg.MODEL.LAYER_SCALE,
+ ls_init_values=cfg.MODEL.LAYER_SCALE_INIT,
+ batch_first=True,
+ norm_first=True,
+ cfg = cfg,
+ ))
+ else:
+ attention_moe = False
+ ffn_moe = False
+
+ moe_layer_start_idx = cfg.MOE.MOE_LAYER_START_IDX
+ moe_layer_end_idx = cfg.MOE.MOE_LAYER_END_IDX
+
+ if cfg.MOE.MOE and cfg.MOE.MOE_EXPERT_LOCATION == 'odd':
+ if layer_idx % 2 == 0 and layer_idx >= moe_layer_start_idx and layer_idx < moe_layer_end_idx:
+ moe_layers = cfg.MOE.MOE_EXPERT_TYPE.split(',')
+ attention_moe = "SA" in moe_layers
+ ffn_moe = 'FFN' in moe_layers
+
+ elif cfg.MOE.MOE and cfg.MOE.MOE_EXPERT_LOCATION == 'four':
+ if layer_idx % 4 == 0 and layer_idx >= moe_layer_start_idx and layer_idx < moe_layer_end_idx:
+ moe_layers = cfg.MOE.MOE_EXPERT_TYPE.split(',')
+ attention_moe = "SA" in moe_layers
+ ffn_moe = 'FFN' in moe_layers
+
+ elif cfg.MOE.MOE and cfg.MOE.MOE_EXPERT_LOCATION == 'all':
+ if layer_idx >= moe_layer_start_idx and layer_idx < moe_layer_end_idx:
+ moe_layers = cfg.MOE.MOE_EXPERT_TYPE.split(',')
+ attention_moe = "SA" in moe_layers
+ ffn_moe = 'FFN' in moe_layers
+ elif cfg.MOE.MOE and cfg.MOE.MOE_EXPERT_LOCATION == 'none':
+ attention_moe = None
+ ffn_moe = None
+
+
+ elif cfg.MOE.MOE:
+ raise NotImplementedError('cfg.MOE.MOE_EXPERT_LOCATION')
+
+ layers.append(
+ MoETransformerEncoderLayer(
+ d_model=cfg.MODEL.BERT.HIDDEN_SIZE,
+ nhead=cfg.MODEL.BERT.NUM_ATTENTION_HEADS,
+ dim_feedforward=cfg.MODEL.BERT.INTERMEDIATE_SIZE,
+ dropout=cfg.MODEL.BERT.HIDDEN_DROPOUT_PROB,
+ drop_path_ratio=dpr[layer_idx],
+ activation=cfg.MODEL.BERT.HIDDEN_ACT,
+ layer_scale=cfg.MODEL.LAYER_SCALE,
+ ls_init_values=cfg.MODEL.LAYER_SCALE_INIT,
+ batch_first=False,
+ norm_first=True,
+ cfg = cfg,
+ ffn_moe=ffn_moe,
+ attn_moe=attention_moe,
+ ))
+
+
+
+ bert_layers = nn.ModuleList(
+ layers
+ )
+ return {
+ "num_hidden_layers": cfg.MODEL.BERT.NUM_HIDDEN_LAYERS,
+ "skip_target_encode": cfg.MODEL.BERT.SKIP_TARGET_ENCODE,
+ "bert_layers": bert_layers,
+ "word_balance_losses": cfg.SOLVER.WORD_BALANCE_LOSSESS,
+ "bookswiki_word_alone": cfg.MODEL.BW_WORD_ALONE,
+ "cfg": cfg
+ }
+
+ @classmethod
+ def add_config(cls, cfg):
+ pass
+
+
+ def forward(self, data, invalid_mask, sample_info, task_info, history_states=None, return_all=False, **kwargs):
+
+ attn_mask = _construct_attention_masks(data, sample_info, task_info)
+ kwargs.update({'sample_info': sample_info})
+ data_type = kwargs.get('data_type', 'input')
+ if data_type == 'target' and self.skip_target_encode:
+ # used for debugging with single gpu sometimes
+ return data
+ if return_all:
+ data_all = [data]
+ for l, layer_module in enumerate(self.layers):
+
+ if history_states is None:
+ data = layer_module(src=data, src_mask=attn_mask, src_key_padding_mask=invalid_mask, task_info=task_info, **kwargs)
+ else:
+ data = layer_module(src=data, src_mask=attn_mask, src_key_padding_mask=invalid_mask, history_states=history_states[l], task_info=task_info, **kwargs)
+
+ if return_all:
+ data_all.append(data)
+
+ return data if not return_all else data_all
diff --git a/uniperceiver/modeling/layers/__init__.py b/uniperceiver/modeling/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c3f64c35c08d62323450a6052a641bd5e4cd838
--- /dev/null
+++ b/uniperceiver/modeling/layers/__init__.py
@@ -0,0 +1 @@
+from .layer_norm import FP16LayerNorm
\ No newline at end of file
diff --git a/uniperceiver/modeling/layers/create_act.py b/uniperceiver/modeling/layers/create_act.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b415ee06cef2bdb204e410ac54296ef00e689ec
--- /dev/null
+++ b/uniperceiver/modeling/layers/create_act.py
@@ -0,0 +1,66 @@
+
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+__all__ = ["get_act_layer", "get_activation"]
+
+########################################### Layer ###########################################
+_ACT_LAYER_DEFAULT = dict(
+ relu=nn.ReLU,
+ elu=nn.ELU,
+ celu=nn.CELU,
+ sigmoid=nn.Sigmoid,
+ tanh=nn.Tanh,
+)
+
+def get_act_layer(name='none'):
+ if name in _ACT_LAYER_DEFAULT:
+ return _ACT_LAYER_DEFAULT[name]
+ else:
+ return None
+
+########################################### Function ###########################################
+def swish(x):
+ return x * torch.sigmoid(x)
+
+def _gelu_python(x):
+ """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
+ For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+ 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+ This is now written in C in torch.nn.functional
+ Also see https://arxiv.org/abs/1606.08415
+ """
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+gelu = getattr(F, "gelu", _gelu_python)
+
+def gelu_new(x):
+ """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
+ Also see https://arxiv.org/abs/1606.08415
+ """
+ return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+def mish(x):
+ return x * torch.tanh(nn.functional.softplus(x))
+
+ACT2FN = {
+ "relu": F.relu,
+ "swish": swish,
+ "gelu": gelu,
+ "tanh": F.tanh,
+ "gelu_new": gelu_new,
+ "mish": mish
+}
+
+def get_activation(activation_string):
+ if activation_string in ACT2FN:
+ return ACT2FN[activation_string]
+ else:
+ raise KeyError(
+ "function {} not found in ACT2FN mapping {} or torch.nn.functional".format(
+ activation_string, list(ACT2FN.keys())
+ )
+ )
diff --git a/uniperceiver/modeling/layers/layer_norm.py b/uniperceiver/modeling/layers/layer_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9b7200facf4440db9419dee01e4873de0eb3940
--- /dev/null
+++ b/uniperceiver/modeling/layers/layer_norm.py
@@ -0,0 +1,46 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor, Size
+from torch.cuda.amp import autocast
+
+
+__all__ = ["LayerNorm"]
+
+try:
+ from apex.normalization import FusedLayerNorm as _FusedLayerNorm
+
+ has_fused_layernorm = True
+
+ class FusedLayerNorm(_FusedLayerNorm):
+ @torch.jit.unused
+ def forward(self, x):
+ if not x.is_cuda:
+ return super().forward(x)
+ else:
+ with torch.cuda.device(x.device):
+ return super().forward(x)
+
+except ImportError:
+ has_fused_layernorm = False
+
+
+def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
+ if torch.jit.is_scripting():
+ export = True
+ if not export and torch.cuda.is_available() and has_fused_layernorm:
+ return FusedLayerNorm(normalized_shape, eps, elementwise_affine)
+ else:
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+ return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine)
+
+class FP16LayerNorm(torch.nn.LayerNorm):
+
+ def forward(self, input):
+ with autocast(enabled=False):
+ return F.layer_norm(input.half(), self.normalized_shape, self.weight.half(), self.bias.half(), self.eps)
diff --git a/uniperceiver/modeling/layers/pe_encoder.py b/uniperceiver/modeling/layers/pe_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..2234d07c3196e43e3790131d219743a3dbb1ada1
--- /dev/null
+++ b/uniperceiver/modeling/layers/pe_encoder.py
@@ -0,0 +1,30 @@
+import torch
+
+
+class DeepPrompt(torch.nn.Module):
+ # naive implementation
+ def __init__(self, cfg):
+ super().__init__()
+
+ embedding_hidden_size = cfg.MODEL.BERT.HIDDEN_SIZE
+ self.target_prompt = cfg.MODEL.PROMPT_EMBED.TARGET_DEEP_PROMPT and not cfg.MODEL.PROMPT_EMBED.SHARE_DEEP_PROMPT
+ self.embedding = torch.nn.Embedding(cfg.MODEL.PROMPT_EMBED.INPUT_DEEP_PROMPT_LENGTH, embedding_hidden_size)
+ if self.target_prompt:
+ self.target_embedding = torch.nn.Embedding(cfg.MODEL.PROMPT_EMBED.TARGET_DEEP_PROMPT_LENGTH, embedding_hidden_size)
+
+
+ def forward(self, x, batch_first=False, data_type=None, **kwargs):
+ # x: length, bs, hidden_size
+
+ if data_type == 'target' and self.target_prompt:
+ embddings = self.target_embedding.weight
+ else:
+ embddings = self.embedding.weight
+
+ if batch_first:
+ bs = x.shape[0]
+ embddings = embddings.unsqueeze(0).expand(bs, -1, -1)
+ else:
+ bs = x.shape[1]
+ embddings = embddings.unsqueeze(1).expand(-1,bs, -1)
+ return embddings
diff --git a/uniperceiver/modeling/layers/transformer_encoder_layer.py b/uniperceiver/modeling/layers/transformer_encoder_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..10036583e5b5acc742e031e44a2c903aeb3c2d8f
--- /dev/null
+++ b/uniperceiver/modeling/layers/transformer_encoder_layer.py
@@ -0,0 +1,218 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from typing import Optional, Any, Union, Callable
+from torch import Tensor
+
+from .create_act import get_act_layer, get_activation
+from timm.models.layers import DropPath
+from .layer_norm import LayerNorm
+from .pe_encoder import DeepPrompt
+
+class TransformerEncoderLayer(nn.Module):
+ r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
+ This standard encoder layer is based on the paper "Attention Is All You Need".
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+ in a different way during application.
+ Args:
+ d_model: the number of expected features in the input (required).
+ nhead: the number of heads in the multiheadattention models (required).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+ activation: the activation function of the intermediate layer, can be a string
+ ("relu" or "gelu") or a unary callable. Default: relu
+ layer_norm_eps: the eps value in layer normalization components (default=1e-5).
+ batch_first: If ``True``, then the input and output tensors are provided
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
+ norm_first: if ``True``, layer norm is done prior to attention and feedforward
+ operations, respectivaly. Otherwise it's done after. Default: ``False`` (after).
+ Examples::
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = encoder_layer(src)
+ Alternatively, when ``batch_first`` is ``True``:
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
+ >>> src = torch.rand(32, 10, 512)
+ >>> out = encoder_layer(src)
+ Fast path:
+ forward() will use a special optimized implementation if all of the following
+ conditions are met:
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
+ argument ``requires_grad``
+ - training is disabled (using ``.eval()``)
+ - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
+ - norm_first is ``False`` (this restriction may be loosened in the future)
+ - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
+ - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
+ - if src is a `NestedTensor `_, neither ``src_mask``
+ nor ``src_key_padding_mask`` is passed
+ - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
+ unless the caller has manually modified one without modifying the other)
+ If the optimized implementation is in use, a
+ `NestedTensor `_ can be
+ passed for ``src`` to represent padding more efficiently than using a padding
+ mask. In this case, a `NestedTensor `_ will be
+ returned, and an additional speedup proportional to the fraction of the input that
+ is padding can be expected.
+ """
+ __constants__ = ['batch_first', 'norm_first'] # we inherit this variable from pytorch's code for jit
+
+ def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, drop_path_ratio: float = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_scale: bool = False, ls_init_values: float = 1e-3,
+ layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
+ device=None, dtype=None, cfg: dict = None) -> None:
+ #
+ factory_kwargs = {}
+ super(TransformerEncoderLayer, self).__init__()
+
+ self.cfg = cfg
+
+ # The interface of nn.MultiheadAttention changed since torch 1.9.0.
+ _torch_version_main = torch.__version__.split('.')[:2]
+ if (int(_torch_version_main[0]) >= 1) and (int(_torch_version_main[1])) >= 9:
+ self._torch_nn_new_interface = True
+ else:
+ self._torch_nn_new_interface = False
+
+ if self._torch_nn_new_interface:
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
+ **factory_kwargs)
+ else:
+ factory_kwargs = {}
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout,
+ **factory_kwargs)
+
+ self.batch_first = batch_first
+
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
+
+ self.norm_first = norm_first
+ if self.cfg.SOLVER.FUSED_LAYERNORM:
+ self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
+ else:
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.drop_path1 = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
+ self.drop_path2 = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
+
+ self.layer_scale = layer_scale
+ if self.layer_scale:
+ self.gamma_1 = nn.Parameter(ls_init_values * torch.ones((d_model)),requires_grad=True)
+ self.gamma_2 = nn.Parameter(ls_init_values * torch.ones((d_model)),requires_grad=True)
+
+ # Legacy string support for activation function.
+ if isinstance(activation, str):
+ activation = get_activation(activation)
+
+ self.activation = activation
+
+ # prompt embedding setup
+ self.deep_prompt = self.cfg.MODEL.PROMPT_EMBED.DEEP_PROMPT
+ if self.deep_prompt:
+ self.deep_prompt_embedding = DeepPrompt(cfg)
+
+
+ def __setstate__(self, state):
+ if 'activation' not in state:
+ state['activation'] = F.relu
+ super(TransformerEncoderLayer, self).__setstate__(state)
+
+ def forward(self,
+ src: Tensor,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ history_states: Optional[Tensor] = None,
+ **kwargs) -> Tensor:
+ r"""Pass the input through the encoder layer.
+ Args:
+ src: the sequence to the encoder layer (required).
+ src_mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+ Shape:
+ see the docs in Transformer class.
+ """
+
+ # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
+
+ if self.batch_first and not self._torch_nn_new_interface:
+ x = src.transpose(0,1)
+ if history_states is not None:
+ history_states = history_states.transpose(0,1)
+ else:
+ x = src
+
+ if self.norm_first:
+ history_states_norm = history_states if (history_states is None) else self.norm1(history_states)
+ x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, history_states=history_states_norm, **kwargs)
+ x = x + self._ff_block(self.norm2(x), **kwargs)
+ else:
+ x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, history_states=history_states, **kwargs))
+ x = self.norm2(x + self._ff_block(x), **kwargs)
+
+ if self.batch_first and not self._torch_nn_new_interface:
+ x = x.transpose(0, 1)
+
+ return x
+
+ # self-attention block
+ def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], history_states: Optional[Tensor],
+ **kwargs) -> Tensor:
+
+ if history_states is not None:
+ kv = torch.cat(
+ [history_states, x],
+ dim=1 if (self.batch_first and self._torch_nn_new_interface) else 0
+ )
+ # TODO: changes for attn_mask and key_padding_mask
+ else:
+ kv = x
+
+ if self.deep_prompt:
+
+ deep_prompt_embedding = self.deep_prompt_embedding(x, batch_first=(self.batch_first and self._torch_nn_new_interface), **kwargs)
+ if self.norm_first:
+ deep_prompt_embedding = self.norm1(deep_prompt_embedding)
+ kv = torch.cat([deep_prompt_embedding, kv], dim=1 if (self.batch_first and self._torch_nn_new_interface) else 0)
+ if attn_mask is not None:
+ L, S = attn_mask.shape
+ pe_length = deep_prompt_embedding.shape[1 if
+ (self.batch_first and self._torch_nn_new_interface) else 0] # length, bs, hidden_size
+ attn_mask = torch.cat([torch.zeros((L, pe_length), dtype=attn_mask.dtype, device=attn_mask.device), attn_mask], dim=1)
+ if key_padding_mask is not None:
+ if self.batch_first and self._torch_nn_new_interface:
+ bs, pe_length = deep_prompt_embedding.shape[:2]
+ else:
+ pe_length, bs = deep_prompt_embedding.shape[:2]
+ key_padding_mask = torch.cat(
+ [torch.zeros((bs, pe_length), dtype=key_padding_mask.dtype, device=key_padding_mask.device), key_padding_mask], dim=1)
+
+
+ x = self.self_attn(x, kv, kv,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask,
+ need_weights=False)[0]
+ x = self.drop_path1(self.dropout1(x))
+ if self.layer_scale:
+ x = self.gamma_1 * x
+ return x
+
+
+ # feed forward block
+ def _ff_block(self, x: Tensor, **kwargs) -> Tensor:
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ x = self.drop_path2(self.dropout2(x))
+ if self.layer_scale:
+ x = self.gamma_2 * x
+ return x
diff --git a/uniperceiver/modeling/layers/transformer_encoder_moe_layer.py b/uniperceiver/modeling/layers/transformer_encoder_moe_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..73f3b35b04cc05e0aef85b0ebeb59126b89243f2
--- /dev/null
+++ b/uniperceiver/modeling/layers/transformer_encoder_moe_layer.py
@@ -0,0 +1,406 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from typing import Optional, Any, Union, Callable
+from torch import Tensor
+
+from .create_act import get_act_layer, get_activation
+from timm.models.layers import DropPath
+from .layer_norm import LayerNorm
+from .pe_encoder import DeepPrompt
+from uniperceiver.task_moe.layer import TaskMoE
+from uniperceiver.utils import comm
+from functools import partial
+import math
+from uniperceiver.modeling.layers import FP16LayerNorm
+from torch.cuda.amp import autocast
+
+class MoETransformerEncoderLayer(nn.Module):
+ r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
+ This standard encoder layer is based on the paper "Attention Is All You Need".
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+ in a different way during application.
+ Args:
+ d_model: the number of expected features in the input (required).
+ nhead: the number of heads in the multiheadattention models (required).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+ activation: the activation function of the intermediate layer, can be a string
+ ("relu" or "gelu") or a unary callable. Default: relu
+ layer_norm_eps: the eps value in layer normalization components (default=1e-5).
+ batch_first: If ``True``, then the input and output tensors are provided
+ as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
+ norm_first: if ``True``, layer norm is done prior to attention and feedforward
+ operations, respectivaly. Otherwise it's done after. Default: ``False`` (after).
+ Examples::
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
+ >>> src = torch.rand(10, 32, 512)
+ >>> out = encoder_layer(src)
+ Alternatively, when ``batch_first`` is ``True``:
+ >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
+ >>> src = torch.rand(32, 10, 512)
+ >>> out = encoder_layer(src)
+ Fast path:
+ forward() will use a special optimized implementation if all of the following
+ conditions are met:
+ - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor
+ argument ``requires_grad``
+ - training is disabled (using ``.eval()``)
+ - batch_first is ``True`` and the input is batched (i.e., ``src.dim() == 3``)
+ - norm_first is ``False`` (this restriction may be loosened in the future)
+ - activation is one of: ``"relu"``, ``"gelu"``, ``torch.functional.relu``, or ``torch.functional.gelu``
+ - at most one of ``src_mask`` and ``src_key_padding_mask`` is passed
+ - if src is a `NestedTensor `_, neither ``src_mask``
+ nor ``src_key_padding_mask`` is passed
+ - the two ``LayerNorm`` instances have a consistent ``eps`` value (this will naturally be the case
+ unless the caller has manually modified one without modifying the other)
+ If the optimized implementation is in use, a
+ `NestedTensor `_ can be
+ passed for ``src`` to represent padding more efficiently than using a padding
+ mask. In this case, a `NestedTensor `_ will be
+ returned, and an additional speedup proportional to the fraction of the input that
+ is padding can be expected.
+ """
+ __constants__ = ['batch_first', 'norm_first'] # we inherit this variable from pytorch's code for jit
+
+ def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, drop_path_ratio: float = 0.1,
+ activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_scale: bool = False, ls_init_values: float = 1e-3,
+ layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
+ device=None, dtype=None, cfg: dict = None, ffn_moe: bool = False, attn_moe: bool = False) -> None:
+
+ if batch_first and comm.is_main_process():
+ print(f'set batch_first to \'False\' to support torch >= 1.12!')
+ batch_first = False
+
+ factory_kwargs = {}
+ super(MoETransformerEncoderLayer, self).__init__()
+
+ self.cfg = cfg
+
+ # The interface of nn.MultiheadAttention changed since torch 1.9.0.
+ _torch_version_main = torch.__version__.split('.')[:2]
+ if (int(_torch_version_main[0]) >= 1) and (int(_torch_version_main[1])) >= 9:
+ self._torch_nn_new_interface = True
+ else:
+ self._torch_nn_new_interface = False
+
+ # for moe
+ self.ffn_moe = ffn_moe and self.cfg.MOE.MOE
+ self.attn_moe = attn_moe and self.cfg.MOE.MOE
+ if self.cfg.MOE.MOE:
+ # assert self.ffn_moe and self.attn_moe
+ # data-independent moe
+ if self.cfg.MOE.MOE_TYPE in ['attribute']:
+ MoE_layer = partial(
+ TaskMoE,
+ num_experts=cfg.MOE.NUM_EXPERTS,
+ k=cfg.MOE.TOP_K,
+ capacity_factor=cfg.MOE.CAPACITY_FACTOR,
+ eval_capacity_factor=cfg.MOE.EVAL_MIN_CAPACITY,
+ min_capacity=cfg.MOE.MIN_CAPACITY,
+ noisy_gate_policy=cfg.MOE.NOISY_GATE_POLICY,
+ use_rts=cfg.MOE.USE_RTS,
+ use_tutel=cfg.MOE.USE_TUTEL,
+ cfg=cfg,
+ )
+ else:
+ raise NotImplementedError(f'{self.cfg.MOE.MOE_TYPE}')
+
+
+
+ self.self_attn = MoEAttentionBlock(d_model, nhead, attention_probs_dropout_prob=dropout, cfg=cfg, moe_layer=MoE_layer, attn_moe=attn_moe)
+
+
+ self.batch_first = batch_first
+
+ # Implementation of Feedforward model
+ self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
+ self.dropout = nn.Dropout(dropout)
+ self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
+
+ if self.ffn_moe:
+ self.linear1 = MoE_layer(hidden_size=d_model, expert=self.linear1)
+ self.linear2 = MoE_layer(hidden_size=d_model, expert=self.linear2)
+
+ self.norm_first = norm_first
+ if self.cfg.SOLVER.FUSED_LAYERNORM:
+ self.norm1 = LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm2 = LayerNorm(d_model, eps=layer_norm_eps)
+ elif self.cfg.SOLVER.FORCE_LN_FP16:
+ self.norm1 = FP16LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+ self.norm2 = FP16LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+ else:
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
+
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+
+ self.drop_path1 = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
+ self.drop_path2 = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
+
+ self.layer_scale = layer_scale
+ if self.layer_scale:
+ self.gamma_1 = nn.Parameter(ls_init_values * torch.ones((d_model)),requires_grad=True)
+ self.gamma_2 = nn.Parameter(ls_init_values * torch.ones((d_model)),requires_grad=True)
+
+ # Legacy string support for activation function.
+ if isinstance(activation, str):
+ activation = get_activation(activation)
+
+ self.activation = activation
+
+ # prompt embedding setup
+ self.deep_prompt = self.cfg.MODEL.PROMPT_EMBED.DEEP_PROMPT
+ if self.deep_prompt:
+ self.deep_prompt_embedding = DeepPrompt(cfg)
+
+
+
+
+ def __setstate__(self, state):
+ if 'activation' not in state:
+ state['activation'] = F.relu
+ super(MoETransformerEncoderLayer, self).__setstate__(state)
+
+ def forward(self,
+ src: Tensor,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ history_states: Optional[Tensor] = None,
+ **kwargs) -> Tensor:
+ r"""Pass the input through the encoder layer.
+ Args:
+ src: the sequence to the encoder layer (required).
+ src_mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+ Shape:
+ see the docs in Transformer class.
+ """
+
+ # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
+
+ x = src
+
+ if self.norm_first:
+ history_states_norm = history_states if (history_states is None) else self.norm1(history_states)
+ x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, history_states=history_states_norm, **kwargs)
+ x = x + self._ff_block(self.norm2(x), **kwargs)
+ else:
+ x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, history_states=history_states, **kwargs))
+ x = self.norm2(x + self._ff_block(x), **kwargs)
+
+
+ return x
+
+ # self-attention block
+ def _sa_block(self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], history_states: Optional[Tensor],
+ **kwargs) -> Tensor:
+
+ if history_states is not None:
+ kv = torch.cat(
+ [history_states, x],
+ dim=1
+ )
+ # TODO: changes for attn_mask and key_padding_mask
+ else:
+ kv = None
+
+ if self.deep_prompt:
+
+ deep_prompt_embedding = self.deep_prompt_embedding(x, batch_first=True, **kwargs)
+ if self.norm_first:
+ deep_prompt_embedding = self.norm1(deep_prompt_embedding)
+ kv = torch.cat([deep_prompt_embedding, x], dim=1) if kv is None else torch.cat([deep_prompt_embedding, kv], dim=1)
+ if 'sample_info' in kwargs:
+ pe_length = deep_prompt_embedding.shape[1]
+ kwargs['sample_info']['pe_length'] = pe_length
+ if attn_mask is not None:
+ L, S = attn_mask.shape
+ pe_length = deep_prompt_embedding.shape[1] # length, bs, hidden_size
+ attn_mask = torch.cat([torch.zeros((L, pe_length), dtype=attn_mask.dtype, device=attn_mask.device), attn_mask], dim=1)
+ if key_padding_mask is not None:
+
+ bs, pe_length = deep_prompt_embedding.shape[:2]
+
+ key_padding_mask = torch.cat(
+ [torch.zeros((bs, pe_length), dtype=key_padding_mask.dtype, device=key_padding_mask.device), key_padding_mask], dim=1)
+
+
+ x, _ = self.self_attn(x, history_states=kv, attn_mask=attn_mask, key_padding_mask=key_padding_mask, **kwargs)
+ x = self.drop_path1(self.dropout1(x))
+ if self.layer_scale:
+ if self.cfg.MODEL.LAYER_SCALE_FP32:
+ x = self.gamma_1 * x
+ else:
+ x = self.gamma_1.to(x.dtype) * x
+ return x
+
+
+ # feed forward block
+ def _ff_block(self, x: Tensor, **kwargs) -> Tensor:
+ if self.ffn_moe:
+ x, gate_decision = self.linear1(x, **kwargs)
+ if not self.cfg.MOE.FFN_SHARE_GATE_DECISION:
+ gate_decision = None
+ x, _ = self.linear2(self.dropout(self.activation(x)), gate_decision=gate_decision, **kwargs)
+ else:
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
+ x = self.drop_path2(self.dropout2(x))
+ if self.layer_scale:
+ if self.cfg.MODEL.LAYER_SCALE_FP32:
+ x = self.gamma_2 * x
+ else:
+ x = self.gamma_2.to(x.dtype) * x
+ return x
+
+
+class MoEAttentionBlock(nn.Module):
+
+ def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, cfg, moe_layer=None, attn_moe=False):
+ super(MoEAttentionBlock, self).__init__()
+ self.cfg = cfg
+ if hidden_size % num_attention_heads != 0:
+ raise ValueError("The hidden size (%d) is not a multiple of the number of attention " "heads (%d)" % (hidden_size, num_attention_heads))
+
+ self.hidden_size = hidden_size
+
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_size = int(hidden_size / num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.qkv_bias = cfg.MODEL.BERT.QKV_BIAS
+
+ self.unify_qkv = cfg.MODEL.BERT.UNIFY_QKV
+
+ if not cfg.MODEL.BERT.UNIFY_QKV:
+ self.query = nn.Linear(hidden_size, self.all_head_size, bias=self.qkv_bias)
+ self.key = nn.Linear(hidden_size, self.all_head_size, bias=self.qkv_bias)
+ self.value = nn.Linear(hidden_size, self.all_head_size, bias=self.qkv_bias)
+ else:
+ self.qkv_proj = nn.Linear(hidden_size, self.all_head_size * 3, bias=self.qkv_bias)
+
+ self.dense = nn.Linear(hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(attention_probs_dropout_prob)
+
+ self.attn_moe = attn_moe
+ if self.attn_moe:
+ if not cfg.MODEL.BERT.UNIFY_QKV:
+ raise NotADirectoryError('use UNIFY_QKV=True please')
+ else:
+ self.qkv_proj = moe_layer(hidden_size=hidden_size, expert=self.qkv_proj)
+ self.dense = moe_layer(hidden_size=hidden_size, expert=self.dense)
+
+ self.scale_multi_before = cfg.MODEL.BERT.SCALE_MULTI_BEFORE
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (
+ self.num_attention_heads,
+ self.attention_head_size,
+ )
+ x = x.view(*new_x_shape)
+
+ shape_list = list(range(len(new_x_shape)))
+ shape_list[-2], shape_list[-3] = shape_list[-3], shape_list[-2]
+ return x.permute(shape_list)
+ #return x.permute(0, 2, 1, 3)
+
+ def forward(self, hidden_states, attn_mask, key_padding_mask, history_states=None, **kwargs):
+ if attn_mask is None and key_padding_mask is None:
+ attention_mask = None
+ else:
+ # attn_mask [L, S] key_padding_mask[N, S]
+ if attn_mask is not None and key_padding_mask is not None:
+ attention_mask = torch.logical_or(attn_mask.unsqueeze(0).bool(), key_padding_mask.unsqueeze(1).bool())
+ elif attn_mask is not None:
+ attention_mask = attn_mask.unsqueeze(0)
+ else:
+ attention_mask = key_padding_mask.unsqueeze(1)
+ if attention_mask is not None:
+ attention_mask = attention_mask.unsqueeze(1) * -10000.0
+
+ if self.unify_qkv:
+ if history_states is None:
+
+ B, N, C = hidden_states.shape
+ if self.attn_moe:
+ # qkv, _, _ = self.self.qkv_proj(hidden_states)
+ hidden_states, gate_decision = self.qkv_proj(hidden_states, **kwargs)
+ mixed_query_layer, mixed_key_layer, mixed_value_layer =hidden_states.chunk(3, dim=-1)
+ else:
+ mixed_query_layer, mixed_key_layer, mixed_value_layer = self.qkv_proj(hidden_states).chunk(3, dim=-1)
+
+ else:
+ # usually inference with history embedding
+ if self.attn_moe:
+
+ mixed_query_layer, gate_decision = self.qkv_proj(hidden_states, mode='q', **kwargs)
+
+ history_states = self.qkv_proj(history_states, mode='kv', gate_decision=gate_decision, **kwargs)[0]
+ mixed_key_layer, mixed_value_layer = history_states.chunk(2, dim=-1)
+
+ else:
+ # query
+ _start = 0
+ _end = self.hidden_size
+ mixed_query_layer = F.linear(hidden_states,
+ self.qkv_proj.weight[_start:_end, :],
+ bias=None if self.qkv_proj.bias is None else self.qkv_proj.bias[_start:_end])
+
+ # key and value
+ # torch.equal(key, value)
+ _start = self.hidden_size
+ mixed_key_layer, mixed_value_layer = F.linear(history_states,
+ self.qkv_proj.weight[_start:, :],
+ bias=None if self.qkv_proj.bias is None else self.qkv_proj.bias[_start:]).chunk(
+ 2, dim=-1)
+
+
+ else:
+ raise NotImplementedError('please use unify qkv_proj')
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+ key_layer = self.transpose_for_scores(mixed_key_layer)
+ value_layer = self.transpose_for_scores(mixed_value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ if self.scale_multi_before:
+ attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
+ else:
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+ if attention_mask is not None:
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ if self.cfg.SOLVER.FORCE_SOFTMAX_FP16:
+ with autocast(enabled=False):
+ attention_probs = F.softmax(attention_scores.half(), dim=-1)
+ else:
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+
+ shape_list = list(range(len(context_layer.shape)))
+ shape_list[-2], shape_list[-3] = shape_list[-3], shape_list[-2]
+ context_layer = context_layer.permute(shape_list).contiguous()
+ #context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size, )
+ context_layer = context_layer.view(*new_context_layer_shape)
+
+ if self.attn_moe:
+ context_layer, _ = self.dense(context_layer, gate_decision=gate_decision, **kwargs)
+ else:
+ context_layer = self.dense(context_layer)
+
+ return context_layer, attention_probs
diff --git a/uniperceiver/modeling/meta_arch/__init__.py b/uniperceiver/modeling/meta_arch/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ee68a61e56ab4b8df05a7c42309e8d8873bb3b0
--- /dev/null
+++ b/uniperceiver/modeling/meta_arch/__init__.py
@@ -0,0 +1,7 @@
+
+
+from .build import META_ARCH_REGISTRY, build_model, add_config
+from .unified_transformer import MultiTaskTransformerEncoder
+
+
+__all__ = list(globals().keys())
diff --git a/uniperceiver/modeling/meta_arch/base_enc_dec.py b/uniperceiver/modeling/meta_arch/base_enc_dec.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f6cfe0f905a7d83abbd75e80fb893a86e50abfb
--- /dev/null
+++ b/uniperceiver/modeling/meta_arch/base_enc_dec.py
@@ -0,0 +1,78 @@
+import copy
+import numpy as np
+import weakref
+import torch
+from torch import nn
+from torch.autograd import Variable
+import torch.nn.functional as F
+from abc import ABCMeta, abstractmethod
+
+from uniperceiver.config import configurable
+from uniperceiver.config import CfgNode as CN
+from uniperceiver.functional import pad_tensor, dict_to_cuda, flat_list_of_lists
+from ..embedding import build_embeddings
+from ..encoder import build_encoder, add_encoder_config
+# from ..decoder import build_decoder, add_decoder_config
+from ..predictor import build_predictor, add_predictor_config
+from ..decode_strategy import build_beam_searcher, build_greedy_decoder
+
+class BaseEncoderDecoder(nn.Module, metaclass=ABCMeta):
+ @configurable
+ def __init__(
+ self,
+ *,
+ vocab_size,
+ max_seq_len,
+ token_embed,
+ fused_encoder,
+ decoder,
+ greedy_decoder,
+ beam_searcher,
+ **kwargs,
+ ):
+ super(BaseEncoderDecoder, self).__init__()
+ self.fused_encoder = fused_encoder
+ self.decoder = decoder
+
+ self.token_embed = token_embed
+ self.greedy_decoder = greedy_decoder
+ self.beam_searcher = beam_searcher
+ self.vocab_size = vocab_size
+ self.max_seq_len = max_seq_len
+
+
+ @classmethod
+ def add_config(cls, cfg, tmp_cfg):
+ add_encoder_config(cfg, tmp_cfg)
+ add_predictor_config(cfg, tmp_cfg)
+
+ def forward(self, batched_inputs, use_beam_search=None, output_sents=False):
+ if use_beam_search is None:
+ return self._forward(batched_inputs)
+ # elif use_beam_search == False or self.beam_searcher.beam_size == 1:
+ elif use_beam_search == False:
+ return self.greedy_decode(batched_inputs, output_sents)
+ else:
+ return self.decode_beam_search(batched_inputs, output_sents)
+
+ @abstractmethod
+ def _forward(self, batched_inputs):
+ pass
+
+ def bind_or_init_weights(self):
+ pass
+
+
+ def greedy_decode(self, batched_inputs, output_sents=False):
+ return self.greedy_decoder(
+ batched_inputs,
+ output_sents,
+ model=weakref.proxy(self)
+ )
+
+ def decode_beam_search(self, batched_inputs, output_sents=False):
+ return self.beam_searcher(
+ batched_inputs,
+ output_sents,
+ model=weakref.proxy(self)
+ )
diff --git a/uniperceiver/modeling/meta_arch/build.py b/uniperceiver/modeling/meta_arch/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..14316f8d995f9c1a32785c709ea4b09e7aa0e675
--- /dev/null
+++ b/uniperceiver/modeling/meta_arch/build.py
@@ -0,0 +1,31 @@
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+import torch
+
+from uniperceiver.utils.logger import _log_api_usage
+from uniperceiver.utils.registry import Registry
+
+META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip
+META_ARCH_REGISTRY.__doc__ = """
+Registry for meta-architectures, i.e. the whole model.
+
+The registered object will be called with `obj(cfg)`
+and expected to return a `nn.Module` object.
+"""
+
+
+def build_model(cfg):
+ """
+ Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``.
+ Note that it does not load any weights from ``cfg``.
+ """
+ meta_arch = cfg.MODEL.META_ARCHITECTURE
+ model = META_ARCH_REGISTRY.get(meta_arch)(cfg)
+ model.to(torch.device(cfg.MODEL.DEVICE))
+ _log_api_usage("modeling.meta_arch." + meta_arch)
+ return model
+
+
+def add_config(cfg, tmp_cfg):
+ meta_arch = tmp_cfg.MODEL.META_ARCHITECTURE
+ META_ARCH_REGISTRY.get(meta_arch).add_config(cfg, tmp_cfg)
diff --git a/uniperceiver/modeling/meta_arch/unified_transformer.py b/uniperceiver/modeling/meta_arch/unified_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..013205dbe2c3e3b8a4e800ee65fccf7d7c170760
--- /dev/null
+++ b/uniperceiver/modeling/meta_arch/unified_transformer.py
@@ -0,0 +1,469 @@
+import os
+import pickle
+import torch
+from torch import nn
+from torch.autograd import Variable
+import torch.nn.functional as F
+import weakref
+
+from uniperceiver.utils.transformer_util import data_half, preprocess, postprocess, null_loss_check
+from uniperceiver.config import configurable
+from uniperceiver.functional import pad_tensor, dict_to_cuda, dict_as_tensor
+from ..predictor import build_v_predictor
+from .build import META_ARCH_REGISTRY
+from ..embedding import build_embeddings
+from ..encoder import build_encoder, add_encoder_config, build_unfused_encoders
+from ..predictor import build_predictor, add_predictor_config
+from collections import defaultdict
+from omegaconf import DictConfig
+from ..decode_strategy import build_beam_searcher, build_greedy_decoder
+from .base_enc_dec import BaseEncoderDecoder
+from uniperceiver.modeling.predictor import EmbedClsAsRetrievalPredictor
+from torch.nn import init
+import math
+from uniperceiver.utils import comm
+import torch.distributed.nn
+from uniperceiver.tokenization import ClipTokenizer
+import logging
+from uniperceiver.losses import build_losses
+
+
+__all__ = ["MultiTaskTransformerEncoder"]
+
+
+@META_ARCH_REGISTRY.register()
+class MultiTaskTransformerEncoder(BaseEncoderDecoder):
+
+ @configurable
+ def __init__(
+ self,
+ *,
+ task_modules,
+ fused_encoder,
+ unfused_encoders,
+ decoder,
+ token_embed,
+ video_embed,
+ prompt_embed,
+ loss_prepare,
+ vocab_size,
+ imagenet_tuning,
+ cfg,
+ ):
+ super().__init__(fused_encoder=fused_encoder,
+ decoder=decoder,
+ vocab_size=vocab_size,
+ token_embed=token_embed,
+ **list(task_modules.values())[0])
+
+ self.unfused_encoders = unfused_encoders
+ for name, module in self.unfused_encoders.items():
+ self.add_module(name, module)
+ self.video_embed = video_embed
+ self.prompt_embed = prompt_embed
+ self.task_modules = dict()
+ self.module_names = set()
+ self.imagenet_tuning = imagenet_tuning
+ self.cfg = cfg
+
+ self.losses = self.build_losses(cfg)
+
+ self.tokenizer = ClipTokenizer()
+
+ self.loss_prepare = loss_prepare
+
+
+ for task_name, task_module in task_modules.items():
+ self.task_modules[task_name] = nn.Module()
+ for module_name, sub_module in task_module.items():
+ setattr(self.task_modules[task_name], module_name, sub_module)
+ self.module_names.add(module_name)
+ self.process_module(sub_module)
+ self.add_module(task_name,self.task_modules[task_name])
+
+
+
+ if self.cfg.MODEL.SHARE_LAYERNORM:
+ from uniperceiver.utils.transformer_util import share_token_embed_ln
+ share_token_embed_ln(self.video_embed, self.token_embed)
+
+ self.prepare_prompt_embed(cfg)
+
+ self.fp16 = self.cfg.SOLVER.AMP_FP16
+ self.bf16 = self.cfg.SOLVER.BF16
+
+
+
+ if self.token_embed is None:
+ # used for standard classification head
+ self.cls_token = nn.Embedding(1,cfg.MODEL.BERT.HIDDEN_SIZE)
+
+
+ self.initialize(cfg)
+
+ # init fc prompt layer
+ if self.use_fc_prompt and self.prompt:
+ nn.init.zeros_(self.fc_prompt.weight)
+ nn.init.zeros_(self.fc_prompt.bias)
+
+
+ self.logger = logging.getLogger(__name__)
+
+ if not self.cfg.MODEL.OLD_CHECKPONT:
+ comm.old_checkpoint = False
+ self.logger.info(f'please note that the <|spe|> is \'spe\' now!')
+
+ def prepare_prompt_embed(self, cfg):
+
+ self.prompt = cfg.MODEL.PROMPT
+ self.deep_prompt = cfg.MODEL.PROMPT_EMBED.DEEP_PROMPT
+ self.use_fc_prompt = cfg.MODEL.FC_PROMPT
+ prompt_params = cfg.MODEL.PROMPT_PARAM
+ fc_prompt_out = cfg.MODEL.FC_PROMPT_OUT
+ fc_prompt_weights = cfg.MODEL.FC_PROMPT_WEIGHTS
+
+ if self.prompt and 's_token_bias' in prompt_params:
+ self.s_token_bias = nn.Parameter(torch.zeros((1, self.token_embed.embeddings.weight.size(1)), device=self.token_embed.embeddings.weight.device))
+ self.token_embed.set_s_token_bias(self.s_token_bias)
+
+ if self.use_fc_prompt:
+ self.fc_prompt = nn.Linear(self.cfg.MODEL.BERT.HIDDEN_SIZE, fc_prompt_out)
+ if fc_prompt_weights == 'learn':
+ self.similarity_weight = nn.Parameter(torch.ones([]))
+ elif fc_prompt_weights == 'zero':
+ self.similarity_weight = 0.
+ else:
+ raise NotImplementedError
+
+ if self.prompt:
+ for name, param in self.named_parameters():
+ if not any([p_param in name for p_param in prompt_params]):
+ param.requires_grad = False
+
+
+ def initialize(self, cfg ):
+ if cfg.MODEL.TimmParamsInit:
+ global INIT_STD
+ INIT_STD = cfg.MODEL.TimmParamsInitSTD
+ global INIT_EMBEDDING_STD
+ INIT_EMBEDDING_STD = cfg.MODEL.TimmParamsINIT_EMBEDDING_STD
+ from uniperceiver.utils.transformer_util import init_timm_params
+ self.apply(init_timm_params)
+ elif cfg.MODEL.MAEParamsInit:
+ from uniperceiver.utils.transformer_util import initialize_weights_as_mae
+ initialize_weights_as_mae(self)
+ elif cfg.MODEL.MOCOv3ParamsInit:
+ from uniperceiver.utils.transformer_util import initialize_weights_as_mocov3
+ initialize_weights_as_mocov3(self)
+ elif cfg.MODEL.SwitchParamsInit:
+ from uniperceiver.utils.transformer_util import init_switchtransformer_params
+ self.apply(init_switchtransformer_params)
+ elif cfg.MODEL.BertParamsInit:
+ from uniperceiver.utils.transformer_util import init_bert_params
+ self.apply(init_bert_params)
+ elif cfg.MODEL.UniformTokenEmbed:
+ init.kaiming_uniform_(self.token_embed.embeddings.weight, a=math.sqrt(5))
+ else:
+ print('please check your parameters initialization method!')
+
+ @classmethod
+ def build_losses(cls, cfg):
+ losses = {}
+ for task_config in cfg.TASKS:
+ task_config = DictConfig(task_config)
+ losses[task_config.NAME] = build_losses(task_config)
+
+ return losses
+
+ def process_module(self, submodule):
+ '''
+ process some submodule
+ '''
+ if isinstance(submodule, EmbedClsAsRetrievalPredictor):
+ submodule.replace_weight(self.token_embed.embeddings.weight)
+
+
+ def operatedweight(self, ):
+ pass
+
+
+ @classmethod
+ def from_config(cls, cfg):
+ task_names = [ a['NAME'] for a in cfg.TASKS]
+ task_modules = defaultdict(dict)
+
+ for idx, task_names in enumerate(task_names):
+ cfg_task = DictConfig(cfg.TASKS[idx])
+ this_task_modules = {
+
+ "greedy_decoder": None,
+ "beam_searcher": None if getattr(cfg_task, 'DECODE_STRATEGY', None) is None
+ else build_beam_searcher(cfg_task),
+ # "vocab_size": cfg_task.MODEL.VOCAB_SIZE,
+ "max_seq_len": cfg_task.MODEL.MAX_SEQ_LEN,
+ }
+
+ task_modules[task_names].update(this_task_modules)
+
+ if cfg.SOLVER.AUGLOSS:
+ num_augloss = (cfg.MODEL.BERT.NUM_HIDDEN_LAYERS - max(
+ 0, cfg.SOLVER.AUGLOSS_START)) // cfg.SOLVER.AUGLOSS_INTERVAL
+ ret = {
+ "task_modules":
+ task_modules,
+ "fused_encoder":
+ build_encoder(cfg),
+ "unfused_encoders":
+ build_unfused_encoders(cfg),
+ "decoder":
+ None,
+ "loss_prepare":
+ build_predictor(cfg) if not cfg.SOLVER.AUGLOSS else nn.ModuleList(build_predictor(cfg) for _ in range(num_augloss)),
+ "vocab_size":
+ cfg.MODEL.VOCAB_SIZE,
+ "prompt_embed":
+ None if getattr(cfg.MODEL, 'PROMPT_EMBED', None) is None or not cfg.MODEL.PROMPT else build_embeddings(
+ cfg, cfg.MODEL.PROMPT_EMBED.NAME),
+ "imagenet_tuning":
+ cfg.MODEL.IN_TUNING,
+
+ "token_embed": None if not getattr(cfg.MODEL.TOKEN_EMBED, 'NAME', None)
+ else build_embeddings(cfg, cfg.MODEL.TOKEN_EMBED.NAME),
+ "video_embed": None if not getattr(cfg.MODEL.VIDEO_EMBED, 'NAME', None)
+ else build_embeddings(cfg, cfg.MODEL.VIDEO_EMBED.NAME),
+ "cfg": cfg,
+ }
+
+
+ return ret
+
+ @classmethod
+ def add_config(cls, cfg, tmp_cfg):
+ add_encoder_config(cfg, tmp_cfg)
+ # we do not have decoder anymore
+ # add_decoder_config(cfg, tmp_cfg)
+ cfg.MODEL.SharePredictor = False
+ cfg.MODEL.UniformTokenEmbed = False
+ cfg.MODEL.BertParamsInit = False
+
+ def to_task(self, task_name):
+ # in train_loop, you do not need to reset_atrr explictly
+ self.reset_attr()
+ for name in self.module_names:
+ setattr(self, name, getattr(self.task_modules[task_name], name))
+
+ def reset_attr(self):
+ for name in self.module_names:
+ # in case different task has different modules
+ if getattr(self, name, 'none') != 'none':
+ delattr(self, name)
+
+
+ def _forward(self, batched_inputs):
+
+
+ batched_inputs = data_half(self.fp16, self.bf16, batched_inputs)
+
+ #TODO: add imagenet classname and word in evaluation mode
+
+ task_info = batched_inputs['task_info']
+
+
+
+ batched_inputs['input_sample_list'] = self._forward_data(
+ batched_inputs['input_sample_list'], task_info=task_info)
+
+ if batched_inputs['target_sample_list'] is not None and len(batched_inputs['target_sample_list']) > 0:
+ batched_inputs['target_sample_list'] = self._forward_data(batched_inputs['target_sample_list'], task_info=task_info)
+
+
+ for target_set_name, data_list in batched_inputs['shared_target_sets'].items():
+ if data_list is not None and len(data_list)>0:
+ batched_inputs['shared_target_sets'][target_set_name] = self._forward_data(data_list, task_info=task_info)
+
+ loss_inputs = self.loss_prepare(**batched_inputs)
+
+ self.fc_prompt_process(loss_inputs)
+
+ if self.training:
+ # training mode
+ loss_dict = {}
+ for loss in self.losses[task_info['task_name']]:
+ loss_dict.update(loss(loss_inputs))
+
+ # if self.load_balance_losses is not None:
+ # loss_dict.update(self.load_balance_losses(batched_inputs))
+
+ loss_dict.update(null_loss_check(outputs_dict=batched_inputs))
+ return loss_dict
+ else:
+ # evaluation mode
+ return loss_inputs
+
+ def fc_prompt_process(self, outputs_dict):
+ if self.prompt and self.use_fc_prompt:
+ for idx, logit in enumerate(outputs_dict['logits']):
+ assert 'feats' in outputs_dict
+ feat = outputs_dict['feats'][idx]
+ logit = self.similarity_weight * logit + self.fc_prompt(feat)
+ outputs_dict['logits'][idx] = logit
+ if 'output' in outputs_dict:
+ outputs_dict['output'] = logit
+
+
+
+ def _forward_data(self, data_list:list, task_info:dict, history_states=None, return_all=False):
+
+ # data is dict value
+ for data in data_list:
+
+ data = data_half(self.fp16, self.bf16, data)
+
+ self._tokenize(data, task_info)
+
+ self._forward_unfused_encoders(data, task_info)
+
+ # fused encoders
+ if self.prompt_embed is not None:
+ # prefix_prompt, label prompt
+ self.prompt_embed(data_list=data_list)
+ fused_data_dict = preprocess(self.tokenizer, self.token_embed, data_list, task_info=task_info)
+
+ fused_data_dict = data_half(self.fp16, self.bf16, fused_data_dict)
+ fused_data_dict['data'] = self.fused_encoder(**fused_data_dict, task_info=task_info, history_states=history_states, return_all=return_all)
+
+ postprocess(fused_data_dict, task_info=task_info)
+
+ return [fused_data_dict]
+
+ def _tokenize(self, data, task_info):
+ # toknizer
+ if data['modality'] in ['image', 'video']:
+ data['data'] = self.video_embed(**data, task_info=task_info)
+ elif data['modality'] == 'text':
+ data['data'] = self.token_embed(**data, task_info=task_info)
+ else:
+ raise NotImplementedError
+
+
+ def _forward_unfused_encoders(self, data, task_info):
+
+
+ # specific encoders.
+ # defaultly, modality-specific encoder
+ if data['modality'] in ['image', 'video']:
+ if "VisualEncoder" in self.unfused_encoders:
+ data['data'] = self.unfused_encoders['VisualEncoder'](**data, task_info=task_info)
+ elif data['modality'] == 'text':
+ if "TextEncoder" in self.unfused_encoders:
+ data['data'] = self.unfused_encoders['TextEncoder'](**data, task_info=task_info)
+ else:
+ raise NotImplementedError
+
+
+
+
+
+ @torch.jit.ignore
+ def no_weight_decay(self,):
+ ret = [
+ 'logit_scale', 'logit_scale_img_cls', 'logit_scale_video_cls',
+ 'logit_scale_text_mlm', 'logit_scale_text_caption',
+ 'logit_scale_caption', 'logit_scale_mlm', 'logit_scale_retrieve',
+ 'logit_scale_text_retrieve', "logit_scale_downstream",
+ "logit_scale_tqa_mlm", "logit_scale_tqa_caption",
+ "logit_scale_tqa_retrieve", "similarity_weight", "gamma_1", "gamma_2",
+ ]
+ if self.cfg.SOLVER.OUTPUTPROJ_NOWD:
+ ret.append("predictor.proj")
+ return ret
+
+ @torch.jit.ignore
+ def expert_gate_group(self, ):
+ return ['gate.wg', 'gate.tag_transform']
+
+
+
+ def load_state_dict(self, state_dict, strict=True):
+ out_dict = {}
+ if self.cfg.MODEL.CHECKPOINT_FILETER:
+ def resize_pos_embed(posemb, posemb_new, cls_token=False):
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+ self.logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
+ ntok_new = posemb_new.shape[0]
+ posemb_tok = posemb
+ if not cls_token:
+ posemb_grid = posemb
+ else:
+ raise NotImplementedError
+ gs_old = int(math.sqrt(len(posemb_grid)))
+ gs_new = int(math.sqrt(ntok_new))
+
+
+ self.logger.info('Position embedding grid-size from %s to %s',
+ gs_old, gs_new)
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
+ posemb_grid = F.interpolate(posemb_grid.float(), size=(gs_new, gs_new), mode='bilinear')
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1).squeeze(0)
+ if cls_token:
+ posemb_grid = torch.cat([posemb_tok, posemb_grid], dim=1)
+ return posemb_grid.to(posemb_new.dtype)
+ # 'convert patch embedding weight from manual patchify'
+
+ for k, v in state_dict.items():
+ if k.startswith('video_embed.embeddings_st_pos.spatial_pos_embed') or k.startswith('visual_embed.patch_embed.pos_embed'):
+ # To resize pos embedding when using model at different size from pretrained weights
+ if v.shape != self.state_dict()[k].shape:
+ v = resize_pos_embed(v, self.state_dict()[k])
+
+ out_dict[k] = v
+ else:
+
+ for k, v in state_dict.items():
+ if k.startswith('video_embed.embeddings_st_pos.spatial_pos_embed') or k.startswith('visual_embed.patch_embed.pos_embed'):
+ # To resize pos embedding when using model at different size from pretrained weights
+ if v.shape != self.state_dict()[k].shape:
+ # v = resize_pos_embed(v, self.state_dict()[k])
+ continue
+ out_dict[k] = v
+
+ if self.cfg.MODEL.CHECKPOINT_FILETER_VIDEO:
+
+ def resize_temporal_pos_embed(posemb, posemb_new, cls_token=False):
+ # Rescale the grid of position embeddings when loading from state_dict. Adapted from
+ # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
+ self.logger.info('Resized position embedding: %s to %s',
+ posemb.shape, posemb_new.shape)
+ ntok_new = posemb_new.shape[0]
+ if not cls_token:
+ posemb_grid = posemb
+ else:
+ raise NotImplementedError
+ gs_old = len(posemb_grid)
+ gs_new = ntok_new
+
+ self.logger.info('temporal embedding grid-size from %s to %s',
+ gs_old, gs_new)
+ posemb_grid = posemb_grid.reshape(1, gs_old,
+ -1).permute(0, 2, 1)
+ posemb_grid = F.interpolate(posemb_grid.float(),
+ size=(gs_new),
+ mode='linear')
+ posemb_grid = posemb_grid.permute(0, 2, 1).squeeze(0)
+
+ return posemb_grid.to(posemb_new.dtype)
+
+ # 'convert patch embedding weight from manual patchify'
+ for k, v in out_dict.items():
+ if k.startswith(
+ 'video_embed.embeddings_st_pos.temporal_pos_embed'
+ ) :
+ # To resize pos embedding when using model at different size from pretrained weights
+ if v.shape != self.state_dict()[k].shape:
+ v = resize_temporal_pos_embed(v, self.state_dict()[k])
+
+ out_dict[k] = v
+
+
+ return super().load_state_dict(out_dict, strict=strict)
\ No newline at end of file
diff --git a/uniperceiver/modeling/predictor/__init__.py b/uniperceiver/modeling/predictor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dbc639afc5ef361d77e88b03a27bb2476d1a8de
--- /dev/null
+++ b/uniperceiver/modeling/predictor/__init__.py
@@ -0,0 +1,7 @@
+
+from .build import build_predictor, build_v_predictor, build_predictor_with_name, add_predictor_config
+
+
+from .embed_cls_predictor import EmbedClsAsRetrievalPredictor
+
+__all__ = list(globals().keys())
diff --git a/uniperceiver/modeling/predictor/base_predictor.py b/uniperceiver/modeling/predictor/base_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..68dc76dfcd7a4da3b139d0fd806090503475cb27
--- /dev/null
+++ b/uniperceiver/modeling/predictor/base_predictor.py
@@ -0,0 +1,277 @@
+
+import torch
+from torch import nn
+
+from uniperceiver.config import configurable
+from uniperceiver.config import kfg
+from .build import PREDICTOR_REGISTRY
+import math
+import torch.nn.functional as F
+
+__all__ = ["BasePredictor", "RobertaLMHead","TwoLayerPredictor", "RobertaRegressionHead"]
+
+@PREDICTOR_REGISTRY.register()
+class BasePredictor(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ *,
+ hidden_size: int,
+ vocab_size: int, # include /
+ dropout: float
+ ):
+ super(BasePredictor, self).__init__()
+ self.logits = nn.Linear(hidden_size, vocab_size)
+ self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None
+
+ @classmethod
+ def from_config(cls, cfg):
+ return {
+ "hidden_size": cfg.MODEL.DECODER_DIM,
+ "vocab_size": cfg.MODEL.VOCAB_SIZE,
+ "dropout": cfg.MODEL.PRED_DROPOUT
+ }
+
+ @classmethod
+ def add_config(cls, cfg):
+ pass
+
+ def forward(self, batched_inputs):
+ hidden_states = batched_inputs[kfg.G_HIDDEN_STATES]
+ if isinstance(hidden_states, list):
+ hidden_states = hidden_states[-1]
+ if self.dropout:
+ hidden_states = self.dropout(hidden_states)
+ logits = self.logits(hidden_states)
+ return { kfg.G_LOGITS: logits }
+
+def gelu_accurate(x):
+ if not hasattr(gelu_accurate, "_a"):
+ gelu_accurate._a = math.sqrt(2 / math.pi)
+ return (
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+ )
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+ return torch.nn.functional.gelu(x.float()).type_as(x)
+
+@PREDICTOR_REGISTRY.register()
+class TwoLayerPredictor(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ *,
+ hidden_size: int,
+ vocab_size: int, # include /
+ dropout: float
+ ):
+ super(TwoLayerPredictor, self).__init__()
+
+ self.dense = nn.Linear(hidden_size, hidden_size)
+ self.activation_fn = gelu
+ self.layer_norm = nn.LayerNorm(hidden_size)
+
+ self.logits = nn.Linear(hidden_size, vocab_size)
+
+ self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None
+
+ def replace_logits(self, shared_weights):
+ self.logits.weight = shared_weights
+
+ @classmethod
+ def from_config(cls, cfg):
+ return {
+ "hidden_size": cfg.MODEL.DECODER_DIM,
+ "vocab_size": cfg.MODEL.VOCAB_SIZE,
+ "dropout": cfg.MODEL.PRED_DROPOUT
+ }
+
+ @classmethod
+ def add_config(cls, cfg):
+ pass
+
+ def forward(self, batched_inputs):
+ hidden_states = batched_inputs[kfg.G_HIDDEN_STATES]
+ if isinstance(hidden_states, list):
+ hidden_states = hidden_states[-1]
+
+ x = self.dense(hidden_states)
+ x = self.activation_fn(x)
+ x = self.layer_norm(x)
+
+ logits = self.logits(x)
+ return { kfg.G_LOGITS: logits }
+
+
+@PREDICTOR_REGISTRY.register()
+class RobertaLMHead(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ *,
+ hidden_size: int,
+ vocab_size: int, # include /
+ dropout: float,
+ untie_weight_embedding: bool,
+ use_bias: bool,
+ share_hidden: bool,
+ ):
+ super(RobertaLMHead, self).__init__()
+
+
+ self.activation_fn = gelu
+
+ if untie_weight_embedding is True:
+ self.weight = nn.Linear(hidden_size, vocab_size, bias=False).weight
+ else:
+ self.weight = None
+
+ if share_hidden:
+ self.dense = None
+ self.layer_norm = None
+ else:
+ self.dense = nn.Linear(hidden_size, hidden_size)
+ self.layer_norm = nn.LayerNorm(hidden_size)
+
+ if use_bias:
+ self.bias = nn.Parameter(torch.zeros(vocab_size))
+ else:
+ self.bias = None
+
+ self.dropout = nn.Dropout(dropout) if dropout > 0.0 else None
+ # print("dropout: {}".format(self.dropout))
+
+ def replace_weight(self, weight):
+ if self.weight is None:
+ self.weight = weight
+ else:
+ print('already has weight, please set UNTIE_WEIGHT_EMBEDDING to False')
+
+ def replace_module_hidden(self, dense, layer_norm):
+ if (self.dense is None) and (self.layer_norm is None):
+ self.dense = dense
+ self.layer_norm = layer_norm
+ else:
+ print('already has hidden layers!')
+ raise ValueError
+
+
+ @classmethod
+ def from_config(cls, cfg):
+ return {
+ "hidden_size": cfg.MODEL.DECODER_DIM,
+ "vocab_size": cfg.MODEL.VOCAB_SIZE,
+ "dropout": cfg.MODEL.PRED_DROPOUT,
+ "untie_weight_embedding": cfg.MODEL.UNTIE_WEIGHT_EMBEDDING,
+ "use_bias": cfg.MODEL.USE_PREDICTOR_BIAS,
+ "share_hidden": cfg.MODEL.SHARE_PREDICTOR_HIDDEN,
+ }
+
+ @classmethod
+ def add_config(cls, cfg):
+ pass
+
+ def forward(self, batched_inputs):
+
+ if kfg.G_HIDDEN_STATES in batched_inputs:
+ hidden_states = batched_inputs[kfg.G_HIDDEN_STATES]
+ if isinstance(hidden_states, list):
+ hidden_states = hidden_states[-1]
+
+ if kfg.G_TARGET_IDS in batched_inputs:
+ mask_tokens = batched_inputs[kfg.G_TARGET_IDS].ne(-1)
+ hidden_states = hidden_states[mask_tokens]
+ batched_inputs[kfg.G_TARGET_IDS] = batched_inputs[kfg.G_TARGET_IDS][mask_tokens]
+ logits = self._forward(hidden_states)
+
+ return { kfg.G_LOGITS: logits }
+
+ elif kfg.U_HIDDEN_STATES in batched_inputs:
+ hidden_states = batched_inputs[kfg.U_HIDDEN_STATES]
+ if isinstance(hidden_states, list):
+ hidden_states = hidden_states[-1]
+
+ mask_tokens = batched_inputs[kfg.U_TARGET_IDS].ne(-1)
+ hidden_states = hidden_states[mask_tokens]
+ batched_inputs[kfg.U_TARGET_IDS] = batched_inputs[kfg.U_TARGET_IDS][mask_tokens]
+ logits = self._forward(hidden_states)
+
+ return { kfg.U_LOGITS: logits }
+
+ def _forward(self, x):
+ x = self.dense(x)
+ x = self.activation_fn(x)
+ x = self.layer_norm(x)
+
+ if self.dropout:
+ x = self.dropout(x)
+ if self.bias is not None:
+ logits = F.linear(x, self.weight) + self.bias
+ else:
+ logits = F.linear(x, self.weight)
+ return logits
+
+@PREDICTOR_REGISTRY.register()
+class RobertaRegressionHead(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ *,
+ hidden_size,
+ feat_dim,
+ transform,
+ sigmoid
+ ):
+ super(RobertaRegressionHead, self).__init__()
+ self.transform = transform
+ self.decoder = nn.Linear(hidden_size, feat_dim)
+ self.output_sigmoid = sigmoid
+
+
+ @classmethod
+ def from_config(cls, cfg):
+ return {
+ "hidden_size": cfg.MODEL.DECODER_DIM,
+ "feat_dim": cfg.MODEL.LABELS_NUM,
+ "sigmoid": cfg.MODEL.SIGMOID,
+ "transform": BertPooler(cfg)
+ }
+
+ @classmethod
+ def add_config(cls, cfg):
+ pass
+
+ def test_forward(self, u_logits):
+ # for Single stream similarity
+ return { kfg.OUTPUT: u_logits }
+
+ def forward(self, batched_inputs):
+ ret = {}
+ if kfg.G_HIDDEN_STATES in batched_inputs:
+ hidden_states = batched_inputs[kfg.G_HIDDEN_STATES]
+ if isinstance(hidden_states, list):
+ hidden_states = hidden_states[-1]
+ hidden_states = self.transform(hidden_states)
+ logits = self.decoder(hidden_states)
+ if self.output_sigmoid:
+ logits = torch.sigmoid(logits)
+ ret.update({ kfg.G_LOGITS: logits })
+ if not self.training:
+ ret_test = self.test_forward(logits)
+ ret.update(ret_test)
+ return ret
+
+ elif kfg.U_HIDDEN_STATES in batched_inputs:
+ hidden_states = batched_inputs[kfg.U_HIDDEN_STATES]
+ if isinstance(hidden_states, list):
+ hidden_states = hidden_states[-1]
+ hidden_states = self.transform(hidden_states)
+ logits = self.decoder(hidden_states)
+ if self.output_sigmoid:
+ logits = torch.sigmoid(logits)
+ ret.update({ kfg.U_LOGITS: logits })
+ if not self.training:
+ ret_test = self.test_forward(logits)
+ ret.update(ret_test)
+ return ret
diff --git a/uniperceiver/modeling/predictor/build.py b/uniperceiver/modeling/predictor/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d38260effa27d578c8e4493c319d7e3f2e34097
--- /dev/null
+++ b/uniperceiver/modeling/predictor/build.py
@@ -0,0 +1,23 @@
+
+from uniperceiver.utils.registry import Registry
+
+PREDICTOR_REGISTRY = Registry("PREDICTOR")
+PREDICTOR_REGISTRY.__doc__ = """
+Registry for PREDICTOR
+"""
+
+def build_predictor(cfg):
+ predictor = PREDICTOR_REGISTRY.get(cfg.MODEL.PREDICTOR)(cfg) if len(cfg.MODEL.PREDICTOR) > 0 else None
+ return predictor
+
+def build_v_predictor(cfg):
+ predictor = PREDICTOR_REGISTRY.get(cfg.MODEL.V_PREDICTOR)(cfg) if len(cfg.MODEL.V_PREDICTOR) > 0 else None
+ return predictor
+
+def build_predictor_with_name(cfg, name):
+ predictor = PREDICTOR_REGISTRY.get(name)(cfg) if len(name) > 0 else None
+ return predictor
+
+def add_predictor_config(cfg, tmp_cfg):
+ if len(tmp_cfg.MODEL.PREDICTOR) > 0:
+ PREDICTOR_REGISTRY.get(tmp_cfg.MODEL.PREDICTOR).add_config(cfg)
\ No newline at end of file
diff --git a/uniperceiver/modeling/predictor/embed_cls_predictor.py b/uniperceiver/modeling/predictor/embed_cls_predictor.py
new file mode 100644
index 0000000000000000000000000000000000000000..695fc197d281ec1ad2857f2af178807c84b57c1f
--- /dev/null
+++ b/uniperceiver/modeling/predictor/embed_cls_predictor.py
@@ -0,0 +1,395 @@
+import torch
+from torch import nn
+
+from uniperceiver.config import configurable
+from .build import PREDICTOR_REGISTRY
+import math
+import pickle
+import torch.nn.functional as F
+
+import numpy as np
+
+from uniperceiver.utils import comm
+import torch.distributed as dist
+from uniperceiver.modeling.layers import FP16LayerNorm
+from torch.cuda.amp import autocast
+
+
+
+__all__ = ["EmbedClsAsRetrievalPredictor"]
+def gelu_accurate(x):
+ if not hasattr(gelu_accurate, "_a"):
+ gelu_accurate._a = math.sqrt(2 / math.pi)
+ return (
+ 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
+ )
+
+
+def gelu(x: torch.Tensor) -> torch.Tensor:
+ return torch.nn.functional.gelu(x.float()).type_as(x)
+
+@PREDICTOR_REGISTRY.register()
+class EmbedClsAsRetrievalPredictor(nn.Module):
+ @configurable
+ def __init__(
+ self,
+ temperature,
+ use_norm,
+ temp_learn,
+ mb_list,
+ queue_len,
+ feat_dim,
+ task2tempname,
+ fc_prompt_feature_index,
+ output_proj,
+ cfg,
+
+ ):
+ super(EmbedClsAsRetrievalPredictor, self).__init__()
+ self.cfg = cfg
+ self.use_norm = use_norm
+ self.temp_learn = temp_learn
+ if temp_learn:
+ self.logit_scale_img_cls = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
+ self.logit_scale_video_cls = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
+ self.logit_scale_text_mlm = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
+ self.logit_scale_text_caption = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
+ self.logit_scale_caption = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
+ self.logit_scale_mlm = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
+ self.logit_scale_retrieve = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
+ self.logit_scale_tqa_mlm = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
+ self.logit_scale_tqa_caption = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
+ self.logit_scale_tqa_retrieve = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
+ self.logit_scale_downstream = nn.Parameter(torch.ones([]) * np.log(1 / temperature))
+
+ else:
+ self.logit_scale_img_cls = torch.ones([]).cuda() * np.log(1 / temperature)
+ self.logit_scale_video_cls = torch.ones([]).cuda() * np.log(1 / temperature)
+ self.logit_scale_text_mlm = torch.ones([]).cuda() * np.log(1 / temperature)
+ self.logit_scale_text_caption = torch.ones([]).cuda() * np.log(1 / temperature)
+ self.logit_scale_caption = torch.ones([]).cuda() * np.log(1 / temperature)
+ self.logit_scale_mlm = torch.ones([]).cuda() * np.log(1 / temperature)
+ self.logit_scale_retrieve = torch.ones([]).cuda() * np.log(1 / temperature)
+ self.logit_scale_tqa_mlm = torch.ones([]).cuda() * np.log(1 / temperature)
+ self.logit_scale_tqa_caption = torch.ones([]).cuda() * np.log(1 / temperature)
+ self.logit_scale_tqa_retrieve = torch.ones([]).cuda() * np.log(1 / temperature)
+ self.logit_scale_downstream = torch.ones([]).cuda() * np.log(1 / temperature)
+
+
+ self.task2tempname = task2tempname
+
+
+ self.memory_save = []
+ self.queue_len = queue_len
+ self.feat_dim = feat_dim
+ self.fc_prompt_feature_index = fc_prompt_feature_index
+ for task_name in mb_list:
+ self.memory_save.append(task_name)
+ self.register_buffer('queue_h1_{}'.format(task_name), torch.randn(queue_len, feat_dim ))
+ self.register_buffer('queue_h2_{}'.format(task_name), torch.randn(queue_len, feat_dim ))
+ setattr(self, 'queue_h1_{}'.format(task_name), nn.functional.normalize(getattr(self, 'queue_h1_{}'.format(task_name)), dim=1))
+ setattr(self, 'queue_h2_{}'.format(task_name), nn.functional.normalize(getattr(self, 'queue_h2_{}'.format(task_name)), dim=1))
+
+ self.register_buffer("queue_ptr1_{}".format(task_name), torch.zeros(1, dtype=torch.long))
+ self.register_buffer("queue_ptr2_{}".format(task_name), torch.zeros(1, dtype=torch.long))
+
+ pass
+
+ self.output_proj = output_proj
+ if self.output_proj:
+ # if cfg.MODEL.LN_FP32:
+ # self.ln_post = CustomLayernorm(feat_dim)
+ # else:
+ # self.ln_post = nn.LayerNorm(feat_dim)
+ if self.cfg.SOLVER.FORCE_LN_FP16:
+ self.ln_post = FP16LayerNorm(feat_dim)
+ else:
+ self.ln_post = nn.LayerNorm(feat_dim)
+ self.proj = nn.Linear(feat_dim, feat_dim)
+
+ if cfg.MODEL.FEATURE_GATHER_FORCE:
+ assert cfg.DATALOADER.STRATEGY == 'turn'
+ self.gather_feature = True
+ else:
+ self.gather_feature = len(cfg.TASKS) == 1 and getattr(cfg.MODEL, "FEATURE_GATHER", False)
+
+ @torch.no_grad()
+ def _dequeue_and_enqueue(self, q1, q2, task_name):
+ """Update queue."""
+ # gather keys before updating queue
+
+
+ batch_size1 = q1.shape[0]
+ batch_size2 = q2.shape[0]
+
+ ptr1 = int(getattr(self, "queue_ptr1_{}".format(task_name)))
+ ptr2 = int(getattr(self, "queue_ptr2_{}".format(task_name)))
+
+ assert self.queue_len % batch_size1 == 0 # for simplicity
+
+ # replace the keys at ptr (dequeue and enqueue)
+
+ getattr(self, 'queue_h1_{}'.format(task_name))[ptr1:ptr1+batch_size1, :] = q1 # save text features
+ getattr(self, 'queue_h2_{}'.format(task_name))[ptr2:ptr2+batch_size2, :] = q2 # save img features
+
+ ptr1 = (ptr1 + batch_size1) % self.queue_len # move pointer
+ ptr2 = (ptr2 + batch_size2) % self.queue_len # move pointer
+
+
+ getattr(self, "queue_ptr1_{}".format(task_name))[0] = ptr1
+ getattr(self, "queue_ptr2_{}".format(task_name))[0] = ptr2
+
+ pass
+
+ def replace_weight(self, weight):
+ pass
+
+ def replace_module_hidden(self,dense, layer_norm):
+ pass
+
+ @classmethod
+ def from_config(cls, cfg):
+
+ mb_list = []
+ task2tempname = {}
+ if len(cfg.TASKS) > 0:
+ for task_config in cfg.TASKS:
+ if 'MODEL' in task_config and task_config['MODEL'].get('MEMORY_BANK', False):
+ mb_list.append(task_config['NAME'])
+ task2tempname[task_config['NAME']] = task_config['MODEL']['TEMP_NAME']
+
+ ret = { "temperature": cfg.MODEL.PRED_TEMPERATURE,
+ "use_norm": cfg.MODEL.PRED_USE_NORM,
+ 'temp_learn': getattr(cfg.MODEL, "LEARN_TEMP", False),
+ 'mb_list': mb_list,
+ 'queue_len': cfg.MODEL.QUEUE_LEN,
+ 'feat_dim': cfg.MODEL.ENCODER_DIM,
+ 'task2tempname': task2tempname,
+ "fc_prompt_feature_index": cfg.MODEL.FC_PROMPT_INDEX,
+ "output_proj": cfg.MODEL.OUTPUT_PROJ,
+ "cfg": cfg,
+ }
+ print(f'********* using temperature {cfg.MODEL.PRED_TEMPERATURE} **********')
+
+ return ret
+
+ @classmethod
+ def add_config(cls, cfg):
+ pass
+
+ def test_forward(self, logits):
+ return { "output": logits }
+
+ def postproj(self, hidden_states):
+ x = self.ln_post(hidden_states)
+ if self.proj is not None:
+ x = self.proj(x)
+ return x
+
+
+
+ def forward(self,
+ input_sample_list,
+ target_sample_list,
+ shared_target_sets,
+ target_set_list,
+ target_idx_list,
+ task_info,
+ **kwargs):
+
+ if len(target_sample_list) > 0:
+ q2_hidden_states = target_sample_list[0]['data']
+ else:
+ if len(target_set_list) > 1:
+ raise NotImplementedError('only one target supported now')
+ target_set_name = target_set_list[0]
+ q2_hidden_states = shared_target_sets[target_set_name][0]['data']
+
+
+ q1_hidden_states = input_sample_list[0]['data']
+
+ q2_hidden_states = q2_hidden_states[:, 0]
+
+
+ task_type = task_info.get('task_type')
+
+ if task_type in ['image_classification', 'video_classification']:
+ q1_hidden_states = q1_hidden_states[:, 0]
+ elif task_type in ['image_retrieval', 'video_retrieval']:
+ q1_hidden_states = q1_hidden_states[:, 0]
+ elif task_type == 'text_mlm':
+ mask_tokens = target_idx_list[0].ne(-1) # -1 is unmasked position
+ q1_hidden_states = q1_hidden_states[:, -mask_tokens.size(1):][mask_tokens]
+ target_idx_list[0] = target_idx_list[0][mask_tokens]
+ elif task_type in ['image_caption', 'video_caption']:
+ if self.training:
+ sample_info = input_sample_list[0]['sample_info']
+ if isinstance(sample_info, list):
+ sample_info = sample_info[0]
+ text_length = sample_info['data_length'][-1] // 2
+ q1_hidden_states = q1_hidden_states[:, -text_length:, :]
+ mask_tokens = target_idx_list[0].ne(-1) # -1 is padding position
+ q1_hidden_states = q1_hidden_states[mask_tokens] # .flatten(0, 1)
+ target_idx_list[0] = target_idx_list[0][mask_tokens] # .flatten(0, 1)
+ else:
+ q1_hidden_states = q1_hidden_states[:, -1]
+ elif task_type in ['text_classification', 'vqa']:
+ sample_info = input_sample_list[0]['sample_info']
+ if isinstance(sample_info, list):
+ sample_info = sample_info[0]
+
+ sample_infos = sample_info if isinstance(sample_info, list) else sample_info['sample_info_per_sample'][-1]
+ if 'spe_index' in sample_infos[0]:
+ text_length = sample_info['data_length'][-1]
+ q1_hidden_states = q1_hidden_states[:, -text_length:, :] # get text part; remove the first spe or the prompt embedding part
+ # gather spe representation from the 'spe_index' from text part via index of spe token
+ spe_index = torch.tensor([si['spe_index'] for si in sample_infos], device=q1_hidden_states.device).view(-1, 1, 1).expand(-1, -1, q1_hidden_states.size(2))
+ q1_hidden_states = torch.gather(q1_hidden_states, 1, spe_index)[:, 0]
+ else:
+ q1_hidden_states = q1_hidden_states[:, 0]
+
+ else:
+ raise NotImplementedError
+
+
+ if self.output_proj:
+
+ q1_hidden_states = self.postproj(q1_hidden_states)
+ q2_hidden_states = self.postproj(q2_hidden_states)
+
+ feat = q1_hidden_states
+
+
+ if len(target_sample_list) == 0:
+ # in1k
+ logits = self._forward(q1_hidden_states, q2_hidden_states, task_name=task_info.get("task_name", None))
+
+
+ ret = { "logits": [logits], "feats": [feat], "loss_names": [''] }
+ if len(target_idx_list) > 0:
+ ret.update({"targets": [target_idx_list[0]]})
+
+ if not self.training:
+ ret_test = self.test_forward(logits)
+ ret.update(ret_test)
+ # ret = self.test_forward(logits)
+
+
+
+ else:
+ # image and text retrieval in one forwarding:
+
+ if not self.training:
+ return {
+ "input_feats": q1_hidden_states / q1_hidden_states.norm(dim=-1, keepdim=True),
+ "tgt_feats": q2_hidden_states / q2_hidden_states.norm(dim=-1, keepdim=True),
+ }
+
+ if self.gather_feature:
+ local_q1 = q1_hidden_states
+ local_q2 = q2_hidden_states
+ packed_feature = torch.cat([local_q1, local_q2], dim=1).float()
+
+ gathered_features = [ torch.zeros_like(packed_feature) for _ in range(comm.get_world_size())]
+
+ dist.all_gather(gathered_features, packed_feature)
+
+ all_features = torch.cat([packed_feature] +
+ gathered_features[:comm.get_rank()] +
+ gathered_features[comm.get_rank() + 1:]).to(local_q1)
+
+ q1_hidden_states, q2_hidden_states = torch.split(all_features, [local_q1.size(1), local_q2.size(1)], dim=1)
+
+
+ if task_info.get("task_name", None) in self.memory_save:
+ # retrieval task with memory buffer
+ logits, logits_per_cls = self._forward_with_mb(
+ q1_hidden_states,
+ q2_hidden_states,
+ task_name=task_info.get("task_name", None))
+ else:
+ logits, logits_per_cls = self._forward(q1_hidden_states,
+ q2_hidden_states,
+ task_name=task_info.get(
+ "task_name", None),
+ mutual_retrieval=True)
+
+ target = torch.arange(logits.shape[0], dtype=torch.long, device=logits.device)
+ target_per_cls = target
+
+ ret = {
+ "logits": [logits, logits_per_cls],
+ "targets": [target, target_per_cls],
+ "loss_names": ['i2t', 't2i'],
+ }
+
+ return ret
+
+
+ def _forward(self, g, cls_name, task_name, mutual_retrieval=False,):
+ temperature = self.temperature_task(task_name)
+ if self.cfg.SOLVER.FORCE_TEMP_FP16:
+ temperature = temperature.half()
+ if self.temp_learn and temperature > 100.0:
+ temperature = 100.0
+
+ # getattr(self, self.task2tempname[task_name]).data.clamp_(max=math.log(100.0))
+
+ if self.use_norm:
+ if not self.cfg.SOLVER.FORCE_NORM_FP16:
+ g = g / g.norm(dim=-1, keepdim=True)
+ cls_name = cls_name / cls_name.norm(dim=-1, keepdim=True)
+ else:
+ with autocast(enabled=False):
+ g = g / g.norm(dim=-1, keepdim=True)
+ cls_name = cls_name / cls_name.norm(dim=-1, keepdim=True)
+
+
+ logits = (g @ cls_name.t()) * temperature
+
+ if mutual_retrieval:
+ logits_per_cls = logits.transpose(0, 1)
+ return logits, logits_per_cls
+ return logits
+
+ def _forward_with_mb(self, g, cls_name, task_name):
+ temperature = self.temperature_task(task_name)
+ if self.temp_learn and temperature > 100.0:
+ temperature = 100.0
+
+ # if self.temp_learn:
+ # getattr(self, self.task2tempname[task_name]).data.clamp_(max=math.log(100.0))
+ if self.use_norm:
+ g = g / g.norm(dim=-1, keepdim=True)
+ cls_name = cls_name / cls_name.norm(dim=-1, keepdim=True)
+
+ logits_per_image = (g @ cls_name.t()) * temperature
+
+ logits_per_cls = logits_per_image.transpose(0, 1)
+
+ logits_per_image_neg = (g @ getattr(self, 'queue_h1_{}'.format(task_name)).clone().detach().t()) * temperature
+
+ logits_per_cls_neg = (cls_name @ getattr(self, 'queue_h2_{}'.format(task_name)).clone().detach().t()) * temperature
+
+ self._dequeue_and_enqueue(cls_name, g, task_name) # reverse sequnce to save
+
+ return torch.cat([logits_per_image, logits_per_image_neg], dim=-1) , torch.cat([logits_per_cls, logits_per_cls_neg], dim=-1)
+
+ @property
+ def temperature_dict(self):
+ return {
+ 'temperature/img_cls': 1/self.logit_scale_img_cls.exp(),
+ 'temperature/video_cls': 1/self.logit_scale_video_cls.exp(),
+ 'temperature/text_mlm': 1/self.logit_scale_text_mlm.exp(),
+ 'temperature/text_caption': 1/self.logit_scale_text_caption.exp(),
+ 'temperature/caption': 1/self.logit_scale_caption.exp(),
+ 'temperature/mlm': 1/self.logit_scale_mlm.exp(),
+ 'temperature/retrieve': 1/self.logit_scale_retrieve.exp(),
+ 'temperature/tqa_mlm': 1/self.logit_scale_tqa_mlm.exp(),
+ 'temperature/tqa_caption': 1/self.logit_scale_tqa_caption.exp(),
+ 'temperature/tqa_retrieve': 1/self.logit_scale_tqa_retrieve.exp(),
+ 'temperature/downstream': 1/self.logit_scale_downstream.exp(),
+ }
+
+ def temperature_task(self, taskname):
+ return getattr(self, self.task2tempname[taskname]).exp()
diff --git a/uniperceiver/optim/__init__.py b/uniperceiver/optim/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..885c5eb4bfb24a1ee7d45ea415283b388b07c5e3
--- /dev/null
+++ b/uniperceiver/optim/__init__.py
@@ -0,0 +1,6 @@
+
+from .build import build_optimizer, create_moe_param_groups, create_seperate_moe_param_groups, create_group_moe_param_groups
+
+
+from .adamw import AdamW
+from .lamb import LAMB
diff --git a/uniperceiver/optim/adamw.py b/uniperceiver/optim/adamw.py
new file mode 100644
index 0000000000000000000000000000000000000000..335b36f700382104e2fab85624da6b30b6f153bc
--- /dev/null
+++ b/uniperceiver/optim/adamw.py
@@ -0,0 +1,41 @@
+# Copyright 2021 JD.com, Inc., JD AI
+"""
+@author: Jianjie Luo
+@contact: jianjieluo.sysu@gmail.com
+"""
+import torch
+from uniperceiver.config import configurable
+from .build import SOLVER_REGISTRY
+
+@SOLVER_REGISTRY.register()
+class AdamW(torch.optim.AdamW):
+ @configurable
+ def __init__(
+ self,
+ *,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0.01,
+ amsgrad=False
+ ):
+ super(AdamW, self).__init__(
+ params,
+ lr,
+ betas,
+ eps,
+ weight_decay,
+ amsgrad
+ )
+
+ @classmethod
+ def from_config(cls, cfg, params):
+ return {
+ "params": params,
+ "lr": cfg.SOLVER.BASE_LR,
+ "betas": cfg.SOLVER.BETAS,
+ "eps": cfg.SOLVER.EPS,
+ "weight_decay": cfg.SOLVER.WEIGHT_DECAY,
+ "amsgrad": cfg.SOLVER.AMSGRAD
+ }
diff --git a/uniperceiver/optim/build.py b/uniperceiver/optim/build.py
new file mode 100644
index 0000000000000000000000000000000000000000..85cfc3f9825c715a5b818527b1933bed28d43ad4
--- /dev/null
+++ b/uniperceiver/optim/build.py
@@ -0,0 +1,652 @@
+import copy
+import torch
+import itertools
+from enum import Enum
+from uniperceiver.config import CfgNode
+from uniperceiver.utils.registry import Registry
+from uniperceiver.utils import comm
+
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
+
+SOLVER_REGISTRY = Registry("SOLVER")
+SOLVER_REGISTRY.__doc__ = """
+Registry for SOLVER.
+"""
+
+_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
+_GradientClipper = Callable[[_GradientClipperInput], None]
+
+def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
+ def clip_grad_norm(p: _GradientClipperInput):
+ torch.nn.utils.clip_grad_norm_(p, cfg.SOLVER.GRAD_CLIP, cfg.SOLVER.NORM_TYPE)
+
+ def clip_grad_value(p: _GradientClipperInput):
+ torch.nn.utils.clip_grad_value_(p, cfg.SOLVER.GRAD_CLIP)
+
+ _GRADIENT_CLIP_TYPE_TO_CLIPPER = {
+ 'value': clip_grad_value,
+ 'norm': clip_grad_norm,
+ }
+ clipper = _GRADIENT_CLIP_TYPE_TO_CLIPPER[cfg.SOLVER.GRAD_CLIP_TYPE]
+ if cfg.SOLVER.GRAD_CLIP_TYPE == 'value':
+ return clipper, None
+ else:
+ return None, clipper
+
+
+def get_default_optimizer_params(
+ model: torch.nn.Module,
+ base_lr: Optional[float] = None,
+ weight_decay: Optional[float] = None,
+ weight_decay_norm: Optional[float] = None,
+ bias_lr_factor: Optional[float] = 1.0,
+ weight_decay_bias: Optional[float] = None,
+ overrides: Optional[Dict[str, Dict[str, float]]] = None,
+):
+ if weight_decay_bias is None:
+ weight_decay_bias = weight_decay
+ norm_module_types = (
+ torch.nn.BatchNorm1d,
+ torch.nn.BatchNorm2d,
+ torch.nn.BatchNorm3d,
+ torch.nn.SyncBatchNorm,
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
+ torch.nn.GroupNorm,
+ torch.nn.InstanceNorm1d,
+ torch.nn.InstanceNorm2d,
+ torch.nn.InstanceNorm3d,
+ torch.nn.LayerNorm,
+ torch.nn.LocalResponseNorm,
+ )
+ params: List[Dict[str, Any]] = []
+ memo: Set[torch.nn.parameter.Parameter] = set()
+
+ no_decay_list = {}
+ if hasattr(model, 'no_weight_decay'):
+ no_decay_list = model.no_weight_decay()
+
+ for module_name, module in model.named_modules():
+ no_decay = False
+ if module_name in no_decay_list:
+ no_decay = True
+ for module_param_name, value in module.named_parameters(recurse=False):
+ if not value.requires_grad:
+ continue
+ # Avoid duplicating parameters
+ if value in memo:
+ continue
+ memo.add(value)
+
+ schedule_params = {
+ "lr": base_lr,
+ "weight_decay": weight_decay,
+ }
+
+
+ if isinstance(module, norm_module_types):
+ schedule_params["weight_decay"] = weight_decay_norm
+ elif module_param_name == "bias":
+ # NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0
+ # and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer
+ # hyperparameters are by default exactly the same as for regular
+ # weights.
+ schedule_params["lr"] = base_lr * bias_lr_factor
+ schedule_params["weight_decay"] = weight_decay_bias
+
+ if no_decay or (module_param_name in no_decay_list):
+ schedule_params["weight_decay"] = 0.
+
+
+ if overrides is not None and module_param_name in overrides:
+ schedule_params.update(overrides[module_param_name])
+ params += [
+ {
+ "params": [value],
+ "lr": schedule_params["lr"],
+ "weight_decay": schedule_params["weight_decay"],
+ }
+ ]
+
+ return params
+
+def get_layer_id(module_name, num_layers):
+ """
+ Assign a parameter with its layer id
+ modified from BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+ """
+ if module_name.split('.')[0] in [
+ 'video_embed', 'token_embed', 'prompt_embed', 'visual_embed', 'cls_token' ''
+ ]:
+ return 0
+ elif module_name.startswith('encoder'):
+ return int(module_name.split('.')[2]) + 1
+ elif module_name.startswith('predictor'):
+ return num_layers
+ else:
+ raise NotImplementedError('please check this layer')
+
+def create_seperate_moe_param_groups(
+ model,
+ base_lr: Optional[float] = None,
+ weight_decay: Optional[float] = None,
+ weight_decay_norm: Optional[float] = None,
+ bias_lr_factor: Optional[float] = 1.0,
+ wg_lr_facetor: Optional[float] = 1.0,
+ weight_decay_bias: Optional[float] = None,
+ weight_decay_embedding: Optional[float] = None,
+ weight_decay_wg: Optional[float] = None,
+ cfg: dict = None,
+):
+ try:
+ from deepspeed.moe.utils import is_moe_param
+ except:
+ def is_moe_param(param: torch.Tensor) -> bool:
+ if hasattr(param, "allreduce") and not param.allreduce:
+ return True
+ return False
+
+ params: List[Dict[str, Any]] = []
+ memo: Set[torch.nn.parameter.Parameter] = set()
+
+ num_layers = cfg.MODEL.BERT.NUM_HIDDEN_LAYERS + 1
+ layer_decay = cfg.SOLVER.LAYER_LR_DECAY
+ layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
+
+
+ if weight_decay_bias is None:
+ weight_decay_bias = weight_decay
+ norm_module_types = (
+ torch.nn.BatchNorm1d,
+ torch.nn.BatchNorm2d,
+ torch.nn.BatchNorm3d,
+ torch.nn.SyncBatchNorm,
+ # NaiveSyncBatchNorm inherits from BatchNorm2d
+ torch.nn.GroupNorm,
+ torch.nn.InstanceNorm1d,
+ torch.nn.InstanceNorm2d,
+ torch.nn.InstanceNorm3d,
+ torch.nn.LayerNorm,
+ torch.nn.LocalResponseNorm,
+ )
+
+
+
+
+ no_decay_list = {}
+ if hasattr(model, 'no_weight_decay'):
+ no_decay_list = model.no_weight_decay()
+
+ wg_list = {}
+ if hasattr(model, 'expert_gate_group'):
+ wg_list = model.expert_gate_group()
+
+
+
+ for module_name, module in model.named_modules():
+ no_decay = False
+ if module_name in no_decay_list:
+ no_decay = True
+ is_wg_param = False
+ for wg_name in wg_list:
+ if wg_name in module_name:
+ is_wg_param = True
+ continue
+
+ for module_param_name, value in module.named_parameters(recurse=False):
+ # layer_id = get_layer_id(module_name, num_layers)
+ this_scale = layer_scales[ get_layer_id(module_name, num_layers)] if layer_decay < 1.0 else 1.0
+ # if isinstance(module, torch.nn.Embedding):
+ # print(module_name, module_param_name)
+ if not value.requires_grad:
+ continue
+ # Avoid duplicating parameters
+ if value in memo:
+ continue
+ memo.add(value)
+ schedule_params = {
+ "lr": base_lr,
+ "weight_decay": weight_decay,
+ "moe": False,
+ }
+ if is_moe_param(value):
+ schedule_params['moe'] = True
+
+ if no_decay or (module_param_name in no_decay_list):
+ schedule_params["weight_decay"] = 0.
+ elif is_wg_param and isinstance(
+ module,
+ torch.nn.Linear) and module_param_name != "bias":
+ # only add linear weights in gate function
+ schedule_params["lr"] = base_lr * wg_lr_facetor
+ schedule_params["weight_decay"] = weight_decay_wg
+
+ elif isinstance(module, torch.nn.Embedding):
+ schedule_params['weight_decay'] = weight_decay_embedding
+
+ elif isinstance(module, norm_module_types):
+ if not cfg.SOLVER.WEIGHT_DECAY_NORMBIAS_WEIGHT and module_param_name == "bias":
+ # ln bias use the same params as linear bias
+ schedule_params["lr"] = base_lr * bias_lr_factor
+ schedule_params['weight_decay'] = weight_decay_bias
+ else:
+ schedule_params['weight_decay'] = weight_decay_norm
+
+ elif module_param_name == "bias" or value.ndim == 1:
+ schedule_params["lr"] = base_lr * bias_lr_factor
+ schedule_params['weight_decay'] = weight_decay_bias
+
+ params += [{
+ "params": [value],
+ "lr": max(schedule_params["lr"] * this_scale, cfg.LR_SCHEDULER.get('MIN_LR', 1e-6)),
+ "moe": schedule_params['moe'],
+ "weight_decay": schedule_params["weight_decay"],
+ "name": f'{module_name}.{module_param_name}'
+ }]
+
+
+
+ return params
+
+
+def create_group_moe_param_groups(
+ model,
+ base_lr: Optional[float] = None,
+ weight_decay: Optional[float] = None,
+ weight_decay_norm: Optional[float] = None,
+ bias_lr_factor: Optional[float] = 1.0,
+ wg_lr_facetor: Optional[float] = 1.0,
+ weight_decay_bias: Optional[float] = None,
+ weight_decay_embedding: Optional[float] = None,
+ weight_decay_wg: Optional[float] = None,
+ cfg: dict = None,
+):
+ from deepspeed.moe.utils import is_moe_param
+
+ # params: List[Dict[str, Any]] = []
+ memo: Set[torch.nn.parameter.Parameter] = set()
+
+ if weight_decay_bias is None:
+ weight_decay_bias = weight_decay
+ norm_module_types = (
+ torch.nn.BatchNorm1d,
+ torch.nn.BatchNorm2d,
+ torch.nn.BatchNorm3d,
+ torch.nn.SyncBatchNorm,
+ torch.nn.GroupNorm,
+ torch.nn.InstanceNorm1d,
+ torch.nn.InstanceNorm2d,
+ torch.nn.InstanceNorm3d,
+ torch.nn.LayerNorm,
+ torch.nn.LocalResponseNorm,
+ )
+
+ group_params_dict = {}
+
+ no_decay_list = {}
+ if hasattr(model, 'no_weight_decay'):
+ no_decay_list = model.no_weight_decay()
+
+ wg_list = {}
+ if hasattr(model, 'expert_gate_group'):
+ wg_list = model.expert_gate_group()
+
+ for module_name, module in model.named_modules():
+ no_decay = False
+ if module_name in no_decay_list:
+ no_decay = True
+ is_wg_param = False
+ for wg_name in wg_list:
+ if wg_name in module_name:
+ is_wg_param = True
+ continue
+
+ for module_param_name, value in module.named_parameters(recurse=False):
+ if not value.requires_grad:
+ continue
+ # Avoid duplicating parameters
+ if value in memo:
+ continue
+ memo.add(value)
+
+ # default setting
+ lr_of_this_param = base_lr
+ wd_of_this_param = weight_decay
+ moe_of_this_param = False
+ if is_moe_param(value):
+ moe_of_this_param = True
+
+ if no_decay or (module_param_name in no_decay_list):
+
+ wd_of_this_param = 0.
+ elif is_wg_param and isinstance(
+ module, torch.nn.Linear) and module_param_name != "bias":
+ # only add linear weights in gate function
+ lr_of_this_param = base_lr * wg_lr_facetor
+ wd_of_this_param = weight_decay_wg
+
+ elif isinstance(module, torch.nn.Embedding):
+ wd_of_this_param = weight_decay_embedding
+
+ elif isinstance(module, norm_module_types):
+ if not cfg.SOLVER.WEIGHT_DECAY_NORMBIAS_WEIGHT and module_param_name == "bias":
+ # ln bias uses the same params as linear bias
+ lr_of_this_param = base_lr * bias_lr_factor
+ wd_of_this_param = weight_decay_bias
+ else:
+ wd_of_this_param = weight_decay_norm
+
+ elif module_param_name == "bias":
+ lr_of_this_param = base_lr * bias_lr_factor
+ wd_of_this_param = weight_decay_bias
+
+ param_group_name = f'lr_{lr_of_this_param}_wd_{wd_of_this_param}_moe_{moe_of_this_param}'
+ if param_group_name not in group_params_dict:
+ group_params_dict[param_group_name] = {
+ 'params': [],
+ "lr": lr_of_this_param,
+ "weight_decay": wd_of_this_param,
+ 'moe': moe_of_this_param,
+ 'name': param_group_name,
+ 'params_name': [],
+ }
+ group_params_dict[param_group_name]['params'].append(value)
+ group_params_dict[param_group_name]['params_name'].append(
+ f'{module_name}.{module_param_name}')
+
+
+ valid_params_groups = list(group_params_dict.values())
+ return valid_params_groups
+
+
+
+
+def create_moe_param_groups(
+ model,
+ base_lr: Optional[float] = None,
+ weight_decay: Optional[float] = None,
+ weight_decay_norm: Optional[float] = None,
+ bias_lr_factor: Optional[float] = 1.0,
+ wg_lr_facetor: Optional[float] = 1.0,
+ weight_decay_bias: Optional[float] = None,
+ weight_decay_embedding: Optional[float] = None,
+ weight_decay_wg: Optional[float] = None,
+
+):
+ from deepspeed.moe.utils import is_moe_param
+
+ '''
+ name:
+ '''
+ if weight_decay_bias is None:
+ weight_decay_bias = weight_decay
+ norm_module_types = (
+ torch.nn.BatchNorm1d,
+ torch.nn.BatchNorm2d,
+ torch.nn.BatchNorm3d,
+ torch.nn.SyncBatchNorm,
+ torch.nn.GroupNorm,
+ torch.nn.InstanceNorm1d,
+ torch.nn.InstanceNorm2d,
+ torch.nn.InstanceNorm3d,
+ torch.nn.LayerNorm,
+ torch.nn.LocalResponseNorm,
+ )
+
+ if weight_decay_embedding == 0.0:
+ norm_module_types = norm_module_types + (torch.nn.Embedding, )
+ else:
+ # if weight_decay_embedding is not 0.0, we set its weight_decay as normal weights
+ # assert weight_decay_embedding == weight_decay
+ pass
+
+
+
+ params_with_weight_decay = {
+ 'params': [],
+ 'name': 'weight_decay_params',
+ 'params_name': [],
+ }
+ params_without_weight_decay = {
+ 'params': [],
+ "weight_decay": 0.0,
+ 'name': 'without_weight_decay_params',
+ 'params_name': [],
+ }
+ bias_params = {
+ 'params': [],
+ "lr": base_lr * bias_lr_factor,
+ "weight_decay": weight_decay_bias,
+ 'name': 'bias_params',
+ 'params_name': [],
+ }
+ wg_params = {
+ 'params': [],
+ "lr": base_lr * wg_lr_facetor,
+ "weight_decay": weight_decay_wg,
+ 'name': 'wg_params',
+ 'params_name': [],
+ }
+ norm_params = {
+ 'params': [],
+ "weight_decay": weight_decay_norm,
+ 'name': 'norm_params',
+ 'params_name': [],
+ }
+ moe_params_with_weight_decay = {
+ 'params': [],
+ 'moe': True,
+ 'name': 'weight_decay_moe_params',
+ 'params_name': [],
+ }
+ moe_params_without_weight_decay = {
+ 'params': [],
+ "weight_decay": 0.0,
+ 'moe': True,
+ 'name': 'without_weight_decay_moe_params',
+ 'params_name': [],
+ }
+ moe_bias_params = {
+ 'params': [],
+ "lr": base_lr * bias_lr_factor,
+ "weight_decay": weight_decay_bias,
+ 'moe': True,
+ 'name': 'bias_moe_params',
+ 'params_name': [],
+ }
+ moe_norm_params = {
+ 'params': [],
+ "weight_decay": weight_decay_norm,
+ 'moe': True,
+ 'name': 'norm_moe_params',
+ 'params_name': [],
+ }
+
+ params_groups = [
+ params_with_weight_decay, params_without_weight_decay, norm_params, bias_params, wg_params, \
+ moe_params_with_weight_decay, moe_params_without_weight_decay, moe_norm_params, moe_bias_params
+ ]
+
+
+
+ no_decay_list = {}
+ if hasattr(model, 'no_weight_decay'):
+ no_decay_list = model.no_weight_decay()
+
+ wg_list = {}
+ if hasattr(model, 'expert_gate_group'):
+ wg_list = model.expert_gate_group()
+
+ memo: Set[torch.nn.parameter.Parameter] = set()
+
+ for module_name, module in model.named_modules():
+ no_decay = False
+ if module_name in no_decay_list:
+ no_decay = True
+ is_wg_param = False
+ for wg_name in wg_list:
+ if wg_name in module_name:
+ is_wg_param = True
+ continue
+
+ for module_param_name, value in module.named_parameters(recurse=False):
+ if not value.requires_grad:
+ continue
+ # Avoid duplicating parameters
+ if value in memo:
+ continue
+ memo.add(value)
+ if is_moe_param(value):
+ if no_decay or (module_param_name in no_decay_list):
+ moe_params_without_weight_decay['params'].append(value)
+ elif isinstance(module, norm_module_types):
+ moe_norm_params['params'].append(value)
+ elif module_param_name == "bias":
+ moe_bias_params['params'].append(value)
+ else:
+ moe_params_with_weight_decay['params'].append(value)
+ else:
+ if no_decay or (module_param_name in no_decay_list):
+ params_without_weight_decay['params'].append(value)
+ params_without_weight_decay['params_name'].append(f'{module_name}.{module_param_name}')
+ elif is_wg_param and isinstance(module, torch.nn.Linear) and module_param_name != "bias":
+ # only add linear weights in gate function
+ wg_params['params'].append(value)
+ wg_params['params_name'].append(
+ f'{module_name}.{module_param_name}')
+ elif isinstance(module, norm_module_types):
+ norm_params['params'].append(value)
+ norm_params['params_name'].append(
+ f'{module_name}.{module_param_name}')
+ elif module_param_name == "bias":
+ bias_params['params'].append(value)
+ bias_params['params_name'].append(
+ f'{module_name}.{module_param_name}')
+ else:
+ params_with_weight_decay['params'].append(value)
+ params_with_weight_decay['params_name'].append(
+ f'{module_name}.{module_param_name}')
+
+ valid_params_groups = [
+ group for group in params_groups if len(group['params']) > 0
+ ]
+
+ return valid_params_groups
+
+
+
+
+
+
+def _generate_optimizer_class_with_gradient_clipping(
+ optimizer: Type[torch.optim.Optimizer],
+ *,
+ per_param_clipper: Optional[_GradientClipper] = None,
+ global_clipper: Optional[_GradientClipper] = None,
+) -> Type[torch.optim.Optimizer]:
+ """
+ Dynamically creates a new type that inherits the type of a given instance
+ and overrides the `step` method to add gradient clipping
+ """
+ assert (
+ per_param_clipper is None or global_clipper is None
+ ), "Not allowed to use both per-parameter clipping and global clipping"
+
+ def optimizer_wgc_step(self, closure=None):
+ if per_param_clipper is not None:
+ for group in self.param_groups:
+ for p in group["params"]:
+ per_param_clipper(p)
+ else:
+ # global clipper for future use with detr
+ # (https://github.com/facebookresearch/detr/pull/287)
+ all_params = itertools.chain(*[g["params"] for g in self.param_groups])
+ norm_before_clip = global_clipper(all_params)
+
+ super(type(self), self).step(closure)
+
+ OptimizerWithGradientClip = type(
+ optimizer.__name__ + "WithGradientClip",
+ (optimizer,),
+ {"step": optimizer_wgc_step},
+ )
+ return OptimizerWithGradientClip
+
+def maybe_add_gradient_clipping(
+ cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
+) -> Type[torch.optim.Optimizer]:
+ """
+ If gradient clipping is enabled through config options, wraps the existing
+ optimizer type to become a new dynamically created class OptimizerWithGradientClip
+ that inherits the given optimizer and overrides the `step` method to
+ include gradient clipping.
+
+ Args:
+ cfg: CfgNode, configuration options
+ optimizer: type. A subclass of torch.optim.Optimizer
+
+ Return:
+ type: either the input `optimizer` (if gradient clipping is disabled), or
+ a subclass of it with gradient clipping included in the `step` method.
+ """
+ if cfg.SOLVER.GRAD_CLIP <= 0:
+ return optimizer
+ if isinstance(optimizer, torch.optim.Optimizer):
+ optimizer_type = type(optimizer)
+ else:
+ assert issubclass(optimizer, torch.optim.Optimizer), optimizer
+ optimizer_type = optimizer
+
+ per_param_clipper, global_clipper = _create_gradient_clipper(cfg)
+ OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
+ optimizer_type, per_param_clipper=per_param_clipper, global_clipper=global_clipper
+ )
+ if isinstance(optimizer, torch.optim.Optimizer):
+ optimizer.__class__ = OptimizerWithGradientClip # a bit hacky, not recommended
+ return optimizer
+ else:
+ return OptimizerWithGradientClip
+
+def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
+ """
+ Build an optimizer from config.
+ """
+ # params = get_default_optimizer_params(
+ # model,
+ # base_lr=cfg.SOLVER.BASE_LR,
+ # weight_decay=cfg.SOLVER.WEIGHT_DECAY,
+ # weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
+ # bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
+ # weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
+ # )
+ params = create_seperate_moe_param_groups(
+ model,
+ base_lr=cfg.SOLVER.BASE_LR,
+ weight_decay=cfg.SOLVER.WEIGHT_DECAY,
+ weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
+ bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
+ wg_lr_facetor=cfg.SOLVER.WG_LR_FACTOR,
+ weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
+ weight_decay_embedding=cfg.SOLVER.WEIGHT_DECAY_EMBEDDING,
+ weight_decay_wg=cfg.SOLVER.WEIGHT_DECAY_WG,
+ cfg=cfg,
+ )
+ if cfg.SOLVER.NAME == 'LAMB':
+ from uniperceiver.optim import LAMB
+ optimizer = LAMB(
+ params,
+ lr=cfg.SOLVER.BASE_LR,
+ betas=cfg.SOLVER.BETAS,
+ eps=cfg.SOLVER.EPS,
+ weight_decay=cfg.SOLVER.WEIGHT_DECAY, )
+
+ else:
+ optimizer = torch.optim.AdamW(
+ params,
+ lr=cfg.SOLVER.BASE_LR,
+ betas=cfg.SOLVER.BETAS,
+ eps=cfg.SOLVER.EPS,
+ weight_decay=cfg.SOLVER.WEIGHT_DECAY,
+ )
+ # optimizer = SOLVER_REGISTRY.get(cfg.SOLVER.NAME)
+ # return maybe_add_gradient_clipping(cfg, optimizer)(cfg, params)
+ return optimizer
diff --git a/uniperceiver/optim/lamb.py b/uniperceiver/optim/lamb.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b840431d0f01a8e00915446184d25665dcb383d
--- /dev/null
+++ b/uniperceiver/optim/lamb.py
@@ -0,0 +1,195 @@
+# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_lamb.py
+
+import torch
+
+from apex.multi_tensor_apply import multi_tensor_applier
+class LAMB(torch.optim.Optimizer):
+ """Implements LAMB algorithm.
+
+ Currently GPU-only. Requires Apex to be installed via
+ ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
+
+ This version of fused LAMB implements 2 fusions.
+
+ * Fusion of the LAMB update's elementwise operations
+ * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
+
+ :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::
+
+ opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
+ ...
+ opt.step()
+
+ :class:`apex.optimizers.FusedLAMB` may be used with or without Amp. If you wish to use :class:`FusedLAMB` with Amp,
+ you may choose any ``opt_level``::
+
+ opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
+ model, opt = amp.initialize(model, opt, opt_level="O0" or "O1 or "O2")
+ ...
+ opt.step()
+
+ In general, ``opt_level="O1"`` is recommended.
+
+ LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups.
+ lr (float, optional): learning rate. (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its norm. (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability. (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
+ NOT SUPPORTED now! (default: False)
+ adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
+ True for decoupled weight decay(also known as AdamW) (default: True)
+ grad_averaging (bool, optional): whether apply (1-beta2) to grad when
+ calculating running averages of gradient. (default: True)
+ set_grad_none (bool, optional): whether set grad to None when zero_grad()
+ method is called. (default: True)
+ max_grad_norm (float, optional): value used to clip global grad norm
+ (default: 1.0)
+ use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
+ weight decay parameter (default: False)
+
+ .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
+ https://arxiv.org/abs/1904.00962
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+ def __init__(self,
+ params,
+ lr=1e-3,
+ bias_correction=True,
+ betas=(0.9, 0.999),
+ eps=1e-6,
+ weight_decay=0.01,
+ amsgrad=False,
+ adam_w_mode=True,
+ grad_averaging=True,
+ set_grad_none=True,
+ max_grad_norm=1.0,
+ use_nvlamb=False):
+ if amsgrad:
+ raise RuntimeError('FusedLAMB does not support the AMSGrad variant.')
+ defaults = dict(lr=lr,
+ bias_correction=bias_correction,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ grad_averaging=grad_averaging,
+ max_grad_norm=max_grad_norm)
+ super(LAMB, self).__init__(params, defaults)
+ if multi_tensor_applier.available:
+ import amp_C
+ self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
+ # Skip buffer
+ self._dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=self.param_groups[0]["params"][0].device)
+ self.multi_tensor_lamb = amp_C.multi_tensor_lamb
+ else:
+ raise RuntimeError('apex.optimizers.FusedLAMB requires cuda extensions')
+
+ self.adam_w_mode = 1 if adam_w_mode else 0
+ self.set_grad_none = set_grad_none
+ self.use_nvlamb = use_nvlamb
+
+ def zero_grad(self):
+ if self.set_grad_none:
+ for group in self.param_groups:
+ for p in group['params']:
+ p.grad = None
+ else:
+ super(LAMB, self).zero_grad()
+
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ # create separate grad lists for fp32 and fp16 params
+ g_all_32, g_all_16 = [], []
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ if p.dtype == torch.float32:
+ g_all_32.append(p.grad.data)
+ elif p.dtype == torch.float16:
+ g_all_16.append(p.grad.data)
+ else:
+ raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+
+ device = self.param_groups[0]["params"][0].device
+ g_norm_32, g_norm_16 = torch.zeros(1, device=device), torch.zeros(1, device=device)
+ # compute grad norm for two lists
+ if len(g_all_32) > 0:
+ g_norm_32 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_32], False)[0]
+ if len(g_all_16) > 0:
+ g_norm_16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [g_all_16], False)[0]
+
+ # blend two grad norms to get global grad norm
+ global_grad_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [[g_norm_32, g_norm_16]], False)[0]
+ max_grad_norm = self.defaults['max_grad_norm']
+
+ for group in self.param_groups:
+ bias_correction = 1 if group['bias_correction'] else 0
+ beta1, beta2 = group['betas']
+ grad_averaging = 1 if group['grad_averaging'] else 0
+
+ # assume same step across group now to simplify things
+ # per parameter step can be easily support by making it tensor, or pass list into kernel
+ if 'step' in group:
+ group['step'] += 1
+ else:
+ group['step'] = 1
+
+ # create lists for multi-tensor apply
+ g_16, p_16, m_16, v_16 = [], [], [], []
+ g_32, p_32, m_32, v_32 = [], [], [], []
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ if p.grad.data.is_sparse:
+ raise RuntimeError('FusedLAMB does not support sparse gradients, please consider SparseAdam instead')
+
+ state = self.state[p]
+ # State initialization
+ if len(state) == 0:
+ # Exponential moving average of gradient values
+ state['exp_avg'] = torch.zeros_like(p.data)
+ # Exponential moving average of gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+ if p.dtype == torch.float16:
+ g_16.append(p.grad.data)
+ p_16.append(p.data)
+ m_16.append(state['exp_avg'])
+ v_16.append(state['exp_avg_sq'])
+ elif p.dtype == torch.float32:
+ g_32.append(p.grad.data)
+ p_32.append(p.data)
+ m_32.append(state['exp_avg'])
+ v_32.append(state['exp_avg_sq'])
+ else:
+ raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+
+ if (len(g_16) > 0):
+ multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_16, p_16, m_16, v_16], group['lr'], beta1, beta2,
+ group['eps'], group['step'], bias_correction, group['weight_decay'], grad_averaging, self.adam_w_mode,
+ global_grad_norm, max_grad_norm, self.use_nvlamb)
+ if (len(g_32) > 0):
+ multi_tensor_applier(self.multi_tensor_lamb, self._dummy_overflow_buf, [g_32, p_32, m_32, v_32], group['lr'], beta1, beta2,
+ group['eps'], group['step'], bias_correction, group['weight_decay'], grad_averaging, self.adam_w_mode,
+ global_grad_norm, max_grad_norm, self.use_nvlamb)
+
+ return loss
diff --git a/uniperceiver/task_moe/__init__.py b/uniperceiver/task_moe/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd247a7e5fc315faab7ccd86b1b7a8d5927af110
--- /dev/null
+++ b/uniperceiver/task_moe/__init__.py
@@ -0,0 +1,6 @@
+
+
+
+from .layer import TaskMoE
+
+__all__ = list(globals().keys())
diff --git a/uniperceiver/task_moe/experts.py b/uniperceiver/task_moe/experts.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ae9a916030503f1ea4592d39bd744f3d2957202
--- /dev/null
+++ b/uniperceiver/task_moe/experts.py
@@ -0,0 +1,162 @@
+'''
+Copyright 2020 The Microsoft DeepSpeed Team
+'''
+
+import torch
+import copy
+from .gate import one_hot_with_dtype
+from uniperceiver.utils import comm
+
+import torch.nn.functional as F
+
+from torch.cuda.amp import autocast
+
+
+class FusedExperts(torch.nn.Module):
+ def __init__(self, expert, cfg, num_local_experts=1):
+ super(FusedExperts, self).__init__()
+ self.cfg = cfg
+
+ self.deepspeed_experts = torch.nn.ModuleList(
+ [copy.deepcopy(expert) for i in range(num_local_experts)])
+ self.num_local_experts = num_local_experts
+
+ self.bias_merge = self.deepspeed_experts[0].bias is not None
+
+
+ def top1_expert_forward(self, x, indice, gate, mode=None, **kwargs):
+ assert mode is None, "unified qkv inference is not supported for top1"
+ if indice.size(0)== 1:
+ #unimodal
+ x = self.deepspeed_experts[indice[0]](x) * gate[0].to(x)
+ elif indice.size(0) == 2:
+ # mulmodal
+ data1_length = kwargs['sample_info']['data_cum_length'][1]
+ x = torch.cat([
+ self.deepspeed_experts[indice[0]](x[:, :data1_length, :]) * gate[0].to(x),
+ self.deepspeed_experts[indice[1]](x[:, data1_length:, :]) * gate[1].to(x)
+ ],
+ dim=1)
+
+ else:
+ raise NotImplementedError('only support one or two modality')
+ return x
+
+ def mergelayer(self, x, index1, index2, gate1, gate2, mode=None):
+
+ if not self.cfg.SOLVER.FORCE_EXPERT_ADDING_FP16:
+ if mode == 'q':
+ # hidden_states
+ _start = 0
+ _end = self.deepspeed_experts[index1].weight.shape[0] // 3
+ return F.linear(
+ x,
+ self.deepspeed_experts[index1].weight[_start:_end, :] * gate1 +
+ self.deepspeed_experts[index2].weight[_start:_end, :] * gate2,
+ bias=self.deepspeed_experts[index1].bias[_start:_end] * gate1 +
+ self.deepspeed_experts[index2].bias[_start:_end] * gate2
+ if self.bias_merge else None,
+ )
+
+ elif mode == 'kv':
+ # history_states
+ _start = self.deepspeed_experts[index1].weight.shape[0] // 3
+
+ return F.linear(
+ x,
+ self.deepspeed_experts[index1].weight[_start:, :] * gate1 +
+ self.deepspeed_experts[index2].weight[_start:, :] * gate2,
+ bias=self.deepspeed_experts[index1].bias[_start:] * gate1 +
+ self.deepspeed_experts[index2].bias[_start:] * gate2
+ if self.bias_merge else None,
+ )
+
+ else:
+
+ return F.linear(
+ x,
+ self.deepspeed_experts[index1].weight * gate1 +
+ self.deepspeed_experts[index2].weight * gate2,
+ bias=self.deepspeed_experts[index1].bias * gate1 +
+ self.deepspeed_experts[index2].bias * gate2 if self.bias_merge else None,
+ )
+ else:
+ if mode == 'q':
+ # hidden_states
+ _start = 0
+ _end = self.deepspeed_experts[index1].weight.shape[0] // 3
+ return F.linear(
+ x,
+ self.deepspeed_experts[index1].weight[_start:_end, :].half() * gate1 +
+ self.deepspeed_experts[index2].weight[_start:_end, :].half() * gate2,
+ bias=self.deepspeed_experts[index1].bias[_start:_end].half() * gate1 +
+ self.deepspeed_experts[index2].bias[_start:_end].half() * gate2 if self.bias_merge else None,
+ )
+
+ elif mode == 'kv':
+ # history_states
+ _start = self.deepspeed_experts[index1].weight.shape[0] // 3
+
+ return F.linear(
+ x,
+ self.deepspeed_experts[index1].weight[_start:, :].half() * gate1 +
+ self.deepspeed_experts[index2].weight[_start:, :].half() * gate2,
+ bias=self.deepspeed_experts[index1].bias[_start:].half() * gate1 +
+ self.deepspeed_experts[index2].bias[_start:].half() * gate2 if self.bias_merge else None,
+ )
+
+ else:
+
+ return F.linear(
+ x,
+ self.deepspeed_experts[index1].weight.half() * gate1 + self.deepspeed_experts[index2].weight.half() * gate2,
+ bias=self.deepspeed_experts[index1].bias.half() * gate1 +
+ self.deepspeed_experts[index2].bias.half() * gate2 if self.bias_merge else None,
+ )
+
+
+ def top2_expert_forward(self, x, indices, gates, mode=None, **kwargs):
+
+ # caption eval mode
+ if comm._CAPTION_GEN_MODE and x.shape[1] == 1:
+ #
+ return self.mergelayer(x,
+ indices[0][1], indices[1][1],
+ gates[0][1], gates[1][1], mode=mode)
+
+ # unimodal
+ if indices[0].size(0) == 1:
+ x = self.mergelayer(x, indices[0][0], indices[1][0], gates[0][0], gates[1][0], mode=mode)
+ elif indices[0].size(0) == 2:
+ data1_length = kwargs['sample_info']['data_cum_length'][1]
+ if mode == 'kv' and kwargs['sample_info'].get('pe_length', 0) > 0:
+ # may have prompt embedding for kv embedding
+ data1_length += kwargs['sample_info'].get('pe_length', 0)
+ x = torch.cat([
+ self.mergelayer(x[:, :data1_length, :], indices[0][0], indices[1][0], gates[0][0], gates[1][0], mode=mode),
+ self.mergelayer(x[:, data1_length:, :], indices[0][1], indices[1][1], gates[0][1], gates[1][1], mode=mode)
+ ],
+ dim=1)
+
+ else:
+ raise NotImplementedError('only support one or two modality')
+ return x
+
+ def forward(self, hidden_states, top_indices=None, gates=None, **kwargs):
+
+ # top1
+ if len(top_indices) == 1:
+ out = self.top1_expert_forward(hidden_states, top_indices[0], gates[0], **kwargs)
+
+ # top2
+ elif len(top_indices) == 2:
+ out = self.top2_expert_forward(hidden_states, top_indices, gates, **kwargs)
+
+ else:
+ raise NotImplementedError("only support top1 and top2 ")
+
+
+
+ assert out.shape[1] == hidden_states.shape[1]
+
+ return out
diff --git a/uniperceiver/task_moe/gate.py b/uniperceiver/task_moe/gate.py
new file mode 100644
index 0000000000000000000000000000000000000000..dec290f760790a2908665f57f8595f7b1f974e72
--- /dev/null
+++ b/uniperceiver/task_moe/gate.py
@@ -0,0 +1,536 @@
+'''
+Copyright 2021 The Microsoft DeepSpeed Team
+'''
+# The file has been adapted from two fairscale files:
+# (1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py
+# (2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py
+# Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf
+# We retain the following license from the original files:
+
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+#
+# This source code is licensed under the BSD license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union, cast
+
+import time
+from time import perf_counter
+import torch
+from torch import nn
+from torch import Tensor
+import torch.distributed as dist
+from torch.nn import Module, ModuleList
+import torch.nn.functional as F
+from uniperceiver.utils.events import get_event_storage
+from torch.cuda.amp import autocast
+
+
+
+if TYPE_CHECKING:
+ Base = Module[Tensor]
+else:
+ Base = Module
+
+uniform_map: Dict[torch.device, Callable] = {}
+gumbel_map: Dict[torch.device, Callable] = {}
+normal_map: Dict[torch.device, Callable] = {}
+exp_selection_uniform_map: Dict[torch.device, Callable] = {}
+
+
+import torch.distributed.nn
+from uniperceiver.utils import comm
+from uniperceiver.modeling.layers import FP16LayerNorm
+
+
+
+def multiplicative_jitter(x, device: torch.device, epsilon=1e-2):
+ """
+ Modified from switch transformer paper. mesh transformers
+ Multiply values by a random number between 1-epsilon and 1+epsilon.
+ Makes models more resilient to rounding errors introduced by bfloat16.
+ This seems particularly important for logits.
+ Args:
+ x: a torch.tensor
+ device: torch.device
+ epsilon: a floating point value
+ Returns:
+ a jittered x.
+ """
+ if epsilon == 0:
+ return x
+ uniform = uniform_map.get(device)
+ if uniform is None:
+ uniform = torch.distributions.uniform.Uniform(
+ low=torch.tensor(1.0 - epsilon, device=device),
+ high=torch.tensor(1.0 + epsilon,
+ device=device)).rsample # type: ignore
+ uniform_map[device] = uniform
+ return x * uniform(x.shape)
+
+
+def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
+ gumbel = gumbel_map.get(device)
+ if gumbel is None:
+ one = torch.tensor(1.0, device=device)
+ zero = torch.tensor(0.0, device=device)
+ gumbel = torch.distributions.gumbel.Gumbel(zero,
+ one).rsample # type: ignore
+ gumbel_map[device] = gumbel
+ return gumbel(shape)
+
+
+def normal_rsample(shape: Tuple, device: torch.device, num_expert: int) -> Tensor:
+ normal = normal_map.get(device)
+ if normal is None:
+ std = torch.tensor(1.0/num_expert, device=device)
+ mean = torch.tensor(0.0, device=device)
+ normal = torch.distributions.normal.Normal(mean, std).rsample # type: ignore
+ normal_map[device] = normal
+ return normal(shape)
+
+
+def one_hot_with_dtype(data, num_classes, dtype):
+ result = torch.zeros([data.size(0), num_classes],
+ device=data.device,
+ dtype=dtype)
+ result.scatter_(1, data.unsqueeze(-1), 1)
+ return result
+
+@torch.jit.script
+def _top_idx(source, k):
+ return torch.topk(source, k=k, dim=0)[1]
+
+
+@torch.jit.script
+def _one_hot_to_float(x, num_classes):
+ return F.one_hot(x, num_classes=num_classes).float()
+
+
+
+
+class TopKGate(nn.Module):
+ """Gate module which implements Top2Gating as described in Gshard_.
+ ::
+
+ gate = TopKGate(model_dim, num_experts)
+ l_aux, combine_weights, dispatch_mask = gate(input)
+
+ .. Gshard_: https://arxiv.org/pdf/2006.16668.pdf
+
+ Args:
+ model_dim (int):
+ size of model embedding dimension
+ num_experts (ints):
+ number of experts in model
+ """
+
+ # wg: torch.nn.Linear
+
+ def __init__(self,
+ model_dim: int,
+ num_experts: int,
+ k: int = 1,
+ noisy_gate_policy: Optional[str] = None,
+ cfg: dict = None,
+ moe_type: str = None,
+ **kwargs):
+ super().__init__( )
+
+ if k != 1 and k != 2:
+ raise ValueError('Only top-1 and top-2 gatings are supported.')
+ self.model_dim = model_dim
+ self.k = k
+
+ self.cfg = cfg
+
+ self.noisy_gate_policy = noisy_gate_policy
+ self.noise_std = self.cfg.MOE.NOISE_STD
+
+ self.batch_prioritized_routing = self.cfg.MOE.BATCH_PRIO
+ self.gate = self.cfg.MOE.GATE_TYPE
+
+
+
+ self.layer_type = kwargs.pop('moe_type', 'ffn')
+
+ self.tag_transform_enable = self.cfg.MOE.TAG_Transform
+
+ self.moe_type = moe_type
+
+ if self.cfg.SOLVER.FORCE_LN_FP16:
+ LayerNormModule = FP16LayerNorm
+ else:
+ LayerNormModule = torch.nn.LayerNorm
+ if self.tag_transform_enable and self.cfg.MOE.TAG_Transform_ACT:
+ self.tag_transform = torch.nn.Sequential(torch.nn.Linear(self.cfg.MOE.ATTRIBUTE_LENGTH, self.model_dim), torch.nn.GELU(),
+ LayerNormModule(self.model_dim))
+ else:
+ self.tag_transform = torch.nn.Sequential(torch.nn.Linear(self.cfg.MOE.ATTRIBUTE_LENGTH, self.model_dim), LayerNormModule(self.model_dim))
+
+ self.wg = torch.nn.Linear(model_dim, num_experts, bias=False).float()
+
+
+
+
+
+ pass
+
+
+
+ def tag_gate(self, x, data_type=None, moe_embedding:torch.Tensor = None, **kwargs):
+ if self.cfg.MODEL.TAG_TRANSFORM_FP32:
+ with autocast(enabled=False):
+ gate_embed = self.tag_transform.float()(moe_embedding.float())
+ else:
+ gate_embed = self.tag_transform(moe_embedding)
+
+
+ return gate_embed
+
+
+
+
+ def forward(
+ self,
+ input,
+ **kwargs,
+ ) -> Tuple[Tensor, Tensor, Tensor]: # type: ignore
+
+
+ if self.tag_transform_enable:
+ input = self.tag_gate(input, **kwargs)
+ if self.wg.weight.dtype != torch.float32:
+ self.wg = self.wg.float()
+ input_fp32 = input.float()
+ # input jittering
+ if self.noisy_gate_policy == 'Jitter' and self.training:
+ input_fp32 = multiplicative_jitter(input_fp32, device=input.device)
+ with autocast(enabled=not self.cfg.MODEL.GATE_FP32):
+ if self.cfg.SOLVER.FORCE_WG_RECAST:
+ # used for dbeugging only
+ logits = self.wg.half().float()(input_fp32)
+ else:
+ logits = self.wg(input_fp32)
+
+ if self.k == 1 and self.gate == 'deepspeed':
+ gate_output = self.top1gating(
+ logits,
+ self.noisy_gate_policy if self.training else None,
+ **kwargs)
+
+ # tutel gate function
+ else:
+ gate_output = self.top2gating(
+ logits,
+ self.noisy_gate_policy if self.training else None,
+ **kwargs )
+
+
+ return gate_output
+
+ def load_balance(self, gates, mask1, num_experts, data_type=None):
+ # Compute l_aux
+ if self.balance_loss and self.training:
+ # TODO: for retrieval task, these maybe some gpu do not have this input
+
+ if data_type == 'INPUT':
+ if comm._LOCAL_IMAGE_LENGTH > 0 and not comm._LOCAL_UTOKEN_LENGTH + comm._LOCAL_GTOKEN_LENGTH > 0:
+ # input image features only
+ me = gates.sum(dim=0)
+ ce = mask1.sum(dim=0)
+
+ # maybe has retrieval pair
+ if comm._MOE_TARGET_MECE_LIST.get(str(comm._LOCAL_CURRENT_LAYER)+'_'+self.layer_type, None) is not None:
+ # if len(comm._MOE_TARGET_MECE_LIST) > 0:
+ me_t, ce_t = comm._MOE_TARGET_MECE_LIST[
+ str(comm._LOCAL_CURRENT_LAYER) + '_' +
+ self.layer_type]
+ me = me + me_t
+ ce = ce + ce_t
+
+ me = me * self.task_weights[comm._LOCAL_CURRENT_TASK]
+ ce = ce * self.task_weights[comm._LOCAL_CURRENT_TASK]
+
+ elif comm._LOCAL_IMAGE_LENGTH > 0 and comm._LOCAL_UTOKEN_LENGTH + comm._LOCAL_GTOKEN_LENGTH > 0:
+
+ # sum of these two distribution from two modalities
+ me = gates.sum(dim=0)
+ ce = mask1.sum(dim=0)
+
+ me = me * self.task_weights[comm._LOCAL_CURRENT_TASK]
+ ce = ce * self.task_weights[comm._LOCAL_CURRENT_TASK]
+
+ elif comm._LOCAL_IMAGE_LENGTH <= 0 and comm._LOCAL_UTOKEN_LENGTH + comm._LOCAL_GTOKEN_LENGTH > 0:
+
+ me = gates.sum(
+ dim=0) * self.task_weights[comm._LOCAL_CURRENT_TASK]
+ ce = mask1.sum(
+ dim=0) * self.task_weights[comm._LOCAL_CURRENT_TASK]
+ # raise NotImplementedError
+ else:
+
+ raise NotImplementedError
+
+ elif data_type == 'TARGET':
+ # the retrieval embedding
+
+ # only remove the padding contributions
+
+ comm._MOE_TARGET_MECE_LIST[str(comm._LOCAL_CURRENT_LAYER) + '_' +self.layer_type] = [gates.sum(dim=0), mask1.sum(dim=0)]
+
+ elif data_type == 'IN_LABEL':
+ # remove paddings contributions
+
+ me = gates.sum(dim=0)
+ ce = mask1.sum(dim=0)
+
+ elif data_type == 'WORD_VOCAB':
+ # do not need padding mask
+ me = gates.sum(dim=0)
+ ce = mask1.sum(dim=0)
+ else:
+ raise NotImplementedError
+
+ # debug left
+
+ if not data_type == 'TARGET':
+ me = torch.distributed.nn.all_reduce(
+ me) / comm.get_world_size()
+ ce = torch.distributed.nn.all_reduce(
+ ce) / comm.get_world_size()
+
+ if data_type not in comm._MOE_LOSSES_COLLECTIONS[
+ 'exp_balance']:
+ comm._MOE_LOSSES_COLLECTIONS['exp_balance'][
+ data_type] = []
+ comm._MOE_LOSSES_COLLECTIONS['exp_balance'][
+ data_type].append([me, ce])
+
+
+ def top1gating(
+ self,
+ logits: Tensor,
+ noisy_gate_policy: Optional[str] = None,
+ **kwargs,
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ """Implements Top1Gating on logits."""
+
+ logits_w_noise = None
+ if noisy_gate_policy == 'RSample':
+ logits_w_noise = logits + gumbel_rsample(logits.shape,
+ device=logits.device)
+ elif noisy_gate_policy == 'vmoe':
+ num_experts = int(logits.shape[-1])
+ logits_w_noise = logits + normal_rsample(logits.shape,
+ device=logits.device,
+ num_expert=num_experts/self.noise_std)
+
+ # everything is in fp32 in this function
+ gates = F.softmax(logits, dim=1)
+ # Create a mask for 1st's expert per token
+ # noisy gating
+ indices1_s = torch.argmax(logits_w_noise if logits_w_noise is not None else gates, dim=1)
+
+ num_experts = int(gates.shape[1])
+ mask1 = F.one_hot(indices1_s, num_classes=num_experts)
+
+ # gating decisions
+ exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
+
+ self.load_balance(gates, mask1, num_experts)
+
+ self.tb_output(
+ mask1,
+ exp_counts,
+ gates=None
+ )
+
+ gates = (gates*mask1).sum(dim=1)
+ self.tb_output(mask1=None, exp_counts=None, gates=[gates])
+
+ return [indices1_s], [gates]
+
+
+
+
+ def top2gating(
+ self,
+ logits: Tensor,
+ noisy_gate_policy: Optional[str] = None,
+ **kwargs,
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+ """Implements Top2Gating on logits."""
+ # everything is in fp32 in this function
+
+ num_experts = int(logits.shape[-1])
+
+ logits_w_noise = None
+ if noisy_gate_policy == 'RSample':
+ logits_w_noise = logits + gumbel_rsample(logits.shape,
+ device=logits.device) * self.noise_std
+ elif noisy_gate_policy == 'vmoe':
+ logits_w_noise = logits + normal_rsample(logits.shape,
+ device=logits.device,
+ num_expert=num_experts/self.noise_std)
+
+ # topk_indices = torch.topk(logits, self.k, dim=1).indices
+ topk_indices = torch.topk(
+ logits_w_noise
+ if logits_w_noise is not None else logits,
+ self.k,
+ dim=1).indices
+
+ indices_s = [x.view(-1) for x in topk_indices.chunk(self.k, dim=1)]
+ masks_se = [
+ one_hot_with_dtype(x, num_classes=num_experts, dtype=x.dtype)
+ for x in indices_s
+ ]
+
+
+ if noisy_gate_policy == 'vmoe':
+ gates = F.softmax(logits_w_noise, dim=1)
+
+ else:
+ gates = F.softmax(logits, dim=1)
+
+ # self.load_balance(gates, masks_se[0], num_experts)
+ gates_s = [(gates * x).sum(dim=1) for x in masks_se]
+
+ # gating decisions
+ exp_counts = torch.sum(masks_se[0], dim=0).detach().to('cpu')
+
+ # self.tb_output(masks_se[0], exp_counts, gates=None)
+ # if self.k>1:
+ # for k in range(1, self.k):
+ # self.tb_output(masks_se[k], torch.sum(masks_se[k], dim=0).detach().to('cpu'), None, postfix='_top{}'.format(k+1))
+
+
+
+ if self.k > 1:
+
+ # Normalize Gate
+ denom_s = torch.clamp(sum(gates_s),
+ min=torch.finfo(gates_s[0].dtype).eps)
+ gates_s = [x / denom_s for x in gates_s]
+
+ # self.tb_output(mask1=None, exp_counts=None, gates=gates_s)
+
+ return indices_s, gates_s
+
+
+ def tb_output(self, data_type=None, mask1=None, exp_counts=None, gates=None, postfix=''):
+ if self.training:
+ storage = get_event_storage()
+ else:
+ return
+
+ if not (comm._LOCAL_CURRENT_TASK == 'imagenet' or comm._LOCAL_CURRENT_TASK.startswith('bookswiki') or comm._LOCAL_CURRENT_TASK.startswith('cc3m') or comm._LOCAL_CURRENT_TASK.startswith('cc12m') or comm._LOCAL_CURRENT_TASK.startswith('tqa')):
+ # to save time
+ return
+
+ if (storage._iter+1)%(comm._EXPERT_LOG_INTERVAL//10) != 0:
+ # to save time
+ return
+
+
+ if storage is not None and comm.is_main_process():
+ # pass
+ # for each expert
+
+ if gates is not None:
+ if data_type == "INPUT" and comm._LOCAL_IMAGE_LENGTH > 0:
+
+
+ gate_logs = {
+ "logits_layer{}_expert_{}/top{}_{}_{}_v".format(
+ comm._LOCAL_CURRENT_LAYER, self.layer_type,
+ e_id+1, comm._LOCAL_CURRENT_TASK,
+ data_type): ratio[0]
+ for e_id, ratio in enumerate(gates)
+ }
+ storage.put_scalars(**gate_logs, avg_hint=True)
+
+
+ if gates[0].shape[0] > 1:
+ gates_t_logs = {
+ "logits_layer{}_expert_{}/top{}_{}_{}_t".
+ format(comm._LOCAL_CURRENT_LAYER,
+ self.layer_type, e_id+1,
+ comm._LOCAL_CURRENT_TASK,
+ data_type): ratio[1]
+ for e_id, ratio in enumerate(gates)
+ }
+ storage.put_scalars(**gates_t_logs, avg_hint=True)
+
+ elif data_type in ['IN_LABEL', 'WORD_VOCAB']:
+
+ gates_logs = {
+ "logits_layer{}_expert_{}/top{}_{}".format(
+ comm._LOCAL_CURRENT_LAYER, self.layer_type,
+ e_id+1, data_type): ratio[0]
+ for e_id, ratio in enumerate(gates)
+ }
+ storage.put_scalars(**gates_logs, avg_hint=True)
+
+ else:
+
+ gates_logs = {
+ "layer{}_expert_{}/top{}_{}_{}".format(
+ comm._LOCAL_CURRENT_LAYER, self.layer_type,
+ e_id+1, comm._LOCAL_CURRENT_TASK,
+ data_type): ratio[0]
+ for e_id, ratio in enumerate(gates)
+ }
+ storage.put_scalars(**gates_logs, avg_hint=True)
+
+ else:
+
+ if data_type == "INPUT" and comm._LOCAL_IMAGE_LENGTH > 0:
+
+ exp_counts_v = mask1[0]
+ exp_count_logs = {
+ "layer{}_expert_{}/E{}_{}_{}_v{}".format(
+ comm._LOCAL_CURRENT_LAYER, self.layer_type, e_id,
+ comm._LOCAL_CURRENT_TASK, data_type,
+ postfix): ratio
+ for e_id, ratio in enumerate((exp_counts_v /
+ exp_counts_v.sum()).tolist())
+ }
+ storage.put_scalars(**exp_count_logs, avg_hint=True)
+
+ if mask1.size(0)>1:
+ exp_counts_t = mask1[1]
+ exp_count_logs = {
+ "layer{}_expert_{}/E{}_{}_{}_t{}".format(
+ comm._LOCAL_CURRENT_LAYER, self.layer_type, e_id,
+ comm._LOCAL_CURRENT_TASK,
+ data_type, postfix): ratio
+ for e_id, ratio in enumerate((
+ exp_counts_t / exp_counts_t.sum()).tolist())
+ }
+ storage.put_scalars(**exp_count_logs, avg_hint=True)
+
+
+
+ elif data_type in ['IN_LABEL', 'WORD_VOCAB']:
+ exp_count_logs = {
+ "layer{}_expert_{}/E{}_{}{}".format(
+ comm._LOCAL_CURRENT_LAYER, self.layer_type, e_id,
+ data_type, postfix): ratio
+ for e_id, ratio in enumerate((exp_counts /
+ exp_counts.sum()).tolist())
+ }
+ storage.put_scalars(**exp_count_logs, avg_hint=True)
+
+ else:
+ exp_count_logs = {
+ "layer{}_expert_{}/E{}_{}_{}{}".format(
+ comm._LOCAL_CURRENT_LAYER, self.layer_type, e_id,
+ comm._LOCAL_CURRENT_TASK, data_type,
+ postfix): ratio
+ for e_id, ratio in enumerate((exp_counts /
+ exp_counts.sum()).tolist())
+ }
+ storage.put_scalars(**exp_count_logs, avg_hint=True)
diff --git a/uniperceiver/task_moe/layer.py b/uniperceiver/task_moe/layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..080c16819ea25862012430a487864d4e117e4eb9
--- /dev/null
+++ b/uniperceiver/task_moe/layer.py
@@ -0,0 +1,103 @@
+'''
+Copyright 2020 The Microsoft DeepSpeed Team
+'''
+
+import torch.nn.init as init
+import torch
+from torch import nn
+import torch.distributed as dist
+
+
+
+
+from .gate import TopKGate
+import copy
+import typing
+
+from .experts import FusedExperts as Experts
+
+
+class TaskMoE(torch.nn.Module):
+ def __init__(self,
+ hidden_size,
+ expert,
+ num_experts=1,
+ k=1,
+ capacity_factor=1.,
+ eval_capacity_factor=1.,
+ min_capacity=4,
+ noisy_gate_policy: typing.Optional[str] = None,
+ drop_tokens: bool = True,
+ use_rts=True,
+ use_tutel: bool = False,
+ cfg=None):
+ """Initialize an MoE layer.
+
+ Arguments:
+ hidden_size (int): the hidden dimension of the model, importantly this is also the input and output dimension.
+
+ expert (torch.nn.Module): the torch module that defines the expert (e.g., MLP, torch.linear).
+
+ num_experts (int, optional): default=1, the total number of experts per layer.
+
+ k (int, optional): default=1, top-k gating value, only supports k=1 or k=2.
+
+ capacity_factor (float, optional): default=1.0, the capacity of the expert at training time.
+
+ eval_capacity_factor (float, optional): default=1.0, the capacity of the expert at eval time.
+
+ min_capacity (int, optional): default=4, the minimum capacity per expert regardless of the capacity_factor.
+
+ noisy_gate_policy (str, optional): default=None, noisy gate policy, valid options are 'Jitter', 'RSample' or 'None'.
+
+ drop_tokens (bool, optional): default=True, whether to drop tokens - (setting to False is equivalent to infinite capacity).
+
+ use_rts (bool, optional): default=True, whether to use Random Token Selection.
+
+ use_tutel (bool, optional): default=False, whether to use Tutel optimizations (if installed).
+ """
+
+ super().__init__()
+
+
+ self.num_experts = num_experts
+
+ if isinstance(expert, nn.Linear):
+ self.expert_type = 'linear'
+ elif isinstance(expert, nn.MultiheadAttention):
+ self.expert_type = 'attention'
+ else:
+ raise NotImplementedError('please check expert type')
+
+ experts = Experts(expert, cfg, num_experts)
+
+ self.gate = TopKGate(hidden_size,
+ num_experts,
+ k,
+ noisy_gate_policy,
+ cfg,
+ moe_type=self.expert_type)
+
+
+ self.experts = experts
+
+
+
+ def forward(self, hidden_states, gate_decision=None, **kwargs):
+ """ MoE forward
+ Arguments:
+ hidden_states (Tensor): input to the layer
+ Returns:
+ A tuple including output
+ * output (Tensor): output of the model
+ """
+
+
+ if gate_decision is not None:
+ top_indices, gates = gate_decision
+ else:
+ top_indices, gates = self.gate(hidden_states, **kwargs)
+
+ expert_output = self.experts(hidden_states, top_indices, gates, **kwargs)
+
+ return expert_output, [top_indices, gates]
diff --git a/uniperceiver/tokenization/__init__.py b/uniperceiver/tokenization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..885c11c0467241be116f0ef876397ed336a7e4c2
--- /dev/null
+++ b/uniperceiver/tokenization/__init__.py
@@ -0,0 +1,2 @@
+
+from .tokenization_clip import ClipTokenizer
\ No newline at end of file
diff --git a/uniperceiver/tokenization/bpe_simple_vocab_16e6.txt.gz b/uniperceiver/tokenization/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000000000000000000000000000000000000..36a15856e00a06a9fbed8cdd34d2393fea4a3113
--- /dev/null
+++ b/uniperceiver/tokenization/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/uniperceiver/tokenization/tokenization_clip.py b/uniperceiver/tokenization/tokenization_clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..84ba60e8733787f390d40f9157cbadc12ca78242
--- /dev/null
+++ b/uniperceiver/tokenization/tokenization_clip.py
@@ -0,0 +1,194 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+from torch._C import Value
+import pandas as pd
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class ClipTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>', '<|mask|>','<|gen|>', '<|spe|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.encoder = pd.Series(list(self.encoder.values()), index=self.encoder.keys())
+
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.decoder = pd.Series(list(self.decoder.values()), index=self.decoder.keys())
+
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>', '<|mask|>': '<|mask|>', '<|gen|>': '<|gen|>', '<|spe|>': '<|spe|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|<\|gen\|>|<\|spe\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ self.vocab = self.encoder
+ self.ids_to_tokens = self.decoder
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+''
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
+ return text
+
+ def basic_tokenize(self, text):
+ text = whitespace_clean(basic_clean(text)).lower()
+ return list(re.findall(self.pat, text))
+
+ def encode_basic_tokenized_token(self, token):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens = [self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')]
+ return bpe_tokens
+
+ def tokenize(self, text):
+ tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
+ return tokens
+
+ def convert_tokens_to_ids(self, tokens):
+ return [self.encoder[bpe_token] for bpe_token in tokens]
+
+ def add_special_tokens_single_sentence(self, token_ids, start_type='SoT'):
+ if start_type == 'SoT':
+ return [self.encoder['<|startoftext|>']] + token_ids + [self.encoder['<|endoftext|>']]
+ elif start_type == 'Gen':
+ return [self.encoder['<|gen|>']] + token_ids + [self.encoder['<|endoftext|>']]
+ elif start_type == 'SPE':
+ return token_ids
+ else:
+ raise ValueError
+
+ def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, start_type='SoT'):
+ sep = [self.encoder['<|endoftext|>']]
+ if start_type == 'SoT':
+ cls = [self.encoder['<|startoftext|>']]
+ elif start_type == 'Gen':
+ cls = [self.encoder['<|gen|>']]
+ elif start_type == 'SPE':
+ cls = []
+ else:
+ raise ValueError
+ return cls + token_ids_0 + sep + token_ids_1
+
+ def get_cls_token_id(self):
+ return self.encoder['<|startoftext|>']
+
+ def get_eos_token_id(self):
+ return self.encoder['<|endoftext|>']
+
+ def convert_tokens_to_string(self, tokens):
+ text = ''.join(tokens).strip()
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
+ return text
diff --git a/uniperceiver/utils/__init__.py b/uniperceiver/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/uniperceiver/utils/collect_env.py b/uniperceiver/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..275d44d495e02c5c111713a9cb865156d3dbef85
--- /dev/null
+++ b/uniperceiver/utils/collect_env.py
@@ -0,0 +1,242 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import importlib
+import numpy as np
+import os
+import re
+import subprocess
+import sys
+from collections import defaultdict
+import PIL
+import torch
+import torchvision
+from tabulate import tabulate
+
+__all__ = ["collect_env_info"]
+
+
+def collect_torch_env():
+ try:
+ import torch.__config__
+
+ return torch.__config__.show()
+ except ImportError:
+ # compatible with older versions of pytorch
+ from torch.utils.collect_env import get_pretty_env_info
+
+ return get_pretty_env_info()
+
+
+def get_env_module():
+ var_name = "DETECTRON2_ENV_MODULE"
+ return var_name, os.environ.get(var_name, "")
+
+
+def detect_compute_compatibility(CUDA_HOME, so_file):
+ try:
+ cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
+ if os.path.isfile(cuobjdump):
+ output = subprocess.check_output(
+ "'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True
+ )
+ output = output.decode("utf-8").strip().split("\n")
+ arch = []
+ for line in output:
+ line = re.findall(r"\.sm_([0-9]*)\.", line)[0]
+ arch.append(".".join(line))
+ arch = sorted(set(arch))
+ return ", ".join(arch)
+ else:
+ return so_file + "; cannot find cuobjdump"
+ except Exception:
+ # unhandled failure
+ return so_file
+
+
+def collect_env_info():
+ has_gpu = torch.cuda.is_available() # true for both CUDA & ROCM
+ torch_version = torch.__version__
+
+ # NOTE that CUDA_HOME/ROCM_HOME could be None even when CUDA runtime libs are functional
+ from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
+
+ has_rocm = False
+ if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None):
+ has_rocm = True
+ has_cuda = has_gpu and (not has_rocm)
+
+ data = []
+ data.append(("sys.platform", sys.platform)) # check-template.yml depends on it
+ data.append(("Python", sys.version.replace("\n", "")))
+ data.append(("numpy", np.__version__))
+
+ try:
+ import detectron2 # noqa
+
+ data.append(
+ ("detectron2", detectron2.__version__ + " @" + os.path.dirname(detectron2.__file__))
+ )
+ except ImportError:
+ data.append(("detectron2", "failed to import"))
+ except AttributeError:
+ data.append(("detectron2", "imported a wrong installation"))
+
+ try:
+ import detectron2._C as _C
+ except ImportError as e:
+ data.append(("detectron2._C", f"not built correctly: {e}"))
+
+ # print system compilers when extension fails to build
+ if sys.platform != "win32": # don't know what to do for windows
+ try:
+ # this is how torch/utils/cpp_extensions.py choose compiler
+ cxx = os.environ.get("CXX", "c++")
+ cxx = subprocess.check_output("'{}' --version".format(cxx), shell=True)
+ cxx = cxx.decode("utf-8").strip().split("\n")[0]
+ except subprocess.SubprocessError:
+ cxx = "Not found"
+ data.append(("Compiler ($CXX)", cxx))
+
+ if has_cuda and CUDA_HOME is not None:
+ try:
+ nvcc = os.path.join(CUDA_HOME, "bin", "nvcc")
+ nvcc = subprocess.check_output("'{}' -V".format(nvcc), shell=True)
+ nvcc = nvcc.decode("utf-8").strip().split("\n")[-1]
+ except subprocess.SubprocessError:
+ nvcc = "Not found"
+ data.append(("CUDA compiler", nvcc))
+ if has_cuda and sys.platform != "win32":
+ try:
+ so_file = importlib.util.find_spec("detectron2._C").origin
+ except (ImportError, AttributeError):
+ pass
+ else:
+ data.append(
+ ("detectron2 arch flags", detect_compute_compatibility(CUDA_HOME, so_file))
+ )
+ else:
+ # print compilers that are used to build extension
+ data.append(("Compiler", _C.get_compiler_version()))
+ data.append(("CUDA compiler", _C.get_cuda_version())) # cuda or hip
+ if has_cuda and getattr(_C, "has_cuda", lambda: True)():
+ data.append(
+ ("detectron2 arch flags", detect_compute_compatibility(CUDA_HOME, _C.__file__))
+ )
+
+ data.append(get_env_module())
+ data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__)))
+ data.append(("PyTorch debug build", torch.version.debug))
+
+ if not has_gpu:
+ has_gpu_text = "No: torch.cuda.is_available() == False"
+ else:
+ has_gpu_text = "Yes"
+ data.append(("GPU available", has_gpu_text))
+ if has_gpu:
+ devices = defaultdict(list)
+ for k in range(torch.cuda.device_count()):
+ cap = ".".join((str(x) for x in torch.cuda.get_device_capability(k)))
+ name = torch.cuda.get_device_name(k) + f" (arch={cap})"
+ devices[name].append(str(k))
+ for name, devids in devices.items():
+ data.append(("GPU " + ",".join(devids), name))
+
+ if has_rocm:
+ msg = " - invalid!" if not (ROCM_HOME and os.path.isdir(ROCM_HOME)) else ""
+ data.append(("ROCM_HOME", str(ROCM_HOME) + msg))
+ else:
+ try:
+ from torch.utils.collect_env import get_nvidia_driver_version, run as _run
+
+ data.append(("Driver version", get_nvidia_driver_version(_run)))
+ except Exception:
+ pass
+ msg = " - invalid!" if not (CUDA_HOME and os.path.isdir(CUDA_HOME)) else ""
+ data.append(("CUDA_HOME", str(CUDA_HOME) + msg))
+
+ cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
+ if cuda_arch_list:
+ data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
+ data.append(("Pillow", PIL.__version__))
+
+ try:
+ data.append(
+ (
+ "torchvision",
+ str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__),
+ )
+ )
+ if has_cuda:
+ try:
+ torchvision_C = importlib.util.find_spec("torchvision._C").origin
+ msg = detect_compute_compatibility(CUDA_HOME, torchvision_C)
+ data.append(("torchvision arch flags", msg))
+ except (ImportError, AttributeError):
+ data.append(("torchvision._C", "Not found"))
+ except AttributeError:
+ data.append(("torchvision", "unknown"))
+
+ try:
+ import fvcore
+
+ data.append(("fvcore", fvcore.__version__))
+ except (ImportError, AttributeError):
+ pass
+
+ try:
+ import iopath
+
+ data.append(("iopath", iopath.__version__))
+ except (ImportError, AttributeError):
+ pass
+
+ try:
+ import cv2
+
+ data.append(("cv2", cv2.__version__))
+ except (ImportError, AttributeError):
+ data.append(("cv2", "Not found"))
+ env_str = tabulate(data) + "\n"
+ env_str += collect_torch_env()
+ return env_str
+
+
+def test_nccl_ops():
+ num_gpu = torch.cuda.device_count()
+ if os.access("/tmp", os.W_OK):
+ import torch.multiprocessing as mp
+
+ dist_url = "file:///tmp/nccl_tmp_file"
+ print("Testing NCCL connectivity ... this should not hang.")
+ mp.spawn(_test_nccl_worker, nprocs=num_gpu, args=(num_gpu, dist_url), daemon=False)
+ print("NCCL succeeded.")
+
+
+def _test_nccl_worker(rank, num_gpu, dist_url):
+ import torch.distributed as dist
+
+ dist.init_process_group(backend="NCCL", init_method=dist_url, rank=rank, world_size=num_gpu)
+ dist.barrier(device_ids=[rank])
+
+
+if __name__ == "__main__":
+ try:
+ from uniperceiver.utils.collect_env import collect_env_info as f
+
+ print(f())
+ except ImportError:
+ print(collect_env_info())
+
+ if torch.cuda.is_available():
+ num_gpu = torch.cuda.device_count()
+ for k in range(num_gpu):
+ device = f"cuda:{k}"
+ try:
+ x = torch.tensor([1, 2.0], dtype=torch.float32)
+ x = x.to(device)
+ except Exception as e:
+ print(
+ f"Unable to copy tensor to device={device}: {e}. "
+ "Your CUDA environment is broken."
+ )
+ if num_gpu > 1:
+ test_nccl_ops()
diff --git a/uniperceiver/utils/comm.py b/uniperceiver/utils/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5ba38ff1b5adb66305a3f27a08588b01752d361
--- /dev/null
+++ b/uniperceiver/utils/comm.py
@@ -0,0 +1,316 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import logging
+import numpy as np
+import pickle
+import torch
+import torch.distributed as dist
+
+_LOCAL_PROCESS_GROUP = None
+
+_CAPTION_GEN_MODE = False
+
+temp_dir = TEMP_DIR = './data/temp'
+IDS = 'IDS'
+image_features = 'image_features'
+text_features = 'text_features'
+
+
+old_checkpoint = True
+
+
+"""
+A torch process group which only includes processes that on the same machine as the current process.
+This variable is set when processes are spawned by `launch()` in "engine/launch.py".
+"""
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size() -> int:
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank() -> int:
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ # assert _LOCAL_PROCESS_GROUP is not None
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def is_main_process() -> bool:
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+@functools.lru_cache()
+def _get_global_gloo_group():
+ """
+ Return a process group based on gloo backend, containing all the ranks
+ The result is cached.
+ """
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+ else:
+ return dist.group.WORLD
+
+
+def _serialize_to_tensor(data, group):
+ backend = dist.get_backend(group)
+ assert backend in ["gloo", "nccl"]
+ device = torch.device("cpu" if backend == "gloo" else "cuda")
+
+ buffer = pickle.dumps(data)
+ if len(buffer) > 1024 ** 3:
+ logger = logging.getLogger(__name__)
+ logger.warning(
+ "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
+ get_rank(), len(buffer) / (1024 ** 3), device
+ )
+ )
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to(device=device)
+ return tensor
+
+
+def _pad_to_largest_tensor(tensor, group):
+ """
+ Returns:
+ list[int]: size of the tensor, on each rank
+ Tensor: padded tensor that has the max size
+ """
+ world_size = dist.get_world_size(group=group)
+ assert (
+ world_size >= 1
+ ), "comm.gather/all_gather must be called from ranks within the given group!"
+ local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
+ size_list = [
+ torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
+ ]
+ dist.all_gather(size_list, local_size, group=group)
+ size_list = [int(size.item()) for size in size_list]
+
+ max_size = max(size_list)
+
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ if local_size != max_size:
+ padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
+ tensor = torch.cat((tensor, padding), dim=0)
+ return size_list, tensor
+
+
+def all_gather(data, group=None):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group) == 1:
+ return [data]
+
+ tensor = _serialize_to_tensor(data, group)
+
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.all_gather(tensor_list, tensor, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def gather(data, dst=0, group=None):
+ """
+ Run gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ dst (int): destination rank
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
+ an empty list.
+ """
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = _get_global_gloo_group()
+ if dist.get_world_size(group=group) == 1:
+ return [data]
+ rank = dist.get_rank(group=group)
+
+ tensor = _serialize_to_tensor(data, group)
+ size_list, tensor = _pad_to_largest_tensor(tensor, group)
+
+ # receiving Tensor from all ranks
+ if rank == dst:
+ max_size = max(size_list)
+ tensor_list = [
+ torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
+ ]
+ dist.gather(tensor, tensor_list, dst=dst, group=group)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+ return data_list
+ else:
+ dist.gather(tensor, [], dst=dst, group=group)
+ return []
+
+def broadcast_object(data, src=0, group=None):
+ """
+ Run gather on arbitrary picklable data (not necessarily tensors).
+
+ Args:
+ data: any picklable object
+ dst (int): destination rank
+ group: a torch process group. By default, will use a group which
+ contains all ranks on gloo backend.
+
+ Returns:
+ list[data]: on dst, a list of data gathered from each rank. Otherwise,
+ an empty list.
+ """
+ # if get_world_size() == 1:
+ # return data
+ # if group is None:
+ # group = _get_global_gloo_group()
+ # if dist.get_world_size(group=group) == 1:
+ # return data
+
+ if not isinstance(data, list):
+ data_list = [data]
+ dist.broadcast_object_list(data_list, src=src, group=group)
+ return data_list[0]
+ else:
+ dist.broadcast_object_list(data, src=src, group=group)
+ return data
+ return data
+
+
+
+def shared_random_seed():
+ """
+ Returns:
+ int: a random number that is the same across all workers.
+ If workers need a shared RNG, they can use this shared seed to
+ create one.
+
+ All workers must call this function, otherwise it will deadlock.
+ """
+ ints = np.random.randint(2 ** 31)
+ all_ints = all_gather(ints)
+ return all_ints[0]
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the reduced results.
+
+ Args:
+ input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
+ average (bool): whether to do average or sum
+
+ Returns:
+ a dict with the same keys as input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
+
+def unwrap_model(model):
+ return model.module if hasattr(model, 'module') else model
\ No newline at end of file
diff --git a/uniperceiver/utils/engine_util.py b/uniperceiver/utils/engine_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..edbe1013822d02cf64dbd647b6ede46daab58cfe
--- /dev/null
+++ b/uniperceiver/utils/engine_util.py
@@ -0,0 +1,98 @@
+import uniperceiver.utils.comm as comm
+import torch
+import numpy as np
+from uniperceiver.utils.events import get_event_storage
+from typing import Dict
+from uniperceiver.datasets import (
+ build_standard_valtest_loader,
+ build_unified_train_loader,
+)
+import weakref
+
+def write_metrics(loss_dict: Dict[str, torch.Tensor],
+ data_time: float,
+ prefix: str = "",
+ ):
+ """
+ Args:
+ loss_dict (dict): dict of scalar losses
+ data_time (float): time taken by the dataloader iteration
+ """
+ metrics_dict = {}
+ for k, v in loss_dict.items():
+ if isinstance(v, torch.Tensor):
+ metrics_dict.update({k: v.detach().cpu().item()})
+ else:
+ metrics_dict.update({k: v})
+ metrics_dict["data_time"] = data_time
+
+ # Gather metrics among all workers for logging
+ # This assumes we do DDP-style training, which is currently the only
+ # supported method in detectron2.
+ all_metrics_dict = [metrics_dict]
+ if comm.is_main_process():
+ # print(all_metrics_dict)
+ storage = get_event_storage()
+
+ # data_time among workers can have high variance. The actual latency
+ # caused by data_time is the maximum among workers.
+ data_time = np.max([x.pop("data_time") for x in all_metrics_dict])
+ storage.put_scalar("data_time", data_time)
+ # average the rest metrics
+ metrics_dict = {
+ k: np.mean([x[k] for x in all_metrics_dict])
+ for k in all_metrics_dict[0].keys()
+ }
+ total_losses_reduced = sum(metrics_dict.values())
+ storage.put_scalar("{}total_loss".format(prefix),
+ total_losses_reduced)
+ if len(metrics_dict) > 1:
+ for k, v in metrics_dict.items():
+ if k != 'null_loss':
+ storage.put_scalar(f'{prefix}{k}', v)
+
+def build_writers(cfg, max_iter):
+ from uniperceiver.engine.defaults import default_writers
+ return default_writers(cfg.OUTPUT_DIR, max_iter)
+
+def build_train_loader(cfg, task_cfg, model):
+ loader = dict()
+ if cfg.DATALOADER.UNIFIED_DATASET:
+ loader = build_unified_train_loader(cfg, task_cfg, model=weakref.proxy(comm.unwrap_model(model)) if cfg.DATALOADER.LOAD_INLABEL else None)
+ return loader
+ else:
+ raise NotImplementedError('please use unified dataset.')
+
+def build_test_loader(cfg, task_cfg):
+ loaders = dict()
+ #TODO: move multi-gpu eval in config file
+ for name, new_cfg in task_cfg.items():
+ multi_gpu = name in [
+ 'K400_retrieve', 'imagenet', 'vqa', 'mscoco_caption',
+ 'flickr30k_caption', 'K700_retrieve', 'imagenet_caption'
+ ]
+ loaders[name] = build_standard_valtest_loader(new_cfg, task_cfg, stage='test', multi_gpu_eval=multi_gpu)
+ return loaders
+
+def build_val_loader(cfg, task_cfg):
+ loaders = dict()
+ for name, new_cfg in task_cfg.items():
+ #TODO: move multi-gpu eval in config file
+ multi_gpu = name in [
+ 'K400_retrieve', 'imagenet', 'vqa', 'mscoco_caption',
+ 'flickr30k_caption', 'K700_retrieve', 'imagenet_caption'
+ ]
+ loaders[name] = build_standard_valtest_loader(new_cfg, task_cfg, stage='val', multi_gpu_eval=multi_gpu)
+ return loaders
+
+def get_batch_data(cfg, train_data_loader_iter, train_data_loader):
+ if not cfg.DATALOADER.FAKE_DATA:
+ try:
+ data = next(train_data_loader_iter)
+ except StopIteration:
+ train_data_loader_iter = iter(train_data_loader)
+ data = next(train_data_loader_iter)
+ else:
+ # fake data
+ bs = 32
+ return data
diff --git a/uniperceiver/utils/env.py b/uniperceiver/utils/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..e933b52509b2399cc59db2b91f85d0658fe74eb0
--- /dev/null
+++ b/uniperceiver/utils/env.py
@@ -0,0 +1,215 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import importlib
+import importlib.util
+import logging
+import numpy as np
+import os
+import random
+import sys
+from datetime import datetime
+import torch
+import socket
+import subprocess
+import time
+from . import comm
+
+__all__ = ["seed_all_rng"]
+
+
+TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
+"""
+PyTorch version as a tuple of 2 ints. Useful for comparison.
+"""
+
+
+def seed_all_rng(seed=None):
+ """
+ Set the random seed for the RNG in torch, numpy and python.
+
+ Args:
+ seed (int): if None, will use a strong random seed.
+ """
+ if seed is None:
+ seed = (
+ os.getpid()
+ + int(datetime.now().strftime("%S%f"))
+ + int.from_bytes(os.urandom(2), "big")
+ )
+ logger = logging.getLogger(__name__)
+ logger.info("Using a generated random seed {}".format(seed))
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ random.seed(seed)
+
+
+# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
+def _import_file(module_name, file_path, make_importable=False):
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ if make_importable:
+ sys.modules[module_name] = module
+ return module
+
+
+def _configure_libraries():
+ """
+ Configurations for some libraries.
+ """
+ # An environment option to disable `import cv2` globally,
+ # in case it leads to negative performance impact
+ disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False))
+ if disable_cv2:
+ sys.modules["cv2"] = None
+ else:
+ # Disable opencl in opencv since its interaction with cuda often has negative effects
+ # This envvar is supported after OpenCV 3.4.0
+ os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
+ try:
+ import cv2
+
+ if int(cv2.__version__.split(".")[0]) >= 3:
+ cv2.ocl.setUseOpenCL(False)
+ except ModuleNotFoundError:
+ # Other types of ImportError, if happened, should not be ignored.
+ # Because a failed opencv import could mess up address space
+ # https://github.com/skvark/opencv-python/issues/381
+ pass
+
+ def get_version(module, digit=2):
+ return tuple(map(int, module.__version__.split(".")[:digit]))
+
+ # fmt: off
+ assert get_version(torch) >= (1, 4), "Requires torch>=1.4"
+ import fvcore
+ assert get_version(fvcore, 3) >= (0, 1, 2), "Requires fvcore>=0.1.2"
+ import yaml
+ assert get_version(yaml) >= (5, 1), "Requires pyyaml>=5.1"
+ # fmt: on
+
+
+_ENV_SETUP_DONE = False
+
+
+def setup_environment():
+ """Perform environment setup work. The default setup is a no-op, but this
+ function allows the user to specify a Python source file or a module in
+ the $DETECTRON2_ENV_MODULE environment variable, that performs
+ custom setup work that may be necessary to their computing environment.
+ """
+ global _ENV_SETUP_DONE
+ if _ENV_SETUP_DONE:
+ return
+ _ENV_SETUP_DONE = True
+
+ _configure_libraries()
+
+ custom_module_path = os.environ.get("DETECTRON2_ENV_MODULE")
+
+ if custom_module_path:
+ setup_custom_environment(custom_module_path)
+ else:
+ # The default setup is a no-op
+ pass
+
+
+def setup_custom_environment(custom_module):
+ """
+ Load custom environment setup by importing a Python source file or a
+ module, and run the setup function.
+ """
+ if custom_module.endswith(".py"):
+ module = _import_file("detectron2.utils.env.custom_module", custom_module)
+ else:
+ module = importlib.import_module(custom_module)
+ assert hasattr(module, "setup_environment") and callable(module.setup_environment), (
+ "Custom environment module defined in {} does not have the "
+ "required callable attribute 'setup_environment'."
+ ).format(custom_module)
+ module.setup_environment()
+
+def check_dist_portfile():
+ if "SLURM_JOB_ID" in os.environ and int(os.environ["SLURM_PROCID"]) == 0: # rank==0
+ hostfile = "dist_url_" + os.environ["SLURM_JOBID"] + ".txt"
+ if os.path.exists(hostfile):
+ os.remove(hostfile)
+
+def find_free_port():
+ s = socket.socket()
+ s.bind(('', 0)) # Bind to a free port provided by the host.
+ return s.getsockname()[1] # Return the port number assigned.
+
+def init_distributed_mode(args):
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+ if int(os.environ["RANK"])==0:
+ print('this task is not running on cluster!')
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ['WORLD_SIZE'])
+ args.gpu = int(os.environ['LOCAL_RANK'])
+ args.dist_url = 'env://'
+ os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
+ addr = socket.gethostname()
+
+ elif 'SLURM_PROCID' in os.environ:
+ proc_id = int(os.environ['SLURM_PROCID'])
+ if proc_id==0:
+ print('Init dist using slurm!')
+ print("Job Id is {} on {} ".format(os.environ["SLURM_JOBID"], os.environ['SLURM_NODELIST']))
+ ntasks = int(os.environ['SLURM_NTASKS'])
+ node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ addr = subprocess.getoutput(
+ 'scontrol show hostname {} | head -n1'.format(node_list))
+ jobid = os.environ["SLURM_JOBID"]
+ hostfile = "dist_url_" + jobid + ".txt"
+ if proc_id == 0:
+ args.tcp_port = str( find_free_port())
+ print('write port {} to file: {} '.format(args.tcp_port, hostfile))
+ with open(hostfile, "w") as f:
+ f.write(args.tcp_port)
+ else:
+ print('read port from file: {}'.format(hostfile))
+ while not os.path.exists(hostfile):
+ time.sleep(1)
+ time.sleep(2)
+ with open(hostfile, "r") as f:
+ args.tcp_port = f.read()
+
+ os.environ['MASTER_PORT'] =str(args.tcp_port)
+ os.environ['MASTER_ADDR'] = addr
+ os.environ['WORLD_SIZE'] = str(ntasks)
+ os.environ['RANK'] = str(proc_id)
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
+ os.environ['LOCAL_SIZE'] = str(num_gpus)
+ args.dist_url = 'env://'
+ args.world_size = ntasks
+ args.rank = proc_id
+ args.gpu = proc_id % num_gpus
+ else:
+ print('Not using distributed mode')
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = 'nccl'
+ print('rank: {} addr: {} port: {}'.format(args.rank, addr, os.environ['MASTER_PORT']))
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+ world_size=args.world_size, rank=args.rank)
+ torch.distributed.barrier()
+ if 'SLURM_PROCID' in os.environ and args.rank == 0:
+ if os.path.isfile(hostfile):
+ os.remove(hostfile)
+ if args.world_size >= 1:
+ # Setup the local process group (which contains ranks within the same machine)
+ assert comm._LOCAL_PROCESS_GROUP is None
+ num_gpus = torch.cuda.device_count()
+ num_machines = args.world_size // num_gpus
+ for i in range(num_machines):
+ ranks_on_i = list(range(i * num_gpus, (i + 1) * num_gpus))
+ print('new_group: {}'.format(ranks_on_i))
+ pg = torch.distributed.new_group(ranks_on_i)
+ if args.rank in ranks_on_i:
+ # if i == os.environ['SLURM_NODEID']:
+ comm._LOCAL_PROCESS_GROUP = pg
diff --git a/uniperceiver/utils/events.py b/uniperceiver/utils/events.py
new file mode 100644
index 0000000000000000000000000000000000000000..29f70277153ba4982960bce62ce6fe1614f9fc03
--- /dev/null
+++ b/uniperceiver/utils/events.py
@@ -0,0 +1,535 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import datetime
+import json
+import logging
+import os
+import time
+from collections import defaultdict
+from contextlib import contextmanager
+from typing import Optional
+from fvcore.common.history_buffer import HistoryBuffer
+
+from uniperceiver.utils.file_io import PathManager
+
+from numpy import random
+
+from uniperceiver.utils import comm
+__all__ = [
+ "get_event_storage",
+ "JSONWriter",
+ "TensorboardXWriter",
+ "CommonMetricPrinter",
+ "EventStorage",
+]
+
+_CURRENT_STORAGE_STACK = []
+
+
+def get_event_storage():
+ """
+ Returns:
+ The :class:`EventStorage` object that's currently being used.
+ Throws an error if no :class:`EventStorage` is currently enabled.
+ """
+ assert len(
+ _CURRENT_STORAGE_STACK
+ ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
+ return _CURRENT_STORAGE_STACK[-1]
+
+
+class EventWriter:
+ """
+ Base class for writers that obtain events from :class:`EventStorage` and process them.
+ """
+
+ def write(self):
+ raise NotImplementedError
+
+ def close(self):
+ pass
+
+
+class JSONWriter(EventWriter):
+ """
+ Write scalars to a json file.
+
+ It saves scalars as one json per line (instead of a big json) for easy parsing.
+
+ Examples parsing such a json file:
+ ::
+ $ cat metrics.json | jq -s '.[0:2]'
+ [
+ {
+ "data_time": 0.008433341979980469,
+ "iteration": 19,
+ "loss": 1.9228371381759644,
+ "loss_box_reg": 0.050025828182697296,
+ "loss_classifier": 0.5316952466964722,
+ "loss_mask": 0.7236229181289673,
+ "loss_rpn_box": 0.0856662318110466,
+ "loss_rpn_cls": 0.48198649287223816,
+ "lr": 0.007173333333333333,
+ "time": 0.25401854515075684
+ },
+ {
+ "data_time": 0.007216215133666992,
+ "iteration": 39,
+ "loss": 1.282649278640747,
+ "loss_box_reg": 0.06222952902317047,
+ "loss_classifier": 0.30682939291000366,
+ "loss_mask": 0.6970193982124329,
+ "loss_rpn_box": 0.038663312792778015,
+ "loss_rpn_cls": 0.1471673548221588,
+ "lr": 0.007706666666666667,
+ "time": 0.2490077018737793
+ }
+ ]
+
+ $ cat metrics.json | jq '.loss_mask'
+ 0.7126231789588928
+ 0.689423680305481
+ 0.6776131987571716
+ ...
+
+ """
+
+ def __init__(self, json_file, window_size=20):
+ """
+ Args:
+ json_file (str): path to the json file. New data will be appended if the file exists.
+ window_size (int): the window size of median smoothing for the scalars whose
+ `smoothing_hint` are True.
+ """
+ self._file_handle = PathManager.open(json_file, "a")
+ self._window_size = window_size
+ self._last_write = -1
+
+ def write(self):
+ storage = get_event_storage()
+ to_save = defaultdict(dict)
+
+ for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
+ # keep scalars that have not been written
+ if iter <= self._last_write:
+ continue
+ to_save[iter][k] = v
+ if len(to_save):
+ all_iters = sorted(to_save.keys())
+ self._last_write = max(all_iters)
+
+ for itr, scalars_per_iter in to_save.items():
+ scalars_per_iter["iteration"] = itr
+ self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
+ self._file_handle.flush()
+ try:
+ os.fsync(self._file_handle.fileno())
+ except AttributeError:
+ pass
+
+ def close(self):
+ self._file_handle.close()
+
+
+class TensorboardXWriter(EventWriter):
+ """
+ Write all scalars to a tensorboard file.
+ """
+
+ def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
+ """
+ Args:
+ log_dir (str): the directory to save the output events
+ window_size (int): the scalars will be median-smoothed by this window size
+
+ kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
+ """
+ self._window_size = window_size
+
+ from torch.utils.tensorboard import SummaryWriter
+
+
+ self._writer = SummaryWriter(log_dir, max_queue=20000, flush_secs=10, **kwargs)
+ self._last_write = -1
+
+ def write(self):
+ storage = get_event_storage()
+ new_last_write = self._last_write
+ for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
+
+ if iter > self._last_write:
+ if k.startswith('total_grad'):
+ values_max_20iters = max([a[0] for a in storage.history(k).values()[-20:]])
+ self._writer.add_scalar(k.replace('total_grad', 'total_grad_max_20iters'), values_max_20iters, iter)
+ if k.endswith("time"):
+ values_max_20iters = max([a[0] for a in storage.history(k).values()[-20:]])
+ self._writer.add_scalar(k.replace('time', 'time_max_during_20iters'), values_max_20iters, iter)
+ self._writer.add_scalar(k, v, iter)
+ new_last_write = max(new_last_write, iter)
+
+ self._last_write = new_last_write
+
+ # storage.put_{image,histogram} is only meant to be used by
+ # tensorboard writer. So we access its internal fields directly from here.
+ if len(storage._vis_data) >= 1:
+ for img_name, img, step_num in storage._vis_data:
+ self._writer.add_image(img_name, img, step_num)
+ # Storage stores all image data and rely on this writer to clear them.
+ # As a result it assumes only one writer will use its image data.
+ # An alternative design is to let storage store limited recent
+ # data (e.g. only the most recent image) that all writers can access.
+ # In that case a writer may not see all image data if its period is long.
+ storage.clear_images()
+
+ if len(storage._histograms) >= 1:
+ for params in storage._histograms:
+ self._writer.add_histogram_raw(**params)
+ storage.clear_histograms()
+
+ def close(self):
+ if hasattr(self, "_writer"
+ ): # doesn't exist wheeventsn the code fails at import
+ self._writer.close()
+
+
+class CommonMetricPrinter(EventWriter):
+ """
+ Print **common** metrics to the terminal, including
+ iteration time, ETA, memory, all losses, and the learning rate.
+ It also applies smoothing using a window of 20 elements.
+
+ It's meant to print common metrics in common ways.
+ To print something in more customized ways, please implement a similar printer by yourself.
+ """
+
+ def __init__(self, max_iter: Optional[int] = None):
+ """
+ Args:
+ max_iter: the maximum number of iterations to train.
+ Used to compute ETA. If not given, ETA will not be printed.
+ """
+ self.logger = logging.getLogger(__name__)
+ self._max_iter = max_iter
+ self._last_write = None # (step, time) of last call to write(). Used to compute ETA
+
+ def _get_eta(self, storage) -> Optional[str]:
+ if self._max_iter is None:
+ return ""
+ iteration = storage.iter
+ try:
+ eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
+ storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
+ return str(datetime.timedelta(seconds=int(eta_seconds)))
+ except KeyError:
+ # estimate eta on our own - more noisy
+ eta_string = None
+ if self._last_write is not None:
+ estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
+ iteration - self._last_write[0]
+ )
+ eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ self._last_write = (iteration, time.perf_counter())
+ return eta_string
+
+ def write(self):
+ import torch
+
+ storage = get_event_storage()
+ iteration = storage.iter
+ if iteration == self._max_iter:
+ # This hook only reports training progress (loss, ETA, etc) but not other data,
+ # therefore do not write anything after training succeeds, even if this method
+ # is called.
+ return
+
+ try:
+ data_time = storage.history("data_time").avg(20)
+ except KeyError:
+ # they may not exist in the first few iterations (due to warmup)
+ # or when SimpleTrainer is not used
+ data_time = None
+ try:
+ iter_time = storage.history("time").global_avg()
+ except KeyError:
+ iter_time = None
+
+ try:
+ lr = "{:.5g}".format(storage.history("lr").latest())
+ except KeyError:
+ lr = "N/A"
+
+ eta_string = self._get_eta(storage)
+
+ if torch.cuda.is_available():
+ max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
+ else:
+ max_mem_mb = None
+
+ # NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
+ # {data_avg}{tocuda_time}{mixup_time}{fwd_time}{loss_time}{bwd_time}{zeroema_time}{step_time}\
+ # {before_time}{loop_time}{after_time} {total_time}
+ self.logger.info(
+ " {eta}iter: {iter} {losses} \t Global Info: {time}{data_time} lr: {lr} {memory} ".format(
+ eta=f"eta: {eta_string} " if eta_string else "",
+ iter=iteration,
+ losses=" ".join([
+ "{}: {:.4g}".format(k, v.median(20))
+ for k, v in storage.histories().items()
+ if not (k.startswith('layer') or k.startswith(
+ 'logits_layer') or k.startswith('update'))
+ ]),
+ time="time: {:.4f} ".format(iter_time)
+ if iter_time is not None else "",
+ data_time="data_time: {:.4f} ".format(data_time)
+ if data_time is not None else "",
+ lr=lr,
+ memory="max_mem: {:.0f}M".format(max_mem_mb)
+ if max_mem_mb is not None else "",
+ ))
+
+
+
+class EventStorage:
+ """
+ The user-facing class that provides metric storage functionalities.
+
+ In the future we may add support for storing / logging other types of data if needed.
+ """
+
+ def __init__(self, start_iter=0):
+ """
+ Args:
+ start_iter (int): the iteration number to start with
+ """
+ self._history = defaultdict(HistoryBuffer)
+ self._smoothing_hints = {}
+ self._avg_hints = {}
+ self._latest_scalars = {}
+ self._iter = start_iter
+ self._current_prefix = ""
+ self._vis_data = []
+ self._histograms = []
+
+ self.expert_log_prob = 0.05
+
+ self.expert_log_interval = getattr(comm, '_EXPERT_LOG_INTERVAL', 200)
+
+ def put_image(self, img_name, img_tensor):
+ """
+ Add an `img_tensor` associated with `img_name`, to be shown on
+ tensorboard.
+
+ Args:
+ img_name (str): The name of the image to put into tensorboard.
+ img_tensor (torch.Tensor or numpy.array): An `uint8` or `float`
+ Tensor of shape `[channel, height, width]` where `channel` is
+ 3. The image format should be RGB. The elements in img_tensor
+ can either have values in [0, 1] (float32) or [0, 255] (uint8).
+ The `img_tensor` will be visualized in tensorboard.
+ """
+ self._vis_data.append((img_name, img_tensor, self._iter))
+
+ def put_scalar(self, name, value, smoothing_hint=True, avg_hint=False):
+ """
+ Add a scalar `value` to the `HistoryBuffer` associated with `name`.
+
+ Args:
+ smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
+ smoothed when logged. The hint will be accessible through
+ :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
+ and apply custom smoothing rule.
+
+ It defaults to True because most scalars we save need to be smoothed to
+ provide any useful signal.
+ """
+ name = self._current_prefix + name
+ history = self._history[name]
+ value = float(value)
+ history.update(value, self._iter)
+ self._latest_scalars[name] = (value, self._iter)
+
+ existing_hint = self._smoothing_hints.get(name)
+ if existing_hint is not None:
+ assert (
+ existing_hint == smoothing_hint
+ ), "Scalar {} was put with a different smoothing_hint!".format(name)
+ else:
+ self._smoothing_hints[name] = smoothing_hint
+
+ existing_hint = self._avg_hints.get(name)
+ if existing_hint is not None:
+ assert (
+ existing_hint == avg_hint
+ ), "Scalar {} was put with a different smoothing_hint!".format(
+ name)
+ else:
+ self._avg_hints[name] = avg_hint
+
+ def put_scalars(self, *, smoothing_hint=True, avg_hint=False, **kwargs):
+ """
+ Put multiple scalars from keyword arguments.
+
+ Examples:
+
+ storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
+ """
+ for k, v in kwargs.items():
+ self.put_scalar(k,
+ v,
+ smoothing_hint=smoothing_hint,
+ avg_hint=avg_hint)
+
+ def put_histogram(self, hist_name, hist_tensor, bins=1000):
+ """
+ Create a histogram from a tensor.
+
+ Args:
+ hist_name (str): The name of the histogram to put into tensorboard.
+ hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
+ into a histogram.
+ bins (int): Number of histogram bins.
+ """
+ import torch
+
+ ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
+ if ht_max <= ht_min:
+ ht_max = ht_max + 0.0001
+ ht_min = ht_min - 0.0001
+ # Create a histogram with PyTorch
+ hist_counts = torch.histc(hist_tensor, bins=bins)
+ hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
+
+ # Parameter for the add_histogram_raw function of SummaryWriter
+ hist_params = dict(
+ tag=hist_name,
+ min=ht_min,
+ max=ht_max,
+ num=len(hist_tensor),
+ sum=float(hist_tensor.sum()),
+ sum_squares=float(torch.sum(hist_tensor ** 2)),
+ bucket_limits=hist_edges[1:].tolist(),
+ bucket_counts=hist_counts.tolist(),
+ global_step=self._iter,
+ )
+ self._histograms.append(hist_params)
+
+ def history(self, name):
+ """
+ Returns:
+ HistoryBuffer: the scalar history for name
+ """
+ ret = self._history.get(name, None)
+ if ret is None:
+ raise KeyError("No history metric available for {}!".format(name))
+ return ret
+
+ def histories(self):
+ """
+ Returns:
+ dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
+ """
+ return self._history
+
+ def latest(self):
+ """
+ Returns:
+ dict[str -> (float, int)]: mapping from the name of each scalar to the most
+ recent value and the iteration number its added.
+ """
+ return self._latest_scalars
+
+ def latest_with_smoothing_hint(self, window_size=20):
+ """
+ Similar to :meth:`latest`, but the returned values
+ are either the un-smoothed original latest value,
+ or a median of the given window_size,
+ depend on whether the smoothing_hint is True.
+
+ This provides a default behavior that other writers can use.
+ """
+ result = {}
+ for k, (v, itr) in self._latest_scalars.items():
+ if (k.startswith('layer') or k.startswith('logits_layer')) and self._avg_hints[k]:
+ # if random.rand() < self.expert_log_prob:
+ if (self._iter + 1) % self.expert_log_interval == 0:
+ result[k] = (
+ self._history[k].avg(10),
+ itr,
+ )
+ else:
+ continue
+ else:
+ result[k] = (
+ self._history[k].median(window_size) if self._smoothing_hints[k] else v,
+ itr,
+ )
+ return result
+
+ def smoothing_hints(self):
+ """
+ Returns:
+ dict[name -> bool]: the user-provided hint on whether the scalar
+ is noisy and needs smoothing.
+ """
+ return self._smoothing_hints
+
+ def step(self):
+ """
+ User should either: (1) Call this function to increment storage.iter when needed. Or
+ (2) Set `storage.iter` to the correct iteration number before each iteration.
+
+ The storage will then be able to associate the new data with an iteration number.
+ """
+ self._iter += 1
+
+ @property
+ def iter(self):
+ """
+ Returns:
+ int: The current iteration number. When used together with a trainer,
+ this is ensured to be the same as trainer.iter.
+ """
+ return self._iter
+
+ @iter.setter
+ def iter(self, val):
+ self._iter = int(val)
+
+ @property
+ def iteration(self):
+ # for backward compatibility
+ return self._iter
+
+ def __enter__(self):
+ _CURRENT_STORAGE_STACK.append(self)
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ assert _CURRENT_STORAGE_STACK[-1] == self
+ _CURRENT_STORAGE_STACK.pop()
+
+ @contextmanager
+ def name_scope(self, name):
+ """
+ Yields:
+ A context within which all the events added to this storage
+ will be prefixed by the name scope.
+ """
+ old_prefix = self._current_prefix
+ self._current_prefix = name.rstrip("/") + "/"
+ yield
+ self._current_prefix = old_prefix
+
+ def clear_images(self):
+ """
+ Delete all the stored images for visualization. This should be called
+ after images are written to tensorboard.
+ """
+ self._vis_data = []
+
+ def clear_histograms(self):
+ """
+ Delete all the stored histograms for visualization.
+ This should be called after histograms are written to tensorboard.
+ """
+ self._histograms = []
diff --git a/uniperceiver/utils/file_io.py b/uniperceiver/utils/file_io.py
new file mode 100644
index 0000000000000000000000000000000000000000..46ee4ec31d04eee77976ff3edbbf84762a3409ed
--- /dev/null
+++ b/uniperceiver/utils/file_io.py
@@ -0,0 +1,37 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler
+from iopath.common.file_io import PathManager as PathManagerBase
+
+__all__ = ["PathManager", "PathHandler"]
+
+
+PathManager = PathManagerBase()
+"""
+This is a detectron2 project-specific PathManager.
+We try to stay away from global PathManager in fvcore as it
+introduces potential conflicts among other libraries.
+"""
+
+
+class Detectron2Handler(PathHandler):
+ """
+ Resolve anything that's hosted under detectron2's namespace.
+ """
+
+ PREFIX = "detectron2://"
+ S3_DETECTRON2_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/"
+
+ def _get_supported_prefixes(self):
+ return [self.PREFIX]
+
+ def _get_local_path(self, path, **kwargs):
+ name = path[len(self.PREFIX) :]
+ return PathManager.get_local_path(self.S3_DETECTRON2_PREFIX + name, **kwargs)
+
+ def _open(self, path, mode="r", **kwargs):
+ return PathManager.open(self._get_local_path(path), mode, **kwargs)
+
+
+PathManager.register_handler(HTTPURLHandler())
+PathManager.register_handler(OneDrivePathHandler())
+PathManager.register_handler(Detectron2Handler())
diff --git a/uniperceiver/utils/logger.py b/uniperceiver/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..a90b1bc6dd29a2d9626f07a8c9717603d2d00470
--- /dev/null
+++ b/uniperceiver/utils/logger.py
@@ -0,0 +1,237 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import atexit
+import functools
+import logging
+import os
+import sys
+import time
+from collections import Counter
+import torch
+from tabulate import tabulate
+from termcolor import colored
+
+from uniperceiver.utils.file_io import PathManager
+
+__all__ = ["setup_logger", "log_first_n", "log_every_n", "log_every_n_seconds"]
+
+
+class _ColorfulFormatter(logging.Formatter):
+ def __init__(self, *args, **kwargs):
+ self._root_name = kwargs.pop("root_name") + "."
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
+ if len(self._abbrev_name):
+ self._abbrev_name = self._abbrev_name + "."
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
+
+ def formatMessage(self, record):
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
+ log = super(_ColorfulFormatter, self).formatMessage(record)
+ if record.levelno == logging.WARNING:
+ prefix = colored("WARNING", "red", attrs=["blink"])
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
+ else:
+ return log
+ return prefix + " " + log
+
+
+@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers
+def setup_logger(
+ output=None, distributed_rank=0, *, color=True, name="uniperceiver", abbrev_name=None
+):
+ """
+ Initialize the detectron2 logger and set its verbosity level to "DEBUG".
+
+ Args:
+ output (str): a file name or a directory to save log. If None, will not save log file.
+ If ends with ".txt" or ".log", assumed to be a file name.
+ Otherwise, logs will be saved to `output/log.txt`.
+ name (str): the root module name of this logger
+ abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
+ Set to "" to not log the root module in logs.
+ By default, will abbreviate "detectron2" to "d2" and leave other
+ modules unchanged.
+
+ Returns:
+ logging.Logger: a logger
+ """
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+ logger.propagate = False
+
+ if abbrev_name is None:
+ abbrev_name = "uni" if name == "uniperceiver" else name
+
+ plain_formatter = logging.Formatter(
+ "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
+ )
+ # stdout logging: master only
+ if distributed_rank == 0:
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.DEBUG)
+ if color:
+ formatter = _ColorfulFormatter(
+ colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
+ datefmt="%m/%d %H:%M:%S",
+ root_name=name,
+ abbrev_name=str(abbrev_name),
+ )
+ else:
+ formatter = plain_formatter
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ # file logging: all workers
+ if output is not None:
+ if output.endswith(".txt") or output.endswith(".log"):
+ filename = output
+ else:
+ filename = os.path.join(output, "log.txt")
+ if distributed_rank > 0:
+ filename = filename + ".rank{}".format(distributed_rank)
+ PathManager.mkdirs(os.path.dirname(filename))
+
+ fh = logging.StreamHandler(_cached_log_stream(filename))
+ fh.setLevel(logging.DEBUG)
+ fh.setFormatter(plain_formatter)
+ logger.addHandler(fh)
+
+ return logger
+
+
+# cache the opened file object, so that different calls to `setup_logger`
+# with the same file name can safely write to the same file.
+@functools.lru_cache(maxsize=None)
+def _cached_log_stream(filename):
+ # use 1K buffer if writing to cloud storage
+ io = PathManager.open(filename, "a", buffering=1024 if "://" in filename else -1)
+ atexit.register(io.close)
+ return io
+
+
+"""
+Below are some other convenient logging methods.
+They are mainly adopted from
+https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
+"""
+
+
+def _find_caller():
+ """
+ Returns:
+ str: module name of the caller
+ tuple: a hashable key to be used to identify different callers
+ """
+ frame = sys._getframe(2)
+ while frame:
+ code = frame.f_code
+ if os.path.join("utils", "logger.") not in code.co_filename:
+ mod_name = frame.f_globals["__name__"]
+ if mod_name == "__main__":
+ mod_name = "uniperceiver"
+ return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
+ frame = frame.f_back
+
+
+_LOG_COUNTER = Counter()
+_LOG_TIMER = {}
+
+
+def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
+ """
+ Log only for the first n times.
+
+ Args:
+ lvl (int): the logging level
+ msg (str):
+ n (int):
+ name (str): name of the logger to use. Will use the caller's module by default.
+ key (str or tuple[str]): the string(s) can be one of "caller" or
+ "message", which defines how to identify duplicated logs.
+ For example, if called with `n=1, key="caller"`, this function
+ will only log the first call from the same caller, regardless of
+ the message content.
+ If called with `n=1, key="message"`, this function will log the
+ same content only once, even if they are called from different places.
+ If called with `n=1, key=("caller", "message")`, this function
+ will not log only if the same caller has logged the same message before.
+ """
+ if isinstance(key, str):
+ key = (key,)
+ assert len(key) > 0
+
+ caller_module, caller_key = _find_caller()
+ hash_key = ()
+ if "caller" in key:
+ hash_key = hash_key + caller_key
+ if "message" in key:
+ hash_key = hash_key + (msg,)
+
+ _LOG_COUNTER[hash_key] += 1
+ if _LOG_COUNTER[hash_key] <= n:
+ logging.getLogger(name or caller_module).log(lvl, msg)
+
+
+def log_every_n(lvl, msg, n=1, *, name=None):
+ """
+ Log once per n times.
+
+ Args:
+ lvl (int): the logging level
+ msg (str):
+ n (int):
+ name (str): name of the logger to use. Will use the caller's module by default.
+ """
+ caller_module, key = _find_caller()
+ _LOG_COUNTER[key] += 1
+ if n == 1 or _LOG_COUNTER[key] % n == 1:
+ logging.getLogger(name or caller_module).log(lvl, msg)
+
+
+def log_every_n_seconds(lvl, msg, n=1, *, name=None):
+ """
+ Log no more than once per n seconds.
+
+ Args:
+ lvl (int): the logging level
+ msg (str):
+ n (int):
+ name (str): name of the logger to use. Will use the caller's module by default.
+ """
+ caller_module, key = _find_caller()
+ last_logged = _LOG_TIMER.get(key, None)
+ current_time = time.time()
+ if last_logged is None or current_time - last_logged >= n:
+ logging.getLogger(name or caller_module).log(lvl, msg)
+ _LOG_TIMER[key] = current_time
+
+
+def create_small_table(small_dict):
+ """
+ Create a small table using the keys of small_dict as headers. This is only
+ suitable for small dictionaries.
+
+ Args:
+ small_dict (dict): a result dictionary of only a few items.
+
+ Returns:
+ str: the table as a string.
+ """
+ keys, values = tuple(zip(*small_dict.items()))
+ table = tabulate(
+ [values],
+ headers=keys,
+ tablefmt="pipe",
+ floatfmt=".3f",
+ stralign="center",
+ numalign="center",
+ )
+ return table
+
+
+def _log_api_usage(identifier: str):
+ """
+ Internal function used to log the usage of different detectron2 components
+ inside facebook's infra.
+ """
+ torch._C._log_api_usage_once("uniperceiver." + identifier)
diff --git a/uniperceiver/utils/misc.py b/uniperceiver/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..40183d77eccc25ea6d23f0d2757c7b5e908c2529
--- /dev/null
+++ b/uniperceiver/utils/misc.py
@@ -0,0 +1,125 @@
+import torch
+import torch.distributed as dist
+from torch._six import inf
+import io
+from timm.utils import get_state_dict
+try:
+ from apex import amp
+ APEX_INSTALLED = True
+except:
+ print('apex has not been installed.')
+ APEX_INSTALLED = False
+
+
+class NativeScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self, enabled=True, growth_interval=500, init_scale=2.**13):
+ self.enabled = enabled
+ self._scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, growth_interval=growth_interval, enabled=self.enabled)
+
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True,
+ fp16=False, iter=0, min_loss_scale= 2048.0, loss_scale_window=200):
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+
+ if fp16:
+ # used for stable training
+ if iter > 5000 and self._scaler.get_scale() < min_loss_scale:
+ min_growth_interval = 5
+ if self._scaler.get_growth_interval() != min_growth_interval:
+ self._scaler.set_growth_interval(min_growth_interval)
+
+ elif iter > 5000 and self._scaler.get_growth_interval() == 5:
+ self._scaler.set_growth_interval(loss_scale_window)
+
+ if update_grad:
+ if clip_grad is not None and clip_grad > 0.0:
+ assert parameters is not None
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ self._scaler.unscale_(optimizer)
+ norm = get_grad_norm_(parameters)
+ else:
+ norm = None
+ return norm
+
+ def step(self, optimizer):
+ self._scaler.step(optimizer)
+
+ def update(self):
+ self._scaler.update()
+
+ def get_scale(self):
+ return self._scaler.get_scale()
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
+
+class ApexScalerWithGradNormCount:
+ state_dict_key = "amp_scaler"
+
+ def __init__(self, enabled=True):
+ self.enabled = enabled
+ self._scaler = amp
+
+ def __call__(self,
+ loss,
+ optimizer,
+ clip_grad=None,
+ parameters=None,
+ create_graph=False,
+ update_grad=True,
+ fp16=False,
+ iter=0,
+ min_loss_scale=2048.0,
+ loss_scale_window=200):
+
+ with self._scaler.scale_loss(loss, optimizer) as scaled_loss:
+ scaled_loss.backward()
+
+ if update_grad:
+ if clip_grad is not None and clip_grad > 0.0:
+ norm = torch.nn.utils.clip_grad_norm_(
+ amp.master_params(optimizer), clip_grad)
+ else:
+
+ norm = get_grad_norm_(amp.master_params(optimizer))
+ else:
+ norm = None
+ return norm
+
+ def step(self, optimizer):
+ optimizer.step()
+
+
+ def update(self):
+ pass
+
+ def get_scale(self):
+ return self._scaler.state_dict()['loss_scaler0']['loss_scale']
+
+ def state_dict(self):
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
+ return total_norm
\ No newline at end of file
diff --git a/uniperceiver/utils/registry.py b/uniperceiver/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b01e9007c2578a7b5ae555c926cc06c8a3010f9
--- /dev/null
+++ b/uniperceiver/utils/registry.py
@@ -0,0 +1,60 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+from typing import Any
+import pydoc
+from fvcore.common.registry import Registry # for backward compatibility.
+
+"""
+``Registry`` and `locate` provide ways to map a string (typically found
+in config files) to callable objects.
+"""
+
+__all__ = ["Registry", "locate"]
+
+
+def _convert_target_to_string(t: Any) -> str:
+ """
+ Inverse of ``locate()``.
+
+ Args:
+ t: any object with ``__module__`` and ``__qualname__``
+ """
+ module, qualname = t.__module__, t.__qualname__
+
+ # Compress the path to this object, e.g. ``module.submodule._impl.class``
+ # may become ``module.submodule.class``, if the later also resolves to the same
+ # object. This simplifies the string, and also is less affected by moving the
+ # class implementation.
+ module_parts = module.split(".")
+ for k in range(1, len(module_parts)):
+ prefix = ".".join(module_parts[:k])
+ candidate = f"{prefix}.{qualname}"
+ try:
+ if locate(candidate) is t:
+ return candidate
+ except ImportError:
+ pass
+ return f"{module}.{qualname}"
+
+
+def locate(name: str) -> Any:
+ """
+ Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``,
+ such as "module.submodule.class_name".
+
+ Raise Exception if it cannot be found.
+ """
+ obj = pydoc.locate(name)
+
+ # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly
+ # by pydoc.locate. Try a private function from hydra.
+ if obj is None:
+ try:
+ # from hydra.utils import get_method - will print many errors
+ from hydra.utils import _locate
+ except ImportError as e:
+ raise ImportError(f"Cannot dynamically locate object {name}!") from e
+ else:
+ obj = _locate(name) # it raises if fails
+
+ return obj
diff --git a/uniperceiver/utils/serialize.py b/uniperceiver/utils/serialize.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b38862804b70cf1159a9bc93acdef73c184d883
--- /dev/null
+++ b/uniperceiver/utils/serialize.py
@@ -0,0 +1,32 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+import cloudpickle
+
+
+class PicklableWrapper(object):
+ """
+ Wrap an object to make it more picklable, note that it uses
+ heavy weight serialization libraries that are slower than pickle.
+ It's best to use it only on closures (which are usually not picklable).
+
+ This is a simplified version of
+ https://github.com/joblib/joblib/blob/master/joblib/externals/loky/cloudpickle_wrapper.py
+ """
+
+ def __init__(self, obj):
+ while isinstance(obj, PicklableWrapper):
+ # Wrapping an object twice is no-op
+ obj = obj._obj
+ self._obj = obj
+
+ def __reduce__(self):
+ s = cloudpickle.dumps(self._obj)
+ return cloudpickle.loads, (s,)
+
+ def __call__(self, *args, **kwargs):
+ return self._obj(*args, **kwargs)
+
+ def __getattr__(self, attr):
+ # Ensure that the wrapped object can be used seamlessly as the previous object.
+ if attr not in ["_obj"]:
+ return getattr(self._obj, attr)
+ return getattr(self, attr)
diff --git a/uniperceiver/utils/transformer_util.py b/uniperceiver/utils/transformer_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..18ce991d387a490aa1cd9946f4d8d9789b01145b
--- /dev/null
+++ b/uniperceiver/utils/transformer_util.py
@@ -0,0 +1,323 @@
+import torch
+from torch import nn
+import math
+import warnings
+from torch.nn import init
+import numpy as np
+from uniperceiver.utils import comm
+
+INIT_STD = 0.02
+INIT_EMBEDDING_STD = 0.02
+
+def null_loss_check(outputs_dict):
+ ret = {}
+ if 'null_loss' in outputs_dict:
+ null_loss = outputs_dict['null_loss']
+ else:
+ null_loss = 0
+ for shared_target in outputs_dict['shared_target_sets'].values():
+ null_loss += torch.sum(shared_target[0]['data']*0)
+ ret.update({'null_loss': null_loss})
+ return ret
+
+def build_2d_sincos_position_embedding(cfg, video_embed, cls_token=False, temperature=10000., pos_emd_fix=False):
+ h, w = int(video_embed.max_spatial_size**.5), int(video_embed.max_spatial_size**.5)
+
+ grid_w = torch.arange(w, dtype=torch.float32)
+ grid_h = torch.arange(h, dtype=torch.float32)
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
+ if cfg.MODEL.POSEMBED_SCALE != 1.0:
+ grid_w = grid_w * cfg.MODEL.POSEMBED_SCALE
+ grid_h = grid_h * cfg.MODEL.POSEMBED_SCALE
+
+ assert cfg.MODEL.BERT.HIDDEN_SIZE % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+ pos_dim = cfg.MODEL.BERT.HIDDEN_SIZE // 4
+ omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+ omega = 1. / (temperature**omega)
+ out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
+ out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])
+ pos_emb = torch.cat([
+ torch.sin(out_w),
+ torch.cos(out_w),
+ torch.sin(out_h),
+ torch.cos(out_h)
+ ],
+ dim=1)[ :, :]
+
+ # assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
+ if cls_token:
+ pe_token = torch.zeros([ 1, cfg.MODEL.BERT.HIDDEN_SIZE], dtype=torch.float32)
+ video_embed.embeddings_st_pos.spatial_pos_embed.weight = nn.Parameter(torch.cat([pe_token, pos_emb], dim=0))
+ else:
+ video_embed.embeddings_st_pos.spatial_pos_embed.weight = nn.Parameter(pos_emb)
+ if cfg.MODEL.POSEMBEDFIX:
+ video_embed.embeddings_st_pos.spatial_pos_embed.weight.requires_grad = False
+
+
+def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+ def norm_cdf(x):
+ # Computes standard normal cumulative distribution function
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
+ warnings.warn(
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+ "The distribution of values may be incorrect.",
+ stacklevel=2)
+
+ with torch.no_grad():
+ # Values are generated by using a truncated uniform distribution and
+ # then using the inverse CDF for the normal distribution.
+ # Get upper and lower cdf values
+ l = norm_cdf((a - mean) / std)
+ u = norm_cdf((b - mean) / std)
+
+ # Uniformly fill tensor with values from [l, u], then translate to
+ # [2l-1, 2u-1].
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
+
+ # Use inverse cdf transform for normal distribution to get truncated
+ # standard normal
+ tensor.erfinv_()
+
+ # Transform to proper mean, std
+ tensor.mul_(std * math.sqrt(2.))
+ tensor.add_(mean)
+
+ # Clamp to ensure it's in the proper range
+ tensor.clamp_(min=a, max=b)
+ return tensor
+
+def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+def truncated_normal_(tensor, mode='fan_in',):
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
+ # so that the RNG is consistent with and without FSDP
+ fan = init._calculate_correct_fan(tensor, mode=mode)
+ gain = 0.1
+ std = math.sqrt(gain/fan)
+ init.trunc_normal_(tensor, mean=0.0, std=std)
+
+def normal_(data):
+ # with FSDP, module params will be on CUDA, so we cast them back to CPU
+ # so that the RNG is consistent with and without FSDP
+ data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
+
+def init_bert_params(module):
+ if isinstance(module, nn.Linear):
+ normal_(module.weight.data)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if isinstance(module, nn.Embedding):
+ normal_(module.weight.data)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ if isinstance(module, nn.MultiheadAttention):
+ # normal_(module.q_proj.weight.data)
+ # normal_(module.k_proj.weight.data)
+ # normal_(module.v_proj.weight.data)
+ normal_(module.in_proj_weight.data)
+
+def init_switchtransformer_params(module):
+ if isinstance(module, nn.Linear):
+ truncated_normal_(module.weight)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ if isinstance(module, nn.Embedding):
+ normal_(module.weight.data)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+def init_timm_params(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=INIT_STD)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if isinstance(m, nn.Embedding):
+ trunc_normal_(m.weight.data, std=INIT_EMBEDDING_STD)
+ if m.padding_idx is not None:
+ m.weight.data[m.padding_idx].zero_()
+ if isinstance(m, nn.MultiheadAttention):
+ trunc_normal_(m.q_proj.weight.data, std=INIT_STD)
+ trunc_normal_(m.k_proj.weight.data, std=INIT_STD)
+ trunc_normal_(m.v_proj.weight.data, std=INIT_STD)
+
+def initialize_weights_as_mae(model):
+ # initialization
+
+ # initialize nn.Linear and nn.LayerNorm
+ model.apply(init_weights_mae)
+
+ # initialize (and freeze) pos_embed by sin-cos embedding
+ if model.video_embed is not None:
+ build_2d_sincos_position_embedding(model.cfg, model.video_embed)
+
+
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
+ w = model.video_embed.embeddings.weight.data
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+ if model.video_embed.embeddings.bias is not None:
+ nn.init.zeros_(model.video_embed.embeddings.bias)
+
+
+def initialize_weights_as_mocov3(model):
+ model.initialize_weights_as_mae()
+
+ # cls token with smaller std
+ # temp = torch.zeros([ 1, self.cfg.MODEL.BERT.HIDDEN_SIZE], dtype=torch.float32)
+ nn.init.normal_(model.token_embed.embeddings.weight[-1, :], std=1e-6) # small std for cls token
+
+
+def init_weights_mae(m):
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
+ # torch.nn.init.normal_(self.cls_token, std=.02)
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+
+ if m.weight.shape[0] == m.weight.shape[1] * 3:
+ # treat the weights of Q, K, V separately
+ val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
+ nn.init.uniform_(m.weight, -val, val)
+ else:
+ torch.nn.init.xavier_uniform_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ # all word embedding e.g. word. spe. type embedding postion embed
+ # MAE only has embedding like cls_token and mask tokens
+ elif isinstance(m, nn.Embedding):
+ torch.nn.init.normal_(m.weight.data, std=INIT_EMBEDDING_STD)
+ if m.padding_idx is not None:
+ m.weight.data[m.padding_idx].zero_()
+
+
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ elif isinstance(m, nn.MultiheadAttention):
+ if m.q_proj_weight is not None:
+ torch.nn.init.xavier_uniform_(m.q_proj_weight.data)
+ torch.nn.init.xavier_uniform_(m.k_proj_weight.data)
+ torch.nn.init.xavier_uniform_(m.v_proj_weight.data)
+ else:
+ # treat the weights of Q, K, V separately
+ val = math.sqrt(6. / float(m.in_proj_weight.shape[0] // 3 + m.in_proj_weight.shape[1]))
+ nn.init.uniform_(m.in_proj_weight, -val, val)
+
+def data_half(fp16, bf16, data):
+ if fp16:
+ for k, v in data.items():
+ if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
+ data[k] = v.half()
+ # print(k)
+
+ elif bf16:
+ for k, v in data.items():
+ if isinstance(v, torch.Tensor) and v.dtype == torch.float32:
+ data[k] = v.to(torch.bfloat16)
+ # print(k)
+
+ return data
+
+def postprocess(data_dict:dict, task_info:dict ):
+ if data_dict.get('sample_info', None) is not None and data_dict['sample_info'].get('distributed', False):
+ data = data_dict['data']
+ hidden_states = data[:, 0].contiguous(
+ ) # HERE only use the spe token feature!
+ hidden_states = torch.cat(torch.distributed.nn.all_gather(hidden_states))
+
+ total_length = data_dict['sample_info']['total_num']
+
+ if hidden_states.shape[0] > total_length:
+ hidden_states = hidden_states[:total_length]
+
+ data_dict['data'] = hidden_states.unsqueeze(1)
+
+
+def get_spe_token(tokenizer, token_embed):
+ if comm.old_checkpoint:
+ a = torch.tensor(tokenizer.encode('<|spe|>')).cuda().unsqueeze(0) # bs, 1
+ return token_embed(a, type_embed=False, pos_embed=False)
+ else:
+ a = torch.tensor(tokenizer.encode('spe')).cuda().unsqueeze(0) # bs, 1
+ return token_embed(a)
+
+def preprocess(tokenizer, token_embed, data_list:list, task_info:dict):
+ # perparation for fused_encoder input
+ bs = data_list[0]['data'].shape[0]
+ device = data_list[0]['data'].device
+ mask_dtype = torch.uint8
+
+ #TODO: prompt embedding
+
+ prefix_spe_before_fuse = task_info.get('prefix_spe_before_fuse', True)
+
+ combined_data = []
+ # spe embedding
+ spe_token = get_spe_token(tokenizer, token_embed).expand(bs, -1, -1)
+
+ length = [ data_dict['data'].shape[1] for data_dict in data_list]
+ if prefix_spe_before_fuse:
+ length = [1] + length
+ combined_data.append(spe_token)
+
+ cum_length = np.cumsum(length).tolist()
+
+ invalid_mask_active = any([ data_dict.get('invalid_mask', None) is not None for data_dict in data_list])
+ if invalid_mask_active:
+
+ combined_valid_mask = torch.zeros((bs, cum_length[-1]), dtype=mask_dtype, device=device)
+ else:
+ combined_valid_mask = None
+
+ for i, data_dict in enumerate(data_list):
+ combined_data.append(data_dict['data'])
+ if data_dict.get('invalid_mask', None) is not None:
+ combined_valid_mask[:, cum_length[i]:cum_length[i+1]] = data_dict['invalid_mask']
+
+ combined_data = torch.cat(combined_data, dim=1)
+
+ sample_info = {
+ 'data_length': length,
+ 'data_cum_length': cum_length,
+ 'sample_info_per_sample': []}
+
+ # for caption task inference
+ if comm._CAPTION_GEN_MODE:
+ sample_info['data_cum_length'] = data_list[0]['sample_info']['data_cum_length']
+
+
+ for data_dict in data_list:
+ if data_dict.get('sample_info', None) is not None:
+ if isinstance(data_dict['sample_info'], dict):
+ sample_info.update(data_dict['sample_info'])
+ elif isinstance(data_dict['sample_info'], list):
+ if isinstance(data_dict['sample_info'][0], dict):
+ sample_info.update(data_dict['sample_info'][0])
+ sample_info['sample_info_per_sample'].append(data_dict['sample_info'])
+
+ moe_embedding = None
+ for data_dict in data_list:
+ if 'data_type' in data_dict:
+ data_type = data_dict['data_type']
+ if 'moe_embedding' in data_dict:
+ moe_embedding = data_dict['moe_embedding']
+
+ return {
+ 'data': combined_data,
+ 'invalid_mask': combined_valid_mask,
+ 'data_type': data_type,
+ 'sample_info': sample_info,
+ 'moe_embedding': moe_embedding,
+ }
+
+def share_token_embed_ln(video_embed, token_embed):
+ if video_embed is not None and token_embed is not None:
+ del video_embed.embeddings_norm
+ video_embed.embeddings_norm = token_embed.embeddings_norm