SunderAli17 committed
Commit 1295a0c · verified · 1 Parent(s): 48eb44d

Create utils.py

Files changed (1)
  1. module/ip_adapter/utils.py +248 -0
module/ip_adapter/utils.py ADDED
@@ -0,0 +1,248 @@
import torch
from collections import namedtuple, OrderedDict
from safetensors import safe_open
from .attention_processor import init_attn_proc
from .ip_adapter import MultiIPAdapterImageProjection
from .resampler import Resampler
from transformers import (
    AutoModel, AutoImageProcessor,
    CLIPVisionModelWithProjection, CLIPImageProcessor)


def init_adapter_in_unet(
    unet,
    image_proj_model=None,
    pretrained_model_path_or_dict=None,
    adapter_tokens=64,
    embedding_dim=None,
    use_lcm=False,
    use_adaln=True,
):
    """Install IP-Adapter cross-attention processors and an image projector in `unet`,
    optionally loading pretrained adapter weights."""
    device = unet.device
    dtype = unet.dtype
    if image_proj_model is None:
        assert embedding_dim is not None, "embedding_dim must be provided if image_proj_model is None."
        image_proj_model = Resampler(
            embedding_dim=embedding_dim,
            output_dim=unet.config.cross_attention_dim,
            num_queries=adapter_tokens,
        )
    if pretrained_model_path_or_dict is not None:
        if not isinstance(pretrained_model_path_or_dict, dict):
            if pretrained_model_path_or_dict.endswith(".safetensors"):
                state_dict = {"image_proj": {}, "ip_adapter": {}}
                with safe_open(pretrained_model_path_or_dict, framework="pt", device=unet.device) as f:
                    for key in f.keys():
                        if key.startswith("image_proj."):
                            state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                        elif key.startswith("ip_adapter."):
                            state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
            else:
                state_dict = torch.load(pretrained_model_path_or_dict, map_location=unet.device)
        else:
            state_dict = pretrained_model_path_or_dict
        keys = list(state_dict.keys())
        if "image_proj" not in keys and "ip_adapter" not in keys:
            state_dict = revise_state_dict(state_dict)

    # Create IP cross-attention processors in the unet.
    attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
    unet.set_attn_processor(attn_procs)

    # Load pretrained weights if needed.
    if pretrained_model_path_or_dict is not None:
        if "ip_adapter" in state_dict.keys():
            adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
            missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
            for mk in missing:
                if "ln" not in mk:
                    raise ValueError(f"Missing keys in adapter_modules: {missing}")
        if "image_proj" in state_dict.keys():
            image_proj_model.load_state_dict(state_dict["image_proj"])

    # Load image projectors into an iterable ModuleList.
    image_projection_layers = []
    image_projection_layers.append(image_proj_model)
    unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)

    # Adjust unet config to handle additional ip hidden states.
    unet.config.encoder_hid_dim_type = "ip_image_proj"
    unet.to(dtype=dtype, device=device)


def load_adapter_to_pipe(
    pipe,
    pretrained_model_path_or_dict,
    image_encoder_or_path=None,
    feature_extractor_or_path=None,
    use_clip_encoder=False,
    adapter_tokens=64,
    use_lcm=False,
    use_adaln=True,
):
    """Attach an IP-Adapter to a diffusers pipeline: register the image encoder and
    feature extractor if missing, install adapter attention processors in the unet,
    and load pretrained adapter weights."""
    if not isinstance(pretrained_model_path_or_dict, dict):
        if pretrained_model_path_or_dict.endswith(".safetensors"):
            state_dict = {"image_proj": {}, "ip_adapter": {}}
            with safe_open(pretrained_model_path_or_dict, framework="pt", device=pipe.device) as f:
                for key in f.keys():
                    if key.startswith("image_proj."):
                        state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                    elif key.startswith("ip_adapter."):
                        state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
        else:
            state_dict = torch.load(pretrained_model_path_or_dict, map_location=pipe.device)
    else:
        state_dict = pretrained_model_path_or_dict
    keys = list(state_dict.keys())
    if "image_proj" not in keys and "ip_adapter" not in keys:
        state_dict = revise_state_dict(state_dict)

    # Load the image encoder if it has not been registered to the pipeline yet.
    if image_encoder_or_path is not None:
        if isinstance(image_encoder_or_path, str):
            feature_extractor_or_path = image_encoder_or_path if feature_extractor_or_path is None else feature_extractor_or_path

            image_encoder_or_path = (
                CLIPVisionModelWithProjection.from_pretrained(
                    image_encoder_or_path
                ) if use_clip_encoder else
                AutoModel.from_pretrained(image_encoder_or_path)
            )

    if feature_extractor_or_path is not None:
        if isinstance(feature_extractor_or_path, str):
            feature_extractor_or_path = (
                CLIPImageProcessor() if use_clip_encoder else
                AutoImageProcessor.from_pretrained(feature_extractor_or_path)
            )

    # Register the image encoder if it has not been registered to the pipeline yet.
    if hasattr(pipe, "image_encoder") and getattr(pipe, "image_encoder", None) is None:
        image_encoder = image_encoder_or_path.to(pipe.device, dtype=pipe.dtype)
        pipe.register_modules(image_encoder=image_encoder)
    else:
        image_encoder = pipe.image_encoder

    # Register the feature extractor if it has not been registered to the pipeline yet.
    if hasattr(pipe, "feature_extractor") and getattr(pipe, "feature_extractor", None) is None:
        feature_extractor = feature_extractor_or_path
        pipe.register_modules(feature_extractor=feature_extractor)
    else:
        feature_extractor = pipe.feature_extractor

    # Load the adapter into the unet.
    unet = getattr(pipe, pipe.unet_name) if not hasattr(pipe, "unet") else pipe.unet
    attn_procs = init_attn_proc(unet, adapter_tokens, use_lcm, use_adaln)
    unet.set_attn_processor(attn_procs)
    image_proj_model = Resampler(
        embedding_dim=image_encoder.config.hidden_size,
        output_dim=unet.config.cross_attention_dim,
        num_queries=adapter_tokens,
    )

    # Load pretrained weights if available.
    if "ip_adapter" in state_dict.keys():
        adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
        missing, unexpected = adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
        for mk in missing:
            if "ln" not in mk:
                raise ValueError(f"Missing keys in adapter_modules: {missing}")
    if "image_proj" in state_dict.keys():
        image_proj_model.load_state_dict(state_dict["image_proj"])

    # Convert IP-Adapter image projection layers to diffusers format.
    image_projection_layers = []
    image_projection_layers.append(image_proj_model)
    unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)

    # Adjust unet config to handle additional ip hidden states.
    unet.config.encoder_hid_dim_type = "ip_image_proj"
    unet.to(dtype=pipe.dtype, device=pipe.device)


def revise_state_dict(old_state_dict_or_path, map_location="cpu"):
    new_state_dict = OrderedDict()
    new_state_dict["image_proj"] = OrderedDict()
    new_state_dict["ip_adapter"] = OrderedDict()
    if isinstance(old_state_dict_or_path, str):
        old_state_dict = torch.load(old_state_dict_or_path, map_location=map_location)
    else:
        old_state_dict = old_state_dict_or_path
    for name, weight in old_state_dict.items():
        if name.startswith("image_proj_model."):
            new_state_dict["image_proj"][name[len("image_proj_model."):]] = weight
        elif name.startswith("adapter_modules."):
            new_state_dict["ip_adapter"][name[len("adapter_modules."):]] = weight
    return new_state_dict


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(image_encoder, feature_extractor, image, device, num_images_per_prompt, output_hidden_states=None):
    dtype = next(image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)
    if output_hidden_states:
        image_enc_hidden_states = image_encoder(image, output_hidden_states=True).hidden_states[-2]
        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
        return image_enc_hidden_states
    else:
        if isinstance(image_encoder, CLIPVisionModelWithProjection):
            # CLIP image encoder.
            image_embeds = image_encoder(image).image_embeds
        else:
            # DINO image encoder.
            image_embeds = image_encoder(image).last_hidden_state
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        return image_embeds


def prepare_training_image_embeds(
    image_encoder, feature_extractor,
    ip_adapter_image, ip_adapter_image_embeds,
    device, drop_rate, output_hidden_state, idx_to_replace=None,
    do_classifier_free_guidance=False, num_images_per_prompt=1,
):
    """Encode IP-Adapter reference images for training, randomly zeroing a fraction
    (`drop_rate`) of them so the unconditional branch is also learned."""
    if ip_adapter_image_embeds is None:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        # if len(ip_adapter_image) != len(unet.encoder_hid_proj.image_projection_layers):
        #     raise ValueError(
        #         f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
        #     )

        image_embeds = []
        for single_ip_adapter_image in ip_adapter_image:
            # Randomly drop reference images at `drop_rate` (or use the provided mask).
            if idx_to_replace is None:
                idx_to_replace = torch.rand(len(single_ip_adapter_image)) < drop_rate
            zero_ip_adapter_image = torch.zeros_like(single_ip_adapter_image)
            single_ip_adapter_image[idx_to_replace] = zero_ip_adapter_image[idx_to_replace]
            single_image_embeds = encode_image(
                image_encoder, feature_extractor, single_ip_adapter_image, device, 1, output_hidden_state
            )
            single_image_embeds = torch.stack([single_image_embeds], dim=1)  # FIXME

            image_embeds.append(single_image_embeds)
    else:
        # Precomputed embeddings: repeat them per prompt and, when classifier-free
        # guidance is enabled, split and re-concatenate the negative/positive halves.
        repeat_dims = [1]
        image_embeds = []
        for single_image_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                single_image_embeds = single_image_embeds.repeat(
                    num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
                )
                single_negative_image_embeds = single_negative_image_embeds.repeat(
                    num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
                )
                single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
            else:
                single_image_embeds = single_image_embeds.repeat(
                    num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
                )
            image_embeds.append(single_image_embeds)

    return image_embeds
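
For context, a minimal usage sketch of the new helpers (illustrative only, not part of the committed file): it assumes the package is importable as module.ip_adapter, a standard diffusers StableDiffusionPipeline, and placeholder ids/paths for the base model, the adapter checkpoint, the image encoder, and the reference image.

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline

from module.ip_adapter.utils import load_adapter_to_pipe, encode_image

# Build a base pipeline (model id is a placeholder).
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Register an image encoder/feature extractor and install the IP-Adapter
# attention processors and image projector in the unet (paths are placeholders).
load_adapter_to_pipe(
    pipe,
    pretrained_model_path_or_dict="path/to/ip_adapter.safetensors",
    image_encoder_or_path="facebook/dinov2-large",
    use_clip_encoder=False,
    adapter_tokens=64,
)

# Encode a reference image into image embeddings for IP conditioning.
reference = Image.open("reference.png")
image_embeds = encode_image(
    pipe.image_encoder, pipe.feature_extractor, reference,
    device=pipe.device, num_images_per_prompt=1, output_hidden_states=False,
)

init_adapter_in_unet plays the same role when only a unet is at hand (e.g. during training), and prepare_training_image_embeds is the training-time counterpart of encode_image, zeroing reference images at drop_rate so classifier-free guidance can be used at inference.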