Panda
- code/PandaGPT +1 -0
- code/gpt4-clone/demo.py +55 -0
- code/gpt4-clone/eval_configs/minigpt4_eval.yaml +25 -0
- code/gpt4-clone/gpt4clone/__init__.py +18 -0
- code/gpt4-clone/gpt4clone/__pycache__/__init__.cpython-39.pyc +0 -0
- code/gpt4-clone/gpt4clone/common/__pycache__/config.cpython-39.pyc +0 -0
- code/gpt4-clone/gpt4clone/common/__pycache__/registry.cpython-39.pyc +0 -0
- code/gpt4-clone/gpt4clone/common/__pycache__/utils.cpython-39.pyc +0 -0
- code/gpt4-clone/gpt4clone/common/config.py +86 -0
- code/gpt4-clone/gpt4clone/common/registry.py +91 -0
- code/gpt4-clone/gpt4clone/common/utils.py +6 -0
- code/gpt4-clone/gpt4clone/configs/default.yaml +5 -0
- code/gpt4-clone/gpt4clone/configs/models/minigpt4.yaml +33 -0
- code/gpt4-clone/gpt4clone/models/__init__.py +5 -0
- code/gpt4-clone/gpt4clone/models/__pycache__/__init__.cpython-39.pyc +0 -0
- code/gpt4-clone/gpt4clone/models/__pycache__/mini_gpt4.cpython-39.pyc +0 -0
- code/gpt4-clone/gpt4clone/models/mini_gpt4.py +125 -0
code/PandaGPT
ADDED
@@ -0,0 +1 @@
Subproject commit 9b77e0412d42a468362b67d1892ff0f4f659a4f5
code/gpt4-clone/demo.py
ADDED
@@ -0,0 +1,55 @@
import argparse
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import gradio as gr

from gpt4clone.common.config import Config
from gpt4clone.common.registry import registry

# imports modules for registration
from gpt4clone.models import *


def parse_args():
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
    parser.add_argument("--gpu-id", type=int, default=0, help="specify the gpu to load the model.")
    parser.add_argument(
        "--options",
        nargs="+",
        help="override some settings in the used config, the key-value pair "
        "in xxx=yyy format will be merged into config file (deprecated), "
        "change to --cfg-options instead.",
    )
    args = parser.parse_args()
    return args


def get_rank():
    # get_rank is not defined anywhere in this commit (there is no dist_utils
    # module yet); single-process fallback so setup_seeds can run.
    return 0


def setup_seeds(config):
    seed = config.run_cfg.seed + get_rank()

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    cudnn.benchmark = False
    cudnn.deterministic = True


# ========================================
# Model Initialization
# ========================================

print('Initializing Chat')
args = parse_args()
cfg = Config(args)
model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))
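As committed, demo.py stops right after model construction, so the gradio import is still unused. A minimal sketch of how a chat UI might eventually be attached (hypothetical; chat_fn and the real image/prompt processing are not part of this commit):

import gradio as gr

def chat_fn(image, question):
    # A real handler would run the vis_processor on `image`, fill the
    # prompt_template with `question`, and call the model's generate path.
    return "model reply placeholder"

ui = gr.Interface(
    fn=chat_fn,
    inputs=[gr.Image(type="pil"), gr.Textbox(label="Question")],
    outputs=gr.Textbox(label="Answer"),
)
ui.launch(server_name="0.0.0.0")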
code/gpt4-clone/eval_configs/minigpt4_eval.yaml
ADDED
@@ -0,0 +1,25 @@
model:
  arch: mini_gpt4
  model_type: pretrain_vicuna
  freeze_vit: True
  freeze_qformer: True
  max_txt_len: 160
  end_sym: "###"
  low_resource: True
  prompt_path: "prompts/alignment.txt"
  prompt_template: '###Human: {} ###Assistant: '
  ckpt: '/workspace/weights/minigpt4/prerained_minigpt4_7b.pth'


datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 224
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
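Config.build_model_config merges the model block above on top of the defaults in gpt4clone/configs/models/minigpt4.yaml, so eval-time keys such as low_resource, ckpt, and max_txt_len override the base file. A minimal sketch of that merge order (paths assumed relative to code/gpt4-clone):

from omegaconf import OmegaConf

base = OmegaConf.load("gpt4clone/configs/models/minigpt4.yaml")
eval_cfg = OmegaConf.load("eval_configs/minigpt4_eval.yaml")
merged = OmegaConf.merge(base, {"model": eval_cfg.model})
print(merged.model.low_resource)  # True, taken from this eval config
print(merged.model.image_size)    # 224, inherited from the base file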
code/gpt4-clone/gpt4clone/__init__.py
ADDED
@@ -0,0 +1,18 @@
import os
import sys

from omegaconf import OmegaConf

from gpt4clone.common.registry import registry

root_dir = os.path.dirname(os.path.abspath(__file__))
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))

registry.register_path("library_root", root_dir)
repo_root = os.path.join(root_dir, "..")
registry.register_path("repo_root", repo_root)
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
registry.register_path("cache_root", cache_root)

registry.register("MAX_INT", sys.maxsize)
registry.register("SPLIT_NAMES", ["train", "val", "test"])
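Because this module runs at import time, path lookups work as soon as the package is imported. A quick illustration (the library_root value depends on where the repo is checked out; cache_root resolves to the absolute default because os.path.join keeps an absolute second argument):

import gpt4clone  # executes the registrations above
from gpt4clone.common.registry import registry

print(registry.get_path("library_root"))  # .../code/gpt4-clone/gpt4clone
print(registry.get_path("cache_root"))    # /export/home/.cache/minigpt4 with the default config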
code/gpt4-clone/gpt4clone/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (648 Bytes).
code/gpt4-clone/gpt4clone/common/__pycache__/config.cpython-39.pyc
ADDED
Binary file (2.45 kB).
code/gpt4-clone/gpt4clone/common/__pycache__/registry.cpython-39.pyc
ADDED
Binary file (2.48 kB).
code/gpt4-clone/gpt4clone/common/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (388 Bytes).
code/gpt4-clone/gpt4clone/common/config.py
ADDED
@@ -0,0 +1,86 @@
from omegaconf import OmegaConf

from gpt4clone.common.registry import registry


class Config:
    def __init__(self, args):
        self.config = {}

        self.args = args

        # Register the configuration for setup
        registry.register("configuration", self)

        user_config = self._build_opt_list(self.args.options)

        config = OmegaConf.load(self.args.cfg_path)

        runner_config = self.build_runner_config(config)
        model_config = self.build_model_config(config, **user_config)
        # dataset_config = self.build_dataset_config(config)

        self.config = OmegaConf.merge(
            runner_config, model_config
        )

    def _build_opt_list(self, opts):
        opts_dot_list = self._convert_to_dot_list(opts)
        return OmegaConf.from_dotlist(opts_dot_list)

    def _convert_to_dot_list(self, opts):
        if opts is None:
            opts = []

        if len(opts) == 0:
            return opts

        has_equal = opts[0].find("=") != -1

        if has_equal:
            return opts

        # Pair up ["key", "value", ...] into ["key=value", ...].
        return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]

    @staticmethod
    def build_runner_config(config):
        return {"run": config.run}

    @staticmethod
    def build_model_config(config, **kwargs):
        model = config.get("model", None)
        assert model is not None, "Missing model configuration file."

        model_cls = registry.get_model_class(model.arch)
        assert model_cls is not None, f"Model '{model.arch}' has not been registered."

        # Note: a dotlist override such as model.model_type=... is nested under
        # a 'model' key by OmegaConf, so this flat lookup usually falls through
        # to the model_type declared in the config file.
        model_type = kwargs.get('model.model_type', None)
        if not model_type:
            model_type = model.get('model_type', None)

        assert model_type is not None, "Missing model_type."

        model_config_path = model_cls.default_config_path(model_type=model_type)

        print('config[model]', config['model'])
        model_config = OmegaConf.merge(
            OmegaConf.load(model_config_path),
            {"model": config['model']},
        )

        return model_config

    @property
    def model_cfg(self):
        return self.config.model
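_convert_to_dot_list means --options accepts either key=value tokens or alternating key value pairs; both end up as an OmegaConf dotlist. A quick illustration with a hypothetical override:

from omegaconf import OmegaConf

# Equivalent command lines:
#   --options model.low_resource=False
#   --options model.low_resource False
overrides = OmegaConf.from_dotlist(["model.low_resource=False"])
print(overrides.model.low_resource)  # False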
code/gpt4-clone/gpt4clone/common/registry.py
ADDED
@@ -0,0 +1,91 @@
class Registry:
    mapping = {
        "builder_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
        "runner_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'

        Args:
            name: Key with which the path will be registered.

        Usage::

            from gpt4clone.common.registry import registry
        """
        assert isinstance(path, str), "All paths must be str."
        if name in cls.mapping["paths"]:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping["paths"][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        Args:
            name: Key with which the item will be registered.

        Usage::

            from gpt4clone.common.registry import registry

            registry.register("config", {})
        """
        path = name.split(".")
        current = cls.mapping["state"]

        for part in path[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[path[-1]] = obj

    @classmethod
    def get_model_class(cls, name):
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def register_model(cls, name):
        r"""Register a model class to registry with key 'name'

        Args:
            name: Key with which the model will be registered.

        Usage::

            from gpt4clone.common.registry import registry
        """

        def wrap(model_cls):
            if name in cls.mapping["model_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["model_name_mapping"][name]
                    )
                )
            cls.mapping["model_name_mapping"][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def get_path(cls, name):
        return cls.mapping["paths"].get(name, None)


registry = Registry()
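register_model is a class decorator, which is how MiniGPT4 lands in model_name_mapping the moment gpt4clone.models is imported. The pattern in miniature (ToyModel is hypothetical):

from gpt4clone.common.registry import registry

@registry.register_model("toy")
class ToyModel:
    pass

assert registry.get_model_class("toy") is ToyModel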
code/gpt4-clone/gpt4clone/common/utils.py
ADDED
@@ -0,0 +1,6 @@
import os

from gpt4clone.common.registry import registry


def get_abs_path(rel_path):
    return os.path.join(registry.get_path("library_root"), rel_path)
code/gpt4-clone/gpt4clone/configs/default.yaml
ADDED
@@ -0,0 +1,5 @@
env:
  # For default users
  # cache_root: "cache"
  # For internal use with persistent storage
  cache_root: "/export/home/.cache/minigpt4"
code/gpt4-clone/gpt4clone/configs/models/minigpt4.yaml
ADDED
@@ -0,0 +1,33 @@
model:
  arch: mini_gpt4

  # vit encoder
  image_size: 224
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True
  freeze_qformer: True

  # Q-Former
  num_query_token: 32

  # Vicuna
  llama_model: "/workspace/weights/vicuna_7b/7b_v0/"

  # generation configs
  prompt: ""

preprocess:
  vis_processor:
    train:
      name: "blip2_image_train"
      image_size: 224
    eval:
      name: "blip2_image_eval"
      image_size: 224
  text_processor:
    train:
      name: "blip_caption"
    eval:
      name: "blip_caption"
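build_model_config locates this file through MiniGPT4.default_config_path, which resolves the relative entry in PRETRAINED_MODEL_CONFIG_DICT against the registered library_root. The resolution can be checked without constructing a model:

from gpt4clone.common.utils import get_abs_path

# Same lookup default_config_path performs for model_type="pretrain_vicuna":
print(get_abs_path("configs/models/minigpt4.yaml"))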
code/gpt4-clone/gpt4clone/models/__init__.py
ADDED
@@ -0,0 +1,5 @@
from gpt4clone.models.mini_gpt4 import MiniGPT4

__all__ = [
    'MiniGPT4'
]
code/gpt4-clone/gpt4clone/models/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (229 Bytes).
code/gpt4-clone/gpt4clone/models/__pycache__/mini_gpt4.cpython-39.pyc
ADDED
Binary file (981 Bytes).
code/gpt4-clone/gpt4clone/models/mini_gpt4.py
ADDED
@@ -0,0 +1,125 @@
import logging

import torch

from gpt4clone.common.registry import registry
from gpt4clone.models.blip2 import Blip2Base, disabled_train
from gpt4clone.common.utils import get_abs_path


@registry.register_model('mini_gpt4')
class MiniGPT4(Blip2Base):
    print('register mini_gpt4')

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_vicuna": "configs/models/minigpt4.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
        low_resource=False,  # use 8 bit and put vit in cpu
        device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()
        self.low_resource = low_resource

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )

        if freeze_vit:
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            for name, param in self.ln_vision.named_parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
            self.ln_vision.train = disabled_train
            logging.info("freeze vision encoder")
        print('Loading VIT Done')

        # NOTE: self.Qformer and self.query_tokens are referenced below but are
        # never initialized in this commit; the Q-Former loading step is still missing.
        if freeze_qformer:
            for name, param in self.Qformer.named_parameters():
                param.requires_grad = False
            self.Qformer = self.Qformer.eval()
            self.Qformer.train = disabled_train
            self.query_tokens.requires_grad = False
            logging.info("freeze Qformer")
        print('Loading Q-Former Done')

        print('Loading LLAMA')

    @classmethod
    def default_config_path(cls, model_type):
        return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])

    @classmethod
    def from_config(cls, cfg):
        print('from_config', cfg)
        vit_model = cfg.get('vit_model', 'eva_clip_g')
        q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
        img_size = cfg.get('image_size')
        num_query_token = cfg.get("num_query_token")
        llama_model = cfg.get("llama_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        freeze_qformer = cfg.get("freeze_qformer", True)
        low_resource = cfg.get("low_resource", False)
        device_8bit = cfg.get("device_8bit", 0)

        prompt_path = cfg.get("prompt_path", "")
        prompt_template = cfg.get("prompt_template", "")
        max_txt_len = cfg.get("max_txt_len", 32)
        end_sym = cfg.get("end_sym", '\n')

        model = cls(
            vit_model=vit_model,
            q_former_model=q_former_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            freeze_qformer=freeze_qformer,
            num_query_token=num_query_token,
            llama_model=llama_model,
            prompt_path=prompt_path,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            end_sym=end_sym,
            low_resource=low_resource,
            device_8bit=device_8bit,
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
        if ckpt_path:
            print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            msg = model.load_state_dict(ckpt['model'], strict=False)

        return model
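from_config loads the checkpoint with strict=False, which silently tolerates key mismatches; the returned load result is where to check what actually matched. A small sanity-check sketch (model is the instance returned by from_config, and the path is whatever the eval config's ckpt points at):

import torch

# `model` is the MiniGPT4 instance built by from_config above.
ckpt = torch.load("/workspace/weights/minigpt4/prerained_minigpt4_7b.pth", map_location="cpu")
msg = model.load_state_dict(ckpt["model"], strict=False)
print("missing keys:", len(msg.missing_keys))        # params the checkpoint did not provide
print("unexpected keys:", len(msg.unexpected_keys))  # checkpoint entries the model has no slot for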