Baykon committed on
Commit
d7b0443
1 Parent(s): 54f5ad4
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. MagicQuill/.DS_Store +0 -0
  2. MagicQuill/brushnet/brushnet.json +0 -58
  3. MagicQuill/brushnet/brushnet.py +0 -949
  4. MagicQuill/brushnet/brushnet_ca.py +0 -983
  5. MagicQuill/brushnet/brushnet_xl.json +0 -63
  6. MagicQuill/brushnet/powerpaint.json +0 -57
  7. MagicQuill/brushnet/powerpaint_utils.py +0 -496
  8. MagicQuill/brushnet/unet_2d_blocks.py +0 -0
  9. MagicQuill/brushnet/unet_2d_condition.py +0 -1355
  10. MagicQuill/brushnet_nodes.py +0 -1094
  11. MagicQuill/comfy/.DS_Store +0 -0
  12. MagicQuill/comfy/checkpoint_pickle.py +0 -13
  13. MagicQuill/comfy/cldm/__pycache__/cldm.cpython-310.pyc +0 -0
  14. MagicQuill/comfy/cldm/cldm.py +0 -313
  15. MagicQuill/comfy/cli_args.py +0 -143
  16. MagicQuill/comfy/clip_config_bigg.json +0 -23
  17. MagicQuill/comfy/clip_model.py +0 -194
  18. MagicQuill/comfy/clip_vision.py +0 -117
  19. MagicQuill/comfy/clip_vision_config_g.json +0 -18
  20. MagicQuill/comfy/clip_vision_config_h.json +0 -18
  21. MagicQuill/comfy/clip_vision_config_vitl.json +0 -18
  22. MagicQuill/comfy/conds.py +0 -83
  23. MagicQuill/comfy/controlnet.py +0 -554
  24. MagicQuill/comfy/diffusers_convert.py +0 -281
  25. MagicQuill/comfy/diffusers_load.py +0 -36
  26. MagicQuill/comfy/extra_samplers/__pycache__/uni_pc.cpython-310.pyc +0 -0
  27. MagicQuill/comfy/extra_samplers/uni_pc.py +0 -875
  28. MagicQuill/comfy/gligen.py +0 -343
  29. MagicQuill/comfy/k_diffusion/__pycache__/sampling.cpython-310.pyc +0 -0
  30. MagicQuill/comfy/k_diffusion/__pycache__/utils.cpython-310.pyc +0 -0
  31. MagicQuill/comfy/k_diffusion/sampling.py +0 -843
  32. MagicQuill/comfy/k_diffusion/utils.py +0 -313
  33. MagicQuill/comfy/latent_formats.py +0 -141
  34. MagicQuill/comfy/ldm/.DS_Store +0 -0
  35. MagicQuill/comfy/ldm/__pycache__/util.cpython-310.pyc +0 -0
  36. MagicQuill/comfy/ldm/audio/__pycache__/autoencoder.cpython-310.pyc +0 -0
  37. MagicQuill/comfy/ldm/audio/__pycache__/dit.cpython-310.pyc +0 -0
  38. MagicQuill/comfy/ldm/audio/__pycache__/embedders.cpython-310.pyc +0 -0
  39. MagicQuill/comfy/ldm/audio/autoencoder.py +0 -282
  40. MagicQuill/comfy/ldm/audio/dit.py +0 -888
  41. MagicQuill/comfy/ldm/audio/embedders.py +0 -108
  42. MagicQuill/comfy/ldm/cascade/__pycache__/common.cpython-310.pyc +0 -0
  43. MagicQuill/comfy/ldm/cascade/__pycache__/controlnet.cpython-310.pyc +0 -0
  44. MagicQuill/comfy/ldm/cascade/__pycache__/stage_a.cpython-310.pyc +0 -0
  45. MagicQuill/comfy/ldm/cascade/__pycache__/stage_b.cpython-310.pyc +0 -0
  46. MagicQuill/comfy/ldm/cascade/__pycache__/stage_c.cpython-310.pyc +0 -0
  47. MagicQuill/comfy/ldm/cascade/__pycache__/stage_c_coder.cpython-310.pyc +0 -0
  48. MagicQuill/comfy/ldm/cascade/common.py +0 -161
  49. MagicQuill/comfy/ldm/cascade/controlnet.py +0 -93
  50. MagicQuill/comfy/ldm/cascade/stage_a.py +0 -255
MagicQuill/.DS_Store DELETED
Binary file (6.15 kB)
 
MagicQuill/brushnet/brushnet.json DELETED
@@ -1,58 +0,0 @@
1
- {
2
- "_class_name": "BrushNetModel",
3
- "_diffusers_version": "0.27.0.dev0",
4
- "_name_or_path": "runs/logs/brushnet_randommask/checkpoint-100000",
5
- "act_fn": "silu",
6
- "addition_embed_type": null,
7
- "addition_embed_type_num_heads": 64,
8
- "addition_time_embed_dim": null,
9
- "attention_head_dim": 8,
10
- "block_out_channels": [
11
- 320,
12
- 640,
13
- 1280,
14
- 1280
15
- ],
16
- "brushnet_conditioning_channel_order": "rgb",
17
- "class_embed_type": null,
18
- "conditioning_channels": 5,
19
- "conditioning_embedding_out_channels": [
20
- 16,
21
- 32,
22
- 96,
23
- 256
24
- ],
25
- "cross_attention_dim": 768,
26
- "down_block_types": [
27
- "DownBlock2D",
28
- "DownBlock2D",
29
- "DownBlock2D",
30
- "DownBlock2D"
31
- ],
32
- "downsample_padding": 1,
33
- "encoder_hid_dim": null,
34
- "encoder_hid_dim_type": null,
35
- "flip_sin_to_cos": true,
36
- "freq_shift": 0,
37
- "global_pool_conditions": false,
38
- "in_channels": 4,
39
- "layers_per_block": 2,
40
- "mid_block_scale_factor": 1,
41
- "mid_block_type": "MidBlock2D",
42
- "norm_eps": 1e-05,
43
- "norm_num_groups": 32,
44
- "num_attention_heads": null,
45
- "num_class_embeds": null,
46
- "only_cross_attention": false,
47
- "projection_class_embeddings_input_dim": null,
48
- "resnet_time_scale_shift": "default",
49
- "transformer_layers_per_block": 1,
50
- "up_block_types": [
51
- "UpBlock2D",
52
- "UpBlock2D",
53
- "UpBlock2D",
54
- "UpBlock2D"
55
- ],
56
- "upcast_attention": false,
57
- "use_linear_projection": false
58
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/brushnet/brushnet.py DELETED
@@ -1,949 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Any, Dict, List, Optional, Tuple, Union
3
-
4
- import torch
5
- from torch import nn
6
- from torch.nn import functional as F
7
-
8
- from diffusers.configuration_utils import ConfigMixin, register_to_config
9
- from diffusers.utils import BaseOutput, logging
10
- from diffusers.models.attention_processor import (
11
- ADDED_KV_ATTENTION_PROCESSORS,
12
- CROSS_ATTENTION_PROCESSORS,
13
- AttentionProcessor,
14
- AttnAddedKVProcessor,
15
- AttnProcessor,
16
- )
17
- from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
18
- from diffusers.models.modeling_utils import ModelMixin
19
-
20
- from .unet_2d_blocks import (
21
- CrossAttnDownBlock2D,
22
- DownBlock2D,
23
- UNetMidBlock2D,
24
- UNetMidBlock2DCrossAttn,
25
- get_down_block,
26
- get_mid_block,
27
- get_up_block,
28
- MidBlock2D
29
- )
30
-
31
- from .unet_2d_condition import UNet2DConditionModel
32
-
33
-
34
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
-
36
-
37
- @dataclass
38
- class BrushNetOutput(BaseOutput):
39
- """
40
- The output of [`BrushNetModel`].
41
-
42
- Args:
43
- up_block_res_samples (`tuple[torch.Tensor]`):
44
- A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
45
- be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
46
- used to condition the original UNet's upsampling activations.
47
- down_block_res_samples (`tuple[torch.Tensor]`):
48
- A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
49
- be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
50
- used to condition the original UNet's downsampling activations.
51
- mid_block_res_sample (`torch.Tensor`):
52
- The activation of the middle block (the lowest sample resolution). The tensor should be of shape
53
- `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
54
- Output can be used to condition the original UNet's middle block activation.
55
- """
56
-
57
- up_block_res_samples: Tuple[torch.Tensor]
58
- down_block_res_samples: Tuple[torch.Tensor]
59
- mid_block_res_sample: torch.Tensor
60
-
61
-
62
- class BrushNetModel(ModelMixin, ConfigMixin):
63
- """
64
- A BrushNet model.
65
-
66
- Args:
67
- in_channels (`int`, defaults to 4):
68
- The number of channels in the input sample.
69
- flip_sin_to_cos (`bool`, defaults to `True`):
70
- Whether to flip the sin to cos in the time embedding.
71
- freq_shift (`int`, defaults to 0):
72
- The frequency shift to apply to the time embedding.
73
- down_block_types (`tuple[str]`, defaults to `("DownBlock2D", "DownBlock2D", "DownBlock2D", "DownBlock2D")`):
74
- The tuple of downsample blocks to use.
75
- mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
76
- Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
77
- `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
78
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "UpBlock2D", "UpBlock2D", "UpBlock2D")`):
79
- The tuple of upsample blocks to use.
80
- only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
81
- block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
82
- The tuple of output channels for each block.
83
- layers_per_block (`int`, defaults to 2):
84
- The number of layers per block.
85
- downsample_padding (`int`, defaults to 1):
86
- The padding to use for the downsampling convolution.
87
- mid_block_scale_factor (`float`, defaults to 1):
88
- The scale factor to use for the mid block.
89
- act_fn (`str`, defaults to "silu"):
90
- The activation function to use.
91
- norm_num_groups (`int`, *optional*, defaults to 32):
92
- The number of groups to use for the normalization. If None, normalization and activation layers are skipped
93
- in post-processing.
94
- norm_eps (`float`, defaults to 1e-5):
95
- The epsilon to use for the normalization.
96
- cross_attention_dim (`int`, defaults to 1280):
97
- The dimension of the cross attention features.
98
- transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
99
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
100
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
101
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
102
- encoder_hid_dim (`int`, *optional*, defaults to None):
103
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
104
- dimension to `cross_attention_dim`.
105
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
106
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
107
- embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
108
- attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
109
- The dimension of the attention heads.
110
- use_linear_projection (`bool`, defaults to `False`):
111
- class_embed_type (`str`, *optional*, defaults to `None`):
112
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
113
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
114
- addition_embed_type (`str`, *optional*, defaults to `None`):
115
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
116
- "text". "text" will use the `TextTimeEmbedding` layer.
117
- num_class_embeds (`int`, *optional*, defaults to `None`):
118
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
119
- class conditioning with `class_embed_type` equal to `None`.
120
- upcast_attention (`bool`, defaults to `False`):
121
- resnet_time_scale_shift (`str`, defaults to `"default"`):
122
- Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
123
- projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
124
- The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
125
- `class_embed_type="projection"`.
126
- brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
127
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
128
- conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
129
- The tuple of output channel for each block in the `conditioning_embedding` layer.
130
- global_pool_conditions (`bool`, defaults to `False`):
131
- TODO(Patrick) - unused parameter.
132
- addition_embed_type_num_heads (`int`, defaults to 64):
133
- The number of heads to use for the `TextTimeEmbedding` layer.
134
- """
135
-
136
- _supports_gradient_checkpointing = True
137
-
138
- @register_to_config
139
- def __init__(
140
- self,
141
- in_channels: int = 4,
142
- conditioning_channels: int = 5,
143
- flip_sin_to_cos: bool = True,
144
- freq_shift: int = 0,
145
- down_block_types: Tuple[str, ...] = (
146
- "DownBlock2D",
147
- "DownBlock2D",
148
- "DownBlock2D",
149
- "DownBlock2D",
150
- ),
151
- mid_block_type: Optional[str] = "UNetMidBlock2D",
152
- up_block_types: Tuple[str, ...] = (
153
- "UpBlock2D",
154
- "UpBlock2D",
155
- "UpBlock2D",
156
- "UpBlock2D",
157
- ),
158
- only_cross_attention: Union[bool, Tuple[bool]] = False,
159
- block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
160
- layers_per_block: int = 2,
161
- downsample_padding: int = 1,
162
- mid_block_scale_factor: float = 1,
163
- act_fn: str = "silu",
164
- norm_num_groups: Optional[int] = 32,
165
- norm_eps: float = 1e-5,
166
- cross_attention_dim: int = 1280,
167
- transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
168
- encoder_hid_dim: Optional[int] = None,
169
- encoder_hid_dim_type: Optional[str] = None,
170
- attention_head_dim: Union[int, Tuple[int, ...]] = 8,
171
- num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
172
- use_linear_projection: bool = False,
173
- class_embed_type: Optional[str] = None,
174
- addition_embed_type: Optional[str] = None,
175
- addition_time_embed_dim: Optional[int] = None,
176
- num_class_embeds: Optional[int] = None,
177
- upcast_attention: bool = False,
178
- resnet_time_scale_shift: str = "default",
179
- projection_class_embeddings_input_dim: Optional[int] = None,
180
- brushnet_conditioning_channel_order: str = "rgb",
181
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
182
- global_pool_conditions: bool = False,
183
- addition_embed_type_num_heads: int = 64,
184
- ):
185
- super().__init__()
186
-
187
- # If `num_attention_heads` is not defined (which is the case for most models)
188
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
189
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
190
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
191
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
192
- # which is why we correct for the naming here.
193
- num_attention_heads = num_attention_heads or attention_head_dim
194
-
195
- # Check inputs
196
- if len(down_block_types) != len(up_block_types):
197
- raise ValueError(
198
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
199
- )
200
-
201
- if len(block_out_channels) != len(down_block_types):
202
- raise ValueError(
203
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
204
- )
205
-
206
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
207
- raise ValueError(
208
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
209
- )
210
-
211
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
212
- raise ValueError(
213
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
214
- )
215
-
216
- if isinstance(transformer_layers_per_block, int):
217
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
218
-
219
- # input
220
- conv_in_kernel = 3
221
- conv_in_padding = (conv_in_kernel - 1) // 2
222
- self.conv_in_condition = nn.Conv2d(
223
- in_channels+conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
224
- )
225
-
226
- # time
227
- time_embed_dim = block_out_channels[0] * 4
228
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
229
- timestep_input_dim = block_out_channels[0]
230
- self.time_embedding = TimestepEmbedding(
231
- timestep_input_dim,
232
- time_embed_dim,
233
- act_fn=act_fn,
234
- )
235
-
236
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
237
- encoder_hid_dim_type = "text_proj"
238
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
239
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
240
-
241
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
242
- raise ValueError(
243
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
244
- )
245
-
246
- if encoder_hid_dim_type == "text_proj":
247
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
248
- elif encoder_hid_dim_type == "text_image_proj":
249
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
250
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
251
- # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
252
- self.encoder_hid_proj = TextImageProjection(
253
- text_embed_dim=encoder_hid_dim,
254
- image_embed_dim=cross_attention_dim,
255
- cross_attention_dim=cross_attention_dim,
256
- )
257
-
258
- elif encoder_hid_dim_type is not None:
259
- raise ValueError(
260
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
261
- )
262
- else:
263
- self.encoder_hid_proj = None
264
-
265
- # class embedding
266
- if class_embed_type is None and num_class_embeds is not None:
267
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
268
- elif class_embed_type == "timestep":
269
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
270
- elif class_embed_type == "identity":
271
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
272
- elif class_embed_type == "projection":
273
- if projection_class_embeddings_input_dim is None:
274
- raise ValueError(
275
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
276
- )
277
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
278
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
279
- # 2. it projects from an arbitrary input dimension.
280
- #
281
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
282
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
283
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
284
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
285
- else:
286
- self.class_embedding = None
287
-
288
- if addition_embed_type == "text":
289
- if encoder_hid_dim is not None:
290
- text_time_embedding_from_dim = encoder_hid_dim
291
- else:
292
- text_time_embedding_from_dim = cross_attention_dim
293
-
294
- self.add_embedding = TextTimeEmbedding(
295
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
296
- )
297
- elif addition_embed_type == "text_image":
298
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
299
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
300
- # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
301
- self.add_embedding = TextImageTimeEmbedding(
302
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
303
- )
304
- elif addition_embed_type == "text_time":
305
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
306
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
307
-
308
- elif addition_embed_type is not None:
309
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
310
-
311
- self.down_blocks = nn.ModuleList([])
312
- self.brushnet_down_blocks = nn.ModuleList([])
313
-
314
- if isinstance(only_cross_attention, bool):
315
- only_cross_attention = [only_cross_attention] * len(down_block_types)
316
-
317
- if isinstance(attention_head_dim, int):
318
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
319
-
320
- if isinstance(num_attention_heads, int):
321
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
322
-
323
- # down
324
- output_channel = block_out_channels[0]
325
-
326
- brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
327
- brushnet_block = zero_module(brushnet_block)
328
- self.brushnet_down_blocks.append(brushnet_block)
329
-
330
- for i, down_block_type in enumerate(down_block_types):
331
- input_channel = output_channel
332
- output_channel = block_out_channels[i]
333
- is_final_block = i == len(block_out_channels) - 1
334
-
335
- down_block = get_down_block(
336
- down_block_type,
337
- num_layers=layers_per_block,
338
- transformer_layers_per_block=transformer_layers_per_block[i],
339
- in_channels=input_channel,
340
- out_channels=output_channel,
341
- temb_channels=time_embed_dim,
342
- add_downsample=not is_final_block,
343
- resnet_eps=norm_eps,
344
- resnet_act_fn=act_fn,
345
- resnet_groups=norm_num_groups,
346
- cross_attention_dim=cross_attention_dim,
347
- num_attention_heads=num_attention_heads[i],
348
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
349
- downsample_padding=downsample_padding,
350
- use_linear_projection=use_linear_projection,
351
- only_cross_attention=only_cross_attention[i],
352
- upcast_attention=upcast_attention,
353
- resnet_time_scale_shift=resnet_time_scale_shift,
354
- )
355
- self.down_blocks.append(down_block)
356
-
357
- for _ in range(layers_per_block):
358
- brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
359
- brushnet_block = zero_module(brushnet_block)
360
- self.brushnet_down_blocks.append(brushnet_block)
361
-
362
- if not is_final_block:
363
- brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
364
- brushnet_block = zero_module(brushnet_block)
365
- self.brushnet_down_blocks.append(brushnet_block)
366
-
367
- # mid
368
- mid_block_channel = block_out_channels[-1]
369
-
370
- brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
371
- brushnet_block = zero_module(brushnet_block)
372
- self.brushnet_mid_block = brushnet_block
373
-
374
- self.mid_block = get_mid_block(
375
- mid_block_type,
376
- transformer_layers_per_block=transformer_layers_per_block[-1],
377
- in_channels=mid_block_channel,
378
- temb_channels=time_embed_dim,
379
- resnet_eps=norm_eps,
380
- resnet_act_fn=act_fn,
381
- output_scale_factor=mid_block_scale_factor,
382
- resnet_time_scale_shift=resnet_time_scale_shift,
383
- cross_attention_dim=cross_attention_dim,
384
- num_attention_heads=num_attention_heads[-1],
385
- resnet_groups=norm_num_groups,
386
- use_linear_projection=use_linear_projection,
387
- upcast_attention=upcast_attention,
388
- )
389
-
390
- # count how many layers upsample the images
391
- self.num_upsamplers = 0
392
-
393
- # up
394
- reversed_block_out_channels = list(reversed(block_out_channels))
395
- reversed_num_attention_heads = list(reversed(num_attention_heads))
396
- reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
397
- only_cross_attention = list(reversed(only_cross_attention))
398
-
399
- output_channel = reversed_block_out_channels[0]
400
-
401
- self.up_blocks = nn.ModuleList([])
402
- self.brushnet_up_blocks = nn.ModuleList([])
403
-
404
- for i, up_block_type in enumerate(up_block_types):
405
- is_final_block = i == len(block_out_channels) - 1
406
-
407
- prev_output_channel = output_channel
408
- output_channel = reversed_block_out_channels[i]
409
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
410
-
411
- # add upsample block for all BUT final layer
412
- if not is_final_block:
413
- add_upsample = True
414
- self.num_upsamplers += 1
415
- else:
416
- add_upsample = False
417
-
418
- up_block = get_up_block(
419
- up_block_type,
420
- num_layers=layers_per_block+1,
421
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
422
- in_channels=input_channel,
423
- out_channels=output_channel,
424
- prev_output_channel=prev_output_channel,
425
- temb_channels=time_embed_dim,
426
- add_upsample=add_upsample,
427
- resnet_eps=norm_eps,
428
- resnet_act_fn=act_fn,
429
- resolution_idx=i,
430
- resnet_groups=norm_num_groups,
431
- cross_attention_dim=cross_attention_dim,
432
- num_attention_heads=reversed_num_attention_heads[i],
433
- use_linear_projection=use_linear_projection,
434
- only_cross_attention=only_cross_attention[i],
435
- upcast_attention=upcast_attention,
436
- resnet_time_scale_shift=resnet_time_scale_shift,
437
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
438
- )
439
- self.up_blocks.append(up_block)
440
- prev_output_channel = output_channel
441
-
442
- for _ in range(layers_per_block+1):
443
- brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
444
- brushnet_block = zero_module(brushnet_block)
445
- self.brushnet_up_blocks.append(brushnet_block)
446
-
447
- if not is_final_block:
448
- brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
449
- brushnet_block = zero_module(brushnet_block)
450
- self.brushnet_up_blocks.append(brushnet_block)
451
-
452
-
453
- @classmethod
454
- def from_unet(
455
- cls,
456
- unet: UNet2DConditionModel,
457
- brushnet_conditioning_channel_order: str = "rgb",
458
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
459
- load_weights_from_unet: bool = True,
460
- conditioning_channels: int = 5,
461
- ):
462
- r"""
463
- Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
464
-
465
- Parameters:
466
- unet (`UNet2DConditionModel`):
467
- The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
468
- where applicable.
469
- """
470
- transformer_layers_per_block = (
471
- unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
472
- )
473
- encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
474
- encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
475
- addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
476
- addition_time_embed_dim = (
477
- unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
478
- )
479
-
480
- brushnet = cls(
481
- in_channels=unet.config.in_channels,
482
- conditioning_channels=conditioning_channels,
483
- flip_sin_to_cos=unet.config.flip_sin_to_cos,
484
- freq_shift=unet.config.freq_shift,
485
- down_block_types=["DownBlock2D" for block_name in unet.config.down_block_types],
486
- mid_block_type='MidBlock2D',
487
- up_block_types=["UpBlock2D" for block_name in unet.config.down_block_types],
488
- only_cross_attention=unet.config.only_cross_attention,
489
- block_out_channels=unet.config.block_out_channels,
490
- layers_per_block=unet.config.layers_per_block,
491
- downsample_padding=unet.config.downsample_padding,
492
- mid_block_scale_factor=unet.config.mid_block_scale_factor,
493
- act_fn=unet.config.act_fn,
494
- norm_num_groups=unet.config.norm_num_groups,
495
- norm_eps=unet.config.norm_eps,
496
- cross_attention_dim=unet.config.cross_attention_dim,
497
- transformer_layers_per_block=transformer_layers_per_block,
498
- encoder_hid_dim=encoder_hid_dim,
499
- encoder_hid_dim_type=encoder_hid_dim_type,
500
- attention_head_dim=unet.config.attention_head_dim,
501
- num_attention_heads=unet.config.num_attention_heads,
502
- use_linear_projection=unet.config.use_linear_projection,
503
- class_embed_type=unet.config.class_embed_type,
504
- addition_embed_type=addition_embed_type,
505
- addition_time_embed_dim=addition_time_embed_dim,
506
- num_class_embeds=unet.config.num_class_embeds,
507
- upcast_attention=unet.config.upcast_attention,
508
- resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
509
- projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
510
- brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
511
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
512
- )
513
-
514
- if load_weights_from_unet:
515
- conv_in_condition_weight=torch.zeros_like(brushnet.conv_in_condition.weight)
516
- conv_in_condition_weight[:,:4,...]=unet.conv_in.weight
517
- conv_in_condition_weight[:,4:8,...]=unet.conv_in.weight
518
- brushnet.conv_in_condition.weight=torch.nn.Parameter(conv_in_condition_weight)
519
- brushnet.conv_in_condition.bias=unet.conv_in.bias
520
-
521
- brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
522
- brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
523
-
524
- if brushnet.class_embedding:
525
- brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
526
-
527
- brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(),strict=False)
528
- brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(),strict=False)
529
- brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(),strict=False)
530
-
531
- return brushnet
532
-
533
- @property
534
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
535
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
536
- r"""
537
- Returns:
538
- `dict` of attention processors: A dictionary containing all attention processors used in the model,
539
- indexed by weight name.
540
- """
541
- # set recursively
542
- processors = {}
543
-
544
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
545
- if hasattr(module, "get_processor"):
546
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
547
-
548
- for sub_name, child in module.named_children():
549
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
550
-
551
- return processors
552
-
553
- for name, module in self.named_children():
554
- fn_recursive_add_processors(name, module, processors)
555
-
556
- return processors
557
-
558
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
559
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
560
- r"""
561
- Sets the attention processor to use to compute attention.
562
-
563
- Parameters:
564
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
565
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
566
- for **all** `Attention` layers.
567
-
568
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
569
- processor. This is strongly recommended when setting trainable attention processors.
570
-
571
- """
572
- count = len(self.attn_processors.keys())
573
-
574
- if isinstance(processor, dict) and len(processor) != count:
575
- raise ValueError(
576
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
577
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
578
- )
579
-
580
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
581
- if hasattr(module, "set_processor"):
582
- if not isinstance(processor, dict):
583
- module.set_processor(processor)
584
- else:
585
- module.set_processor(processor.pop(f"{name}.processor"))
586
-
587
- for sub_name, child in module.named_children():
588
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
589
-
590
- for name, module in self.named_children():
591
- fn_recursive_attn_processor(name, module, processor)
592
-
593
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
594
- def set_default_attn_processor(self):
595
- """
596
- Disables custom attention processors and sets the default attention implementation.
597
- """
598
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
599
- processor = AttnAddedKVProcessor()
600
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
601
- processor = AttnProcessor()
602
- else:
603
- raise ValueError(
604
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
605
- )
606
-
607
- self.set_attn_processor(processor)
608
-
609
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
610
- def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
611
- r"""
612
- Enable sliced attention computation.
613
-
614
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
615
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
616
-
617
- Args:
618
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
619
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
620
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
621
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
622
- must be a multiple of `slice_size`.
623
- """
624
- sliceable_head_dims = []
625
-
626
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
627
- if hasattr(module, "set_attention_slice"):
628
- sliceable_head_dims.append(module.sliceable_head_dim)
629
-
630
- for child in module.children():
631
- fn_recursive_retrieve_sliceable_dims(child)
632
-
633
- # retrieve number of attention layers
634
- for module in self.children():
635
- fn_recursive_retrieve_sliceable_dims(module)
636
-
637
- num_sliceable_layers = len(sliceable_head_dims)
638
-
639
- if slice_size == "auto":
640
- # half the attention head size is usually a good trade-off between
641
- # speed and memory
642
- slice_size = [dim // 2 for dim in sliceable_head_dims]
643
- elif slice_size == "max":
644
- # make smallest slice possible
645
- slice_size = num_sliceable_layers * [1]
646
-
647
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
648
-
649
- if len(slice_size) != len(sliceable_head_dims):
650
- raise ValueError(
651
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
652
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
653
- )
654
-
655
- for i in range(len(slice_size)):
656
- size = slice_size[i]
657
- dim = sliceable_head_dims[i]
658
- if size is not None and size > dim:
659
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
660
-
661
- # Recursively walk through all the children.
662
- # Any children which exposes the set_attention_slice method
663
- # gets the message
664
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
665
- if hasattr(module, "set_attention_slice"):
666
- module.set_attention_slice(slice_size.pop())
667
-
668
- for child in module.children():
669
- fn_recursive_set_attention_slice(child, slice_size)
670
-
671
- reversed_slice_size = list(reversed(slice_size))
672
- for module in self.children():
673
- fn_recursive_set_attention_slice(module, reversed_slice_size)
674
-
675
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
676
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
677
- module.gradient_checkpointing = value
678
-
679
- def forward(
680
- self,
681
- sample: torch.FloatTensor,
682
- encoder_hidden_states: torch.Tensor,
683
- brushnet_cond: torch.FloatTensor,
684
- timestep = None,
685
- time_emb = None,
686
- conditioning_scale: float = 1.0,
687
- class_labels: Optional[torch.Tensor] = None,
688
- timestep_cond: Optional[torch.Tensor] = None,
689
- attention_mask: Optional[torch.Tensor] = None,
690
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
691
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
692
- guess_mode: bool = False,
693
- return_dict: bool = True,
694
- debug = False,
695
- ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
696
- """
697
- The [`BrushNetModel`] forward method.
698
-
699
- Args:
700
- sample (`torch.FloatTensor`):
701
- The noisy input tensor.
702
- timestep (`Union[torch.Tensor, float, int]`):
703
- The number of timesteps to denoise an input.
704
- encoder_hidden_states (`torch.Tensor`):
705
- The encoder hidden states.
706
- brushnet_cond (`torch.FloatTensor`):
707
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
708
- conditioning_scale (`float`, defaults to `1.0`):
709
- The scale factor for BrushNet outputs.
710
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
711
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
712
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
713
- Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
714
- timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
715
- embeddings.
716
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
717
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
718
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
719
- negative values to the attention scores corresponding to "discard" tokens.
720
- added_cond_kwargs (`dict`):
721
- Additional conditions for the Stable Diffusion XL UNet.
722
- cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
723
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
724
- guess_mode (`bool`, defaults to `False`):
725
- In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
726
- you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
727
- return_dict (`bool`, defaults to `True`):
728
- Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
729
-
730
- Returns:
731
- [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
732
- If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
733
- returned where the first element is the sample tensor.
734
- """
735
-
736
- # check channel order
737
- channel_order = self.config.brushnet_conditioning_channel_order
738
-
739
- if channel_order == "rgb":
740
- # in rgb order by default
741
- ...
742
- elif channel_order == "bgr":
743
- brushnet_cond = torch.flip(brushnet_cond, dims=[1])
744
- else:
745
- raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
746
-
747
- # prepare attention_mask
748
- if attention_mask is not None:
749
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
750
- attention_mask = attention_mask.unsqueeze(1)
751
-
752
- if timestep is None and time_emb is None:
753
- raise ValueError(f"`timestep` and `emb` are both None")
754
-
755
- #print("BN: sample.device", sample.device)
756
- #print("BN: TE.device", self.time_embedding.linear_1.weight.device)
757
-
758
- if timestep is not None:
759
- # 1. time
760
- timesteps = timestep
761
- if not torch.is_tensor(timesteps):
762
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
763
- # This would be a good case for the `match` statement (Python 3.10+)
764
- is_mps = sample.device.type == "mps"
765
- if isinstance(timestep, float):
766
- dtype = torch.float32 if is_mps else torch.float64
767
- else:
768
- dtype = torch.int32 if is_mps else torch.int64
769
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
770
- elif len(timesteps.shape) == 0:
771
- timesteps = timesteps[None].to(sample.device)
772
-
773
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
774
- timesteps = timesteps.expand(sample.shape[0])
775
-
776
- t_emb = self.time_proj(timesteps)
777
-
778
- # timesteps does not contain any weights and will always return f32 tensors
779
- # but time_embedding might actually be running in fp16. so we need to cast here.
780
- # there might be better ways to encapsulate this.
781
- t_emb = t_emb.to(dtype=sample.dtype)
782
-
783
- #print("t_emb.device =",t_emb.device)
784
-
785
- emb = self.time_embedding(t_emb, timestep_cond)
786
- aug_emb = None
787
-
788
- #print('emb.shape', emb.shape)
789
-
790
- if self.class_embedding is not None:
791
- if class_labels is None:
792
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
793
-
794
- if self.config.class_embed_type == "timestep":
795
- class_labels = self.time_proj(class_labels)
796
-
797
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
798
- emb = emb + class_emb
799
-
800
- if self.config.addition_embed_type is not None:
801
- if self.config.addition_embed_type == "text":
802
- aug_emb = self.add_embedding(encoder_hidden_states)
803
-
804
- elif self.config.addition_embed_type == "text_time":
805
- if "text_embeds" not in added_cond_kwargs:
806
- raise ValueError(
807
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
808
- )
809
- text_embeds = added_cond_kwargs.get("text_embeds")
810
- if "time_ids" not in added_cond_kwargs:
811
- raise ValueError(
812
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
813
- )
814
- time_ids = added_cond_kwargs.get("time_ids")
815
- time_embeds = self.add_time_proj(time_ids.flatten())
816
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
817
-
818
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
819
- add_embeds = add_embeds.to(emb.dtype)
820
- aug_emb = self.add_embedding(add_embeds)
821
-
822
- #print('text_embeds', text_embeds.shape, 'time_ids', time_ids.shape, 'time_embeds', time_embeds.shape, 'add__embeds', add_embeds.shape, 'aug_emb', aug_emb.shape)
823
-
824
- emb = emb + aug_emb if aug_emb is not None else emb
825
- else:
826
- emb = time_emb
827
-
828
- # 2. pre-process
829
-
830
- brushnet_cond=torch.concat([sample,brushnet_cond],1)
831
- sample = self.conv_in_condition(brushnet_cond)
832
-
833
- # 3. down
834
- down_block_res_samples = (sample,)
835
- for downsample_block in self.down_blocks:
836
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
837
- sample, res_samples = downsample_block(
838
- hidden_states=sample,
839
- temb=emb,
840
- encoder_hidden_states=encoder_hidden_states,
841
- attention_mask=attention_mask,
842
- cross_attention_kwargs=cross_attention_kwargs,
843
- )
844
- else:
845
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
846
-
847
- down_block_res_samples += res_samples
848
-
849
- # 4. PaintingNet down blocks
850
- brushnet_down_block_res_samples = ()
851
- for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
852
- down_block_res_sample = brushnet_down_block(down_block_res_sample)
853
- brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
854
-
855
-
856
- # 5. mid
857
- if self.mid_block is not None:
858
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
859
- sample = self.mid_block(
860
- sample,
861
- emb,
862
- encoder_hidden_states=encoder_hidden_states,
863
- attention_mask=attention_mask,
864
- cross_attention_kwargs=cross_attention_kwargs,
865
- )
866
- else:
867
- sample = self.mid_block(sample, emb)
868
-
869
- # 6. BrushNet mid blocks
870
- brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
871
-
872
- # 7. up
873
- up_block_res_samples = ()
874
- for i, upsample_block in enumerate(self.up_blocks):
875
- is_final_block = i == len(self.up_blocks) - 1
876
-
877
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
878
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
879
-
880
- # if we have not reached the final block and need to forward the
881
- # upsample size, we do it here
882
- if not is_final_block:
883
- upsample_size = down_block_res_samples[-1].shape[2:]
884
-
885
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
886
- sample, up_res_samples = upsample_block(
887
- hidden_states=sample,
888
- temb=emb,
889
- res_hidden_states_tuple=res_samples,
890
- encoder_hidden_states=encoder_hidden_states,
891
- cross_attention_kwargs=cross_attention_kwargs,
892
- upsample_size=upsample_size,
893
- attention_mask=attention_mask,
894
- return_res_samples=True
895
- )
896
- else:
897
- sample, up_res_samples = upsample_block(
898
- hidden_states=sample,
899
- temb=emb,
900
- res_hidden_states_tuple=res_samples,
901
- upsample_size=upsample_size,
902
- return_res_samples=True
903
- )
904
-
905
- up_block_res_samples += up_res_samples
906
-
907
- # 8. BrushNet up blocks
908
- brushnet_up_block_res_samples = ()
909
- for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
910
- up_block_res_sample = brushnet_up_block(up_block_res_sample)
911
- brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
912
-
913
- # 6. scaling
914
- if guess_mode and not self.config.global_pool_conditions:
915
- scales = torch.logspace(-1, 0, len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), device=sample.device) # 0.1 to 1.0
916
- scales = scales * conditioning_scale
917
-
918
- brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, scales[:len(brushnet_down_block_res_samples)])]
919
- brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
920
- brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples)+1:])]
921
- else:
922
- brushnet_down_block_res_samples = [sample * conditioning_scale for sample in brushnet_down_block_res_samples]
923
- brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
924
- brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
925
-
926
-
927
- if self.config.global_pool_conditions:
928
- brushnet_down_block_res_samples = [
929
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
930
- ]
931
- brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
932
- brushnet_up_block_res_samples = [
933
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
934
- ]
935
-
936
- if not return_dict:
937
- return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
938
-
939
- return BrushNetOutput(
940
- down_block_res_samples=brushnet_down_block_res_samples,
941
- mid_block_res_sample=brushnet_mid_block_res_sample,
942
- up_block_res_samples=brushnet_up_block_res_samples
943
- )
944
-
945
-
946
- def zero_module(module):
947
- for p in module.parameters():
948
- nn.init.zeros_(p)
949
- return module
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/brushnet/brushnet_ca.py DELETED
@@ -1,983 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Any, Dict, List, Optional, Tuple, Union
3
-
4
- import torch
5
- from torch import nn
6
-
7
- from diffusers.configuration_utils import ConfigMixin, register_to_config
8
- from diffusers.utils import BaseOutput, logging
9
- from diffusers.models.attention_processor import (
10
- ADDED_KV_ATTENTION_PROCESSORS,
11
- CROSS_ATTENTION_PROCESSORS,
12
- AttentionProcessor,
13
- AttnAddedKVProcessor,
14
- AttnProcessor,
15
- )
16
- from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
17
- from diffusers.models.modeling_utils import ModelMixin
18
-
19
- from .unet_2d_blocks import (
20
- CrossAttnDownBlock2D,
21
- DownBlock2D,
22
- UNetMidBlock2D,
23
- UNetMidBlock2DCrossAttn,
24
- get_down_block,
25
- get_mid_block,
26
- get_up_block,
27
- MidBlock2D
28
- )
29
-
30
- from .unet_2d_condition import UNet2DConditionModel
31
-
32
-
33
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
-
35
-
36
- @dataclass
37
- class BrushNetOutput(BaseOutput):
38
- """
39
- The output of [`BrushNetModel`].
40
-
41
- Args:
42
- up_block_res_samples (`tuple[torch.Tensor]`):
43
- A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
44
- be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
45
- used to condition the original UNet's upsampling activations.
46
- down_block_res_samples (`tuple[torch.Tensor]`):
47
- A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
48
- be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
49
- used to condition the original UNet's downsampling activations.
50
- mid_down_block_re_sample (`torch.Tensor`):
51
- The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
52
- `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
53
- Output can be used to condition the original UNet's middle block activation.
54
- """
55
-
56
- up_block_res_samples: Tuple[torch.Tensor]
57
- down_block_res_samples: Tuple[torch.Tensor]
58
- mid_block_res_sample: torch.Tensor
59
-
60
-
61
- class BrushNetModel(ModelMixin, ConfigMixin):
62
- """
63
- A BrushNet model.
64
-
65
- Args:
66
- in_channels (`int`, defaults to 4):
67
- The number of channels in the input sample.
68
- flip_sin_to_cos (`bool`, defaults to `True`):
69
- Whether to flip the sin to cos in the time embedding.
70
- freq_shift (`int`, defaults to 0):
71
- The frequency shift to apply to the time embedding.
72
- down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
73
- The tuple of downsample blocks to use.
74
- mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
75
- Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
76
- `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
77
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
78
- The tuple of upsample blocks to use.
79
- only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
80
- block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
81
- The tuple of output channels for each block.
82
- layers_per_block (`int`, defaults to 2):
83
- The number of layers per block.
84
- downsample_padding (`int`, defaults to 1):
85
- The padding to use for the downsampling convolution.
86
- mid_block_scale_factor (`float`, defaults to 1):
87
- The scale factor to use for the mid block.
88
- act_fn (`str`, defaults to "silu"):
89
- The activation function to use.
90
- norm_num_groups (`int`, *optional*, defaults to 32):
91
- The number of groups to use for the normalization. If None, normalization and activation layers is skipped
92
- in post-processing.
93
- norm_eps (`float`, defaults to 1e-5):
94
- The epsilon to use for the normalization.
95
- cross_attention_dim (`int`, defaults to 1280):
96
- The dimension of the cross attention features.
97
- transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
98
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
99
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
100
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
101
- encoder_hid_dim (`int`, *optional*, defaults to None):
102
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
103
- dimension to `cross_attention_dim`.
104
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
105
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
106
- embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
107
- attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
108
- The dimension of the attention heads.
109
- use_linear_projection (`bool`, defaults to `False`):
110
- class_embed_type (`str`, *optional*, defaults to `None`):
111
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
112
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
113
- addition_embed_type (`str`, *optional*, defaults to `None`):
114
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
115
- "text". "text" will use the `TextTimeEmbedding` layer.
116
- num_class_embeds (`int`, *optional*, defaults to 0):
117
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
118
- class conditioning with `class_embed_type` equal to `None`.
119
- upcast_attention (`bool`, defaults to `False`):
120
- resnet_time_scale_shift (`str`, defaults to `"default"`):
121
- Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
122
- projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
123
- The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
124
- `class_embed_type="projection"`.
125
- brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
126
- The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
127
- conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
128
- The tuple of output channel for each block in the `conditioning_embedding` layer.
129
- global_pool_conditions (`bool`, defaults to `False`):
130
- TODO(Patrick) - unused parameter.
131
- addition_embed_type_num_heads (`int`, defaults to 64):
132
- The number of heads to use for the `TextTimeEmbedding` layer.
133
- """
134
-
135
- _supports_gradient_checkpointing = True
136
-
137
- @register_to_config
138
- def __init__(
139
- self,
140
- in_channels: int = 4,
141
- conditioning_channels: int = 5,
142
- flip_sin_to_cos: bool = True,
143
- freq_shift: int = 0,
144
- down_block_types: Tuple[str, ...] = (
145
- "CrossAttnDownBlock2D",
146
- "CrossAttnDownBlock2D",
147
- "CrossAttnDownBlock2D",
148
- "DownBlock2D",
149
- ),
150
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
151
- up_block_types: Tuple[str, ...] = (
152
- "UpBlock2D",
153
- "CrossAttnUpBlock2D",
154
- "CrossAttnUpBlock2D",
155
- "CrossAttnUpBlock2D",
156
- ),
157
- only_cross_attention: Union[bool, Tuple[bool]] = False,
158
- block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
159
- layers_per_block: int = 2,
160
- downsample_padding: int = 1,
161
- mid_block_scale_factor: float = 1,
162
- act_fn: str = "silu",
163
- norm_num_groups: Optional[int] = 32,
164
- norm_eps: float = 1e-5,
165
- cross_attention_dim: int = 1280,
166
- transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
167
- encoder_hid_dim: Optional[int] = None,
168
- encoder_hid_dim_type: Optional[str] = None,
169
- attention_head_dim: Union[int, Tuple[int, ...]] = 8,
170
- num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
171
- use_linear_projection: bool = False,
172
- class_embed_type: Optional[str] = None,
173
- addition_embed_type: Optional[str] = None,
174
- addition_time_embed_dim: Optional[int] = None,
175
- num_class_embeds: Optional[int] = None,
176
- upcast_attention: bool = False,
177
- resnet_time_scale_shift: str = "default",
178
- projection_class_embeddings_input_dim: Optional[int] = None,
179
- brushnet_conditioning_channel_order: str = "rgb",
180
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
181
- global_pool_conditions: bool = False,
182
- addition_embed_type_num_heads: int = 64,
183
- ):
184
- super().__init__()
185
-
186
- # If `num_attention_heads` is not defined (which is the case for most models)
187
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
188
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
189
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
190
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
191
- # which is why we correct for the naming here.
192
- num_attention_heads = num_attention_heads or attention_head_dim
193
-
194
- # Check inputs
195
- if len(down_block_types) != len(up_block_types):
196
- raise ValueError(
197
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
198
- )
199
-
200
- if len(block_out_channels) != len(down_block_types):
201
- raise ValueError(
202
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
203
- )
204
-
205
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
206
- raise ValueError(
207
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
208
- )
209
-
210
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
211
- raise ValueError(
212
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
213
- )
214
-
215
- if isinstance(transformer_layers_per_block, int):
216
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
217
-
218
- # input
219
- conv_in_kernel = 3
220
- conv_in_padding = (conv_in_kernel - 1) // 2
221
- self.conv_in_condition = nn.Conv2d(
222
- in_channels + conditioning_channels,
223
- block_out_channels[0],
224
- kernel_size=conv_in_kernel,
225
- padding=conv_in_padding,
226
- )
227
-
228
- # time
229
- time_embed_dim = block_out_channels[0] * 4
230
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
231
- timestep_input_dim = block_out_channels[0]
232
- self.time_embedding = TimestepEmbedding(
233
- timestep_input_dim,
234
- time_embed_dim,
235
- act_fn=act_fn,
236
- )
237
-
238
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
239
- encoder_hid_dim_type = "text_proj"
240
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
241
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
242
-
243
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
244
- raise ValueError(
245
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
246
- )
247
-
248
- if encoder_hid_dim_type == "text_proj":
249
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
250
- elif encoder_hid_dim_type == "text_image_proj":
251
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
252
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
253
- # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
254
- self.encoder_hid_proj = TextImageProjection(
255
- text_embed_dim=encoder_hid_dim,
256
- image_embed_dim=cross_attention_dim,
257
- cross_attention_dim=cross_attention_dim,
258
- )
259
-
260
- elif encoder_hid_dim_type is not None:
261
- raise ValueError(
262
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
263
- )
264
- else:
265
- self.encoder_hid_proj = None
266
-
267
- # class embedding
268
- if class_embed_type is None and num_class_embeds is not None:
269
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
270
- elif class_embed_type == "timestep":
271
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
272
- elif class_embed_type == "identity":
273
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
274
- elif class_embed_type == "projection":
275
- if projection_class_embeddings_input_dim is None:
276
- raise ValueError(
277
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
278
- )
279
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
280
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
281
- # 2. it projects from an arbitrary input dimension.
282
- #
283
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
284
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
285
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
286
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
287
- else:
288
- self.class_embedding = None
289
-
290
- if addition_embed_type == "text":
291
- if encoder_hid_dim is not None:
292
- text_time_embedding_from_dim = encoder_hid_dim
293
- else:
294
- text_time_embedding_from_dim = cross_attention_dim
295
-
296
- self.add_embedding = TextTimeEmbedding(
297
- text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
298
- )
299
- elif addition_embed_type == "text_image":
300
- # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
301
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
302
- # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
303
- self.add_embedding = TextImageTimeEmbedding(
304
- text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
305
- )
306
- elif addition_embed_type == "text_time":
307
- self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
308
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
309
-
310
- elif addition_embed_type is not None:
311
- raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
312
-
313
- self.down_blocks = nn.ModuleList([])
314
- self.brushnet_down_blocks = nn.ModuleList([])
315
-
316
- if isinstance(only_cross_attention, bool):
317
- only_cross_attention = [only_cross_attention] * len(down_block_types)
318
-
319
- if isinstance(attention_head_dim, int):
320
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
321
-
322
- if isinstance(num_attention_heads, int):
323
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
324
-
325
- # down
326
- output_channel = block_out_channels[0]
327
-
328
- brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
329
- brushnet_block = zero_module(brushnet_block)
330
- self.brushnet_down_blocks.append(brushnet_block)
331
-
332
- for i, down_block_type in enumerate(down_block_types):
333
- input_channel = output_channel
334
- output_channel = block_out_channels[i]
335
- is_final_block = i == len(block_out_channels) - 1
336
-
337
- down_block = get_down_block(
338
- down_block_type,
339
- num_layers=layers_per_block,
340
- transformer_layers_per_block=transformer_layers_per_block[i],
341
- in_channels=input_channel,
342
- out_channels=output_channel,
343
- temb_channels=time_embed_dim,
344
- add_downsample=not is_final_block,
345
- resnet_eps=norm_eps,
346
- resnet_act_fn=act_fn,
347
- resnet_groups=norm_num_groups,
348
- cross_attention_dim=cross_attention_dim,
349
- num_attention_heads=num_attention_heads[i],
350
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
351
- downsample_padding=downsample_padding,
352
- use_linear_projection=use_linear_projection,
353
- only_cross_attention=only_cross_attention[i],
354
- upcast_attention=upcast_attention,
355
- resnet_time_scale_shift=resnet_time_scale_shift,
356
- )
357
- self.down_blocks.append(down_block)
358
-
359
- for _ in range(layers_per_block):
360
- brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
361
- brushnet_block = zero_module(brushnet_block)
362
- self.brushnet_down_blocks.append(brushnet_block)
363
-
364
- if not is_final_block:
365
- brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
366
- brushnet_block = zero_module(brushnet_block)
367
- self.brushnet_down_blocks.append(brushnet_block)
368
-
369
- # mid
370
- mid_block_channel = block_out_channels[-1]
371
-
372
- brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
373
- brushnet_block = zero_module(brushnet_block)
374
- self.brushnet_mid_block = brushnet_block
375
-
376
- self.mid_block = get_mid_block(
377
- mid_block_type,
378
- transformer_layers_per_block=transformer_layers_per_block[-1],
379
- in_channels=mid_block_channel,
380
- temb_channels=time_embed_dim,
381
- resnet_eps=norm_eps,
382
- resnet_act_fn=act_fn,
383
- output_scale_factor=mid_block_scale_factor,
384
- resnet_time_scale_shift=resnet_time_scale_shift,
385
- cross_attention_dim=cross_attention_dim,
386
- num_attention_heads=num_attention_heads[-1],
387
- resnet_groups=norm_num_groups,
388
- use_linear_projection=use_linear_projection,
389
- upcast_attention=upcast_attention,
390
- )
391
-
392
- # count how many layers upsample the images
393
- self.num_upsamplers = 0
394
-
395
- # up
396
- reversed_block_out_channels = list(reversed(block_out_channels))
397
- reversed_num_attention_heads = list(reversed(num_attention_heads))
398
- reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
399
- only_cross_attention = list(reversed(only_cross_attention))
400
-
401
- output_channel = reversed_block_out_channels[0]
402
-
403
- self.up_blocks = nn.ModuleList([])
404
- self.brushnet_up_blocks = nn.ModuleList([])
405
-
406
- for i, up_block_type in enumerate(up_block_types):
407
- is_final_block = i == len(block_out_channels) - 1
408
-
409
- prev_output_channel = output_channel
410
- output_channel = reversed_block_out_channels[i]
411
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
412
-
413
- # add upsample block for all BUT final layer
414
- if not is_final_block:
415
- add_upsample = True
416
- self.num_upsamplers += 1
417
- else:
418
- add_upsample = False
419
-
420
- up_block = get_up_block(
421
- up_block_type,
422
- num_layers=layers_per_block + 1,
423
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
424
- in_channels=input_channel,
425
- out_channels=output_channel,
426
- prev_output_channel=prev_output_channel,
427
- temb_channels=time_embed_dim,
428
- add_upsample=add_upsample,
429
- resnet_eps=norm_eps,
430
- resnet_act_fn=act_fn,
431
- resolution_idx=i,
432
- resnet_groups=norm_num_groups,
433
- cross_attention_dim=cross_attention_dim,
434
- num_attention_heads=reversed_num_attention_heads[i],
435
- use_linear_projection=use_linear_projection,
436
- only_cross_attention=only_cross_attention[i],
437
- upcast_attention=upcast_attention,
438
- resnet_time_scale_shift=resnet_time_scale_shift,
439
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
440
- )
441
- self.up_blocks.append(up_block)
442
- prev_output_channel = output_channel
443
-
444
- for _ in range(layers_per_block + 1):
445
- brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
446
- brushnet_block = zero_module(brushnet_block)
447
- self.brushnet_up_blocks.append(brushnet_block)
448
-
449
- if not is_final_block:
450
- brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
451
- brushnet_block = zero_module(brushnet_block)
452
- self.brushnet_up_blocks.append(brushnet_block)
453
-
454
- @classmethod
455
- def from_unet(
456
- cls,
457
- unet: UNet2DConditionModel,
458
- brushnet_conditioning_channel_order: str = "rgb",
459
- conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
460
- load_weights_from_unet: bool = True,
461
- conditioning_channels: int = 5,
462
- ):
463
- r"""
464
- Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
465
-
466
- Parameters:
467
- unet (`UNet2DConditionModel`):
468
- The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
469
- where applicable.
470
- """
471
- transformer_layers_per_block = (
472
- unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
473
- )
474
- encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
475
- encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
476
- addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
477
- addition_time_embed_dim = (
478
- unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
479
- )
480
-
481
- brushnet = cls(
482
- in_channels=unet.config.in_channels,
483
- conditioning_channels=conditioning_channels,
484
- flip_sin_to_cos=unet.config.flip_sin_to_cos,
485
- freq_shift=unet.config.freq_shift,
486
- # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
487
- down_block_types=[
488
- "CrossAttnDownBlock2D",
489
- "CrossAttnDownBlock2D",
490
- "CrossAttnDownBlock2D",
491
- "DownBlock2D",
492
- ],
493
- # mid_block_type='MidBlock2D',
494
- mid_block_type="UNetMidBlock2DCrossAttn",
495
- # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
496
- up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
497
- only_cross_attention=unet.config.only_cross_attention,
498
- block_out_channels=unet.config.block_out_channels,
499
- layers_per_block=unet.config.layers_per_block,
500
- downsample_padding=unet.config.downsample_padding,
501
- mid_block_scale_factor=unet.config.mid_block_scale_factor,
502
- act_fn=unet.config.act_fn,
503
- norm_num_groups=unet.config.norm_num_groups,
504
- norm_eps=unet.config.norm_eps,
505
- cross_attention_dim=unet.config.cross_attention_dim,
506
- transformer_layers_per_block=transformer_layers_per_block,
507
- encoder_hid_dim=encoder_hid_dim,
508
- encoder_hid_dim_type=encoder_hid_dim_type,
509
- attention_head_dim=unet.config.attention_head_dim,
510
- num_attention_heads=unet.config.num_attention_heads,
511
- use_linear_projection=unet.config.use_linear_projection,
512
- class_embed_type=unet.config.class_embed_type,
513
- addition_embed_type=addition_embed_type,
514
- addition_time_embed_dim=addition_time_embed_dim,
515
- num_class_embeds=unet.config.num_class_embeds,
516
- upcast_attention=unet.config.upcast_attention,
517
- resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
518
- projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
519
- brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
520
- conditioning_embedding_out_channels=conditioning_embedding_out_channels,
521
- )
522
-
523
- if load_weights_from_unet:
524
- conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
525
- conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
526
- conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
527
- brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
528
- brushnet.conv_in_condition.bias = unet.conv_in.bias
529
-
530
- brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
531
- brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
532
-
533
- if brushnet.class_embedding:
534
- brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
535
-
536
- brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
537
- brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
538
- brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
539
-
540
- return brushnet.to(unet.dtype)
541
-
542
- @property
543
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
544
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
545
- r"""
546
- Returns:
547
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
548
- indexed by its weight name.
549
- """
550
- # set recursively
551
- processors = {}
552
-
553
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
554
- if hasattr(module, "get_processor"):
555
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
556
-
557
- for sub_name, child in module.named_children():
558
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
559
-
560
- return processors
561
-
562
- for name, module in self.named_children():
563
- fn_recursive_add_processors(name, module, processors)
564
-
565
- return processors
566
-
567
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
568
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
569
- r"""
570
- Sets the attention processor to use to compute attention.
571
-
572
- Parameters:
573
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
574
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
575
- for **all** `Attention` layers.
576
-
577
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
578
- processor. This is strongly recommended when setting trainable attention processors.
579
-
580
- """
581
- count = len(self.attn_processors.keys())
582
-
583
- if isinstance(processor, dict) and len(processor) != count:
584
- raise ValueError(
585
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
586
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
587
- )
588
-
589
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
590
- if hasattr(module, "set_processor"):
591
- if not isinstance(processor, dict):
592
- module.set_processor(processor)
593
- else:
594
- module.set_processor(processor.pop(f"{name}.processor"))
595
-
596
- for sub_name, child in module.named_children():
597
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
598
-
599
- for name, module in self.named_children():
600
- fn_recursive_attn_processor(name, module, processor)
601
-
602
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
603
- def set_default_attn_processor(self):
604
- """
605
- Disables custom attention processors and sets the default attention implementation.
606
- """
607
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
608
- processor = AttnAddedKVProcessor()
609
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
610
- processor = AttnProcessor()
611
- else:
612
- raise ValueError(
613
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
614
- )
615
-
616
- self.set_attn_processor(processor)
617
-
618
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
619
- def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
620
- r"""
621
- Enable sliced attention computation.
622
-
623
- When this option is enabled, the attention module splits the input tensor in slices to compute attention in
624
- several steps. This is useful for saving some memory in exchange for a small decrease in speed.
625
-
626
- Args:
627
- slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
628
- When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
629
- `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
630
- provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
631
- must be a multiple of `slice_size`.
632
- """
633
- sliceable_head_dims = []
634
-
635
- def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
636
- if hasattr(module, "set_attention_slice"):
637
- sliceable_head_dims.append(module.sliceable_head_dim)
638
-
639
- for child in module.children():
640
- fn_recursive_retrieve_sliceable_dims(child)
641
-
642
- # retrieve number of attention layers
643
- for module in self.children():
644
- fn_recursive_retrieve_sliceable_dims(module)
645
-
646
- num_sliceable_layers = len(sliceable_head_dims)
647
-
648
- if slice_size == "auto":
649
- # half the attention head size is usually a good trade-off between
650
- # speed and memory
651
- slice_size = [dim // 2 for dim in sliceable_head_dims]
652
- elif slice_size == "max":
653
- # make smallest slice possible
654
- slice_size = num_sliceable_layers * [1]
655
-
656
- slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
657
-
658
- if len(slice_size) != len(sliceable_head_dims):
659
- raise ValueError(
660
- f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
661
- f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
662
- )
663
-
664
- for i in range(len(slice_size)):
665
- size = slice_size[i]
666
- dim = sliceable_head_dims[i]
667
- if size is not None and size > dim:
668
- raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
669
-
670
- # Recursively walk through all the children.
671
- # Any children which exposes the set_attention_slice method
672
- # gets the message
673
- def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
674
- if hasattr(module, "set_attention_slice"):
675
- module.set_attention_slice(slice_size.pop())
676
-
677
- for child in module.children():
678
- fn_recursive_set_attention_slice(child, slice_size)
679
-
680
- reversed_slice_size = list(reversed(slice_size))
681
- for module in self.children():
682
- fn_recursive_set_attention_slice(module, reversed_slice_size)
683
-
684
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
685
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
686
- module.gradient_checkpointing = value
687
-
688
- def forward(
689
- self,
690
- sample: torch.FloatTensor,
691
- timestep: Union[torch.Tensor, float, int],
692
- encoder_hidden_states: torch.Tensor,
693
- brushnet_cond: torch.FloatTensor,
694
- conditioning_scale: float = 1.0,
695
- class_labels: Optional[torch.Tensor] = None,
696
- timestep_cond: Optional[torch.Tensor] = None,
697
- attention_mask: Optional[torch.Tensor] = None,
698
- added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
699
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
700
- guess_mode: bool = False,
701
- return_dict: bool = True,
702
- debug=False,
703
- ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
704
- """
705
- The [`BrushNetModel`] forward method.
706
-
707
- Args:
708
- sample (`torch.FloatTensor`):
709
- The noisy input tensor.
710
- timestep (`Union[torch.Tensor, float, int]`):
711
- The number of timesteps to denoise an input.
712
- encoder_hidden_states (`torch.Tensor`):
713
- The encoder hidden states.
714
- brushnet_cond (`torch.FloatTensor`):
715
- The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
716
- conditioning_scale (`float`, defaults to `1.0`):
717
- The scale factor for BrushNet outputs.
718
- class_labels (`torch.Tensor`, *optional*, defaults to `None`):
719
- Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
720
- timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
721
- Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
722
- timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
723
- embeddings.
724
- attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
725
- An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
726
- is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
727
- negative values to the attention scores corresponding to "discard" tokens.
728
- added_cond_kwargs (`dict`):
729
- Additional conditions for the Stable Diffusion XL UNet.
730
- cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
731
- A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
732
- guess_mode (`bool`, defaults to `False`):
733
- In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
734
- you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
735
- return_dict (`bool`, defaults to `True`):
736
- Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
737
-
738
- Returns:
739
- [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
740
- If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
741
- returned where the first element is the sample tensor.
742
- """
743
- # check channel order
744
- channel_order = self.config.brushnet_conditioning_channel_order
745
-
746
- if channel_order == "rgb":
747
- # in rgb order by default
748
- ...
749
- elif channel_order == "bgr":
750
- brushnet_cond = torch.flip(brushnet_cond, dims=[1])
751
- else:
752
- raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
753
-
754
- if debug: print('BrushNet CA: attn mask')
755
-
756
- # prepare attention_mask
757
- if attention_mask is not None:
758
- attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
759
- attention_mask = attention_mask.unsqueeze(1)
760
-
761
- if debug: print('BrushNet CA: time')
762
-
763
- # 1. time
764
- timesteps = timestep
765
- if not torch.is_tensor(timesteps):
766
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
767
- # This would be a good case for the `match` statement (Python 3.10+)
768
- is_mps = sample.device.type == "mps"
769
- if isinstance(timestep, float):
770
- dtype = torch.float32 if is_mps else torch.float64
771
- else:
772
- dtype = torch.int32 if is_mps else torch.int64
773
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
774
- elif len(timesteps.shape) == 0:
775
- timesteps = timesteps[None].to(sample.device)
776
-
777
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
778
- timesteps = timesteps.expand(sample.shape[0])
779
-
780
- t_emb = self.time_proj(timesteps)
781
-
782
- # timesteps does not contain any weights and will always return f32 tensors
783
- # but time_embedding might actually be running in fp16. so we need to cast here.
784
- # there might be better ways to encapsulate this.
785
- t_emb = t_emb.to(dtype=sample.dtype)
786
-
787
- emb = self.time_embedding(t_emb, timestep_cond)
788
- aug_emb = None
789
-
790
- if self.class_embedding is not None:
791
- if class_labels is None:
792
- raise ValueError("class_labels should be provided when num_class_embeds > 0")
793
-
794
- if self.config.class_embed_type == "timestep":
795
- class_labels = self.time_proj(class_labels)
796
-
797
- class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
798
- emb = emb + class_emb
799
-
800
- if self.config.addition_embed_type is not None:
801
- if self.config.addition_embed_type == "text":
802
- aug_emb = self.add_embedding(encoder_hidden_states)
803
-
804
- elif self.config.addition_embed_type == "text_time":
805
- if "text_embeds" not in added_cond_kwargs:
806
- raise ValueError(
807
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
808
- )
809
- text_embeds = added_cond_kwargs.get("text_embeds")
810
- if "time_ids" not in added_cond_kwargs:
811
- raise ValueError(
812
- f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
813
- )
814
- time_ids = added_cond_kwargs.get("time_ids")
815
- time_embeds = self.add_time_proj(time_ids.flatten())
816
- time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
817
-
818
- add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
819
- add_embeds = add_embeds.to(emb.dtype)
820
- aug_emb = self.add_embedding(add_embeds)
821
-
822
- emb = emb + aug_emb if aug_emb is not None else emb
823
-
824
- if debug: print('BrushNet CA: pre-process')
825
-
826
-
827
- # 2. pre-process
828
- brushnet_cond = torch.concat([sample, brushnet_cond], 1)
829
- sample = self.conv_in_condition(brushnet_cond)
830
-
831
- if debug: print('BrushNet CA: down')
832
-
833
- # 3. down
834
- down_block_res_samples = (sample,)
835
- for downsample_block in self.down_blocks:
836
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
837
- if debug: print('BrushNet CA (down block with XA): ', type(downsample_block))
838
- sample, res_samples = downsample_block(
839
- hidden_states=sample,
840
- temb=emb,
841
- encoder_hidden_states=encoder_hidden_states,
842
- attention_mask=attention_mask,
843
- cross_attention_kwargs=cross_attention_kwargs,
844
- debug=debug,
845
- )
846
- else:
847
- if debug: print('BrushNet CA (down block): ', type(downsample_block))
848
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, debug=debug)
849
-
850
- down_block_res_samples += res_samples
851
-
852
- if debug: print('BrushNet CA: PP down')
853
-
854
- # 4. PaintingNet down blocks
855
- brushnet_down_block_res_samples = ()
856
- for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
857
- down_block_res_sample = brushnet_down_block(down_block_res_sample)
858
- brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
859
-
860
- if debug: print('BrushNet CA: PP mid')
861
-
862
- # 5. mid
863
- if self.mid_block is not None:
864
- if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
865
- sample = self.mid_block(
866
- sample,
867
- emb,
868
- encoder_hidden_states=encoder_hidden_states,
869
- attention_mask=attention_mask,
870
- cross_attention_kwargs=cross_attention_kwargs,
871
- )
872
- else:
873
- sample = self.mid_block(sample, emb)
874
-
875
- if debug: print('BrushNet CA: mid')
876
-
877
- # 6. BrushNet mid blocks
878
- brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
879
-
880
- if debug: print('BrushNet CA: PP up')
881
-
882
- # 7. up
883
- up_block_res_samples = ()
884
- for i, upsample_block in enumerate(self.up_blocks):
885
- is_final_block = i == len(self.up_blocks) - 1
886
-
887
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
888
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
889
-
890
- # if we have not reached the final block and need to forward the
891
- # upsample size, we do it here
892
- if not is_final_block:
893
- upsample_size = down_block_res_samples[-1].shape[2:]
894
-
895
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
896
- sample, up_res_samples = upsample_block(
897
- hidden_states=sample,
898
- temb=emb,
899
- res_hidden_states_tuple=res_samples,
900
- encoder_hidden_states=encoder_hidden_states,
901
- cross_attention_kwargs=cross_attention_kwargs,
902
- upsample_size=upsample_size,
903
- attention_mask=attention_mask,
904
- return_res_samples=True,
905
- )
906
- else:
907
- sample, up_res_samples = upsample_block(
908
- hidden_states=sample,
909
- temb=emb,
910
- res_hidden_states_tuple=res_samples,
911
- upsample_size=upsample_size,
912
- return_res_samples=True,
913
- )
914
-
915
- up_block_res_samples += up_res_samples
916
-
917
- if debug: print('BrushNet CA: up')
918
-
919
- # 8. BrushNet up blocks
920
- brushnet_up_block_res_samples = ()
921
- for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
922
- up_block_res_sample = brushnet_up_block(up_block_res_sample)
923
- brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
924
-
925
- if debug: print('BrushNet CA: scaling')
926
-
927
- # 6. scaling
928
- if guess_mode and not self.config.global_pool_conditions:
929
- scales = torch.logspace(
930
- -1,
931
- 0,
932
- len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
933
- device=sample.device,
934
- ) # 0.1 to 1.0
935
- scales = scales * conditioning_scale
936
-
937
- brushnet_down_block_res_samples = [
938
- sample * scale
939
- for sample, scale in zip(
940
- brushnet_down_block_res_samples, scales[: len(brushnet_down_block_res_samples)]
941
- )
942
- ]
943
- brushnet_mid_block_res_sample = (
944
- brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
945
- )
946
- brushnet_up_block_res_samples = [
947
- sample * scale
948
- for sample, scale in zip(
949
- brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples) + 1 :]
950
- )
951
- ]
952
- else:
953
- brushnet_down_block_res_samples = [
954
- sample * conditioning_scale for sample in brushnet_down_block_res_samples
955
- ]
956
- brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
957
- brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
958
-
959
- if self.config.global_pool_conditions:
960
- brushnet_down_block_res_samples = [
961
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
962
- ]
963
- brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
964
- brushnet_up_block_res_samples = [
965
- torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
966
- ]
967
-
968
- if debug: print('BrushNet CA: finish')
969
-
970
- if not return_dict:
971
- return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
972
-
973
- return BrushNetOutput(
974
- down_block_res_samples=brushnet_down_block_res_samples,
975
- mid_block_res_sample=brushnet_mid_block_res_sample,
976
- up_block_res_samples=brushnet_up_block_res_samples,
977
- )
978
-
979
-
980
- def zero_module(module):
981
- for p in module.parameters():
982
- nn.init.zeros_(p)
983
- return module
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/brushnet/brushnet_xl.json DELETED
@@ -1,63 +0,0 @@
1
- {
2
- "_class_name": "BrushNetModel",
3
- "_diffusers_version": "0.27.0.dev0",
4
- "_name_or_path": "runs/logs/brushnetsdxl_randommask/checkpoint-80000",
5
- "act_fn": "silu",
6
- "addition_embed_type": "text_time",
7
- "addition_embed_type_num_heads": 64,
8
- "addition_time_embed_dim": 256,
9
- "attention_head_dim": [
10
- 5,
11
- 10,
12
- 20
13
- ],
14
- "block_out_channels": [
15
- 320,
16
- 640,
17
- 1280
18
- ],
19
- "brushnet_conditioning_channel_order": "rgb",
20
- "class_embed_type": null,
21
- "conditioning_channels": 5,
22
- "conditioning_embedding_out_channels": [
23
- 16,
24
- 32,
25
- 96,
26
- 256
27
- ],
28
- "cross_attention_dim": 2048,
29
- "down_block_types": [
30
- "DownBlock2D",
31
- "DownBlock2D",
32
- "DownBlock2D"
33
- ],
34
- "downsample_padding": 1,
35
- "encoder_hid_dim": null,
36
- "encoder_hid_dim_type": null,
37
- "flip_sin_to_cos": true,
38
- "freq_shift": 0,
39
- "global_pool_conditions": false,
40
- "in_channels": 4,
41
- "layers_per_block": 2,
42
- "mid_block_scale_factor": 1,
43
- "mid_block_type": "MidBlock2D",
44
- "norm_eps": 1e-05,
45
- "norm_num_groups": 32,
46
- "num_attention_heads": null,
47
- "num_class_embeds": null,
48
- "only_cross_attention": false,
49
- "projection_class_embeddings_input_dim": 2816,
50
- "resnet_time_scale_shift": "default",
51
- "transformer_layers_per_block": [
52
- 1,
53
- 2,
54
- 10
55
- ],
56
- "up_block_types": [
57
- "UpBlock2D",
58
- "UpBlock2D",
59
- "UpBlock2D"
60
- ],
61
- "upcast_attention": null,
62
- "use_linear_projection": true
63
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/brushnet/powerpaint.json DELETED
@@ -1,57 +0,0 @@
1
- {
2
- "_class_name": "BrushNetModel",
3
- "_diffusers_version": "0.27.2",
4
- "act_fn": "silu",
5
- "addition_embed_type": null,
6
- "addition_embed_type_num_heads": 64,
7
- "addition_time_embed_dim": null,
8
- "attention_head_dim": 8,
9
- "block_out_channels": [
10
- 320,
11
- 640,
12
- 1280,
13
- 1280
14
- ],
15
- "brushnet_conditioning_channel_order": "rgb",
16
- "class_embed_type": null,
17
- "conditioning_channels": 5,
18
- "conditioning_embedding_out_channels": [
19
- 16,
20
- 32,
21
- 96,
22
- 256
23
- ],
24
- "cross_attention_dim": 768,
25
- "down_block_types": [
26
- "CrossAttnDownBlock2D",
27
- "CrossAttnDownBlock2D",
28
- "CrossAttnDownBlock2D",
29
- "DownBlock2D"
30
- ],
31
- "downsample_padding": 1,
32
- "encoder_hid_dim": null,
33
- "encoder_hid_dim_type": null,
34
- "flip_sin_to_cos": true,
35
- "freq_shift": 0,
36
- "global_pool_conditions": false,
37
- "in_channels": 4,
38
- "layers_per_block": 2,
39
- "mid_block_scale_factor": 1,
40
- "mid_block_type": "UNetMidBlock2DCrossAttn",
41
- "norm_eps": 1e-05,
42
- "norm_num_groups": 32,
43
- "num_attention_heads": null,
44
- "num_class_embeds": null,
45
- "only_cross_attention": false,
46
- "projection_class_embeddings_input_dim": null,
47
- "resnet_time_scale_shift": "default",
48
- "transformer_layers_per_block": 1,
49
- "up_block_types": [
50
- "UpBlock2D",
51
- "CrossAttnUpBlock2D",
52
- "CrossAttnUpBlock2D",
53
- "CrossAttnUpBlock2D"
54
- ],
55
- "upcast_attention": false,
56
- "use_linear_projection": false
57
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/brushnet/powerpaint_utils.py DELETED
@@ -1,496 +0,0 @@
1
- import copy
2
- import random
3
-
4
- import torch
5
- import torch.nn as nn
6
- from transformers import CLIPTokenizer
7
- from typing import Any, List, Optional, Union
8
-
9
class TokenizerWrapper:
    """Wrap a tokenizer so one placeholder string can expand to several tokens.

    Only ``CLIPTokenizer`` is supported currently. Adapted from
    https://github.com/huggingface/diffusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.py#L358

    The wrapper keeps a ``token_map`` from each placeholder string to the list
    of concrete tokens that were added to the wrapped tokenizer for it. Text is
    rewritten with that map before tokenization and rewritten back after
    decoding. Any attribute not defined here is looked up on the wrapped
    tokenizer.

    Args:
        tokenizer (CLIPTokenizer): The tokenizer instance to wrap.
    """

    def __init__(self, tokenizer: "CLIPTokenizer"):
        self.wrapped = tokenizer
        self.token_map = {}

    def __getattr__(self, name: str) -> Any:
        """Delegate unknown attributes to the wrapped tokenizer.

        ``__getattr__`` only runs after normal lookup has failed, so there is
        no need to consult ``self.__dict__`` here.

        Raises:
            AttributeError: If neither the wrapper nor the wrapped tokenizer
                has the attribute.
        """
        # Without this guard, touching any attribute on an instance whose
        # ``wrapped`` is not set yet (e.g. while copy/pickle rebuilds the
        # object) would recurse forever through ``self.wrapped``.
        if name == "wrapped":
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute 'wrapped'"
            )

        try:
            return getattr(self.wrapped, name)
        except AttributeError:
            raise AttributeError(
                f"'{name}' cannot be found in both "
                f"'{self.__class__.__name__}' and "
                f"'{self.__class__.__name__}.wrapped'."
            ) from None

    def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
        """Attempt to add tokens to the wrapped tokenizer.

        Args:
            tokens (Union[str, List[str]]): The tokens to be added.
            *args, **kwargs: Forwarded to ``self.wrapped.add_tokens``.

        Raises:
            AssertionError: If the tokenizer reports that no token was added.
        """
        num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
        assert num_added_tokens != 0, (
            f"The tokenizer already contains the token {tokens}. Please pass "
            "a different `placeholder_token` that is not already in the "
            "tokenizer."
        )

    def get_token_info(self, token: str) -> dict:
        """Return the id range that ``token`` maps to in the tokenizer.

        Args:
            token (str): The token (or placeholder) to be queried.

        Returns:
            dict: ``{"name": token, "start": first_id, "end": last_id + 1}``.
        """
        token_ids = self.__call__(token).input_ids
        # Positions 0 and -1 are skipped; presumably they hold the BOS/EOS ids
        # the tokenizer adds around the text -- confirm against the tokenizer.
        start, end = token_ids[1], token_ids[-2] + 1
        return {"name": token, "start": start, "end": end}

    def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
        """Add a placeholder and its backing token(s) to the tokenizer.

        Args:
            placeholder_token (str): The placeholder token to be added.
            num_vec_per_token (int, optional): How many concrete tokens back
                the placeholder. With 1 the placeholder itself is added;
                otherwise ``<placeholder>_0 ... <placeholder>_{n-1}`` are.
            *args, **kwargs: The arguments for `self.wrapped.add_tokens`.

        Raises:
            AssertionError: If a token already exists in the tokenizer.
            ValueError: If an existing placeholder is a substring of the new
                one, which would make text replacement ambiguous.
        """
        output = []
        if num_vec_per_token == 1:
            self.try_adding_tokens(placeholder_token, *args, **kwargs)
            output.append(placeholder_token)
        else:
            for i in range(num_vec_per_token):
                ith_token = placeholder_token + f"_{i}"
                self.try_adding_tokens(ith_token, *args, **kwargs)
                output.append(ith_token)

        for token in self.token_map:
            if token in placeholder_token:
                raise ValueError(
                    f"The tokenizer already has placeholder token {token} "
                    f"that can get confused with {placeholder_token} "
                    "keep placeholder tokens independent"
                )
        self.token_map[placeholder_token] = output

    def replace_placeholder_tokens_in_text(
        self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
    ) -> Union[str, List[str]]:
        """Expand every known placeholder in ``text`` to its concrete tokens.

        Called from `self.__call__` and `self.encode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            # Forward *both* options; previously ``prop_tokens_to_load`` was
            # silently dropped for list inputs.
            return [
                self.replace_placeholder_tokens_in_text(
                    item, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
                )
                for item in text
            ]

        for placeholder_token in self.token_map:
            if placeholder_token in text:
                tokens = self.token_map[placeholder_token]
                # Always keeps at least the first token.
                tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
                if vector_shuffle:
                    # Shuffle a copy so the stored mapping keeps its order.
                    tokens = copy.copy(tokens)
                    random.shuffle(tokens)
                text = text.replace(placeholder_token, " ".join(tokens))
        return text

    def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
        """Collapse expanded token sequences in ``text`` back to placeholders.

        Called from `self.decode`.

        Args:
            text (Union[str, List[str]]): The text to be processed.

        Returns:
            Union[str, List[str]]: The processed text.
        """
        if isinstance(text, list):
            return [self.replace_text_with_placeholder_tokens(item) for item in text]

        for placeholder_token, tokens in self.token_map.items():
            merged_tokens = " ".join(tokens)
            if merged_tokens in text:
                text = text.replace(merged_tokens, placeholder_token)
        return text

    def __call__(
        self,
        text: Union[str, List[str]],
        *args,
        vector_shuffle: bool = False,
        prop_tokens_to_load: float = 1.0,
        **kwargs,
    ):
        """Expand placeholders in ``text`` and tokenize the result.

        Args:
            text (Union[str, List[str]]): The text to be tokenized.
            vector_shuffle (bool, optional): Whether to shuffle the vectors.
                Defaults to False.
            prop_tokens_to_load (float, optional): The proportion of tokens to
                be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
            *args, **kwargs: The arguments for `self.wrapped.__call__`.

        Returns:
            Whatever ``self.wrapped.__call__`` returns.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(
            text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
        )

        return self.wrapped.__call__(replaced_text, *args, **kwargs)

    def encode(self, text: Union[str, List[str]], *args, **kwargs):
        """Expand placeholders in ``text`` and call the wrapped tokenizer.

        Args:
            text (Union[str, List[str]]): The text to be encoded.
            *args, **kwargs: The arguments for `self.wrapped.__call__`.
        """
        replaced_text = self.replace_placeholder_tokens_in_text(text)
        return self.wrapped(replaced_text, *args, **kwargs)

    def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
        """Decode token ids to text.

        Args:
            token_ids: The token ids to be decoded.
            return_raw: If True, return the decoded text as is, without
                collapsing expanded tokens back into placeholders.
                Defaults to False.
            *args, **kwargs: The arguments for `self.wrapped.decode`.

        Returns:
            Union[str, List[str]]: The decoded text.
        """
        text = self.wrapped.decode(token_ids, *args, **kwargs)
        if return_raw:
            return text
        return self.replace_text_with_placeholder_tokens(text)

    def __repr__(self):
        """Return a debug representation of the wrapper.

        Reads ``__dict__`` directly so it also works on a partially
        initialised instance.
        """
        wrapped = self.__dict__.get("wrapped")
        placeholders = list(self.__dict__.get("token_map", {}))
        return (
            f"{self.__class__.__name__}("
            f"wrapped={wrapped.__class__.__name__}, "
            f"placeholder_tokens={placeholders})"
        )
220
-
221
-
222
class EmbeddingLayerWithFixes(nn.Module):
    """Embedding layer that can splice in externally supplied embeddings.

    The design of this class is inspired by
    https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hijack.py#L224

    Token ids at or beyond the wrapped layer's vocabulary size are looked up
    as id 0 in the wrapped layer, and the resulting rows are then replaced by
    the matching external embedding.

    Args:
        wrapped (nn.Embedding): The embedding layer to be wrapped.
        external_embeddings (Union[dict, List[dict]], optional): The external
            embeddings added to this layer. Defaults to None.
    """

    def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
        super().__init__()
        self.wrapped = wrapped
        self.num_embeddings = wrapped.weight.shape[0]

        self.external_embeddings = []
        # Must exist before `add_embeddings` runs: that method registers
        # trainable embeddings into this dict.
        self.trainable_embeddings = nn.ParameterDict()
        if external_embeddings:
            self.add_embeddings(external_embeddings)

    @property
    def weight(self):
        """Get the weight of wrapped embedding layer."""
        return self.wrapped.weight

    def check_duplicate_names(self, embeddings: List[dict]):
        """Check that no two entries of ``embeddings`` share a name.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.

        Raises:
            AssertionError: If a name occurs more than once.
        """
        names = [emb["name"] for emb in embeddings]
        assert len(names) == len(set(names)), (
            "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
        )

    def check_ids_overlap(self, embeddings):
        """Check that the token-id ranges of ``embeddings`` do not overlap.

        Args:
            embeddings (List[dict]): A list of embeddings to be checked.

        Raises:
            AssertionError: If two ``[start, end)`` ranges overlap.
        """
        ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
        ids_range.sort()  # sort by 'start'
        # After sorting, only neighbouring ranges can overlap.
        for current, following in zip(ids_range, ids_range[1:]):
            name1, name2 = current[-1], following[-1]
            assert current[1] <= following[0], (
                f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
            )

    def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
        """Add external embeddings to this layer.

        Use case (sketch; ``tokenizer`` is a ``TokenizerWrapper``):

        >>> # 1. Add a placeholder to the tokenizer.
        >>> tokenizer.add_placeholder_token('ngapi', num_vec_per_token=4)
        >>>
        >>> # 2. Add the external embedding to the model.
        >>> info = tokenizer.get_token_info('ngapi')
        >>> new_embedding = {
        >>>     'name': 'ngapi',
        >>>     'embedding': torch.ones(4, 15) * 2.3,
        >>>     'start': info['start'],
        >>>     'end': info['end'],
        >>>     'trainable': False  # if True, registered as a parameter
        >>> }
        >>> embedding_layer = nn.Embedding(10, 15)
        >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
        >>> embedding_layer_wrapper.add_embeddings(new_embedding)
        >>>
        >>> # 3. Forward tokenizer and embedding layer.
        >>> input_ids = tokenizer(
        >>>     ['hello, ngapi!'], padding='max_length', truncation=True,
        >>>     return_tensors='pt')['input_ids']
        >>> out_feat = embedding_layer_wrapper(input_ids)

        Args:
            embeddings (Union[dict, list[dict]]): The external embeddings to
                be added. Each dict must contain the following 4 fields: 'name'
                (the name of this embedding), 'embedding' (the embedding
                tensor), 'start' (the start token id of this embedding), 'end'
                (the end token id of this embedding). An optional 'trainable'
                flag registers the tensor as an ``nn.Parameter``. For example:
                `{name: NAME, start: START, end: END, embedding: torch.Tensor}`

        Raises:
            AssertionError: If names are duplicated or id ranges overlap. In
                that case the layer is left unchanged.
        """
        if isinstance(embeddings, dict):
            embeddings = [embeddings]

        # Validate the prospective state first so a failed check does not
        # leave half-registered embeddings behind.
        candidate = self.external_embeddings + list(embeddings)
        self.check_duplicate_names(candidate)
        self.check_ids_overlap(candidate)
        self.external_embeddings.extend(embeddings)

        # set for trainable
        added_trainable_emb_info = []
        for embedding in embeddings:
            trainable = embedding.get("trainable", False)
            if trainable:
                name = embedding["name"]
                embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
                self.trainable_embeddings[name] = embedding["embedding"]
                added_trainable_emb_info.append(name)

        added_emb_info = [emb["name"] for emb in embeddings]
        added_emb_info = ", ".join(added_emb_info)
        print(f"Successfully add external embeddings: {added_emb_info}.", "current")

        if added_trainable_emb_info:
            added_trainable_emb_info = ", ".join(added_trainable_emb_info)
            print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}", "current")

    def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Replace ids outside the wrapped vocabulary with 0.

        Args:
            input_ids (torch.Tensor): The input ids to be replaced.

        Returns:
            torch.Tensor: A copy of ``input_ids`` that is safe to feed to the
                wrapped embedding layer.
        """
        input_ids_fwd = input_ids.clone()
        input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
        return input_ids_fwd

    def replace_embeddings(
        self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
    ) -> torch.Tensor:
        """Splice one external embedding into ``embedding``.

        The result is assembled with `torch.cat` to avoid in-place
        modification of ``embedding``.

        Args:
            input_ids (torch.Tensor): The original token ids. Shape like
                [LENGTH, ].
            embedding (torch.Tensor): The embedding of token ids after
                `replace_input_ids` function.
            external_embedding (dict): The external embedding to be replaced.

        Returns:
            torch.Tensor: The replaced embedding.

        Raises:
            AssertionError: If an occurrence of the start id is not followed
                by the full ``start .. end - 1`` id sequence.
        """
        name = external_embedding["name"]
        start = external_embedding["start"]
        end = external_embedding["end"]
        span = end - start
        target_ids_to_replace = list(range(start, end))
        ext_emb = external_embedding["embedding"]

        # do not need to replace
        if not (input_ids == start).any():
            return embedding

        pieces = []
        total = len(input_ids)
        seg_start = 0  # first position not yet copied into `pieces`
        pos = 0
        while pos < total:
            if input_ids[pos] != start:
                pos += 1
                continue

            # Keep the untouched rows that precede this occurrence.
            if pos > seg_start:
                pieces.append(embedding[seg_start:pos])

            # check if the next embedding need to replace is valid
            actually_ids_to_replace = [int(i) for i in input_ids[pos : pos + span]]
            assert actually_ids_to_replace == target_ids_to_replace, (
                f"Invalid 'input_ids' in position: {pos} to {pos + span}. "
                f"Expect '{target_ids_to_replace}' for embedding "
                f"'{name}' but found '{actually_ids_to_replace}'."
            )

            pieces.append(ext_emb)

            seg_start = pos + span
            # Resume scanning right after the replaced span so that a second
            # occurrence immediately following it is replaced too. `max`
            # guarantees progress even for a degenerate empty span.
            pos = max(seg_start, pos + 1)

        if seg_start < total:
            pieces.append(embedding[seg_start:])

        return torch.cat(pieces, dim=0)

    def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None):
        """Embed ``input_ids`` and apply every external embedding.

        Args:
            input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
                [LENGTH, ]. A 1-D input is treated as a batch of one.
            external_embeddings (Optional[List[dict]]): The external
                embeddings. If not passed, only `self.external_embeddings`
                will be used. Defaults to None.

        Returns:
            torch.Tensor: The embeddings, with a leading batch dimension.
        """
        assert input_ids.ndim in [1, 2]
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        if external_embeddings is None and not self.external_embeddings:
            return self.wrapped(input_ids)

        input_ids_fwd = self.replace_input_ids(input_ids)
        inputs_embeds = self.wrapped(input_ids_fwd)

        if external_embeddings is None:
            external_embeddings = []
        elif isinstance(external_embeddings, dict):
            external_embeddings = [external_embeddings]
        embeddings = self.external_embeddings + external_embeddings

        vecs = []
        for input_id, embedding in zip(input_ids, inputs_embeds):
            new_embedding = embedding
            for external_embedding in embeddings:
                new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
            vecs.append(new_embedding)

        return torch.stack(vecs)
449
-
450
-
451
-
452
def add_tokens(
    tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: list = None, num_vectors_per_token: int = 1
):
    """Register placeholder tokens and trainable embeddings for them.

    Each placeholder is added to ``tokenizer``; the text encoder's
    ``text_model.embeddings.token_embedding`` is wrapped in an
    ``EmbeddingLayerWithFixes`` (once), and one trainable external embedding
    per placeholder is added to that wrapper.

    Args:
        tokenizer: Object providing ``add_placeholder_token``,
            ``get_token_info`` and ``__call__`` (a ``TokenizerWrapper``).
        text_encoder: Model exposing
            ``text_model.embeddings.token_embedding``.
        placeholder_tokens (list): Placeholder strings to add.
        initialize_tokens (list, optional): One existing token per
            placeholder whose embedding is copied as the initial value. If
            None, embeddings are initialised uniformly in [-0.25, 0.25).
        num_vectors_per_token (int): Number of embedding vectors backing each
            placeholder. Defaults to 1.

    Raises:
        AssertionError: If ``initialize_tokens`` and ``placeholder_tokens``
            differ in length, or the text encoder has no token embedding.

    # TODO: support add tokens as dict, then we can load pretrained tokens.
    """
    if initialize_tokens is not None:
        assert len(initialize_tokens) == len(
            placeholder_tokens
        ), "placeholder_token should be the same length as initialize_token"
    for placeholder in placeholder_tokens:
        tokenizer.add_placeholder_token(placeholder, num_vec_per_token=num_vectors_per_token)

    embeddings_module = text_encoder.text_model.embeddings
    embedding_layer = embeddings_module.token_embedding
    # Check before the layer is used; after wrapping it can never be None.
    assert embedding_layer is not None, (
        "Do not support get embedding layer for current text encoder. " "Please check your configuration."
    )
    # Wrap only once: wrapping an existing wrapper again would hide the
    # external embeddings registered by earlier calls.
    if not isinstance(embedding_layer, EmbeddingLayerWithFixes):
        embedding_layer = EmbeddingLayerWithFixes(embedding_layer)
        embeddings_module.token_embedding = embedding_layer

    weight = embedding_layer.weight
    initialize_embedding = []
    if initialize_tokens is not None:
        for init_token in initialize_tokens:
            # Index 1 is used as the token's own id; presumably index 0 is the
            # BOS id added by the tokenizer -- confirm against the tokenizer.
            init_id = tokenizer(init_token).input_ids[1]
            temp_embedding = weight[init_id]
            initialize_embedding.append(temp_embedding[None, ...].repeat(num_vectors_per_token, 1))
    else:
        len_emb = weight.shape[1]
        for _ in placeholder_tokens:
            init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
            # Match the existing weights so later concatenation with the
            # encoder's embeddings does not mix devices or dtypes.
            initialize_embedding.append(init_weight.to(device=weight.device, dtype=weight.dtype))

    token_info_all = []
    for placeholder, init_embedding in zip(placeholder_tokens, initialize_embedding):
        token_info = tokenizer.get_token_info(placeholder)
        token_info["embedding"] = init_embedding
        token_info["trainable"] = True
        token_info_all.append(token_info)
    embedding_layer.add_embeddings(token_info_all)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/brushnet/unet_2d_blocks.py DELETED
The diff for this file is too large to render. See raw diff
 
MagicQuill/brushnet/unet_2d_condition.py DELETED
@@ -1,1355 +0,0 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- from dataclasses import dataclass
15
- from typing import Any, Dict, List, Optional, Tuple, Union
16
-
17
- import torch
18
- import torch.nn as nn
19
- import torch.utils.checkpoint
20
-
21
- from diffusers.configuration_utils import ConfigMixin, register_to_config
22
- from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
23
- from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
24
- from diffusers.models.activations import get_activation
25
- from diffusers.models.attention_processor import (
26
- ADDED_KV_ATTENTION_PROCESSORS,
27
- CROSS_ATTENTION_PROCESSORS,
28
- Attention,
29
- AttentionProcessor,
30
- AttnAddedKVProcessor,
31
- AttnProcessor,
32
- )
33
- from diffusers.models.embeddings import (
34
- GaussianFourierProjection,
35
- GLIGENTextBoundingboxProjection,
36
- ImageHintTimeEmbedding,
37
- ImageProjection,
38
- ImageTimeEmbedding,
39
- TextImageProjection,
40
- TextImageTimeEmbedding,
41
- TextTimeEmbedding,
42
- TimestepEmbedding,
43
- Timesteps,
44
- )
45
- from diffusers.models.modeling_utils import ModelMixin
46
- from .unet_2d_blocks import (
47
- get_down_block,
48
- get_mid_block,
49
- get_up_block,
50
- )
51
-
52
-
53
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
-
55
-
56
- @dataclass
57
- class UNet2DConditionOutput(BaseOutput):
58
- """
59
- The output of [`UNet2DConditionModel`].
60
-
61
- Args:
62
- sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
63
- The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
64
- """
65
-
66
- sample: torch.FloatTensor = None
67
-
68
-
69
- class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
70
- r"""
71
- A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
72
- shaped output.
73
-
74
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
75
- for all models (such as downloading or saving).
76
-
77
- Parameters:
78
- sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
79
- Height and width of input/output sample.
80
- in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
81
- out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
82
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
83
- flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
84
- Whether to flip the sin to cos in the time embedding.
85
- freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
86
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
87
- The tuple of downsample blocks to use.
88
- mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
89
- Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
90
- `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
91
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
92
- The tuple of upsample blocks to use.
93
- only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
94
- Whether to include self-attention in the basic transformer blocks, see
95
- [`~models.attention.BasicTransformerBlock`].
96
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
97
- The tuple of output channels for each block.
98
- layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
99
- downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
100
- mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
101
- dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
102
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
103
- norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
104
- If `None`, normalization and activation layers is skipped in post-processing.
105
- norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
106
- cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
107
- The dimension of the cross attention features.
108
- transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
109
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
110
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
111
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
112
- reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
113
- The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
114
- blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
115
- [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
116
- [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
117
- encoder_hid_dim (`int`, *optional*, defaults to None):
118
- If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
119
- dimension to `cross_attention_dim`.
120
- encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
121
- If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
122
- embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
123
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
124
- num_attention_heads (`int`, *optional*):
125
- The number of attention heads. If not defined, defaults to `attention_head_dim`
126
- resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
127
- for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
128
- class_embed_type (`str`, *optional*, defaults to `None`):
129
- The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
130
- `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
131
- addition_embed_type (`str`, *optional*, defaults to `None`):
132
- Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
133
- "text". "text" will use the `TextTimeEmbedding` layer.
134
- addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
135
- Dimension for the timestep embeddings.
136
- num_class_embeds (`int`, *optional*, defaults to `None`):
137
- Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
138
- class conditioning with `class_embed_type` equal to `None`.
139
- time_embedding_type (`str`, *optional*, defaults to `positional`):
140
- The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
141
- time_embedding_dim (`int`, *optional*, defaults to `None`):
142
- An optional override for the dimension of the projected time embedding.
143
- time_embedding_act_fn (`str`, *optional*, defaults to `None`):
144
- Optional activation function to use only once on the time embeddings before they are passed to the rest of
145
- the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
146
- timestep_post_act (`str`, *optional*, defaults to `None`):
147
- The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
148
- time_cond_proj_dim (`int`, *optional*, defaults to `None`):
149
- The dimension of `cond_proj` layer in the timestep embedding.
150
- conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
151
- conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
152
- projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
153
- `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
154
- class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
155
- embeddings with the class embeddings.
156
- mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
157
- Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
158
- `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
159
- `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
160
- otherwise.
161
- """
162
-
163
- _supports_gradient_checkpointing = True
164
-
165
- @register_to_config
166
- def __init__(
167
- self,
168
- sample_size: Optional[int] = None,
169
- in_channels: int = 4,
170
- out_channels: int = 4,
171
- center_input_sample: bool = False,
172
- flip_sin_to_cos: bool = True,
173
- freq_shift: int = 0,
174
- down_block_types: Tuple[str] = (
175
- "CrossAttnDownBlock2D",
176
- "CrossAttnDownBlock2D",
177
- "CrossAttnDownBlock2D",
178
- "DownBlock2D",
179
- ),
180
- mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
181
- up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
182
- only_cross_attention: Union[bool, Tuple[bool]] = False,
183
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
184
- layers_per_block: Union[int, Tuple[int]] = 2,
185
- downsample_padding: int = 1,
186
- mid_block_scale_factor: float = 1,
187
- dropout: float = 0.0,
188
- act_fn: str = "silu",
189
- norm_num_groups: Optional[int] = 32,
190
- norm_eps: float = 1e-5,
191
- cross_attention_dim: Union[int, Tuple[int]] = 1280,
192
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
193
- reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
194
- encoder_hid_dim: Optional[int] = None,
195
- encoder_hid_dim_type: Optional[str] = None,
196
- attention_head_dim: Union[int, Tuple[int]] = 8,
197
- num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
198
- dual_cross_attention: bool = False,
199
- use_linear_projection: bool = False,
200
- class_embed_type: Optional[str] = None,
201
- addition_embed_type: Optional[str] = None,
202
- addition_time_embed_dim: Optional[int] = None,
203
- num_class_embeds: Optional[int] = None,
204
- upcast_attention: bool = False,
205
- resnet_time_scale_shift: str = "default",
206
- resnet_skip_time_act: bool = False,
207
- resnet_out_scale_factor: float = 1.0,
208
- time_embedding_type: str = "positional",
209
- time_embedding_dim: Optional[int] = None,
210
- time_embedding_act_fn: Optional[str] = None,
211
- timestep_post_act: Optional[str] = None,
212
- time_cond_proj_dim: Optional[int] = None,
213
- conv_in_kernel: int = 3,
214
- conv_out_kernel: int = 3,
215
- projection_class_embeddings_input_dim: Optional[int] = None,
216
- attention_type: str = "default",
217
- class_embeddings_concat: bool = False,
218
- mid_block_only_cross_attention: Optional[bool] = None,
219
- cross_attention_norm: Optional[str] = None,
220
- addition_embed_type_num_heads: int = 64,
221
- ):
222
- super().__init__()
223
-
224
- self.sample_size = sample_size
225
-
226
- if num_attention_heads is not None:
227
- raise ValueError(
228
- "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
229
- )
230
-
231
- # If `num_attention_heads` is not defined (which is the case for most models)
232
- # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
233
- # The reason for this behavior is to correct for incorrectly named variables that were introduced
234
- # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
235
- # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
236
- # which is why we correct for the naming here.
237
- num_attention_heads = num_attention_heads or attention_head_dim
238
-
239
- # Check inputs
240
- self._check_config(
241
- down_block_types=down_block_types,
242
- up_block_types=up_block_types,
243
- only_cross_attention=only_cross_attention,
244
- block_out_channels=block_out_channels,
245
- layers_per_block=layers_per_block,
246
- cross_attention_dim=cross_attention_dim,
247
- transformer_layers_per_block=transformer_layers_per_block,
248
- reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
249
- attention_head_dim=attention_head_dim,
250
- num_attention_heads=num_attention_heads,
251
- )
252
-
253
- # input
254
- conv_in_padding = (conv_in_kernel - 1) // 2
255
- self.conv_in = nn.Conv2d(
256
- in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
257
- )
258
-
259
- # time
260
- time_embed_dim, timestep_input_dim = self._set_time_proj(
261
- time_embedding_type,
262
- block_out_channels=block_out_channels,
263
- flip_sin_to_cos=flip_sin_to_cos,
264
- freq_shift=freq_shift,
265
- time_embedding_dim=time_embedding_dim,
266
- )
267
-
268
- self.time_embedding = TimestepEmbedding(
269
- timestep_input_dim,
270
- time_embed_dim,
271
- act_fn=act_fn,
272
- post_act_fn=timestep_post_act,
273
- cond_proj_dim=time_cond_proj_dim,
274
- )
275
-
276
- self._set_encoder_hid_proj(
277
- encoder_hid_dim_type,
278
- cross_attention_dim=cross_attention_dim,
279
- encoder_hid_dim=encoder_hid_dim,
280
- )
281
-
282
- # class embedding
283
- self._set_class_embedding(
284
- class_embed_type,
285
- act_fn=act_fn,
286
- num_class_embeds=num_class_embeds,
287
- projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
288
- time_embed_dim=time_embed_dim,
289
- timestep_input_dim=timestep_input_dim,
290
- )
291
-
292
- self._set_add_embedding(
293
- addition_embed_type,
294
- addition_embed_type_num_heads=addition_embed_type_num_heads,
295
- addition_time_embed_dim=addition_time_embed_dim,
296
- cross_attention_dim=cross_attention_dim,
297
- encoder_hid_dim=encoder_hid_dim,
298
- flip_sin_to_cos=flip_sin_to_cos,
299
- freq_shift=freq_shift,
300
- projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
301
- time_embed_dim=time_embed_dim,
302
- )
303
-
304
- if time_embedding_act_fn is None:
305
- self.time_embed_act = None
306
- else:
307
- self.time_embed_act = get_activation(time_embedding_act_fn)
308
-
309
- self.down_blocks = nn.ModuleList([])
310
- self.up_blocks = nn.ModuleList([])
311
-
312
- if isinstance(only_cross_attention, bool):
313
- if mid_block_only_cross_attention is None:
314
- mid_block_only_cross_attention = only_cross_attention
315
-
316
- only_cross_attention = [only_cross_attention] * len(down_block_types)
317
-
318
- if mid_block_only_cross_attention is None:
319
- mid_block_only_cross_attention = False
320
-
321
- if isinstance(num_attention_heads, int):
322
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
323
-
324
- if isinstance(attention_head_dim, int):
325
- attention_head_dim = (attention_head_dim,) * len(down_block_types)
326
-
327
- if isinstance(cross_attention_dim, int):
328
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
329
-
330
- if isinstance(layers_per_block, int):
331
- layers_per_block = [layers_per_block] * len(down_block_types)
332
-
333
- if isinstance(transformer_layers_per_block, int):
334
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
335
-
336
- if class_embeddings_concat:
337
- # The time embeddings are concatenated with the class embeddings. The dimension of the
338
- # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
339
- # regular time embeddings
340
- blocks_time_embed_dim = time_embed_dim * 2
341
- else:
342
- blocks_time_embed_dim = time_embed_dim
343
-
344
- # down
345
- output_channel = block_out_channels[0]
346
- for i, down_block_type in enumerate(down_block_types):
347
- input_channel = output_channel
348
- output_channel = block_out_channels[i]
349
- is_final_block = i == len(block_out_channels) - 1
350
-
351
- down_block = get_down_block(
352
- down_block_type,
353
- num_layers=layers_per_block[i],
354
- transformer_layers_per_block=transformer_layers_per_block[i],
355
- in_channels=input_channel,
356
- out_channels=output_channel,
357
- temb_channels=blocks_time_embed_dim,
358
- add_downsample=not is_final_block,
359
- resnet_eps=norm_eps,
360
- resnet_act_fn=act_fn,
361
- resnet_groups=norm_num_groups,
362
- cross_attention_dim=cross_attention_dim[i],
363
- num_attention_heads=num_attention_heads[i],
364
- downsample_padding=downsample_padding,
365
- dual_cross_attention=dual_cross_attention,
366
- use_linear_projection=use_linear_projection,
367
- only_cross_attention=only_cross_attention[i],
368
- upcast_attention=upcast_attention,
369
- resnet_time_scale_shift=resnet_time_scale_shift,
370
- attention_type=attention_type,
371
- resnet_skip_time_act=resnet_skip_time_act,
372
- resnet_out_scale_factor=resnet_out_scale_factor,
373
- cross_attention_norm=cross_attention_norm,
374
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
375
- dropout=dropout,
376
- )
377
- self.down_blocks.append(down_block)
378
-
379
- # mid
380
- self.mid_block = get_mid_block(
381
- mid_block_type,
382
- temb_channels=blocks_time_embed_dim,
383
- in_channels=block_out_channels[-1],
384
- resnet_eps=norm_eps,
385
- resnet_act_fn=act_fn,
386
- resnet_groups=norm_num_groups,
387
- output_scale_factor=mid_block_scale_factor,
388
- transformer_layers_per_block=transformer_layers_per_block[-1],
389
- num_attention_heads=num_attention_heads[-1],
390
- cross_attention_dim=cross_attention_dim[-1],
391
- dual_cross_attention=dual_cross_attention,
392
- use_linear_projection=use_linear_projection,
393
- mid_block_only_cross_attention=mid_block_only_cross_attention,
394
- upcast_attention=upcast_attention,
395
- resnet_time_scale_shift=resnet_time_scale_shift,
396
- attention_type=attention_type,
397
- resnet_skip_time_act=resnet_skip_time_act,
398
- cross_attention_norm=cross_attention_norm,
399
- attention_head_dim=attention_head_dim[-1],
400
- dropout=dropout,
401
- )
402
-
403
- # count how many layers upsample the images
404
- self.num_upsamplers = 0
405
-
406
- # up
407
- reversed_block_out_channels = list(reversed(block_out_channels))
408
- reversed_num_attention_heads = list(reversed(num_attention_heads))
409
- reversed_layers_per_block = list(reversed(layers_per_block))
410
- reversed_cross_attention_dim = list(reversed(cross_attention_dim))
411
- reversed_transformer_layers_per_block = (
412
- list(reversed(transformer_layers_per_block))
413
- if reverse_transformer_layers_per_block is None
414
- else reverse_transformer_layers_per_block
415
- )
416
- only_cross_attention = list(reversed(only_cross_attention))
417
-
418
- output_channel = reversed_block_out_channels[0]
419
- for i, up_block_type in enumerate(up_block_types):
420
- is_final_block = i == len(block_out_channels) - 1
421
-
422
- prev_output_channel = output_channel
423
- output_channel = reversed_block_out_channels[i]
424
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
425
-
426
- # add upsample block for all BUT final layer
427
- if not is_final_block:
428
- add_upsample = True
429
- self.num_upsamplers += 1
430
- else:
431
- add_upsample = False
432
-
433
- up_block = get_up_block(
434
- up_block_type,
435
- num_layers=reversed_layers_per_block[i] + 1,
436
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
437
- in_channels=input_channel,
438
- out_channels=output_channel,
439
- prev_output_channel=prev_output_channel,
440
- temb_channels=blocks_time_embed_dim,
441
- add_upsample=add_upsample,
442
- resnet_eps=norm_eps,
443
- resnet_act_fn=act_fn,
444
- resolution_idx=i,
445
- resnet_groups=norm_num_groups,
446
- cross_attention_dim=reversed_cross_attention_dim[i],
447
- num_attention_heads=reversed_num_attention_heads[i],
448
- dual_cross_attention=dual_cross_attention,
449
- use_linear_projection=use_linear_projection,
450
- only_cross_attention=only_cross_attention[i],
451
- upcast_attention=upcast_attention,
452
- resnet_time_scale_shift=resnet_time_scale_shift,
453
- attention_type=attention_type,
454
- resnet_skip_time_act=resnet_skip_time_act,
455
- resnet_out_scale_factor=resnet_out_scale_factor,
456
- cross_attention_norm=cross_attention_norm,
457
- attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
458
- dropout=dropout,
459
- )
460
- self.up_blocks.append(up_block)
461
- prev_output_channel = output_channel
462
-
463
- # out
464
- if norm_num_groups is not None:
465
- self.conv_norm_out = nn.GroupNorm(
466
- num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
467
- )
468
-
469
- self.conv_act = get_activation(act_fn)
470
-
471
- else:
472
- self.conv_norm_out = None
473
- self.conv_act = None
474
-
475
- conv_out_padding = (conv_out_kernel - 1) // 2
476
- self.conv_out = nn.Conv2d(
477
- block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
478
- )
479
-
480
- self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
481
-
482
- def _check_config(
483
- self,
484
- down_block_types: Tuple[str],
485
- up_block_types: Tuple[str],
486
- only_cross_attention: Union[bool, Tuple[bool]],
487
- block_out_channels: Tuple[int],
488
- layers_per_block: Union[int, Tuple[int]],
489
- cross_attention_dim: Union[int, Tuple[int]],
490
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
491
- reverse_transformer_layers_per_block: bool,
492
- attention_head_dim: int,
493
- num_attention_heads: Optional[Union[int, Tuple[int]]],
494
- ):
495
- if len(down_block_types) != len(up_block_types):
496
- raise ValueError(
497
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
498
- )
499
-
500
- if len(block_out_channels) != len(down_block_types):
501
- raise ValueError(
502
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
503
- )
504
-
505
- if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
506
- raise ValueError(
507
- f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
508
- )
509
-
510
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
511
- raise ValueError(
512
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
513
- )
514
-
515
- if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
516
- raise ValueError(
517
- f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
518
- )
519
-
520
- if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
521
- raise ValueError(
522
- f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
523
- )
524
-
525
- if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
526
- raise ValueError(
527
- f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
528
- )
529
- if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
530
- for layer_number_per_block in transformer_layers_per_block:
531
- if isinstance(layer_number_per_block, list):
532
- raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
533
-
534
- def _set_time_proj(
535
- self,
536
- time_embedding_type: str,
537
- block_out_channels: int,
538
- flip_sin_to_cos: bool,
539
- freq_shift: float,
540
- time_embedding_dim: int,
541
- ) -> Tuple[int, int]:
542
- if time_embedding_type == "fourier":
543
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
544
- if time_embed_dim % 2 != 0:
545
- raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
546
- self.time_proj = GaussianFourierProjection(
547
- time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
548
- )
549
- timestep_input_dim = time_embed_dim
550
- elif time_embedding_type == "positional":
551
- time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
552
-
553
- self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
554
- timestep_input_dim = block_out_channels[0]
555
- else:
556
- raise ValueError(
557
- f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
558
- )
559
-
560
- return time_embed_dim, timestep_input_dim
561
-
562
- def _set_encoder_hid_proj(
563
- self,
564
- encoder_hid_dim_type: Optional[str],
565
- cross_attention_dim: Union[int, Tuple[int]],
566
- encoder_hid_dim: Optional[int],
567
- ):
568
- if encoder_hid_dim_type is None and encoder_hid_dim is not None:
569
- encoder_hid_dim_type = "text_proj"
570
- self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
571
- logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
572
-
573
- if encoder_hid_dim is None and encoder_hid_dim_type is not None:
574
- raise ValueError(
575
- f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
576
- )
577
-
578
- if encoder_hid_dim_type == "text_proj":
579
- self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
580
- elif encoder_hid_dim_type == "text_image_proj":
581
- # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
582
- # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
583
- # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
584
- self.encoder_hid_proj = TextImageProjection(
585
- text_embed_dim=encoder_hid_dim,
586
- image_embed_dim=cross_attention_dim,
587
- cross_attention_dim=cross_attention_dim,
588
- )
589
- elif encoder_hid_dim_type == "image_proj":
590
- # Kandinsky 2.2
591
- self.encoder_hid_proj = ImageProjection(
592
- image_embed_dim=encoder_hid_dim,
593
- cross_attention_dim=cross_attention_dim,
594
- )
595
- elif encoder_hid_dim_type is not None:
596
- raise ValueError(
597
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
598
- )
599
- else:
600
- self.encoder_hid_proj = None
601
-
602
- def _set_class_embedding(
603
- self,
604
- class_embed_type: Optional[str],
605
- act_fn: str,
606
- num_class_embeds: Optional[int],
607
- projection_class_embeddings_input_dim: Optional[int],
608
- time_embed_dim: int,
609
- timestep_input_dim: int,
610
- ):
611
- if class_embed_type is None and num_class_embeds is not None:
612
- self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
613
- elif class_embed_type == "timestep":
614
- self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
615
- elif class_embed_type == "identity":
616
- self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
617
- elif class_embed_type == "projection":
618
- if projection_class_embeddings_input_dim is None:
619
- raise ValueError(
620
- "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
621
- )
622
- # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
623
- # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
624
- # 2. it projects from an arbitrary input dimension.
625
- #
626
- # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
627
- # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
628
- # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
629
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
630
- elif class_embed_type == "simple_projection":
631
- if projection_class_embeddings_input_dim is None:
632
- raise ValueError(
633
- "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
634
- )
635
- self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
636
- else:
637
- self.class_embedding = None
638
-
639
def _set_add_embedding(
    self,
    addition_embed_type: str,
    addition_embed_type_num_heads: int,
    addition_time_embed_dim: Optional[int],
    flip_sin_to_cos: bool,
    freq_shift: float,
    cross_attention_dim: Optional[int],
    encoder_hid_dim: Optional[int],
    projection_class_embeddings_input_dim: Optional[int],
    time_embed_dim: int,
):
    """Create ``self.add_embedding`` for the requested ``addition_embed_type``.

    Nothing is created when ``addition_embed_type`` is ``None``. For
    ``"text_time"`` an extra ``self.add_time_proj`` projection is created too.

    Args:
        addition_embed_type: One of ``None``, ``"text"``, ``"text_image"``,
            ``"text_time"``, ``"image"`` or ``"image_hint"``.
        addition_embed_type_num_heads: ``num_heads`` for the ``"text"`` embedding.
        addition_time_embed_dim: First argument of ``Timesteps`` (``"text_time"`` only).
        flip_sin_to_cos: Forwarded to ``Timesteps`` (``"text_time"`` only).
        freq_shift: Forwarded to ``Timesteps`` (``"text_time"`` only).
        cross_attention_dim: Input dim for ``"text"`` when ``encoder_hid_dim`` is
            ``None``; text and image dim for ``"text_image"``.
        encoder_hid_dim: Preferred input dim for ``"text"``; image dim for
            ``"image"`` and ``"image_hint"``.
        projection_class_embeddings_input_dim: Input dim of the ``"text_time"``
            embedding.
        time_embed_dim: Output dimension of every embedding created here.

    Raises:
        ValueError: If ``addition_embed_type`` is not one of the supported values.
    """
    if addition_embed_type is None:
        return

    if addition_embed_type == "text":
        # Prefer the encoder hidden size; fall back to the cross-attention size.
        if encoder_hid_dim is not None:
            text_time_embedding_from_dim = encoder_hid_dim
        else:
            text_time_embedding_from_dim = cross_attention_dim

        self.add_embedding = TextTimeEmbedding(
            text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
        )
    elif addition_embed_type == "text_image":
        # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter
        # the __init__ too much they are set to `cross_attention_dim` here, as this is exactly the
        # required dimension for the currently only use case of "text_image" (Kandinsky 2.1).
        self.add_embedding = TextImageTimeEmbedding(
            text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
        )
    elif addition_embed_type == "text_time":
        self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
    elif addition_embed_type == "image":
        # Kandinsky 2.2
        self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    elif addition_embed_type == "image_hint":
        # Kandinsky 2.2 ControlNet
        self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
    else:
        # The message lists every value accepted above, not just the first two.
        raise ValueError(
            f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', "
            "'image' or 'image_hint'."
        )
678
-
679
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
    """Create ``self.position_net`` when a gated (GLIGEN) attention type is used.

    Does nothing unless ``attention_type`` is ``"gated"`` or ``"gated-text-image"``.
    ``positive_len`` is ``cross_attention_dim`` when that is an int, its first
    element when it is a tuple or list, and 768 otherwise.
    """
    if attention_type not in ("gated", "gated-text-image"):
        return

    if isinstance(cross_attention_dim, int):
        positive_len = cross_attention_dim
    elif isinstance(cross_attention_dim, (tuple, list)):
        positive_len = cross_attention_dim[0]
    else:
        positive_len = 768

    if attention_type == "gated":
        feature_type = "text-only"
    else:
        feature_type = "text-image"

    self.position_net = GLIGENTextBoundingboxProjection(
        positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
    )
691
-
692
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: every sub-module that exposes `get_processor` contributes one entry,
        keyed by its dotted module path followed by `.processor`.
    """
    collected: Dict[str, AttentionProcessor] = {}

    # Explicit pre-order walk: children are pushed in reverse so that they are
    # popped (and therefore recorded) in their registration order.
    pending = list(self.named_children())
    pending.reverse()

    while pending:
        path, module = pending.pop()

        if hasattr(module, "get_processor"):
            collected[f"{path}.processor"] = module.get_processor(return_deprecated_lora=True)

        nested = [(f"{path}.{sub_name}", child) for sub_name, child in module.named_children()]
        pending.extend(reversed(nested))

    return collected
715
-
716
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors. The dict passed
            by the caller is left untouched.

    Raises:
        ValueError: If a dict is passed whose length differs from the number of attention layers.
        KeyError: If a dict is passed that lacks the `"<module path>.processor"` key of some layer.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict):
        if len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )
        # Entries are popped while they are assigned; do that on a shallow copy so
        # the caller's mapping is not emptied as a side effect.
        processor = dict(processor)

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if isinstance(processor, dict):
                module.set_processor(processor.pop(f"{name}.processor"))
            else:
                module.set_processor(processor)

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
749
-
750
def set_default_attn_processor(self):
    """
    Replace all attention processors with the default implementation.

    `AttnAddedKVProcessor` is used when every current processor class is in
    `ADDED_KV_ATTENTION_PROCESSORS`, `AttnProcessor` when every one is in
    `CROSS_ATTENTION_PROCESSORS`; any other mix raises `ValueError`.
    """
    active = list(self.attn_processors.values())

    if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in active):
        default_processor = AttnAddedKVProcessor()
    elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in active):
        default_processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(active))}"
        )

    self.set_attn_processor(default_processor)
764
-
765
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
    r"""
    Enable sliced attention computation.

    When this option is enabled, the attention module splits the input tensor in slices to compute attention in
    several steps. This is useful for saving some memory in exchange for a small decrease in speed.

    Args:
        slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
            When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
            `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
            provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
            must be a multiple of `slice_size`.

    Raises:
        ValueError: If a list of the wrong length is given, or a size exceeds its layer's sliceable head dim.
    """
    head_dims = []

    def gather_head_dims(module: torch.nn.Module):
        # Record the head dim of every sliceable module, depth first.
        if hasattr(module, "set_attention_slice"):
            head_dims.append(module.sliceable_head_dim)

        for sub_module in module.children():
            gather_head_dims(sub_module)

    for top_level in self.children():
        gather_head_dims(top_level)

    layer_count = len(head_dims)

    if slice_size == "auto":
        # half the attention head size is usually a good trade-off between speed and memory
        slice_size = [head_dim // 2 for head_dim in head_dims]
    elif slice_size == "max":
        # make smallest slice possible
        slice_size = [1] * layer_count

    if not isinstance(slice_size, list):
        slice_size = [slice_size] * layer_count

    if len(slice_size) != layer_count:
        raise ValueError(
            f"You have provided {len(slice_size)}, but {self.config} has {layer_count} different"
            f" attention layers. Make sure to match `len(slice_size)` to be {layer_count}."
        )

    for size, dim in zip(slice_size, head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

    # Hand the sizes out in the same depth-first order the head dims were collected in.
    remaining_sizes = iter(slice_size)

    def apply_slice_sizes(module: torch.nn.Module):
        if hasattr(module, "set_attention_slice"):
            module.set_attention_slice(next(remaining_sizes))

        for sub_module in module.children():
            apply_slice_sizes(sub_module)

    for top_level in self.children():
        apply_slice_sizes(top_level)
829
-
830
def _set_gradient_checkpointing(self, module, value=False):
    """Set ``module.gradient_checkpointing`` to ``value`` if the module has that attribute."""
    if not hasattr(module, "gradient_checkpointing"):
        return
    module.gradient_checkpointing = value
833
-
834
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
    r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

    The four factors are stored as attributes `s1`, `s2`, `b1` and `b2` on every block in `self.up_blocks`.
    The suffixes after the scaling factors represent the stage blocks where they are being applied.

    Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
    are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

    Args:
        s1 (`float`):
            Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
            mitigate the "oversmoothing effect" in the enhanced denoising process.
        s2 (`float`):
            Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
            mitigate the "oversmoothing effect" in the enhanced denoising process.
        b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
        b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
    """
    factors = {"s1": s1, "s2": s2, "b1": b1, "b2": b2}
    for up_block in self.up_blocks:
        for attr_name, factor in factors.items():
            setattr(up_block, attr_name, factor)
857
-
858
def disable_freeu(self):
    """Disables the FreeU mechanism.

    Every FreeU attribute (`s1`, `s2`, `b1`, `b2`) that exists on a block in
    `self.up_blocks` is reset to `None`; attributes that are absent stay absent.
    """
    for up_block in self.up_blocks:
        for attr_name in ("s1", "s2", "b1", "b2"):
            if hasattr(up_block, attr_name):
                setattr(up_block, attr_name, None)
865
-
866
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
    are fused. For cross-attention modules, key and value projection matrices are fused.

    The current processors are kept in `self.original_attn_processors` before fusing. A `ValueError` is raised,
    with `self.original_attn_processors` left as `None`, if any processor class name contains "Added".

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>
    """
    self.original_attn_processors = None

    has_added_kv = any(
        "Added" in str(attn_processor.__class__.__name__) for attn_processor in self.attn_processors.values()
    )
    if has_added_kv:
        raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    self.original_attn_processors = self.attn_processors

    attention_modules = (module for module in self.modules() if isinstance(module, Attention))
    for attention_module in attention_modules:
        attention_module.fuse_projections(fuse=True)
888
-
889
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    Restores the processors saved in `self.original_attn_processors`. The call is a no-op when nothing was saved,
    including when the attribute was never set because `fuse_qkv_projections()` has not been called.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    """
    # Missing attribute and explicit `None` both mean "nothing to restore".
    saved_processors = getattr(self, "original_attn_processors", None)
    if saved_processors is None:
        return

    # Pass a copy so the saved mapping survives even if the setter consumes its argument.
    if isinstance(saved_processors, dict):
        saved_processors = dict(saved_processors)
    self.set_attn_processor(saved_processors)
901
-
902
def unload_lora(self):
    """Unloads LoRA weights.

    Emits a deprecation notice, then calls `set_lora_layer(None)` on every sub-module that has that method.
    """
    deprecate(
        "unload_lora",
        "0.28.0",
        "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
    )
    lora_capable = (module for module in self.modules() if hasattr(module, "set_lora_layer"))
    for module in lora_capable:
        module.set_lora_layer(None)
912
-
913
def get_time_embed(
    self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
) -> Optional[torch.Tensor]:
    """Project `timestep` with `self.time_proj` and cast the result to `sample.dtype`.

    A Python number becomes a one-element tensor on `sample.device`, and a 0-d tensor gains a leading
    dimension; the result is then expanded to `sample.shape[0]` entries before projection.
    """
    if torch.is_tensor(timestep):
        timesteps = timestep
        if len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
    else:
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # 64-bit dtypes are avoided when the sample lives on an "mps" device.
        on_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if on_mps else torch.float64
        else:
            dtype = torch.int32 if on_mps else torch.int64
        timesteps = torch.tensor([timestep], dtype=dtype, device=sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    projected = self.time_proj(timesteps)
    # `Timesteps` does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    return projected.to(dtype=sample.dtype)
938
-
939
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Return the class embedding cast to `sample.dtype`, or `None` when `self.class_embedding` is `None`.

    When `self.config.class_embed_type` is `"timestep"`, the labels are first passed through `self.time_proj`
    and cast to `sample.dtype`.

    Raises:
        ValueError: If a class embedding layer exists but `class_labels` is `None`.
    """
    if self.class_embedding is None:
        return None

    if class_labels is None:
        raise ValueError("class_labels should be provided when num_class_embeds > 0")

    if self.config.class_embed_type == "timestep":
        # `Timesteps` does not contain any weights and will always return f32 tensors,
        # hence the cast straight after the projection.
        class_labels = self.time_proj(class_labels).to(dtype=sample.dtype)

    return self.class_embedding(class_labels).to(dtype=sample.dtype)
954
-
955
def get_aug_embed(
    self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Optional[Dict[str, Any]]
) -> Optional[torch.Tensor]:
    """Compute the additional embedding selected by `self.config.addition_embed_type`.

    Args:
        emb: Current time embedding; only its dtype is used (for `"text_time"`).
        encoder_hidden_states: Input of the `"text"` embedding and fallback text input for `"text_image"`.
        added_cond_kwargs: Extra conditioning tensors. `None` is treated as an empty dict so that a missing
            entry is reported with the descriptive `ValueError` below rather than a `TypeError`.

    Returns:
        Whatever `self.add_embedding` returns for the configured type, or `None` when the configured type
        matches none of the handled values.

    Raises:
        ValueError: If a key required by the configured type is missing from `added_cond_kwargs`.
    """
    cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
    embed_type = self.config.addition_embed_type

    aug_emb = None
    if embed_type == "text":
        aug_emb = self.add_embedding(encoder_hidden_states)
    elif embed_type == "text_image":
        # Kandinsky 2.1 - style
        if "image_embeds" not in cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )

        image_embs = cond_kwargs.get("image_embeds")
        text_embs = cond_kwargs.get("text_embeds", encoder_hidden_states)
        aug_emb = self.add_embedding(text_embs, image_embs)
    elif embed_type == "text_time":
        # SDXL - style
        if "text_embeds" not in cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
            )
        text_embeds = cond_kwargs.get("text_embeds")
        if "time_ids" not in cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
            )
        time_ids = cond_kwargs.get("time_ids")
        # Project every id, then regroup the projections per row of `text_embeds`.
        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(add_embeds)
    elif embed_type == "image":
        # Kandinsky 2.2 - style
        if "image_embeds" not in cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
            )
        image_embs = cond_kwargs.get("image_embeds")
        aug_emb = self.add_embedding(image_embs)
    elif embed_type == "image_hint":
        # Kandinsky 2.2 - style
        if "image_embeds" not in cond_kwargs or "hint" not in cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
            )
        image_embs = cond_kwargs.get("image_embeds")
        hint = cond_kwargs.get("hint")
        aug_emb = self.add_embedding(image_embs, hint)
    return aug_emb
1006
-
1007
def process_encoder_hidden_states(
    self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Optional[Dict[str, Any]]
) -> torch.Tensor:
    """Apply `self.encoder_hid_proj` as dictated by `self.config.encoder_hid_dim_type`.

    Args:
        encoder_hidden_states: Hidden states to project.
        added_cond_kwargs: Extra conditioning tensors; must contain `image_embeds` for the image-based
            projection types. `None` is treated as an empty dict so the missing key is reported with a
            descriptive `ValueError` rather than a `TypeError`.

    Returns:
        * the input unchanged when there is no projection or the type is not one handled here;
        * the projected hidden states for `"text_proj"`, `"text_image_proj"` and `"image_proj"`;
        * a tuple `(encoder_hidden_states, projected_image_embeds)` for `"ip_image_proj"`.

    Raises:
        ValueError: If `image_embeds` is required but missing.
    """
    projection = self.encoder_hid_proj
    if projection is None:
        return encoder_hidden_states

    hid_dim_type = self.config.encoder_hid_dim_type
    if hid_dim_type == "text_proj":
        return projection(encoder_hidden_states)

    if hid_dim_type not in ("text_image_proj", "image_proj", "ip_image_proj"):
        return encoder_hidden_states

    # Every remaining type needs `image_embeds`.
    cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
    if "image_embeds" not in cond_kwargs:
        raise ValueError(
            f"{self.__class__} has the config param `encoder_hid_dim_type` set to '{hid_dim_type}' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
        )
    image_embeds = cond_kwargs.get("image_embeds")

    if hid_dim_type == "text_image_proj":
        # Kandinsky 2.1 - style
        return projection(encoder_hidden_states, image_embeds)
    if hid_dim_type == "image_proj":
        # Kandinsky 2.2 - style
        return projection(image_embeds)

    # "ip_image_proj": the hidden states are kept as they are and paired with the projected image embeds.
    return (encoder_hidden_states, projection(image_embeds))
1038
-
1039
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
    mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
    up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
) -> Union[UNet2DConditionOutput, Tuple]:
    r"""
    The [`UNet2DConditionModel`] forward method.

    Args:
        sample (`torch.FloatTensor`):
            The noisy input tensor with the following shape `(batch, channel, height, width)`.
        timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.FloatTensor`):
            The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
        class_labels (`torch.Tensor`, *optional*, defaults to `None`):
            Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
        timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
            Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
            through the `self.time_embedding` layer to obtain the timestep embeddings.
        attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
        added_cond_kwargs: (`dict`, *optional*):
            A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
            are passed along to the UNet blocks.
        down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
            A tuple of tensors that if specified are added to the residuals of down unet blocks.
        mid_block_additional_residual: (`torch.Tensor`, *optional*):
            A tensor that if specified is added to the residual of the middle unet block.
        down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
            additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
        encoder_attention_mask (`torch.Tensor`):
            A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
            `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
            which adds large negative values to the attention scores corresponding to "discard" tokens.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
            tuple.
        down_block_add_samples (sequence of `torch.Tensor`, *optional*):
            BrushNet samples for the down path. The first one is added to the output of `conv_in`; the rest are
            handed to the down blocks in order, each block receiving `len(block.resnets)` entries plus one more
            when it has downsamplers. The sequence is consumed from a local copy; the caller's object is not
            modified.
        mid_block_add_sample (`torch.Tensor`, *optional*):
            BrushNet sample added to the output of the mid block.
        up_block_add_samples (sequence of `torch.Tensor`, *optional*):
            BrushNet samples for the up path, handed to the up blocks in order, each block receiving
            `len(block.resnets)` entries plus one more when it has upsamplers. Consumed from a local copy.

        The three BrushNet arguments only take effect when all of them are provided.

    Returns:
        [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
            otherwise a `tuple` is returned where the first element is the sample tensor.
    """
    # By default samples have to be AT least a multiple of the overall upsampling factor.
    # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
    # However, the upsampling interpolation output size can be forced to fit any upsampling size
    # on the fly if necessary.
    default_overall_up_factor = 2**self.num_upsamplers

    # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
    forward_upsample_size = any(dim % default_overall_up_factor != 0 for dim in sample.shape[-2:])
    upsample_size = None

    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch,                    1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None:
        # assume that mask is expressed as:
        #   (1 = keep,      0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #       (keep = +0,     discard = -10000.0)
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None:
        encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

    # 0. center input if necessary
    if self.config.center_input_sample:
        sample = 2 * sample - 1.0

    # 1. time
    t_emb = self.get_time_embed(sample=sample, timestep=timestep)
    emb = self.time_embedding(t_emb, timestep_cond)
    aug_emb = None

    class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
    if class_emb is not None:
        if self.config.class_embeddings_concat:
            emb = torch.cat([emb, class_emb], dim=-1)
        else:
            emb = emb + class_emb

    aug_emb = self.get_aug_embed(
        emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
    )
    if self.config.addition_embed_type == "image_hint":
        # For this type the embedding call yields a pair; the hint is concatenated on the channel axis.
        aug_emb, hint = aug_emb
        sample = torch.cat([sample, hint], dim=1)

    emb = emb + aug_emb if aug_emb is not None else emb

    if self.time_embed_act is not None:
        emb = self.time_embed_act(emb)

    encoder_hidden_states = self.process_encoder_hidden_states(
        encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
    )

    # 2. pre-process
    sample = self.conv_in(sample)

    # 2.5 GLIGEN position net
    if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        gligen_args = cross_attention_kwargs.pop("gligen")
        cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

    # 3. down
    # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
    # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
    if cross_attention_kwargs is not None:
        cross_attention_kwargs = cross_attention_kwargs.copy()
        lora_scale = cross_attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)

    is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
    # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
    is_adapter = down_intrablock_additional_residuals is not None
    # BrushNet conditioning is only active when all three of its inputs are present.
    is_brushnet = (
        down_block_add_samples is not None
        and mid_block_add_sample is not None
        and up_block_add_samples is not None
    )
    if is_brushnet:
        # Consume local copies: the samples are popped one by one below, which must neither
        # empty the caller's lists nor fail when tuples are passed.
        down_block_add_samples = list(down_block_add_samples)
        up_block_add_samples = list(up_block_add_samples)

    # maintain backward compatibility for legacy usage, where
    #       T2I-Adapter and ControlNet both use down_block_additional_residuals arg
    #       but can only use one or the other
    if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
        deprecate(
            "T2I should not use down_block_additional_residuals",
            "1.3.0",
            "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
                   and will be removed in diffusers 1.3.0.  `down_block_additional_residuals` should only be used \
                   for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
            standard_warn=False,
        )
        down_intrablock_additional_residuals = down_block_additional_residuals
        is_adapter = True

    down_block_res_samples = (sample,)

    if is_brushnet:
        sample = sample + down_block_add_samples.pop(0)

    for downsample_block in self.down_blocks:
        additional_residuals = {}

        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            # For t2i-adapter CrossAttnDownBlock2D
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)

            if is_brushnet and len(down_block_add_samples) > 0:
                # One sample per resnet, plus one when the block has downsamplers.
                take = len(downsample_block.resnets) + (downsample_block.downsamplers is not None)
                additional_residuals["down_block_add_samples"] = [
                    down_block_add_samples.pop(0) for _ in range(take)
                ]

            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
                **additional_residuals,
            )
        else:
            if is_brushnet and len(down_block_add_samples) > 0:
                take = len(downsample_block.resnets) + (downsample_block.downsamplers is not None)
                additional_residuals["down_block_add_samples"] = [
                    down_block_add_samples.pop(0) for _ in range(take)
                ]

            sample, res_samples = downsample_block(hidden_states=sample, temb=emb, **additional_residuals)
            if is_adapter and len(down_intrablock_additional_residuals) > 0:
                sample += down_intrablock_additional_residuals.pop(0)

        down_block_res_samples += res_samples

    if is_controlnet:
        # zip() stops at the shorter sequence, exactly like the pairwise accumulation it replaces.
        down_block_res_samples = tuple(
            down_block_res_sample + down_block_additional_residual
            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            )
        )

    # 4. mid
    if self.mid_block is not None:
        if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
        else:
            sample = self.mid_block(sample, emb)

        # To support T2I-Adapter-XL
        if (
            is_adapter
            and len(down_intrablock_additional_residuals) > 0
            and sample.shape == down_intrablock_additional_residuals[0].shape
        ):
            sample += down_intrablock_additional_residuals.pop(0)

    if is_controlnet:
        sample = sample + mid_block_additional_residual

    if is_brushnet:
        sample = sample + mid_block_add_sample

    # 5. up
    for i, upsample_block in enumerate(self.up_blocks):
        is_final_block = i == len(self.up_blocks) - 1

        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        # if we have not reached the final block and need to forward the
        # upsample size, we do it here
        if not is_final_block and forward_upsample_size:
            upsample_size = down_block_res_samples[-1].shape[2:]

        additional_residuals = {}
        if is_brushnet and len(up_block_add_samples) > 0:
            # One sample per resnet, plus one when the block has upsamplers.
            take = len(upsample_block.resnets) + (upsample_block.upsamplers is not None)
            additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0) for _ in range(take)]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                upsample_size=upsample_size,
                attention_mask=attention_mask,
                encoder_attention_mask=encoder_attention_mask,
                **additional_residuals,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                upsample_size=upsample_size,
                **additional_residuals,
            )

    # 6. post-process
    if self.conv_norm_out:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if USE_PEFT_BACKEND:
        # remove `lora_scale` from each PEFT layer
        unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/brushnet_nodes.py DELETED
@@ -1,1094 +0,0 @@
1
- import os
2
- import types
3
- from typing import Tuple
4
-
5
- import torch
6
- import torchvision.transforms as T
7
- import torch.nn.functional as F
8
- from accelerate import init_empty_weights, load_checkpoint_and_dispatch
9
- import sys
10
-
11
- import comfy.sd
12
- import comfy.utils
13
- import comfy.model_management
14
- import comfy.sd1_clip
15
- import comfy.ldm.models.autoencoder
16
- import comfy.supported_models
17
-
18
- import folder_paths
19
-
20
- from .model_patch import add_model_patch_option, patch_model_function_wrapper
21
- from .brushnet.brushnet import BrushNetModel
22
- from .brushnet.brushnet_ca import BrushNetModel as PowerPaintModel
23
- from .brushnet.powerpaint_utils import TokenizerWrapper, add_tokens
24
-
25
- current_directory = os.path.dirname(os.path.abspath(__file__))
26
- brushnet_config_file = os.path.join(current_directory, 'brushnet', 'brushnet.json')
27
- brushnet_xl_config_file = os.path.join(current_directory, 'brushnet', 'brushnet_xl.json')
28
- powerpaint_config_file = os.path.join(current_directory,'brushnet', 'powerpaint.json')
29
-
30
- sd15_scaling_factor = 0.18215
31
- sdxl_scaling_factor = 0.13025
32
-
33
- print(sys.path)
34
- ModelsToUnload = [comfy.sd1_clip.SD1ClipModel,
35
- comfy.ldm.models.autoencoder.AutoencoderKL
36
- ]
37
-
38
-
39
- class BrushNetLoader:
40
- @classmethod
41
- def INPUT_TYPES(self):
42
- self.inpaint_files = get_files_with_extension('inpaint')
43
- return {"required":
44
- {
45
- "brushnet": ([file for file in self.inpaint_files], ),
46
- "dtype": (['float16', 'bfloat16', 'float32', 'float64'], ),
47
- },
48
- }
49
-
50
- CATEGORY = "inpaint"
51
- RETURN_TYPES = ("BRMODEL",)
52
- RETURN_NAMES = ("brushnet",)
53
-
54
- FUNCTION = "brushnet_loading"
55
-
56
- def brushnet_loading(self, brushnet, dtype):
57
- brushnet_file = os.path.join(self.inpaint_files[brushnet], brushnet)
58
- print('BrushNet model file:', brushnet_file)
59
- is_SDXL = False
60
- is_PP = False
61
- sd = comfy.utils.load_torch_file(brushnet_file)
62
- brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = brushnet_blocks(sd)
63
- del sd
64
- if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
65
- is_SDXL = False
66
- if keys == 322:
67
- is_PP = False
68
- print('BrushNet model type: SD1.5')
69
- else:
70
- is_PP = True
71
- print('PowerPaint model type: SD1.5')
72
- elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
73
- print('BrushNet model type: Loading SDXL')
74
- is_SDXL = True
75
- is_PP = False
76
- else:
77
- raise Exception("Unknown BrushNet model")
78
-
79
- with init_empty_weights():
80
- if is_SDXL:
81
- brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
82
- brushnet_model = BrushNetModel.from_config(brushnet_config)
83
- elif is_PP:
84
- brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
85
- brushnet_model = PowerPaintModel.from_config(brushnet_config)
86
- else:
87
- brushnet_config = BrushNetModel.load_config(brushnet_config_file)
88
- brushnet_model = BrushNetModel.from_config(brushnet_config)
89
-
90
- if is_PP:
91
- print("PowerPaint model file:", brushnet_file)
92
- else:
93
- print("BrushNet model file:", brushnet_file)
94
-
95
- if dtype == 'float16':
96
- torch_dtype = torch.float16
97
- elif dtype == 'bfloat16':
98
- torch_dtype = torch.bfloat16
99
- elif dtype == 'float32':
100
- torch_dtype = torch.float32
101
- else:
102
- torch_dtype = torch.float64
103
-
104
- brushnet_model = load_checkpoint_and_dispatch(
105
- brushnet_model,
106
- brushnet_file,
107
- device_map="sequential",
108
- max_memory=None,
109
- offload_folder=None,
110
- offload_state_dict=False,
111
- dtype=torch_dtype,
112
- force_hooks=False,
113
- )
114
-
115
- if is_PP:
116
- print("PowerPaint model is loaded")
117
- elif is_SDXL:
118
- print("BrushNet SDXL model is loaded")
119
- else:
120
- print("BrushNet SD1.5 model is loaded")
121
-
122
- return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype}, )
123
-
124
-
125
- class PowerPaintCLIPLoader:
126
-
127
- @classmethod
128
- def INPUT_TYPES(self):
129
- self.inpaint_files = get_files_with_extension('inpaint', ['.bin'])
130
- self.clip_files = get_files_with_extension('clip')
131
- return {"required":
132
- {
133
- "base": ([file for file in self.clip_files], ),
134
- "powerpaint": ([file for file in self.inpaint_files], ),
135
- },
136
- }
137
-
138
- CATEGORY = "inpaint"
139
- RETURN_TYPES = ("CLIP",)
140
- RETURN_NAMES = ("clip",)
141
-
142
- FUNCTION = "ppclip_loading"
143
-
144
- def ppclip_loading(self, base, powerpaint):
145
- base_CLIP_file = os.path.join(self.clip_files[base], base)
146
- pp_CLIP_file = os.path.join(self.inpaint_files[powerpaint], powerpaint)
147
-
148
- pp_clip = comfy.sd.load_clip(ckpt_paths=[base_CLIP_file])
149
-
150
- print('PowerPaint base CLIP file: ', base_CLIP_file)
151
-
152
- pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
153
- pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
154
-
155
- add_tokens(
156
- tokenizer = pp_tokenizer,
157
- text_encoder = pp_text_encoder,
158
- placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"],
159
- initialize_tokens = ["a", "a", "a"],
160
- num_vectors_per_token = 10,
161
- )
162
-
163
- pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_CLIP_file), strict=False)
164
-
165
- print('PowerPaint CLIP file: ', pp_CLIP_file)
166
-
167
- pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
168
- pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
169
-
170
- return (pp_clip,)
171
-
172
-
173
- class PowerPaint:
174
-
175
- @classmethod
176
- def INPUT_TYPES(s):
177
- return {"required":
178
- {
179
- "model": ("MODEL",),
180
- "vae": ("VAE", ),
181
- "image": ("IMAGE",),
182
- "mask": ("MASK",),
183
- "powerpaint": ("BRMODEL", ),
184
- "clip": ("CLIP", ),
185
- "positive": ("CONDITIONING", ),
186
- "negative": ("CONDITIONING", ),
187
- "fitting" : ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}),
188
- "function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'], ),
189
- "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
190
- "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
191
- "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
192
- "save_memory": (['none', 'auto', 'max'], ),
193
- },
194
- }
195
-
196
- CATEGORY = "inpaint"
197
- RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
198
- RETURN_NAMES = ("model","positive","negative","latent",)
199
-
200
- FUNCTION = "model_update"
201
-
202
- def model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at, save_memory):
203
-
204
- is_SDXL, is_PP = check_compatibilty(model, powerpaint)
205
- if not is_PP:
206
- raise Exception("BrushNet model was loaded, please use BrushNet node")
207
-
208
- # Make a copy of the model so that we're not patching it everywhere in the workflow.
209
- model = model.clone()
210
-
211
- # prepare image and mask
212
- # no batches for original image and mask
213
- masked_image, mask = prepare_image(image, mask)
214
-
215
- batch = masked_image.shape[0]
216
- #width = masked_image.shape[2]
217
- #height = masked_image.shape[1]
218
-
219
- if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
220
- scaling_factor = model.model.model_config.latent_format.scale_factor
221
- else:
222
- scaling_factor = sd15_scaling_factor
223
-
224
- torch_dtype = powerpaint['dtype']
225
-
226
- # prepare conditioning latents
227
- conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
228
- conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
229
- conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
230
-
231
- # prepare embeddings
232
-
233
- if function == "object removal":
234
- promptA = "P_ctxt"
235
- promptB = "P_ctxt"
236
- negative_promptA = "P_obj"
237
- negative_promptB = "P_obj"
238
- print('You should add to positive prompt: "empty scene blur"')
239
- #positive = positive + " empty scene blur"
240
- elif function == "context aware":
241
- promptA = "P_ctxt"
242
- promptB = "P_ctxt"
243
- negative_promptA = ""
244
- negative_promptB = ""
245
- #positive = positive + " empty scene"
246
- print('You should add to positive prompt: "empty scene"')
247
- elif function == "shape guided":
248
- promptA = "P_shape"
249
- promptB = "P_ctxt"
250
- negative_promptA = "P_shape"
251
- negative_promptB = "P_ctxt"
252
- elif function == "image outpainting":
253
- promptA = "P_ctxt"
254
- promptB = "P_ctxt"
255
- negative_promptA = "P_obj"
256
- negative_promptB = "P_obj"
257
- #positive = positive + " empty scene"
258
- print('You should add to positive prompt: "empty scene"')
259
- else:
260
- promptA = "P_obj"
261
- promptB = "P_obj"
262
- negative_promptA = "P_obj"
263
- negative_promptB = "P_obj"
264
-
265
- tokens = clip.tokenize(promptA)
266
- prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
267
-
268
- tokens = clip.tokenize(negative_promptA)
269
- negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
270
-
271
- tokens = clip.tokenize(promptB)
272
- prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
273
-
274
- tokens = clip.tokenize(negative_promptB)
275
- negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
276
-
277
- prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
278
- negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
279
-
280
- # unload vae and CLIPs
281
- del vae
282
- del clip
283
- for loaded_model in comfy.model_management.current_loaded_models:
284
- if type(loaded_model.model.model) in ModelsToUnload:
285
- comfy.model_management.current_loaded_models.remove(loaded_model)
286
- loaded_model.model_unload()
287
- del loaded_model
288
-
289
- # apply patch to model
290
-
291
- brushnet_conditioning_scale = scale
292
- control_guidance_start = start_at
293
- control_guidance_end = end_at
294
-
295
- if save_memory != 'none':
296
- powerpaint['brushnet'].set_attention_slice(save_memory)
297
-
298
- add_brushnet_patch(model,
299
- powerpaint['brushnet'],
300
- torch_dtype,
301
- conditioning_latents,
302
- (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
303
- negative_prompt_embeds_pp, prompt_embeds_pp,
304
- None, None, None,
305
- False)
306
-
307
- latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=powerpaint['brushnet'].device)
308
-
309
- return (model, positive, negative, {"samples":latent},)
310
-
311
-
312
- class BrushNet:
313
-
314
- @classmethod
315
- def INPUT_TYPES(s):
316
- return {"required":
317
- {
318
- "model": ("MODEL",),
319
- "vae": ("VAE", ),
320
- "image": ("IMAGE",),
321
- "mask": ("MASK",),
322
- "brushnet": ("BRMODEL", ),
323
- "positive": ("CONDITIONING", ),
324
- "negative": ("CONDITIONING", ),
325
- "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
326
- "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
327
- "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
328
- },
329
- }
330
-
331
- CATEGORY = "inpaint"
332
- RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
333
- RETURN_NAMES = ("model","positive","negative","latent",)
334
-
335
- FUNCTION = "model_update"
336
-
337
- def model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
338
-
339
- is_SDXL, is_PP = check_compatibilty(model, brushnet)
340
-
341
- if is_PP:
342
- raise Exception("PowerPaint model was loaded, please use PowerPaint node")
343
-
344
- # Make a copy of the model so that we're not patching it everywhere in the workflow.
345
- model = model.clone()
346
-
347
- # prepare image and mask
348
- # no batches for original image and mask
349
- masked_image, mask = prepare_image(image, mask)
350
-
351
- batch = masked_image.shape[0]
352
- width = masked_image.shape[2]
353
- height = masked_image.shape[1]
354
-
355
- if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
356
- scaling_factor = model.model.model_config.latent_format.scale_factor
357
- elif is_SDXL:
358
- scaling_factor = sdxl_scaling_factor
359
- else:
360
- scaling_factor = sd15_scaling_factor
361
-
362
- torch_dtype = brushnet['dtype']
363
-
364
- # prepare conditioning latents
365
- conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
366
- conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
367
- conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
368
-
369
- # unload vae
370
- del vae
371
- for loaded_model in comfy.model_management.current_loaded_models:
372
- if type(loaded_model.model.model) in ModelsToUnload:
373
- comfy.model_management.current_loaded_models.remove(loaded_model)
374
- loaded_model.model_unload()
375
- del loaded_model
376
-
377
- # prepare embeddings
378
-
379
- prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
380
- negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
381
-
382
- max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
383
- if prompt_embeds.shape[1] < max_tokens:
384
- multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
385
- prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:,-77:,:]] * multiplier, dim=1)
386
- print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape, 'multiplying prompt_embeds')
387
- if negative_prompt_embeds.shape[1] < max_tokens:
388
- multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
389
- negative_prompt_embeds = torch.concat([negative_prompt_embeds] + [negative_prompt_embeds[:,-77:,:]] * multiplier, dim=1)
390
- print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape, 'multiplying negative_prompt_embeds')
391
-
392
- if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
393
- pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
394
- else:
395
- print('BrushNet: positive conditioning has not pooled_output')
396
- if is_SDXL:
397
- print('BrushNet will not produce correct results')
398
- pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
399
-
400
- if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
401
- negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
402
- else:
403
- print('BrushNet: negative conditioning has not pooled_output')
404
- if is_SDXL:
405
- print('BrushNet will not produce correct results')
406
- negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
407
-
408
- time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(brushnet['brushnet'].device)
409
-
410
- if not is_SDXL:
411
- pooled_prompt_embeds = None
412
- negative_pooled_prompt_embeds = None
413
- time_ids = None
414
-
415
- # apply patch to model
416
-
417
- brushnet_conditioning_scale = scale
418
- control_guidance_start = start_at
419
- control_guidance_end = end_at
420
-
421
- add_brushnet_patch(model,
422
- brushnet['brushnet'],
423
- torch_dtype,
424
- conditioning_latents,
425
- (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
426
- prompt_embeds, negative_prompt_embeds,
427
- pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
428
- False)
429
-
430
- latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=brushnet['brushnet'].device)
431
-
432
- return (model, positive, negative, {"samples":latent},)
433
-
434
-
435
- class BlendInpaint:
436
-
437
- @classmethod
438
- def INPUT_TYPES(s):
439
- return {"required":
440
- {
441
- "inpaint": ("IMAGE",),
442
- "original": ("IMAGE",),
443
- "mask": ("MASK",),
444
- "kernel": ("INT", {"default": 10, "min": 1, "max": 1000}),
445
- "sigma": ("FLOAT", {"default": 10.0, "min": 0.01, "max": 1000}),
446
- },
447
- "optional":
448
- {
449
- "origin": ("VECTOR",),
450
- },
451
- }
452
-
453
- CATEGORY = "inpaint"
454
- RETURN_TYPES = ("IMAGE","MASK",)
455
- RETURN_NAMES = ("image","MASK",)
456
-
457
- FUNCTION = "blend_inpaint"
458
-
459
- def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma:int, origin=None) -> Tuple[torch.Tensor]:
460
-
461
- original, mask = check_image_mask(original, mask, 'Blend Inpaint')
462
-
463
- if len(inpaint.shape) < 4:
464
- # image tensor shape should be [B, H, W, C], but batch somehow is missing
465
- inpaint = inpaint[None,:,:,:]
466
-
467
- if inpaint.shape[0] < original.shape[0]:
468
- print("Blend Inpaint gets batch of original images (%d) but only (%d) inpaint images" % (original.shape[0], inpaint.shape[0]))
469
- original= original[:inpaint.shape[0],:,:]
470
- mask = mask[:inpaint.shape[0],:,:]
471
-
472
- if inpaint.shape[0] > original.shape[0]:
473
- # batch over inpaint
474
- count = 0
475
- original_list = []
476
- mask_list = []
477
- origin_list = []
478
- while (count < inpaint.shape[0]):
479
- for i in range(original.shape[0]):
480
- original_list.append(original[i][None,:,:,:])
481
- mask_list.append(mask[i][None,:,:])
482
- if origin is not None:
483
- origin_list.append(origin[i][None,:])
484
- count += 1
485
- if count >= inpaint.shape[0]:
486
- break
487
- original = torch.concat(original_list, dim=0)
488
- mask = torch.concat(mask_list, dim=0)
489
- if origin is not None:
490
- origin = torch.concat(origin_list, dim=0)
491
-
492
- if kernel % 2 == 0:
493
- kernel += 1
494
- transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))
495
-
496
- ret = []
497
- blurred = []
498
- for i in range(inpaint.shape[0]):
499
- if origin is None:
500
- blurred_mask = transform(mask[i][None,None,:,:]).to(original.device).to(original.dtype)
501
- blurred.append(blurred_mask[0])
502
-
503
- result = torch.nn.functional.interpolate(
504
- inpaint[i][None,:,:,:].permute(0, 3, 1, 2),
505
- size=(
506
- original[i].shape[0],
507
- original[i].shape[1],
508
- )
509
- ).permute(0, 2, 3, 1).to(original.device).to(original.dtype)
510
- else:
511
- # got mask from CutForInpaint
512
- height, width, _ = original[i].shape
513
- x0 = origin[i][0].item()
514
- y0 = origin[i][1].item()
515
-
516
- if mask[i].shape[0] < height or mask[i].shape[1] < width:
517
- padded_mask = F.pad(input=mask[i], pad=(x0, width-x0-mask[i].shape[1],
518
- y0, height-y0-mask[i].shape[0]), mode='constant', value=0)
519
- else:
520
- padded_mask = mask[i]
521
- blurred_mask = transform(padded_mask[None,None,:,:]).to(original.device).to(original.dtype)
522
- blurred.append(blurred_mask[0][0])
523
-
524
- result = F.pad(input=inpaint[i], pad=(0, 0, x0, width-x0-inpaint[i].shape[1],
525
- y0, height-y0-inpaint[i].shape[0]), mode='constant', value=0)
526
- result = result[None,:,:,:].to(original.device).to(original.dtype)
527
-
528
- ret.append(original[i] * (1.0 - blurred_mask[0][0][:,:,None]) + result[0] * blurred_mask[0][0][:,:,None])
529
-
530
- return (torch.stack(ret), torch.stack(blurred), )
531
-
532
-
533
- class CutForInpaint:
534
-
535
- @classmethod
536
- def INPUT_TYPES(s):
537
- return {"required":
538
- {
539
- "image": ("IMAGE",),
540
- "mask": ("MASK",),
541
- "width": ("INT", {"default": 512, "min": 64, "max": 2048}),
542
- "height": ("INT", {"default": 512, "min": 64, "max": 2048}),
543
- },
544
- }
545
-
546
- CATEGORY = "inpaint"
547
- RETURN_TYPES = ("IMAGE","MASK","VECTOR",)
548
- RETURN_NAMES = ("image","mask","origin",)
549
-
550
- FUNCTION = "cut_for_inpaint"
551
-
552
- def cut_for_inpaint(self, image: torch.Tensor, mask: torch.Tensor, width: int, height: int):
553
-
554
- image, mask = check_image_mask(image, mask, 'BrushNet')
555
-
556
- ret = []
557
- msk = []
558
- org = []
559
- for i in range(image.shape[0]):
560
- x0, y0, w, h = cut_with_mask(mask[i], width, height)
561
- ret.append((image[i][y0:y0+h,x0:x0+w,:]))
562
- msk.append((mask[i][y0:y0+h,x0:x0+w]))
563
- org.append(torch.IntTensor([x0,y0]))
564
-
565
- return (torch.stack(ret), torch.stack(msk), torch.stack(org), )
566
-
567
-
568
- #### Utility function
569
-
570
- def get_files_with_extension(folder_name, extension=['.safetensors']):
571
-
572
- try:
573
- folders = folder_paths.get_folder_paths(folder_name)
574
- except:
575
- folders = []
576
-
577
- if not folders:
578
- folders = [os.path.join(folder_paths.models_dir, folder_name)]
579
- if not os.path.isdir(folders[0]):
580
- folders = [os.path.join(folder_paths.base_path, folder_name)]
581
- if not os.path.isdir(folders[0]):
582
- return {}
583
-
584
- filtered_folders = []
585
- for x in folders:
586
- if not os.path.isdir(x):
587
- continue
588
- the_same = False
589
- for y in filtered_folders:
590
- if os.path.samefile(x, y):
591
- the_same = True
592
- break
593
- if not the_same:
594
- filtered_folders.append(x)
595
-
596
- if not filtered_folders:
597
- return {}
598
-
599
- output = {}
600
- for x in filtered_folders:
601
- files, folders_all = folder_paths.recursive_search(x, excluded_dir_names=[".git"])
602
- filtered_files = folder_paths.filter_files_extensions(files, extension)
603
-
604
- for f in filtered_files:
605
- output[f] = x
606
-
607
- return output
608
-
609
-
610
- # get blocks from state_dict so we could know which model it is
611
- def brushnet_blocks(sd):
612
- brushnet_down_block = 0
613
- brushnet_mid_block = 0
614
- brushnet_up_block = 0
615
- for key in sd:
616
- if 'brushnet_down_block' in key:
617
- brushnet_down_block += 1
618
- if 'brushnet_mid_block' in key:
619
- brushnet_mid_block += 1
620
- if 'brushnet_up_block' in key:
621
- brushnet_up_block += 1
622
- return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd))
623
-
624
-
625
- # Check models compatibility
626
- def check_compatibilty(model, brushnet):
627
- is_SDXL = False
628
- is_PP = False
629
- if isinstance(model.model.model_config, comfy.supported_models.SD15):
630
- print('Base model type: SD1.5')
631
- is_SDXL = False
632
- if brushnet["SDXL"]:
633
- raise Exception("Base model is SD15, but BrushNet is SDXL type")
634
- if brushnet["PP"]:
635
- is_PP = True
636
- elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
637
- print('Base model type: SDXL')
638
- is_SDXL = True
639
- if not brushnet["SDXL"]:
640
- raise Exception("Base model is SDXL, but BrushNet is SD15 type")
641
- else:
642
- print('Base model type: ', type(model.model.model_config))
643
- raise Exception("Unsupported model type: " + str(type(model.model.model_config)))
644
-
645
- return (is_SDXL, is_PP)
646
-
647
-
648
- def check_image_mask(image, mask, name):
649
- if len(image.shape) < 4:
650
- # image tensor shape should be [B, H, W, C], but batch somehow is missing
651
- image = image[None,:,:,:]
652
-
653
- if len(mask.shape) > 3:
654
- # mask tensor shape should be [B, H, W] but we get [B, H, W, C], image may be?
655
- # take first mask, red channel
656
- mask = (mask[:,:,:,0])[:,:,:]
657
- elif len(mask.shape) < 3:
658
- # mask tensor shape should be [B, H, W] but batch somehow is missing
659
- mask = mask[None,:,:]
660
-
661
- if image.shape[0] > mask.shape[0]:
662
- print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
663
- if mask.shape[0] == 1:
664
- print(name, "will copy the mask to fill batch")
665
- mask = torch.cat([mask] * image.shape[0], dim=0)
666
- else:
667
- print(name, "will add empty masks to fill batch")
668
- empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
669
- mask = torch.cat([mask, empty_mask], dim=0)
670
- elif image.shape[0] < mask.shape[0]:
671
- print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
672
- mask = mask[:image.shape[0],:,:]
673
-
674
- return (image, mask)
675
-
676
-
677
- # Prepare image and mask
678
- def prepare_image(image, mask):
679
-
680
- image, mask = check_image_mask(image, mask, 'BrushNet')
681
-
682
- print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)
683
-
684
- if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]:
685
- raise Exception("Image and mask should be the same size")
686
-
687
- # As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
688
- mask = mask.round()
689
-
690
- masked_image = image * (1.0 - mask[:,:,:,None])
691
-
692
- return (masked_image, mask)
693
-
694
-
695
- # Get origin of the mask
696
- def cut_with_mask(mask, width, height):
697
- iy, ix = (mask == 1).nonzero(as_tuple=True)
698
-
699
- h0, w0 = mask.shape
700
-
701
- if iy.numel() == 0:
702
- x_c = w0 / 2.0
703
- y_c = h0 / 2.0
704
- else:
705
- x_min = ix.min().item()
706
- x_max = ix.max().item()
707
- y_min = iy.min().item()
708
- y_max = iy.max().item()
709
-
710
- if x_max - x_min > width or y_max - y_min > height:
711
- raise Exception("Masked area is bigger than provided dimensions")
712
-
713
- x_c = (x_min + x_max) / 2.0
714
- y_c = (y_min + y_max) / 2.0
715
-
716
- width2 = width / 2.0
717
- height2 = height / 2.0
718
-
719
- if w0 <= width:
720
- x0 = 0
721
- w = w0
722
- else:
723
- x0 = max(0, x_c - width2)
724
- w = width
725
- if x0 + width > w0:
726
- x0 = w0 - width
727
-
728
- if h0 <= height:
729
- y0 = 0
730
- h = h0
731
- else:
732
- y0 = max(0, y_c - height2)
733
- h = height
734
- if y0 + height > h0:
735
- y0 = h0 - height
736
-
737
- return (int(x0), int(y0), int(w), int(h))
738
-
739
-
740
- # Prepare conditioning_latents
741
- @torch.inference_mode()
742
- def get_image_latents(masked_image, mask, vae, scaling_factor):
743
- processed_image = masked_image.to(vae.device)
744
- image_latents = vae.encode(processed_image[:,:,:,:3]) * scaling_factor
745
- processed_mask = 1. - mask[:,None,:,:]
746
- interpolated_mask = torch.nn.functional.interpolate(
747
- processed_mask,
748
- size=(
749
- image_latents.shape[-2],
750
- image_latents.shape[-1]
751
- )
752
- )
753
- interpolated_mask = interpolated_mask.to(image_latents.device)
754
-
755
- conditioning_latents = [image_latents, interpolated_mask]
756
-
757
- print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =', interpolated_mask.shape)
758
-
759
- return conditioning_latents
760
-
761
-
762
- # Main function where magic happens
763
@torch.inference_mode()
def brushnet_inference(x, timesteps, transformer_options, debug):
    """Run the BrushNet model for the current sampling call.

    All configuration is read from ``transformer_options['model_patch']['brushnet']``
    (the dictionary filled in by ``add_brushnet_patch``).

    Args:
        x: incoming latents; ``x.shape[0]`` is treated as the number of latents
            in this call and ``x.shape[2:4]`` as the spatial size.
        timesteps: timestep tensor, forwarded to the model after a dtype/device cast.
        transformer_options: sampler options; must contain ``cond_or_uncond`` and
            ``model_patch`` (with ``step`` and the ``brushnet`` sub-dictionary).
        debug: forwarded to the model and enables a per-step print.

    Returns:
        The result of calling the BrushNet model with ``return_dict=False``
        (callers unpack it as input samples, mid sample, output samples), or
        ``([], 0, [])`` when the patch is missing or the model has an unexpected type.
    """
    if 'model_patch' not in transformer_options:
        print('BrushNet inference: there is no model_patch key in transformer_options')
        return ([], 0, [])
    mp = transformer_options['model_patch']
    if 'brushnet' not in mp:
        print('BrushNet inference: there is no brushnet key in model_patch')
        return ([], 0, [])
    bo = mp['brushnet']
    if 'model' not in bo:
        print('BrushNet inference: there is no model key in brushnet')
        return ([], 0, [])
    brushnet = bo['model']
    if not isinstance(brushnet, (BrushNetModel, PowerPaintModel)):
        print('BrushNet model is not a BrushNetModel class')
        return ([], 0, [])

    torch_dtype = bo['dtype']
    cl_list = bo['latents']
    brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
    pe = bo['prompt_embeds']
    npe = bo['negative_prompt_embeds']
    ppe, nppe, time_ids = bo['add_embeds']

    # more than one entry in cond_or_uncond means cond and uncond arrive in the same call
    do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1

    # work on private copies so the sampler's tensors are never modified
    x = x.detach().clone()
    x = x.to(torch_dtype).to(brushnet.device)

    timesteps = timesteps.detach().clone()
    timesteps = timesteps.to(torch_dtype).to(brushnet.device)

    step = mp['step']

    if do_classifier_free_guidance and step == 0:
        print('BrushNet inference: do_classifier_free_guidance is True')

    sub_idx = None
    if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
        sub_idx = transformer_options['ad_params']['sub_idxs']

    # we have batch input images
    batch = cl_list[0].shape[0]
    # we have incoming latents
    latents_incoming = x.shape[0]
    # and we already got some
    latents_got = bo['latent_id']
    if step == 0 or batch > 1:
        print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
                % (step, batch, latents_incoming, latents_got))

    image_latents = []
    masks = []
    prompt_embeds = []
    negative_prompt_embeds = []
    pooled_prompt_embeds = []
    negative_pooled_prompt_embeds = []

    def take(index):
        # keep a leading axis of size 1 so the pieces can be concatenated on dim 0 later
        image_latents.append(cl_list[0][index][None, :, :, :])
        masks.append(cl_list[1][index][None, :, :, :])

    if sub_idx:
        # AnimateDiff indexes detected
        if step == 0:
            print('BrushNet inference: AnimateDiff indexes detected and applied')

        if do_classifier_free_guidance:
            for i in sub_idx:
                take(i)
                prompt_embeds.append(pe)
                negative_prompt_embeds.append(npe)
                pooled_prompt_embeds.append(ppe)
                negative_pooled_prompt_embeds.append(nppe)
            # second copy of the same frames for the other half of the guidance pair
            for i in sub_idx:
                take(i)
        else:
            for i in sub_idx:
                take(i)
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
    else:
        # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
        continue_batch = True
        for i in range(latents_incoming):
            number = latents_got + i
            if number < batch:
                # 1st pass, cond
                take(number)
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
            elif do_classifier_free_guidance and number < batch * 2:
                # 2nd pass, uncond
                take(number - batch)
                negative_prompt_embeds.append(npe)
                negative_pooled_prompt_embeds.append(nppe)
            else:
                # latent batch: more latents than images, reuse the first image
                take(0)
                prompt_embeds.append(pe)
                pooled_prompt_embeds.append(ppe)
                continue_batch = False

        if continue_batch:
            # we don't have full batch yet: remember where the next call continues,
            # wrapping to 0 once every expected latent has been served
            expected = batch * 2 if do_classifier_free_guidance else batch
            bo['latent_id'] = number + 1 if number < expected - 1 else 0
        else:
            bo['latent_id'] = 0

    # channel-wise concat of latent and mask per sample, then stack samples on the batch axis
    cl = [torch.concat([il, m], dim=1) for il, m in zip(image_latents, masks)]
    cl2apply = torch.concat(cl, dim=0)

    conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)

    # positive embeddings first, negative ones after
    prompt_embeds.extend(negative_prompt_embeds)
    prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)

    if ppe is not None:
        # pooled embeddings are present: also pass time ids and pooled text embeddings
        added_cond_kwargs = {}
        added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)

        pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
        pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
        added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
    else:
        added_cond_kwargs = None

    if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
        if step == 0:
            print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
        conditioning_latents = torch.nn.functional.interpolate(
            conditioning_latents, size=(
                x.shape[2],
                x.shape[3],
            ), mode='bicubic',
        ).to(torch_dtype).to(brushnet.device)

    if step == 0:
        print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape, 'dtype', torch_dtype)

    if debug: print('BrushNet: step =', step)

    # outside the [start, end] step window the model still runs, but with zero scale
    if step < control_guidance_start or step > control_guidance_end:
        cond_scale = 0.0
    else:
        cond_scale = brushnet_conditioning_scale

    return brushnet(x,
                    encoder_hidden_states=prompt_embeds,
                    brushnet_cond=conditioning_latents,
                    timestep = timesteps,
                    conditioning_scale=cond_scale,
                    guess_mode=False,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                    debug=debug,
                    )
943
-
944
-
945
- # This is main patch function
946
def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
                       controls,
                       prompt_embeds, negative_prompt_embeds,
                       pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids,
                       debug):
    """Install BrushNet on ``model`` so its outputs are added inside the UNet.

    Three things happen here:
      1. ``brushnet_forward`` is registered through ``patch_model_function_wrapper``;
         on each call it runs ``brushnet_inference`` and parks one sample on the
         chosen layer of every listed block (``add_sample_after``).
      2. The BrushNet settings are stored under ``model_patch['brushnet']`` in the
         options returned by ``add_model_patch_option``.
      3. Every layer of the input, middle and output blocks gets its ``forward``
         replaced by a wrapper that adds the parked sample to the layer output.

    Args:
        model: object exposing ``model.model.model_config`` and
            ``model.model.diffusion_model`` (input/middle/output blocks).
        brushnet: the BrushNet model, stored as ``bo['model']``.
        torch_dtype: dtype stored as ``bo['dtype']``.
        conditioning_latents: stored as ``bo['latents']``.
        controls: stored as ``bo['controls']``; ``brushnet_inference`` unpacks it
            as (conditioning scale, guidance start, guidance end).
        prompt_embeds, negative_prompt_embeds: stored under the same keys.
        pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids: stored
            together as the ``bo['add_embeds']`` tuple.
        debug: captured by ``brushnet_forward`` and passed to ``brushnet_inference``.
    """
    openaimodel = comfy.ldm.modules.diffusionmodules.openaimodel
    Conv2d = comfy.ops.disable_weight_init.Conv2d
    ResBlock = openaimodel.ResBlock
    Downsample = openaimodel.Downsample
    Upsample = openaimodel.Upsample
    SpatialTransformer = comfy.ldm.modules.attention.SpatialTransformer

    is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)

    # Each entry is [block index, layer type]: the sample goes to the LAST layer of
    # that type inside the block. A block index may appear twice (e.g. a transformer
    # and the upsampler of the same output block each receive their own sample).
    if is_SDXL:
        input_blocks = [[0, Conv2d],
                        [1, ResBlock],
                        [2, ResBlock],
                        [3, Downsample],
                        [4, SpatialTransformer],
                        [5, SpatialTransformer],
                        [6, Downsample],
                        [7, SpatialTransformer],
                        [8, SpatialTransformer]]
        middle_block = [0, ResBlock]
        output_blocks = [[0, SpatialTransformer],
                         [1, SpatialTransformer],
                         [2, SpatialTransformer],
                         [2, Upsample],
                         [3, SpatialTransformer],
                         [4, SpatialTransformer],
                         [5, SpatialTransformer],
                         [5, Upsample],
                         [6, ResBlock],
                         [7, ResBlock],
                         [8, ResBlock]]
    else:
        input_blocks = [[0, Conv2d],
                        [1, SpatialTransformer],
                        [2, SpatialTransformer],
                        [3, Downsample],
                        [4, SpatialTransformer],
                        [5, SpatialTransformer],
                        [6, Downsample],
                        [7, SpatialTransformer],
                        [8, SpatialTransformer],
                        [9, Downsample],
                        [10, ResBlock],
                        [11, ResBlock]]
        middle_block = [0, ResBlock]
        output_blocks = [[0, ResBlock],
                         [1, ResBlock],
                         [2, ResBlock],
                         [2, Upsample],
                         [3, SpatialTransformer],
                         [4, SpatialTransformer],
                         [5, SpatialTransformer],
                         [5, Upsample],
                         [6, SpatialTransformer],
                         [7, SpatialTransformer],
                         [8, SpatialTransformer],
                         [8, Upsample],
                         [9, SpatialTransformer],
                         [10, SpatialTransformer],
                         [11, SpatialTransformer]]

    def last_layer_index(block, tp):
        """Return (index of the last layer in ``block`` whose type is ``tp``, layer types).

        The index is -1 when no layer has that type; the type list is always
        returned (in block order) so callers can print it in diagnostics.
        """
        layer_types = [type(layer) for layer in block]
        for idx in range(len(layer_types) - 1, -1, -1):
            if layer_types[idx] == tp:
                return idx, layer_types
        return -1, layer_types

    def brushnet_forward(model, x, timesteps, transformer_options, control):
        """Run BrushNet and hand its samples to the layers listed in the tables."""
        if 'brushnet' not in transformer_options['model_patch']:
            input_samples, mid_sample, output_samples = [], 0, []
        else:
            # brushnet inference
            input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options, debug)

        # give additional samples to blocks; a sample is consumed only when its
        # target layer exists, and 0 (a no-op addend) is used once samples run out
        pending_inputs = iter(input_samples)
        for i, tp in input_blocks:
            idx, layer_list = last_layer_index(model.input_blocks[i], tp)
            if idx < 0:
                print("BrushNet can't find", tp, "layer in", i,"input block:", layer_list)
                continue
            model.input_blocks[i][idx].add_sample_after = next(pending_inputs, 0)

        idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
        if idx < 0:
            print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
        else:
            # only assign when the layer was found; indexing with -1 would silently
            # attach the sample to the last layer of the middle block instead
            model.middle_block[idx].add_sample_after = mid_sample

        pending_outputs = iter(output_samples)
        for i, tp in output_blocks:
            idx, layer_list = last_layer_index(model.output_blocks[i], tp)
            if idx < 0:
                print("BrushNet can't find", tp, "layer in", i,"output block:", layer_list)
                continue
            model.output_blocks[i][idx].add_sample_after = next(pending_outputs, 0)

    patch_model_function_wrapper(model, brushnet_forward)

    to = add_model_patch_option(model)
    mp = to['model_patch']
    if 'brushnet' not in mp:
        mp['brushnet'] = {}
    bo = mp['brushnet']

    bo['model'] = brushnet
    bo['dtype'] = torch_dtype
    bo['latents'] = conditioning_latents
    bo['controls'] = controls
    bo['prompt_embeds'] = prompt_embeds
    bo['negative_prompt_embeds'] = negative_prompt_embeds
    bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
    bo['latent_id'] = 0

    # patch layers `forward` so we can apply brushnet
    def forward_patched_by_brushnet(self, x, *args, **kwargs):
        """Call the layer's original forward, then add the parked BrushNet sample."""
        h = self.original_forward(x, *args, **kwargs)
        if hasattr(self, 'add_sample_after'):
            to_add = self.add_sample_after
            if torch.is_tensor(to_add):
                # interpolate due to RAUNet
                if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
                    to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
                h += to_add.to(h.dtype).to(h.device)
            else:
                h += to_add
            # the sample is single-use: reset so a later call cannot re-add it
            self.add_sample_after = 0
        return h

    def patch_layer(layer):
        # remember the genuine forward only once, so patching the same model twice
        # never makes the wrapper wrap itself
        # NOTE(review): the source lost its indentation; this assumes only the
        # `original_forward` assignment was guarded by the hasattr check - confirm.
        if not hasattr(layer, 'original_forward'):
            layer.original_forward = layer.forward
        layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
        layer.add_sample_after = 0

    diffusion_model = model.model.diffusion_model
    for block in diffusion_model.input_blocks:
        for layer in block:
            patch_layer(layer)

    for layer in diffusion_model.middle_block:
        patch_layer(layer)

    for block in diffusion_model.output_blocks:
        for layer in block:
            patch_layer(layer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/.DS_Store DELETED
Binary file (6.15 kB)
 
MagicQuill/comfy/checkpoint_pickle.py DELETED
@@ -1,13 +0,0 @@
1
- import pickle
2
-
3
# Re-exported so this module offers the same `load` entry point as `pickle`.
load = pickle.load


class Empty:
    """Inert stand-in returned in place of any ``pytorch_lightning`` class."""


class Unpickler(pickle.Unpickler):
    """Unpickler that resolves every ``pytorch_lightning*`` global to ``Empty``.

    Only module names starting with ``pytorch_lightning`` are replaced; every
    other global is resolved by the stock ``pickle.Unpickler``. This is NOT a
    safe unpickler: do not use it on untrusted data.
    """

    def find_class(self, module, name):
        #TODO: safe unpickle
        from_lightning = module.startswith("pytorch_lightning")
        return Empty if from_lightning else super().find_class(module, name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/cldm/__pycache__/cldm.cpython-310.pyc DELETED
Binary file (6.11 kB)
 
MagicQuill/comfy/cldm/cldm.py DELETED
@@ -1,313 +0,0 @@
1
- #taken from: https://github.com/lllyasviel/ControlNet
2
- #and modified
3
-
4
- import torch
5
- import torch as th
6
- import torch.nn as nn
7
-
8
- from ..ldm.modules.diffusionmodules.util import (
9
- zero_module,
10
- timestep_embedding,
11
- )
12
-
13
- from ..ldm.modules.attention import SpatialTransformer
14
- from ..ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample
15
- from ..ldm.util import exists
16
- import comfy.ops
17
-
18
class ControlledUnetModel(UNetModel):
    """Plain alias of ``UNetModel``; the control handling is implemented in the ldm unet."""
21
-
22
class ControlNet(nn.Module):
    """Encoder-only network that turns a hint image into per-block control outputs.

    The module mirrors the input (down) path and middle block of the UNet and
    attaches a 1x1 "zero conv" to every input block and to the middle block.
    ``forward`` returns the zero-conv outputs, one per input block followed by
    one for the middle block.

    Notes on arguments (only what this class itself does with them):
        num_res_blocks: int (same count on every level) or a per-level sequence
            whose length must equal ``len(channel_mult)``.
        transformer_depth: sequence with one depth per residual block, consumed
            in construction order (a depth of 0 means no transformer after that
            block), or a single int applied to every residual block.
        transformer_depth_middle: depth of the middle transformer; a negative
            value builds the middle block without transformer. ``None`` falls
            back to the last entry of ``transformer_depth``.
        use_spatial_transformer / context_dim: ``use_spatial_transformer`` must
            be true and ``context_dim`` must be given (asserted).
        use_new_attention_order, transformer_depth_output, **kwargs: accepted
            but not used here.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        hint_channels,
        num_res_blocks,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        dtype=torch.float32,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
        adm_in_channels=None,
        transformer_depth_middle=None,
        transformer_depth_output=None,
        attn_precision=None,
        device=None,
        operations=comfy.ops.disable_weight_init,
        **kwargs,
    ):
        super().__init__()
        assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels

        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks

        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))

        # One depth is popped per residual block below, so work on a private list.
        # An int (the declared default) used to crash on slicing; expand it to one
        # entry per residual block instead.
        if isinstance(transformer_depth, int):
            transformer_depth = [transformer_depth] * sum(self.num_res_blocks)
        else:
            transformer_depth = list(transformer_depth)
        # The declared default of None used to crash on the `>= 0` comparison.
        # Resolve it now, before the construction loop empties the list.
        if transformer_depth_middle is None:
            transformer_depth_middle = transformer_depth[-1] if transformer_depth else -1

        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = dtype
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # keyword bundles shared by every layer built below
        tensor_opts = {"dtype": self.dtype, "device": device}
        module_opts = {"dtype": self.dtype, "device": device, "operations": operations}

        time_embed_dim = model_channels * 4

        def res_block(channels, **extra):
            # all ResBlocks share these settings; `extra` carries out_channels / down
            return ResBlock(
                channels,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                **extra,
                **module_opts,
            )

        def transformer(channels, heads, head_dim, depth, disable_self_attn):
            return SpatialTransformer(
                channels, heads, head_dim, depth=depth, context_dim=context_dim,
                disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint, attn_precision=attn_precision, **module_opts
            )

        def attention_dims(channels, heads):
            # returns (num_heads, dim_head) for a transformer working on `channels`
            if num_head_channels == -1:
                head_dim = channels // heads
            else:
                heads = channels // num_head_channels
                head_dim = num_head_channels
            if legacy:
                head_dim = channels // heads if use_spatial_transformer else num_head_channels
            return heads, head_dim

        self.time_embed = nn.Sequential(
            operations.Linear(model_channels, time_embed_dim, **tensor_opts),
            nn.SiLU(),
            operations.Linear(time_embed_dim, time_embed_dim, **tensor_opts),
        )

        if self.num_classes is not None:
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
                self.label_emb = nn.Sequential(
                    nn.Sequential(
                        operations.Linear(adm_in_channels, time_embed_dim, **tensor_opts),
                        nn.SiLU(),
                        operations.Linear(time_embed_dim, time_embed_dim, **tensor_opts),
                    )
                )
            else:
                raise ValueError(f"unsupported num_classes: {self.num_classes!r}")

        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, **tensor_opts)
                )
            ]
        )
        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, **tensor_opts)])

        # hint encoder: three stride-2 convolutions, ending at model_channels channels
        self.input_hint_block = TimestepEmbedSequential(
            operations.conv_nd(dims, hint_channels, 16, 3, padding=1, **tensor_opts),
            nn.SiLU(),
            operations.conv_nd(dims, 16, 16, 3, padding=1, **tensor_opts),
            nn.SiLU(),
            operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, **tensor_opts),
            nn.SiLU(),
            operations.conv_nd(dims, 32, 32, 3, padding=1, **tensor_opts),
            nn.SiLU(),
            operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, **tensor_opts),
            nn.SiLU(),
            operations.conv_nd(dims, 96, 96, 3, padding=1, **tensor_opts),
            nn.SiLU(),
            operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, **tensor_opts),
            nn.SiLU(),
            operations.conv_nd(dims, 256, model_channels, 3, padding=1, **tensor_opts)
        )

        self._feature_size = model_channels
        ch = model_channels
        last_level = len(channel_mult) - 1
        for level, mult in enumerate(channel_mult):
            for nr in range(self.num_res_blocks[level]):
                layers = [res_block(ch, out_channels=mult * model_channels)]
                ch = mult * model_channels
                num_transformers = transformer_depth.pop(0)
                if num_transformers > 0:
                    num_heads, dim_head = attention_dims(ch, num_heads)
                    if exists(disable_self_attentions):
                        disabled_sa = disable_self_attentions[level]
                    else:
                        disabled_sa = False

                    if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
                        layers.append(transformer(ch, num_heads, dim_head, num_transformers, disabled_sa))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, **tensor_opts))
                self._feature_size += ch
            if level != last_level:
                # every level except the last ends with a downsampling block
                out_ch = ch
                if resblock_updown:
                    downsampler = res_block(ch, out_channels=out_ch, down=True)
                else:
                    downsampler = Downsample(ch, conv_resample, dims=dims, out_channels=out_ch, **module_opts)
                self.input_blocks.append(TimestepEmbedSequential(downsampler))
                ch = out_ch
                self.zero_convs.append(self.make_zero_conv(ch, operations=operations, **tensor_opts))
                self._feature_size += ch

        num_heads, dim_head = attention_dims(ch, num_heads)
        mid_block = [res_block(ch)]
        if transformer_depth_middle >= 0:
            mid_block += [
                transformer(ch, num_heads, dim_head, transformer_depth_middle, disable_middle_self_attn),
                res_block(ch),
            ]
        self.middle_block = TimestepEmbedSequential(*mid_block)
        self.middle_block_out = self.make_zero_conv(ch, operations=operations, **tensor_opts)
        self._feature_size += ch

    def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
        """Build the 1x1 convolution (channels -> channels) attached to one block."""
        return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))

    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
        """Return the list of control outputs for ``x`` guided by ``hint``.

        The encoded hint is added to the output of the first input block only.
        ``y`` is required when the model was built with ``num_classes``; its
        first dimension must match ``x``. Extra keyword arguments are ignored.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0]
            emb = emb + self.label_emb(y)

        outs = []
        h = x
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            h = module(h, emb, context)
            if guided_hint is not None:
                # inject the hint once, right after the first input block
                h += guided_hint
                guided_hint = None
            outs.append(zero_conv(h, emb, context))

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))

        return outs
313
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/cli_args.py DELETED
@@ -1,143 +0,0 @@
1
- import argparse
2
- import enum
3
- import comfy.options
4
-
5
- class EnumAction(argparse.Action):
6
- """
7
- Argparse action for handling Enums
8
- """
9
- def __init__(self, **kwargs):
10
- # Pop off the type value
11
- enum_type = kwargs.pop("type", None)
12
-
13
- # Ensure an Enum subclass is provided
14
- if enum_type is None:
15
- raise ValueError("type must be assigned an Enum when using EnumAction")
16
- if not issubclass(enum_type, enum.Enum):
17
- raise TypeError("type must be an Enum when using EnumAction")
18
-
19
- # Generate choices from the Enum
20
- choices = tuple(e.value for e in enum_type)
21
- kwargs.setdefault("choices", choices)
22
- kwargs.setdefault("metavar", f"[{','.join(list(choices))}]")
23
-
24
- super(EnumAction, self).__init__(**kwargs)
25
-
26
- self._enum = enum_type
27
-
28
- def __call__(self, parser, namespace, values, option_string=None):
29
- # Convert value back into an Enum
30
- value = self._enum(values)
31
- setattr(namespace, self.dest, value)
32
-
33
-
34
- parser = argparse.ArgumentParser()
35
-
36
- parser.add_argument("--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0", help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
37
- parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
38
- parser.add_argument("--tls-keyfile", type=str, help="Path to TLS (SSL) key file. Enables TLS, makes app accessible at https://... requires --tls-certfile to function")
39
- parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certificate file. Enables TLS, makes app accessible at https://... requires --tls-keyfile to function")
40
- parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
41
- parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
42
-
43
- parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
44
- parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
45
- parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
46
- parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
47
- parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
48
- parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
49
- parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
50
- cm_group = parser.add_mutually_exclusive_group()
51
- cm_group.add_argument("--cuda-malloc", action="store_true", help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
52
- cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
53
-
54
-
55
- fp_group = parser.add_mutually_exclusive_group()
56
- fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
57
- fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
58
-
59
- fpunet_group = parser.add_mutually_exclusive_group()
60
- fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
61
- fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
62
- fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
63
- fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
64
-
65
- fpvae_group = parser.add_mutually_exclusive_group()
66
- fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
67
- fpvae_group.add_argument("--fp32-vae", action="store_true", help="Run the VAE in full precision fp32.")
68
- fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in bf16.")
69
-
70
- parser.add_argument("--cpu-vae", action="store_true", help="Run the VAE on the CPU.")
71
-
72
- fpte_group = parser.add_mutually_exclusive_group()
73
- fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Store text encoder weights in fp8 (e4m3fn variant).")
74
- fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
75
- fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
76
- fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
77
-
78
- parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
79
-
80
- parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
81
-
82
- parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
83
-
84
- class LatentPreviewMethod(enum.Enum):
85
- NoPreviews = "none"
86
- Auto = "auto"
87
- Latent2RGB = "latent2rgb"
88
- TAESD = "taesd"
89
-
90
- parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
91
-
92
- attn_group = parser.add_mutually_exclusive_group()
93
- attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
94
- attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
95
- attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
96
-
97
- parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
98
-
99
- upcast = parser.add_mutually_exclusive_group()
100
- upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
101
- upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
102
-
103
-
104
- vram_group = parser.add_mutually_exclusive_group()
105
- vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
106
- vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
107
- vram_group.add_argument("--normalvram", action="store_true", help="Used to force normal vram use if lowvram gets automatically enabled.")
108
- vram_group.add_argument("--lowvram", action="store_true", help="Split the unet in parts to use less vram.")
109
- vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
110
- vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
111
-
112
-
113
- parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
114
- parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
115
-
116
- parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
117
- parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
118
- parser.add_argument("--windows-standalone-build", action="store_true", help="Windows standalone build: Enable convenient things that most people using the standalone windows build will probably enjoy (like auto opening the page on startup).")
119
-
120
- parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
121
-
122
- parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
123
-
124
- parser.add_argument("--verbose", action="store_true", help="Enables more debug prints.")
125
-
126
-
127
- if comfy.options.args_parsing:
128
- args = parser.parse_args()
129
- else:
130
- args = parser.parse_args([])
131
-
132
- if args.windows_standalone_build:
133
- args.auto_launch = True
134
-
135
- if args.disable_auto_launch:
136
- args.auto_launch = False
137
-
138
- import logging
139
- logging_level = logging.INFO
140
- if args.verbose:
141
- logging_level = logging.DEBUG
142
-
143
- logging.basicConfig(format="%(message)s", level=logging_level)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/clip_config_bigg.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "architectures": [
3
- "CLIPTextModel"
4
- ],
5
- "attention_dropout": 0.0,
6
- "bos_token_id": 0,
7
- "dropout": 0.0,
8
- "eos_token_id": 2,
9
- "hidden_act": "gelu",
10
- "hidden_size": 1280,
11
- "initializer_factor": 1.0,
12
- "initializer_range": 0.02,
13
- "intermediate_size": 5120,
14
- "layer_norm_eps": 1e-05,
15
- "max_position_embeddings": 77,
16
- "model_type": "clip_text_model",
17
- "num_attention_heads": 20,
18
- "num_hidden_layers": 32,
19
- "pad_token_id": 1,
20
- "projection_dim": 1280,
21
- "torch_dtype": "float32",
22
- "vocab_size": 49408
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/clip_model.py DELETED
@@ -1,194 +0,0 @@
1
- import torch
2
- from comfy.ldm.modules.attention import optimized_attention_for_device
3
-
4
- class CLIPAttention(torch.nn.Module):
5
- def __init__(self, embed_dim, heads, dtype, device, operations):
6
- super().__init__()
7
-
8
- self.heads = heads
9
- self.q_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
10
- self.k_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
11
- self.v_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
12
-
13
- self.out_proj = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
14
-
15
- def forward(self, x, mask=None, optimized_attention=None):
16
- q = self.q_proj(x)
17
- k = self.k_proj(x)
18
- v = self.v_proj(x)
19
-
20
- out = optimized_attention(q, k, v, self.heads, mask)
21
- return self.out_proj(out)
22
-
23
- ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
24
- "gelu": torch.nn.functional.gelu,
25
- }
26
-
27
- class CLIPMLP(torch.nn.Module):
28
- def __init__(self, embed_dim, intermediate_size, activation, dtype, device, operations):
29
- super().__init__()
30
- self.fc1 = operations.Linear(embed_dim, intermediate_size, bias=True, dtype=dtype, device=device)
31
- self.activation = ACTIVATIONS[activation]
32
- self.fc2 = operations.Linear(intermediate_size, embed_dim, bias=True, dtype=dtype, device=device)
33
-
34
- def forward(self, x):
35
- x = self.fc1(x)
36
- x = self.activation(x)
37
- x = self.fc2(x)
38
- return x
39
-
40
- class CLIPLayer(torch.nn.Module):
41
- def __init__(self, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
42
- super().__init__()
43
- self.layer_norm1 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
44
- self.self_attn = CLIPAttention(embed_dim, heads, dtype, device, operations)
45
- self.layer_norm2 = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
46
- self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device, operations)
47
-
48
- def forward(self, x, mask=None, optimized_attention=None):
49
- x += self.self_attn(self.layer_norm1(x), mask, optimized_attention)
50
- x += self.mlp(self.layer_norm2(x))
51
- return x
52
-
53
-
54
- class CLIPEncoder(torch.nn.Module):
55
- def __init__(self, num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations):
56
- super().__init__()
57
- self.layers = torch.nn.ModuleList([CLIPLayer(embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations) for i in range(num_layers)])
58
-
59
- def forward(self, x, mask=None, intermediate_output=None):
60
- optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
61
-
62
- if intermediate_output is not None:
63
- if intermediate_output < 0:
64
- intermediate_output = len(self.layers) + intermediate_output
65
-
66
- intermediate = None
67
- for i, l in enumerate(self.layers):
68
- x = l(x, mask, optimized_attention)
69
- if i == intermediate_output:
70
- intermediate = x.clone()
71
- return x, intermediate
72
-
73
- class CLIPEmbeddings(torch.nn.Module):
74
- def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
75
- super().__init__()
76
- self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
77
- self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
78
-
79
- def forward(self, input_tokens):
80
- return self.token_embedding(input_tokens) + self.position_embedding.weight
81
-
82
-
83
- class CLIPTextModel_(torch.nn.Module):
84
- def __init__(self, config_dict, dtype, device, operations):
85
- num_layers = config_dict["num_hidden_layers"]
86
- embed_dim = config_dict["hidden_size"]
87
- heads = config_dict["num_attention_heads"]
88
- intermediate_size = config_dict["intermediate_size"]
89
- intermediate_activation = config_dict["hidden_act"]
90
-
91
- super().__init__()
92
- self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
93
- self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
94
- self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
95
-
96
- def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
97
- x = self.embeddings(input_tokens)
98
- mask = None
99
- if attention_mask is not None:
100
- mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
101
- mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
102
-
103
- causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
104
- if mask is not None:
105
- mask += causal_mask
106
- else:
107
- mask = causal_mask
108
-
109
- x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
110
- x = self.final_layer_norm(x)
111
- if i is not None and final_layer_norm_intermediate:
112
- i = self.final_layer_norm(i)
113
-
114
- pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
115
- return x, i, pooled_output
116
-
117
- class CLIPTextModel(torch.nn.Module):
118
- def __init__(self, config_dict, dtype, device, operations):
119
- super().__init__()
120
- self.num_layers = config_dict["num_hidden_layers"]
121
- self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
122
- embed_dim = config_dict["hidden_size"]
123
- self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
124
- self.text_projection.weight.copy_(torch.eye(embed_dim))
125
- self.dtype = dtype
126
-
127
- def get_input_embeddings(self):
128
- return self.text_model.embeddings.token_embedding
129
-
130
- def set_input_embeddings(self, embeddings):
131
- self.text_model.embeddings.token_embedding = embeddings
132
-
133
- def forward(self, *args, **kwargs):
134
- x = self.text_model(*args, **kwargs)
135
- out = self.text_projection(x[2])
136
- return (x[0], x[1], out, x[2])
137
-
138
-
139
- class CLIPVisionEmbeddings(torch.nn.Module):
140
- def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
141
- super().__init__()
142
- self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
143
-
144
- self.patch_embedding = operations.Conv2d(
145
- in_channels=num_channels,
146
- out_channels=embed_dim,
147
- kernel_size=patch_size,
148
- stride=patch_size,
149
- bias=False,
150
- dtype=dtype,
151
- device=device
152
- )
153
-
154
- num_patches = (image_size // patch_size) ** 2
155
- num_positions = num_patches + 1
156
- self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
157
-
158
- def forward(self, pixel_values):
159
- embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
160
- return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
161
-
162
-
163
- class CLIPVision(torch.nn.Module):
164
- def __init__(self, config_dict, dtype, device, operations):
165
- super().__init__()
166
- num_layers = config_dict["num_hidden_layers"]
167
- embed_dim = config_dict["hidden_size"]
168
- heads = config_dict["num_attention_heads"]
169
- intermediate_size = config_dict["intermediate_size"]
170
- intermediate_activation = config_dict["hidden_act"]
171
-
172
- self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
173
- self.pre_layrnorm = operations.LayerNorm(embed_dim)
174
- self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
175
- self.post_layernorm = operations.LayerNorm(embed_dim)
176
-
177
- def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
178
- x = self.embeddings(pixel_values)
179
- x = self.pre_layrnorm(x)
180
- #TODO: attention_mask?
181
- x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
182
- pooled_output = self.post_layernorm(x[:, 0, :])
183
- return x, i, pooled_output
184
-
185
- class CLIPVisionModelProjection(torch.nn.Module):
186
- def __init__(self, config_dict, dtype, device, operations):
187
- super().__init__()
188
- self.vision_model = CLIPVision(config_dict, dtype, device, operations)
189
- self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
190
-
191
- def forward(self, *args, **kwargs):
192
- x = self.vision_model(*args, **kwargs)
193
- out = self.visual_projection(x[2])
194
- return (x[0], x[1], out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/clip_vision.py DELETED
@@ -1,117 +0,0 @@
1
- from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
2
- import os
3
- import torch
4
- import json
5
- import logging
6
-
7
- import comfy.ops
8
- import comfy.model_patcher
9
- import comfy.model_management
10
- import comfy.utils
11
- import comfy.clip_model
12
-
13
- class Output:
14
- def __getitem__(self, key):
15
- return getattr(self, key)
16
- def __setitem__(self, key, item):
17
- setattr(self, key, item)
18
-
19
- def clip_preprocess(image, size=224):
20
- mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
21
- std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
22
- image = image.movedim(-1, 1)
23
- if not (image.shape[2] == size and image.shape[3] == size):
24
- scale = (size / min(image.shape[2], image.shape[3]))
25
- image = torch.nn.functional.interpolate(image, size=(round(scale * image.shape[2]), round(scale * image.shape[3])), mode="bicubic", antialias=True)
26
- h = (image.shape[2] - size)//2
27
- w = (image.shape[3] - size)//2
28
- image = image[:,:,h:h+size,w:w+size]
29
- image = torch.clip((255. * image), 0, 255).round() / 255.0
30
- return (image - mean.view([3,1,1])) / std.view([3,1,1])
31
-
32
- class ClipVisionModel():
33
- def __init__(self, json_config):
34
- with open(json_config) as f:
35
- config = json.load(f)
36
-
37
- self.load_device = comfy.model_management.text_encoder_device()
38
- offload_device = comfy.model_management.text_encoder_offload_device()
39
- self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
40
- self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
41
- self.model.eval()
42
-
43
- self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
44
-
45
- def load_sd(self, sd):
46
- return self.model.load_state_dict(sd, strict=False)
47
-
48
- def get_sd(self):
49
- return self.model.state_dict()
50
-
51
- def encode_image(self, image):
52
- comfy.model_management.load_model_gpu(self.patcher)
53
- pixel_values = clip_preprocess(image.to(self.load_device)).float()
54
- out = self.model(pixel_values=pixel_values, intermediate_output=-2)
55
-
56
- outputs = Output()
57
- outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
58
- outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
59
- outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
60
- return outputs
61
-
62
- def convert_to_transformers(sd, prefix):
63
- sd_k = sd.keys()
64
- if "{}transformer.resblocks.0.attn.in_proj_weight".format(prefix) in sd_k:
65
- keys_to_replace = {
66
- "{}class_embedding".format(prefix): "vision_model.embeddings.class_embedding",
67
- "{}conv1.weight".format(prefix): "vision_model.embeddings.patch_embedding.weight",
68
- "{}positional_embedding".format(prefix): "vision_model.embeddings.position_embedding.weight",
69
- "{}ln_post.bias".format(prefix): "vision_model.post_layernorm.bias",
70
- "{}ln_post.weight".format(prefix): "vision_model.post_layernorm.weight",
71
- "{}ln_pre.bias".format(prefix): "vision_model.pre_layrnorm.bias",
72
- "{}ln_pre.weight".format(prefix): "vision_model.pre_layrnorm.weight",
73
- }
74
-
75
- for x in keys_to_replace:
76
- if x in sd_k:
77
- sd[keys_to_replace[x]] = sd.pop(x)
78
-
79
- if "{}proj".format(prefix) in sd_k:
80
- sd['visual_projection.weight'] = sd.pop("{}proj".format(prefix)).transpose(0, 1)
81
-
82
- sd = transformers_convert(sd, prefix, "vision_model.", 48)
83
- else:
84
- replace_prefix = {prefix: ""}
85
- sd = state_dict_prefix_replace(sd, replace_prefix)
86
- return sd
87
-
88
- def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
89
- if convert_keys:
90
- sd = convert_to_transformers(sd, prefix)
91
- if "vision_model.encoder.layers.47.layer_norm1.weight" in sd:
92
- json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_g.json")
93
- elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
94
- json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
95
- elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
96
- json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
97
- else:
98
- return None
99
-
100
- clip = ClipVisionModel(json_config)
101
- m, u = clip.load_sd(sd)
102
- if len(m) > 0:
103
- logging.warning("missing clip vision: {}".format(m))
104
- u = set(u)
105
- keys = list(sd.keys())
106
- for k in keys:
107
- if k not in u:
108
- t = sd.pop(k)
109
- del t
110
- return clip
111
-
112
- def load(ckpt_path):
113
- sd = load_torch_file(ckpt_path)
114
- if "visual.transformer.resblocks.0.attn.in_proj_weight" in sd:
115
- return load_clipvision_from_sd(sd, prefix="visual.", convert_keys=True)
116
- else:
117
- return load_clipvision_from_sd(sd)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/clip_vision_config_g.json DELETED
@@ -1,18 +0,0 @@
1
- {
2
- "attention_dropout": 0.0,
3
- "dropout": 0.0,
4
- "hidden_act": "gelu",
5
- "hidden_size": 1664,
6
- "image_size": 224,
7
- "initializer_factor": 1.0,
8
- "initializer_range": 0.02,
9
- "intermediate_size": 8192,
10
- "layer_norm_eps": 1e-05,
11
- "model_type": "clip_vision_model",
12
- "num_attention_heads": 16,
13
- "num_channels": 3,
14
- "num_hidden_layers": 48,
15
- "patch_size": 14,
16
- "projection_dim": 1280,
17
- "torch_dtype": "float32"
18
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/clip_vision_config_h.json DELETED
@@ -1,18 +0,0 @@
1
- {
2
- "attention_dropout": 0.0,
3
- "dropout": 0.0,
4
- "hidden_act": "gelu",
5
- "hidden_size": 1280,
6
- "image_size": 224,
7
- "initializer_factor": 1.0,
8
- "initializer_range": 0.02,
9
- "intermediate_size": 5120,
10
- "layer_norm_eps": 1e-05,
11
- "model_type": "clip_vision_model",
12
- "num_attention_heads": 16,
13
- "num_channels": 3,
14
- "num_hidden_layers": 32,
15
- "patch_size": 14,
16
- "projection_dim": 1024,
17
- "torch_dtype": "float32"
18
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/clip_vision_config_vitl.json DELETED
@@ -1,18 +0,0 @@
1
- {
2
- "attention_dropout": 0.0,
3
- "dropout": 0.0,
4
- "hidden_act": "quick_gelu",
5
- "hidden_size": 1024,
6
- "image_size": 224,
7
- "initializer_factor": 1.0,
8
- "initializer_range": 0.02,
9
- "intermediate_size": 4096,
10
- "layer_norm_eps": 1e-05,
11
- "model_type": "clip_vision_model",
12
- "num_attention_heads": 16,
13
- "num_channels": 3,
14
- "num_hidden_layers": 24,
15
- "patch_size": 14,
16
- "projection_dim": 768,
17
- "torch_dtype": "float32"
18
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/conds.py DELETED
@@ -1,83 +0,0 @@
1
- import torch
2
- import math
3
- import comfy.utils
4
-
5
-
6
- def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
7
- return abs(a*b) // math.gcd(a, b)
8
-
9
- class CONDRegular:
10
- def __init__(self, cond):
11
- self.cond = cond
12
-
13
- def _copy_with(self, cond):
14
- return self.__class__(cond)
15
-
16
- def process_cond(self, batch_size, device, **kwargs):
17
- return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
18
-
19
- def can_concat(self, other):
20
- if self.cond.shape != other.cond.shape:
21
- return False
22
- return True
23
-
24
- def concat(self, others):
25
- conds = [self.cond]
26
- for x in others:
27
- conds.append(x.cond)
28
- return torch.cat(conds)
29
-
30
- class CONDNoiseShape(CONDRegular):
31
- def process_cond(self, batch_size, device, area, **kwargs):
32
- data = self.cond
33
- if area is not None:
34
- dims = len(area) // 2
35
- for i in range(dims):
36
- data = data.narrow(i + 2, area[i + dims], area[i])
37
-
38
- return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device))
39
-
40
-
41
- class CONDCrossAttn(CONDRegular):
42
- def can_concat(self, other):
43
- s1 = self.cond.shape
44
- s2 = other.cond.shape
45
- if s1 != s2:
46
- if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
47
- return False
48
-
49
- mult_min = lcm(s1[1], s2[1])
50
- diff = mult_min // min(s1[1], s2[1])
51
- if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
52
- return False
53
- return True
54
-
55
- def concat(self, others):
56
- conds = [self.cond]
57
- crossattn_max_len = self.cond.shape[1]
58
- for x in others:
59
- c = x.cond
60
- crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
61
- conds.append(c)
62
-
63
- out = []
64
- for c in conds:
65
- if c.shape[1] < crossattn_max_len:
66
- c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
67
- out.append(c)
68
- return torch.cat(out)
69
-
70
- class CONDConstant(CONDRegular):
71
- def __init__(self, cond):
72
- self.cond = cond
73
-
74
- def process_cond(self, batch_size, device, **kwargs):
75
- return self._copy_with(self.cond)
76
-
77
- def can_concat(self, other):
78
- if self.cond != other.cond:
79
- return False
80
- return True
81
-
82
- def concat(self, others):
83
- return self.cond
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/controlnet.py DELETED
@@ -1,554 +0,0 @@
1
- import torch
2
- import math
3
- import os
4
- import logging
5
- import comfy.utils
6
- import comfy.model_management
7
- import comfy.model_detection
8
- import comfy.model_patcher
9
- import comfy.ops
10
-
11
- import comfy.cldm.cldm
12
- import comfy.t2i_adapter.adapter
13
- import comfy.ldm.cascade.controlnet
14
-
15
-
16
def broadcast_image_to(tensor, target_batch_size, batched_number):
    """Tile ``tensor`` along dimension 0 so it can be used with ``target_batch_size``.

    A tensor with a single entry in dimension 0 is returned untouched.
    Otherwise the tensor is cut or cycled to ``target_batch_size //
    batched_number`` entries, and that group is repeated ``batched_number``
    times unless it already has ``target_batch_size`` entries.
    """
    if tensor.shape[0] == 1:
        return tensor

    per_batch = target_batch_size // batched_number
    tensor = tensor[:per_batch]

    available = tensor.shape[0]
    if available < per_batch:
        # Cycle through the available entries until per_batch is reached.
        whole, remainder = divmod(per_batch, available)
        tensor = torch.cat([tensor] * whole + [tensor[:remainder]], dim=0)

    if tensor.shape[0] == target_batch_size:
        return tensor
    return torch.cat([tensor] * batched_number, dim=0)
33
-
34
class ControlBase:
    """Shared state and merge logic for control models that can be chained.

    Holds the hint image, its strength and the timestep window, plus an
    optional ``previous_controlnet`` whose results are merged with this one's.
    """

    def __init__(self, device=None):
        self.cond_hint_original = None
        self.cond_hint = None
        self.strength = 1.0
        self.timestep_percent_range = (0.0, 1.0)
        self.global_average_pooling = False
        self.timestep_range = None
        self.compression_ratio = 8
        self.upscale_algorithm = 'nearest-exact'

        # Fall back to the globally selected torch device when none is given.
        self.device = comfy.model_management.get_torch_device() if device is None else device
        self.previous_controlnet = None

    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0)):
        """Store the hint, strength and percent range; returns ``self`` for chaining."""
        self.cond_hint_original = cond_hint
        self.strength = strength
        self.timestep_percent_range = timestep_percent_range
        return self

    def pre_run(self, model, percent_to_timestep_function):
        """Convert the percent range to timesteps, here and on the chained control."""
        first_percent, second_percent = self.timestep_percent_range[0], self.timestep_percent_range[1]
        self.timestep_range = (percent_to_timestep_function(first_percent), percent_to_timestep_function(second_percent))
        if self.previous_controlnet is not None:
            self.previous_controlnet.pre_run(model, percent_to_timestep_function)

    def set_previous_controlnet(self, controlnet):
        """Chain ``controlnet`` before this one; returns ``self``."""
        self.previous_controlnet = controlnet
        return self

    def cleanup(self):
        """Drop the cached hint and timestep range, chained control first."""
        if self.previous_controlnet is not None:
            self.previous_controlnet.cleanup()
        # Rebinding to None releases the reference to any cached hint.
        self.cond_hint = None
        self.timestep_range = None

    def get_models(self):
        """Return the models reported by the chained control (none of its own)."""
        if self.previous_controlnet is None:
            return []
        models = []
        models += self.previous_controlnet.get_models()
        return models

    def copy_to(self, c):
        """Copy the hint and tuning attributes onto ``c``."""
        for attr in (
            'cond_hint_original',
            'strength',
            'timestep_percent_range',
            'global_average_pooling',
            'compression_ratio',
            'upscale_algorithm',
        ):
            setattr(c, attr, getattr(self, attr))

    def inference_memory_requirements(self, dtype):
        """Return the chained control's requirement, or 0 when there is none."""
        if self.previous_controlnet is None:
            return 0
        return self.previous_controlnet.inference_memory_requirements(dtype)

    def control_merge(self, control_input, control_output, control_prev, output_dtype):
        """Scale the given tensors by ``strength`` and merge them with ``control_prev``.

        ``control_input`` entries are placed under ``'input'`` in reverse
        order. The last ``control_output`` entry goes under ``'middle'`` and
        the rest under ``'output'``. ``None`` entries are kept as placeholders.
        Tensors are scaled in place before any dtype conversion.
        """
        merged = {'input': [], 'middle': [], 'output': []}

        if control_input is not None:
            for x in control_input:
                if x is not None:
                    x *= self.strength
                    if x.dtype != output_dtype:
                        x = x.to(output_dtype)
                merged['input'].insert(0, x)

        if control_output is not None:
            last = len(control_output) - 1
            for i, x in enumerate(control_output):
                key = 'middle' if i == last else 'output'
                if x is not None:
                    if self.global_average_pooling:
                        x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])

                    x *= self.strength
                    if x.dtype != output_dtype:
                        x = x.to(output_dtype)

                merged[key].append(x)

        if control_prev is not None:
            for key in ('input', 'middle', 'output'):
                current = merged[key]
                for i, prev_val in enumerate(control_prev[key]):
                    if i >= len(current):
                        current.append(prev_val)
                    elif prev_val is not None:
                        if current[i] is None:
                            current[i] = prev_val
                        elif current[i].shape[0] < prev_val.shape[0]:
                            # Out-of-place add so the larger batch decides the result shape.
                            current[i] = prev_val + current[i]
                        else:
                            current[i] += prev_val
        return merged
139
-
140
class ControlNet(ControlBase):
    """Control implementation that runs a ``control_model`` on the hint image."""

    def __init__(self, control_model=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
        super().__init__(device)
        self.control_model = control_model
        self.load_device = load_device
        if control_model is not None:
            offload_device = comfy.model_management.unet_offload_device()
            self.control_model_wrapped = comfy.model_patcher.ModelPatcher(self.control_model, load_device=load_device, offload_device=offload_device)

        self.global_average_pooling = global_average_pooling
        self.model_sampling_current = None
        self.manual_cast_dtype = manual_cast_dtype

    def get_control(self, x_noisy, t, cond, batched_number):
        """Run the control model and merge its output with the chained control's.

        Returns the chained control's result (possibly ``None``) unchanged
        when ``t[0]`` is above ``timestep_range[0]`` or below
        ``timestep_range[1]``.
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        dtype = self.control_model.dtype
        if self.manual_cast_dtype is not None:
            dtype = self.manual_cast_dtype

        output_dtype = x_noisy.dtype
        target_height = x_noisy.shape[2] * self.compression_ratio
        target_width = x_noisy.shape[3] * self.compression_ratio
        if self.cond_hint is None or target_height != self.cond_hint.shape[2] or target_width != self.cond_hint.shape[3]:
            # Release any stale hint before building the resized one.
            self.cond_hint = None
            resized = comfy.utils.common_upscale(self.cond_hint_original, target_width, target_height, self.upscale_algorithm, "center")
            self.cond_hint = resized.to(dtype).to(self.device)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

        context = cond.get('crossattn_controlnet', cond['c_crossattn'])
        y = cond.get('y', None)
        if y is not None:
            y = y.to(dtype)
        timestep = self.model_sampling_current.timestep(t)
        x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
        return self.control_merge(None, control, control_prev, output_dtype)

    def copy(self):
        """Return a new ControlNet sharing the same model and wrapper objects."""
        clone = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
        clone.control_model = self.control_model
        clone.control_model_wrapped = self.control_model_wrapped
        self.copy_to(clone)
        return clone

    def get_models(self):
        """Return the chained control's models followed by this one's wrapper."""
        models = super().get_models()
        models.append(self.control_model_wrapped)
        return models

    def pre_run(self, model, percent_to_timestep_function):
        """Prepare the timestep range and remember ``model.model_sampling``."""
        super().pre_run(model, percent_to_timestep_function)
        self.model_sampling_current = model.model_sampling

    def cleanup(self):
        """Forget the sampling object, then run the base-class cleanup."""
        self.model_sampling_current = None
        super().cleanup()
206
-
207
class ControlLoraOps:
    """Layer classes whose weight can be offset by the product of ``up`` and ``down``.

    ``weight``, ``bias``, ``up`` and ``down`` all start as ``None`` and are
    expected to be assigned after construction.
    """

    class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        def __init__(self, in_features: int, out_features: int, bias: bool = True,
                     device=None, dtype=None) -> None:
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = None
            self.up = None
            self.down = None
            self.bias = None

        def forward(self, input):
            """Apply a linear layer, adding ``up @ down`` to the weight when ``up`` is set."""
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            if self.up is None:
                return torch.nn.functional.linear(input, weight, bias)

            delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))
            delta = delta.reshape(self.weight.shape).type(input.dtype)
            return torch.nn.functional.linear(input, weight + delta, bias)

    class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
        def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=True,
            padding_mode='zeros',
            device=None,
            dtype=None
        ):
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.kernel_size = kernel_size
            self.stride = stride
            self.padding = padding
            self.dilation = dilation
            self.transposed = False
            self.output_padding = 0
            self.groups = groups
            self.padding_mode = padding_mode

            self.weight = None
            self.bias = None
            self.up = None
            self.down = None

        def forward(self, input):
            """Apply a 2D convolution, adding ``up @ down`` to the weight when ``up`` is set."""
            weight, bias = comfy.ops.cast_bias_weight(self, input)
            if self.up is None:
                return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

            delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))
            delta = delta.reshape(self.weight.shape).type(input.dtype)
            return torch.nn.functional.conv2d(input, weight + delta, bias, self.stride, self.padding, self.dilation, self.groups)
266
-
267
-
268
class ControlLora(ControlNet):
    """Control whose model is built in ``pre_run`` from the sampled model plus stored weights.

    ``control_weights`` is the loaded state dict; the control model itself
    only exists between ``pre_run`` and ``cleanup``.
    """

    def __init__(self, control_weights, global_average_pooling=False, device=None):
        # Deliberately skips ControlNet.__init__: there is no model to wrap yet.
        ControlBase.__init__(self, device)
        self.control_weights = control_weights
        self.global_average_pooling = global_average_pooling

    def pre_run(self, model, percent_to_timestep_function):
        """Build the control model from ``model``'s UNet config and load weights into it.

        Parameters of ``model.diffusion_model`` are copied onto the new
        control model wherever an attribute with the same name exists. Every
        entry of ``control_weights`` except ``"lora_controlnet"`` is then set
        on top.
        """
        super().pre_run(model, percent_to_timestep_function)
        controlnet_config = model.model_config.unet_config.copy()
        controlnet_config.pop("out_channels")
        controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
        self.manual_cast_dtype = model.manual_cast_dtype
        dtype = model.get_dtype()
        if self.manual_cast_dtype is None:
            class control_lora_ops(ControlLoraOps, comfy.ops.disable_weight_init):
                pass
        else:
            class control_lora_ops(ControlLoraOps, comfy.ops.manual_cast):
                pass
            dtype = self.manual_cast_dtype

        controlnet_config["operations"] = control_lora_ops
        controlnet_config["dtype"] = dtype
        self.control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
        torch_device = comfy.model_management.get_torch_device()
        self.control_model.to(torch_device)
        sd = model.diffusion_model.state_dict()

        for k in sd:
            try:
                comfy.utils.set_attr_param(self.control_model, k, sd[k])
            except Exception:
                # Best effort: keys with no counterpart on the control model
                # are skipped. Unlike a bare except, KeyboardInterrupt and
                # SystemExit still propagate.
                pass

        for k in self.control_weights:
            if k not in {"lora_controlnet"}:
                comfy.utils.set_attr_param(self.control_model, k, self.control_weights[k].to(dtype).to(torch_device))

    def copy(self):
        """Return a new ControlLora sharing ``control_weights`` with copied settings."""
        c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
        self.copy_to(c)
        return c

    def cleanup(self):
        """Drop the control model built by ``pre_run``, then run the inherited cleanup."""
        # Plain rebinding (no `del`) so cleanup is safe even if pre_run never ran.
        self.control_model = None
        super().cleanup()

    def get_models(self):
        """Return only the chained control's models; this class has no wrapped model."""
        out = ControlBase.get_models(self)
        return out

    def inference_memory_requirements(self, dtype):
        """Return the stored weights' size at ``dtype`` plus the chained control's requirement."""
        return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
324
-
325
def load_controlnet(ckpt_path, model=None):
    """Load a control model from ``ckpt_path`` and return a control object.

    Returns a ``ControlLora`` when the file contains a ``"lora_controlnet"``
    key, the result of ``load_t2i_adapter`` (possibly ``None``) when no
    ``zero_convs`` key is found, and a ``ControlNet`` otherwise.

    ``model`` is only used for checkpoints carrying a ``'difference'`` key:
    matching ``diffusion_model.*`` weights from it are added onto the
    checkpoint tensors in place.
    """
    controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    if "lora_controlnet" in controlnet_data:
        return ControlLora(controlnet_data)

    controlnet_config = None
    supported_inference_dtypes = None

    if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
        controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data)
        diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
        diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
        diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"

        # Map controlnet_down_blocks.N -> zero_convs.N.0 until a block index is missing.
        count = 0
        loop = True
        while loop:
            suffix = [".weight", ".bias"]
            for s in suffix:
                k_in = "controlnet_down_blocks.{}{}".format(count, s)
                k_out = "zero_convs.{}.0{}".format(count, s)
                if k_in not in controlnet_data:
                    loop = False
                    break
                diffusers_keys[k_in] = k_out
            count += 1

        # Map the hint embedding: conv_in, then blocks.N, and finally conv_out
        # once a block index is missing. Targets are input_hint_block.(2 * count).
        count = 0
        loop = True
        while loop:
            suffix = [".weight", ".bias"]
            for s in suffix:
                if count == 0:
                    k_in = "controlnet_cond_embedding.conv_in{}".format(s)
                else:
                    k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
                k_out = "input_hint_block.{}{}".format(count * 2, s)
                if k_in not in controlnet_data:
                    k_in = "controlnet_cond_embedding.conv_out{}".format(s)
                    loop = False
                diffusers_keys[k_in] = k_out
            count += 1

        # Move every mapped tensor into a dict keyed by the new names.
        new_sd = {}
        for k in diffusers_keys:
            if k in controlnet_data:
                new_sd[diffusers_keys[k]] = controlnet_data.pop(k)

        leftover_keys = controlnet_data.keys()
        if len(leftover_keys) > 0:
            logging.warning("leftover keys: {}".format(leftover_keys))
        controlnet_data = new_sd

    # Detect whether keys carry the "control_model." prefix or none at all.
    pth_key = 'control_model.zero_convs.0.0.weight'
    pth = False
    key = 'zero_convs.0.0.weight'
    if pth_key in controlnet_data:
        pth = True
        key = pth_key
        prefix = "control_model."
    elif key in controlnet_data:
        prefix = ""
    else:
        # Neither layout matched: try to interpret the data as a T2I adapter.
        net = load_t2i_adapter(controlnet_data)
        if net is None:
            logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
        return net

    if controlnet_config is None:
        model_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, True)
        supported_inference_dtypes = model_config.supported_inference_dtypes
        controlnet_config = model_config.unet_config

    load_device = comfy.model_management.get_torch_device()
    if supported_inference_dtypes is None:
        unet_dtype = comfy.model_management.unet_dtype()
    else:
        unet_dtype = comfy.model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)

    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
    if manual_cast_dtype is not None:
        controlnet_config["operations"] = comfy.ops.manual_cast
    controlnet_config["dtype"] = unet_dtype
    controlnet_config.pop("out_channels")
    controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
    control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)

    if pth:
        if 'difference' in controlnet_data:
            if model is not None:
                comfy.model_management.load_models_gpu([model])
                model_sd = model.model_state_dict()
                for x in controlnet_data:
                    c_m = "control_model."
                    if x.startswith(c_m):
                        sd_key = "diffusion_model.{}".format(x[len(c_m):])
                        if sd_key in model_sd:
                            cd = controlnet_data[x]
                            # In-place add: the checkpoint tensor itself is updated.
                            cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
            else:
                logging.warning("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")

        # Wrap the model under a "control_model" attribute so the prefixed
        # checkpoint keys line up with the module's parameter names.
        class WeightsLoader(torch.nn.Module):
            pass
        w = WeightsLoader()
        w.control_model = control_model
        missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
    else:
        missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)

    if len(missing) > 0:
        logging.warning("missing controlnet keys: {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("unexpected controlnet keys: {}".format(unexpected))

    # Global average pooling is switched on purely from the file name.
    global_average_pooling = False
    filename = os.path.splitext(ckpt_path)[0]
    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
        global_average_pooling = True

    control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
    return control
448
-
449
class T2IAdapter(ControlBase):
    """Control implementation backed by an adapter model whose output is cached.

    The adapter is run once per hint and its result is kept in
    ``control_input`` until the hint has to be rebuilt.
    """

    def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
        super().__init__(device)
        self.t2i_model = t2i_model
        self.channels_in = channels_in
        self.control_input = None
        self.compression_ratio = compression_ratio
        self.upscale_algorithm = upscale_algorithm

    def scale_image_to(self, width, height):
        """Round ``width`` and ``height`` up to multiples of the model's ``unshuffle_amount``."""
        step = self.t2i_model.unshuffle_amount
        return math.ceil(width / step) * step, math.ceil(height / step) * step

    def get_control(self, x_noisy, t, cond, batched_number):
        """Return the merged control for this adapter and any chained control.

        Returns the chained control's result (possibly ``None``) unchanged
        when ``t[0]`` is above ``timestep_range[0]`` or below
        ``timestep_range[1]``.
        """
        control_prev = None
        if self.previous_controlnet is not None:
            control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)

        if self.timestep_range is not None:
            if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
                return control_prev

        target_height = x_noisy.shape[2] * self.compression_ratio
        target_width = x_noisy.shape[3] * self.compression_ratio
        if self.cond_hint is None or target_height != self.cond_hint.shape[2] or target_width != self.cond_hint.shape[3]:
            # The hint is stale: drop it together with the cached adapter output.
            self.control_input = None
            self.cond_hint = None
            width, height = self.scale_image_to(target_width, target_height)
            resized = comfy.utils.common_upscale(self.cond_hint_original, width, height, self.upscale_algorithm, "center")
            self.cond_hint = resized.float().to(self.device)
            if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
                self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
        if x_noisy.shape[0] != self.cond_hint.shape[0]:
            self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
        if self.control_input is None:
            # Run the adapter on self.device, then move it back to the CPU.
            self.t2i_model.to(x_noisy.dtype)
            self.t2i_model.to(self.device)
            self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
            self.t2i_model.cpu()

        # Clone so control_merge's in-place scaling never touches the cache.
        control_input = [None if a is None else a.clone() for a in self.control_input]
        mid = None
        if self.t2i_model.xl == True:
            mid = control_input[-1:]
            control_input = control_input[:-1]
        return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)

    def copy(self):
        """Return a new T2IAdapter sharing the same adapter model."""
        clone = T2IAdapter(self.t2i_model, self.channels_in, self.compression_ratio, self.upscale_algorithm)
        self.copy_to(clone)
        return clone
504
-
505
def load_t2i_adapter(t2i_data):
    """Build a ``T2IAdapter`` from a state dict, or return ``None`` if the layout is unknown.

    The model class is chosen from the keys present in ``t2i_data``.
    """
    compression_ratio = 8
    upscale_algorithm = 'nearest-exact'

    if 'adapter' in t2i_data:
        t2i_data = t2i_data['adapter']
    if 'adapter.body.0.resnets.0.block1.weight' in t2i_data: #diffusers format
        prefix_replace = {}
        for i in range(4):
            for j in range(2):
                prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
            # Registered after the more specific resnets prefixes for this body index.
            prefix_replace["adapter.body.{}.".format(i)] = "body.{}.".format(i * 2)
        prefix_replace["adapter."] = ""
        t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
    keys = t2i_data.keys()

    if "body.0.in_conv.weight" in keys:
        cin = t2i_data['body.0.in_conv.weight'].shape[1]
        model_ad = comfy.t2i_adapter.adapter.Adapter_light(cin=cin, channels=[320, 640, 1280, 1280], nums_rb=4)
    elif 'conv_in.weight' in keys:
        conv_in_shape = t2i_data['conv_in.weight'].shape
        cin = conv_in_shape[1]
        channel = conv_in_shape[0]
        ksize = t2i_data['body.0.block2.weight'].shape[2]
        use_conv = any(name.endswith("down_opt.op.weight") for name in keys)
        xl = cin in (256, 768)
        model_ad = comfy.t2i_adapter.adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
    elif "backbone.0.0.weight" in keys:
        model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.0.weight'].shape[1], proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 32
        upscale_algorithm = 'bilinear'
    elif "backbone.10.blocks.0.weight" in keys:
        model_ad = comfy.ldm.cascade.controlnet.ControlNet(c_in=t2i_data['backbone.0.weight'].shape[1], bottleneck_mode="large", proj_blocks=[0, 4, 8, 12, 51, 55, 59, 63])
        compression_ratio = 1
        upscale_algorithm = 'nearest-exact'
    else:
        return None

    missing, unexpected = model_ad.load_state_dict(t2i_data)
    if len(missing) > 0:
        logging.warning("t2i missing {}".format(missing))

    if len(unexpected) > 0:
        logging.debug("t2i unexpected {}".format(unexpected))

    return T2IAdapter(model_ad, model_ad.input_channels, compression_ratio, upscale_algorithm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/diffusers_convert.py DELETED
@@ -1,281 +0,0 @@
1
- import re
2
- import torch
3
- import logging
4
-
5
- # conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
6
-
7
- # =================#
8
- # UNet Conversion #
9
- # =================#
10
-
11
# Whole-key renames for the UNet. Each pair is (stable-diffusion name,
# HF Diffusers name).
unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

# Substring renames applied only to keys that contain "resnets".
unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

# Prefix renames for whole layers, filled in by the loops below.
unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        # The upsampler sits at sub-index 1 for the first up block and 2 for the others.
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

# The mid block: one attention layer sitting between two resnets.
hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2 * j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
84
-
85
-
86
def convert_unet_state_dict(unet_state_dict):
    """Return a new dict with HF Diffusers UNet keys renamed to stable-diffusion names.

    Every key listed in ``unet_conversion_map`` is looked up in
    ``unet_state_dict``, so a state dict missing one of them raises
    ``KeyError``.
    """
    # buyer beware: this is a *brittle* function, and correct output requires
    # the three renaming passes to run in exactly this order.
    renamed = {name: name for name in unet_state_dict.keys()}

    # Pass 1: whole-key renames.
    for sd_name, hf_name in unet_conversion_map:
        renamed[hf_name] = sd_name

    # Pass 2: resnet-internal substrings, only for keys that mention resnets.
    for hf_key in renamed:
        if "resnets" not in hf_key:
            continue
        target = renamed[hf_key]
        for sd_part, hf_part in unet_conversion_map_resnet:
            target = target.replace(hf_part, sd_part)
        renamed[hf_key] = target

    # Pass 3: layer prefixes, for every key.
    for hf_key in renamed:
        target = renamed[hf_key]
        for sd_part, hf_part in unet_conversion_map_layer:
            target = target.replace(hf_part, sd_part)
        renamed[hf_key] = target

    return {sd_key: unet_state_dict[hf_key] for hf_key, sd_key in renamed.items()}
104
-
105
-
106
- # ================#
107
- # VAE Conversion #
108
- # ================#
109
-
110
# Substring renames for the VAE. Each pair is (stable-diffusion name,
# HF Diffusers name); the loops below append the per-block prefixes.
vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        # Only the first three indices get down/up-sampler prefixes.
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3 - i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i + 1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))

# Substring renames applied only to keys that contain "attentions". Several
# HF spellings map onto the same stable-diffusion name.
vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("q.", "to_q."),
    ("k.", "to_k."),
    ("v.", "to_v."),
    ("proj_out.", "to_out.0."),
    ("proj_out.", "proj_attn."),
]
158
-
159
-
160
def reshape_weight_for_sd(w):
    """Return ``w`` reshaped with two trailing dimensions of size 1.

    Used to turn HF linear weights into the conv2d weight layout used by SD.
    """
    target_shape = tuple(w.shape) + (1, 1)
    return w.reshape(target_shape)
163
-
164
-
165
def convert_vae_state_dict(vae_state_dict):
    """Return a new dict with HF Diffusers VAE keys renamed to stable-diffusion names.

    The ``mid.attn_1`` q/k/v/proj_out weights are additionally reshaped with
    ``reshape_weight_for_sd``.
    """
    renamed = {name: name for name in vae_state_dict.keys()}

    # Pass 1: general substring renames for every key.
    for hf_key in renamed:
        target = renamed[hf_key]
        for sd_part, hf_part in vae_conversion_map:
            target = target.replace(hf_part, sd_part)
        renamed[hf_key] = target

    # Pass 2: attention-specific renames, only for keys that mention attentions.
    for hf_key in renamed:
        if "attentions" not in hf_key:
            continue
        target = renamed[hf_key]
        for sd_part, hf_part in vae_conversion_map_attn:
            target = target.replace(hf_part, sd_part)
        renamed[hf_key] = target

    new_state_dict = {sd_key: vae_state_dict[hf_key] for hf_key, sd_key in renamed.items()}

    weights_to_convert = ["q", "k", "v", "proj_out"]
    for sd_key, tensor in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in sd_key:
                logging.debug(f"Reshaping {sd_key} for SD format")
                new_state_dict[sd_key] = reshape_weight_for_sd(tensor)
    return new_state_dict
184
-
185
-
186
- # =========================#
187
- # Text Encoder Conversion #
188
- # =========================#
189
-
190
-
191
# Substring renames for the text encoder. Each pair is (stable-diffusion
# name, HF Diffusers name).
textenc_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
# Maps the regex-escaped HF name to its stable-diffusion replacement. Lookups
# must therefore re-escape the matched text before indexing.
protected = {re.escape(x[1]): x[0] for x in textenc_conversion_lst}
# Single alternation that matches any of the HF names above.
textenc_pattern = re.compile("|".join(protected.keys()))

# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
# Position of each projection inside the stacked q/k/v list.
code2idx = {"q": 0, "k": 1, "v": 2}
208
-
209
- # This function exists because at the time of writing torch.cat can't do fp8 with cuda
210
def cat_tensors(tensors):
    """Concatenate ``tensors`` along dimension 0 by copying into a fresh buffer.

    The output takes its device, dtype and trailing dimensions from the
    first tensor. Raises ``IndexError`` when ``tensors`` is empty.
    """
    total_rows = sum(t.shape[0] for t in tensors)
    first = tensors[0]
    out = torch.empty([total_rows, *first.shape[1:]], device=first.device, dtype=first.dtype)

    offset = 0
    for t in tensors:
        rows = t.shape[0]
        out[offset:offset + rows] = t
        offset += rows

    return out
224
-
225
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
    """Return a new dict with HF text-encoder keys renamed to stable-diffusion names.

    Only keys starting with ``prefix`` are processed. The separate q/k/v
    projection weights and biases of each attention layer are stacked into a
    single ``in_proj_weight`` / ``in_proj_bias`` entry.

    Raises:
        Exception: when one of the q/k/v tensors of a layer is missing.
    """
    def relabel(name):
        # The matched text is re-escaped because `protected` is keyed by escaped names.
        return textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], name)

    weight_suffixes = (
        ".self_attn.q_proj.weight",
        ".self_attn.k_proj.weight",
        ".self_attn.v_proj.weight",
    )
    bias_suffixes = (
        ".self_attn.q_proj.bias",
        ".self_attn.k_proj.bias",
        ".self_attn.v_proj.bias",
    )
    text_proj = "transformer.text_projection.weight"

    new_state_dict = {}
    capture_qkv_weight = {}
    capture_qkv_bias = {}
    for name, tensor in text_enc_dict.items():
        if not name.startswith(prefix):
            continue

        if name.endswith(weight_suffixes):
            # Strip ".x_proj.weight"; the letter before "_proj" selects the slot.
            layer = name[: -len(".q_proj.weight")]
            code = name[-len("q_proj.weight")]
            capture_qkv_weight.setdefault(layer, [None, None, None])[code2idx[code]] = tensor
            continue

        if name.endswith(bias_suffixes):
            layer = name[: -len(".q_proj.bias")]
            code = name[-len("q_proj.bias")]
            capture_qkv_bias.setdefault(layer, [None, None, None])[code2idx[code]] = tensor
            continue

        if name.endswith(text_proj):
            new_state_dict[name.replace(text_proj, "text_projection")] = tensor.transpose(0, 1).contiguous()
        else:
            new_state_dict[relabel(name)] = tensor

    for layer, tensors in capture_qkv_weight.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        new_state_dict[relabel(layer) + ".in_proj_weight"] = cat_tensors(tensors)

    for layer, tensors in capture_qkv_bias.items():
        if None in tensors:
            raise Exception("CORRUPTED MODEL: one of the q-k-v values for the text encoder was missing")
        new_state_dict[relabel(layer) + ".in_proj_bias"] = cat_tensors(tensors)

    return new_state_dict
276
-
277
-
278
def convert_text_enc_state_dict(text_enc_dict):
    """Return ``text_enc_dict`` unchanged (the very same object, no conversion)."""
    return text_enc_dict
280
-
281
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/diffusers_load.py DELETED
@@ -1,36 +0,0 @@
1
- import os
2
-
3
- import comfy.sd
4
-
5
def first_file(path, filenames):
    """Return the first existing ``path/<name>`` among *filenames*.

    Candidates are probed in the order given and probing stops at the first
    hit. ``None`` is returned when no candidate exists on disk.
    """
    candidates = (os.path.join(path, name) for name in filenames)
    return next((found for found in candidates if os.path.exists(found)), None)
11
-
12
def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_directory=None):
    """Load UNet, CLIP and VAE from a model directory with per-component subfolders.

    Weight files are looked up under ``unet``, ``vae``, ``text_encoder`` and
    ``text_encoder_2`` inside *model_path*; for each folder the first file
    name that exists (fp16 safetensors first, plain ``.bin`` last) is used.

    Args:
        model_path: Directory containing the component subfolders.
        output_vae: When false, the VAE is not loaded and ``None`` is returned for it.
        output_clip: When false, CLIP is not loaded and ``None`` is returned for it.
        embedding_directory: Forwarded to ``comfy.sd.load_clip``.

    Returns:
        A ``(unet, clip, vae)`` tuple.
    """
    weight_names = ["diffusion_pytorch_model.fp16.safetensors", "diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.fp16.bin", "diffusion_pytorch_model.bin"]
    encoder_names = ["model.fp16.safetensors", "model.safetensors", "pytorch_model.fp16.bin", "pytorch_model.bin"]

    def locate(subfolder, names):
        # Resolve the preferred weight file inside one component subfolder.
        return first_file(os.path.join(model_path, subfolder), names)

    unet_path = locate("unet", weight_names)
    vae_path = locate("vae", weight_names)

    # The first text encoder is always listed (even when it resolved to None);
    # the second one is optional and only added when a file was found.
    text_encoder_paths = [locate("text_encoder", encoder_names)]
    second_encoder_path = locate("text_encoder_2", encoder_names)
    if second_encoder_path is not None:
        text_encoder_paths.append(second_encoder_path)

    unet = comfy.sd.load_unet(unet_path)

    if output_clip:
        clip = comfy.sd.load_clip(text_encoder_paths, embedding_directory=embedding_directory)
    else:
        clip = None

    if output_vae:
        vae = comfy.sd.VAE(sd=comfy.utils.load_torch_file(vae_path))
    else:
        vae = None

    return (unet, clip, vae)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/extra_samplers/__pycache__/uni_pc.cpython-310.pyc DELETED
Binary file (28.5 kB)
 
MagicQuill/comfy/extra_samplers/uni_pc.py DELETED
@@ -1,875 +0,0 @@
1
- #code taken from: https://github.com/wl-zhao/UniPC and modified
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- import math
6
-
7
- from tqdm.auto import trange, tqdm
8
-
9
-
10
class NoiseScheduleVP:
    def __init__(
            self,
            schedule='discrete',
            betas=None,
            alphas_cumprod=None,
            continuous_beta_0=0.1,
            continuous_beta_1=20.,
        ):
        r"""Create a wrapper class for the forward SDE (VP type).

        ***
        Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
                We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
        ***

        The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
        We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
        Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:

            log_alpha_t = self.marginal_log_mean_coeff(t)
            sigma_t = self.marginal_std(t)
            lambda_t = self.marginal_lambda(t)

        Moreover, as lambda(t) is an invertible function, we also support its inverse function:

            t = self.inverse_lambda(lambda_t)

        ===============================================================

        We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).

        1. For discrete-time DPMs:

            For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
                t_i = (i + 1) / N
            e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
            We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.

            Args:
                betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
                alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)

            Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
            When both are given, `betas` takes precedence.

            **Important**:  Please pay special attention to the `alphas_cumprod` argument:
                The `alphas_cumprod` is the \hat{alpha_n} array in the notation of DDPM. Specifically, DDPMs assume that
                    q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
                Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
                    alpha_{t_n} = \sqrt{\hat{alpha_n}},
                and
                    log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).


        2. For continuous-time DPMs:

            We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM).

            Args:
                continuous_beta_0: A `float` number. The smallest beta for the linear schedule.
                continuous_beta_1: A `float` number. The largest beta for the linear schedule.

            The cosine hyperparameters are fixed inside this class (cosine_s = 0.008, cosine_beta_max = 999.),
            and the ending time T is 0.9946 for 'cosine' and 1. otherwise.

        ===============================================================

        Args:
            schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
                    'linear' or 'cosine' for continuous-time DPMs.
        Returns:
            A wrapper object of the forward SDE (VP type).
        Raises:
            ValueError: if `schedule` is not one of 'discrete', 'linear', 'cosine'.

        ===============================================================

        Example:

        # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', betas=betas)

        # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)

        # For continuous-time DPMs (VPSDE), linear schedule:
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)

        """

        if schedule not in ['discrete', 'linear', 'cosine']:
            raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))

        self.schedule = schedule
        if schedule == 'discrete':
            if betas is not None:
                # log(alpha_n) = 0.5 * log(prod_{i<=n} (1 - beta_i)), accumulated in log space.
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.
            # Keypoints t_i = (i + 1) / N for i = 0..N-1, stored as a (1, N) row.
            self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1,))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.
            self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
            if schedule == 'cosine':
                # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
                # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
                self.T = 0.9946
            else:
                self.T = 1.

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].

        The 'discrete' schedule interpolates the stored (t, log_alpha) keypoints via the
        module-level `interpolate_fn`; 'linear' and 'cosine' use closed-form expressions.
        """
        if self.schedule == 'discrete':
            return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
        elif self.schedule == 'linear':
            return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == 'cosine':
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
            # Normalised so that log(alpha_0) == 0.
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """
        Compute sigma_t = sqrt(1 - alpha_t^2) of a given continuous-time label t in [0, T].
        """
        return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        if self.schedule == 'linear':
            # tmp == -4 * (beta_1 - beta_0) * log(alpha_t); t is then the positive root of the
            # quadratic defining log(alpha_t), written in a cancellation-free form.
            tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == 'discrete':
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
            # Both keypoint rows are flipped before interpolating log_alpha -> t.
            t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
            t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
            t = t_fn(log_alpha)
            return t
179
-
180
-
181
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs=None,
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.,
    classifier_fn=None,
    classifier_kwargs=None,
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.

    We support four types of the diffusion model by setting `model_type`:

        1. "noise": noise prediction model. (Trained by predicting noise).

        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).

        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].

            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).

        4. "score": marginal score function. (Trained by denoising score matching).
            Note that the score function and the noise prediction model follows a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```

    We support three types of guided sampling by DPMs by setting `guidance_type`:
        1. "uncond": unconditional sampling by DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

            The input `classifier_fn` has the following format:
            ``
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ``

            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.

        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
            In this implementation the model is still invoked as
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``
            i.e. the condition tensors are only used to decide whether the batch is doubled;
            they are not forwarded to `model`.

            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).


    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).

    We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
    ``
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return noise_pred(model, x, t_input, **model_kwargs)
    ``
    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.

    ===============================================================

    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score".
        model_kwargs: A `dict`. A dict for the other inputs of the model function.
                    `None` (the default) means no extra inputs.
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
                    Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
                    Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
                    `None` (the default) means no extra inputs.
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    Raises:
        AssertionError: if `model_type` or `guidance_type` is not one of the supported values.
    """
    # Validate before building any closure. "score" is accepted because it is both documented
    # above and implemented in `noise_pred_fn`.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]

    # Fresh dicts per call instead of shared mutable default arguments.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    classifier_kwargs = {} if classifier_kwargs is None else classifier_kwargs

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == 'discrete':
            return (t_continuous - 1. / noise_schedule.total_N) * 1000.
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """
        Run `model` and convert its output to a noise prediction according to `model_type`.
        `cond` is accepted for call compatibility but is not passed on to `model`.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            # Broadcast a single time value over the batch.
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        output = model(x, t_input, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1. or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                # Evaluate the unconditional and conditional halves in one doubled batch.
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
350
-
351
-
352
class UniPC:
    def __init__(
        self,
        model_fn,
        noise_schedule,
        predict_x0=True,
        thresholding=False,
        max_val=1.,
        variant='bh1',
    ):
        """Construct a UniPC.

        We support both data_prediction and noise_prediction.

        Args:
            model_fn: Callable `model_fn(x, t)` returning the noise prediction.
            noise_schedule: Object providing `marginal_alpha`, `marginal_std`,
                `marginal_lambda`, `marginal_log_mean_coeff` and `inverse_lambda`.
            predict_x0: If true, the solver works on the data prediction, otherwise on the noise prediction.
            thresholding: If true, `data_prediction_fn` applies dynamic thresholding.
            max_val: Lower bound of the thresholding scale.
            variant: 'bh1', 'bh2' or 'vary_coeff'.
        """
        self.model = model_fn
        self.noise_schedule = noise_schedule
        self.variant = variant
        self.predict_x0 = predict_x0
        self.thresholding = thresholding
        self.max_val = max_val

    def dynamic_thresholding_fn(self, x0, t=None):
        """
        The dynamic thresholding method.

        Clamps `x0` per sample to [-s, s] and divides by s, where s is the larger of the
        chosen quantile of |x0| and the maximum-value floor. `t` is unused.
        """
        dims = x0.dim()
        # `__init__` never sets `dynamic_thresholding_ratio` / `thresholding_max_val`, so reading
        # them unconditionally raised AttributeError. Fall back to the values used by
        # `data_prediction_fn` (0.995 and `self.max_val`) unless a caller assigned them.
        p = getattr(self, 'dynamic_thresholding_ratio', 0.995)
        floor = getattr(self, 'thresholding_max_val', self.max_val)
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = torch.maximum(s, floor * torch.ones_like(s).to(s.device))
        # Shape (B, 1, ..., 1) so the per-sample scale broadcasts against x0.
        s = s.reshape((-1,) + (1,) * (dims - 1))
        x0 = torch.clamp(x0, -s, s) / s
        return x0

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction model (with thresholding).
        """
        noise = self.noise_prediction_fn(x, t)
        dims = x.dim()
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        # x_t = alpha_t * x0 + sigma_t * noise  =>  x0 = (x_t - sigma_t * noise) / alpha_t
        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
        if self.thresholding:
            p = 0.995   # A hyperparameter in the paper of "Imagen" [1].
            s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
            s = expand_dims(torch.maximum(s, self.max_val * torch.ones_like(s).to(s.device)), dims)
            x0 = torch.clamp(x0, -s, s) / s
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.predict_x0:
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.

        Returns N + 1 time values from `t_T` to `t_0`, spaced uniformly in logSNR
        ('logSNR'), in t ('time_uniform') or in sqrt(t) ('time_quadratic').

        Raises:
            ValueError: for any other `skip_type`.
        """
        if skip_type == 'logSNR':
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == 'time_uniform':
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == 'time_quadratic':
            t_order = 2
            t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
            return t
        else:
            raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))

    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
        """
        Get the order of each step for sampling by the singlestep DPM-Solver.

        Returns `(timesteps_outer, orders)`; the entries of `orders` sum to `steps`.

        Raises:
            ValueError: if `order` is not 1, 2 or 3.
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3,] * (K - 2) + [2, 1]
            elif steps % 3 == 1:
                orders = [3,] * (K - 1) + [1]
            else:
                orders = [3,] * (K - 1) + [2]
        elif order == 2:
            if steps % 2 == 0:
                K = steps // 2
                orders = [2,] * K
            else:
                K = steps // 2 + 1
                orders = [2,] * (K - 1) + [1]
        elif order == 1:
            K = steps
            orders = [1,] * steps
        else:
            raise ValueError("'order' must be '1' or '2' or '3'.")
        if skip_type == 'logSNR':
            # To reproduce the results in DPM-Solver paper
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
        else:
            # Pick the fine-grid time at the start of every outer step.
            timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
        return timesteps_outer, orders

    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization.
        """
        return self.data_prediction_fn(x, s)

    def multistep_uni_pc_update(self, x, model_prev_list, t_prev_list, t, order, **kwargs):
        """
        Dispatch one UniPC step to the B(h) update (variant contains 'bh') or to the
        varying-coefficient update (variant == 'vary_coeff'). A 0-dim `t` is flattened to 1-D first.
        """
        if len(t.shape) == 0:
            t = t.view(-1)
        if 'bh' in self.variant:
            return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
        else:
            assert self.variant == 'vary_coeff'
            return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)

    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
        """
        One predictor(-corrector) step of the 'vary_coeff' variant.

        Returns `(x_t, model_t)`; `model_t` is the model output evaluated at the predicted
        point when the corrector ran, otherwise None. Progress messages are printed to stdout.
        """
        print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_t = ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            rk = (lambda_prev_i - lambda_prev_0) / h
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        K = len(rks)
        # build C matrix
        C = []

        col = torch.ones_like(rks)
        for k in range(1, K + 1):
            C.append(col)
            col = col * rks / (k + 1)
        C = torch.stack(C, dim=1)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K)
            C_inv_p = torch.linalg.inv(C[:-1, :-1])
            A_p = C_inv_p

        if use_corrector:
            print('using corrector')
            C_inv = torch.linalg.inv(C)
            A_c = C_inv

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)
        h_phi_ks = []
        factorial_k = 1
        h_phi_k = h_phi_1
        for k in range(1, K + 2):
            h_phi_ks.append(h_phi_k)
            h_phi_k = h_phi_k / hh - 1 / factorial_k
            factorial_k *= (k + 1)

        model_t = None
        if self.predict_x0:
            x_t_ = (
                sigma_t / sigma_prev_0 * x
                - alpha_t * h_phi_1 * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                # k keeps its last loop value (or 0 when the loop body never runs).
                k = 0
                for k in range(K - 1):
                    x_t = x_t - alpha_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - alpha_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        else:
            log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
            x_t_ = (
                (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                - (sigma_t * h_phi_1) * model_prev_0
            )
            # now predictor
            x_t = x_t_
            if len(D1s) > 0:
                # compute the residuals for predictor
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
            # now corrector
            if use_corrector:
                model_t = self.model_fn(x_t, t)
                D1_t = (model_t - model_prev_0)
                x_t = x_t_
                k = 0
                for k in range(K - 1):
                    x_t = x_t - sigma_t * h_phi_ks[k + 1] * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
                x_t = x_t - sigma_t * h_phi_ks[K] * (D1_t * A_c[k][-1])
        return x_t, model_t

    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
        """
        One predictor(-corrector) step of the B(h) variants ('bh1': B(h) = h, 'bh2': B(h) = e^h - 1).

        When `x_t` is given, the predictor is skipped and `x_t` is used as the predicted point.
        Returns `(x_t, model_t)`; `model_t` is None when the corrector did not run.

        Raises:
            NotImplementedError: if `self.variant` is neither 'bh1' nor 'bh2'.
        """
        # print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)
        dims = x.dim()

        # first compute rks
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = ns.marginal_lambda(t_prev_0)
        lambda_t = ns.marginal_lambda(t)
        model_prev_0 = model_prev_list[-1]
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        alpha_t = torch.exp(log_alpha_t)

        h = lambda_t - lambda_prev_0

        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = ns.marginal_lambda(t_prev_i)
            # Only the first element is used as the step ratio.
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        R = []
        b = []

        hh = -h[0] if self.predict_x0 else h[0]
        h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)

        # now predictor
        use_predictor = len(D1s) > 0 and x_t is None
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1) # (B, K)
            if x_t is None:
                # for order 2, we use a simplified version
                if order == 2:
                    rhos_p = torch.tensor([0.5], device=b.device)
                else:
                    rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if use_corrector:
            # print('using corrector')
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
            else:
                rhos_c = torch.linalg.solve(R, b)

        model_t = None
        if self.predict_x0:
            x_t_ = (
                expand_dims(sigma_t / sigma_prev_0, dims) * x
                - expand_dims(alpha_t * h_phi_1, dims)* model_prev_0
            )

            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
            )
            if x_t is None:
                if use_predictor:
                    pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
                else:
                    pred_res = 0
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res

            if use_corrector:
                model_t = self.model_fn(x_t, t)
                if D1s is not None:
                    corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
                else:
                    corr_res = 0
                D1_t = (model_t - model_prev_0)
                x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        return x_t, model_t


    def sample(self, x, timesteps, t_start=None, t_end=None, order=3, skip_type='time_uniform',
        method='singlestep', lower_order_final=True, denoise_to_zero=False, solver_type='dpm_solver',
        atol=0.0078, rtol=0.05, corrector=False, callback=None, disable_pbar=False
    ):
        """
        Run the multistep UniPC sampler over the given `timesteps` and return the final sample.

        Only `method='multistep'` is implemented. `t_start`, `t_end`, `skip_type`,
        `denoise_to_zero`, `solver_type`, `atol`, `rtol` and `corrector` are accepted but
        not read by this implementation.

        Args:
            x: Initial sample; its first dimension is used as the batch size.
            timesteps: 1-D tensor of time values; `len(timesteps) - 1` steps are taken.
            order: Maximum solver order; the first `order` steps ramp up from lower orders.
            lower_order_final: If true, the order is reduced for the last steps.
            callback: Optional callable invoked once per step with a dict holding
                'x', 'i' and 'denoised' (the latest stored model output).
            disable_pbar: Passed to `trange` as `disable`.

        Raises:
            NotImplementedError: if `method` is not 'multistep'.
        """
        # t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        # t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        steps = len(timesteps) - 1
        if method == 'multistep':
            assert steps >= order
            # timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            assert timesteps.shape[0] - 1 == steps
            # with torch.no_grad():
            for step_index in trange(steps, disable=disable_pbar):
                if step_index == 0:
                    vec_t = timesteps[0].expand((x.shape[0]))
                    model_prev_list = [self.model_fn(x, vec_t)]
                    t_prev_list = [vec_t]
                elif step_index < order:
                    init_order = step_index
                    # Init the first `order` values by lower order multistep DPM-Solver.
                    # for init_order in range(1, order):
                    vec_t = timesteps[init_order].expand(x.shape[0])
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
                    if model_x is None:
                        model_x = self.model_fn(x, vec_t)
                    model_prev_list.append(model_x)
                    t_prev_list.append(vec_t)
                else:
                    # The last loop iteration also performs the step onto timesteps[steps].
                    extra_final_step = 0
                    if step_index == (steps - 1):
                        extra_final_step = 1
                    for step in range(step_index, step_index + 1 + extra_final_step):
                        vec_t = timesteps[step].expand(x.shape[0])
                        if lower_order_final:
                            step_order = min(order, steps + 1 - step)
                        else:
                            step_order = order
                        # print('this step order:', step_order)
                        if step == steps:
                            # print('do not run corrector at the last step')
                            use_corrector = False
                        else:
                            use_corrector = True
                        x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
                        # Shift the history window left by one and store the newest time last.
                        for i in range(order - 1):
                            t_prev_list[i] = t_prev_list[i + 1]
                            model_prev_list[i] = model_prev_list[i + 1]
                        t_prev_list[-1] = vec_t
                        # We do not need to evaluate the final model value.
                        if step < steps:
                            if model_x is None:
                                model_x = self.model_fn(x, vec_t)
                            model_prev_list[-1] = model_x
                if callback is not None:
                    callback({'x': x, 'i': step_index, 'denoised': model_prev_list[-1]})
        else:
            raise NotImplementedError()
        # if denoise_to_zero:
        #     x = self.denoise_to_zero_fn(x, torch.ones((x.shape[0],)).to(device) * t_0)
        return x
762
-
763
-
764
- #############################################################
765
- # other utility functions
766
- #############################################################
767
-
768
def interpolate_fn(x, xp, yp):
    """
    Evaluate the piecewise linear function defined by keypoints (xp, yp) at x.

    Only torch tensor operations are used, so the result stays inside the
    autograd graph. Queries outside the range of xp are extrapolated along the
    outermost segment.

    Args:
        x: PyTorch tensor with shape [N, C] (N = batch size, C = channels).
        xp: PyTorch tensor with shape [C, K] (K = number of keypoints).
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    batch, num_keys = x.shape[0], xp.shape[1]
    keypoints = xp.unsqueeze(0).repeat((batch, 1, 1))
    merged = torch.cat([x.unsqueeze(2), keypoints], dim=2)
    merged_sorted, order = torch.sort(merged, dim=2)
    # The query was concatenated first (original index 0), so the argmin of
    # the permutation is the query's rank among the keypoints.
    rank = torch.argmin(order, dim=2)
    below = rank - 1
    at_left = torch.eq(rank, 0)
    at_right = torch.eq(rank, num_keys)
    right_or_below = torch.where(
        at_right, torch.tensor(num_keys - 2, device=x.device), below,
    )

    # Neighbouring x keypoints, addressed inside the sorted (query-included) array.
    lo = torch.where(at_left, torch.tensor(1, device=x.device), right_or_below)
    hi = torch.where(torch.eq(lo, below), lo + 2, lo + 1)
    x_lo = torch.gather(merged_sorted, dim=2, index=lo.unsqueeze(2)).squeeze(2)
    x_hi = torch.gather(merged_sorted, dim=2, index=hi.unsqueeze(2)).squeeze(2)

    # Matching y keypoints, addressed inside yp (which does not contain the query).
    y_lo_idx = torch.where(at_left, torch.tensor(0, device=x.device), right_or_below)
    yp_batched = yp.unsqueeze(0).expand(batch, -1, -1)
    y_lo = torch.gather(yp_batched, dim=2, index=y_lo_idx.unsqueeze(2)).squeeze(2)
    y_hi = torch.gather(yp_batched, dim=2, index=(y_lo_idx + 1).unsqueeze(2)).squeeze(2)

    return y_lo + (x - x_lo) * (y_hi - y_lo) / (x_hi - x_lo)
808
-
809
-
810
def expand_dims(v, dims):
    """
    Append trailing singleton axes to `v` until it has `dims` dimensions.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`, the total number of dimensions wanted.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    index = (Ellipsis,) + (None,) * (dims - 1)
    return v[index]
821
-
822
-
823
class SigmaConvert:
    """Noise-schedule helper whose quantities are all computed from sigma."""

    schedule = ""

    def marginal_log_mean_coeff(self, sigma):
        """Return log(alpha) = 0.5 * log(1 / (sigma^2 + 1))."""
        sigma_sq = sigma * sigma
        return 0.5 * torch.log(1 / (sigma_sq + 1))

    def marginal_alpha(self, t):
        """Return alpha, the exponential of the log mean coefficient."""
        log_alpha = self.marginal_log_mean_coeff(t)
        return torch.exp(log_alpha)

    def marginal_std(self, t):
        """Return sqrt(1 - alpha^2)."""
        log_alpha = self.marginal_log_mean_coeff(t)
        return torch.sqrt(1. - torch.exp(2. * log_alpha))

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) for the given label t.
        """
        log_alpha = self.marginal_log_mean_coeff(t)
        log_sigma = 0.5 * torch.log(1. - torch.exp(2. * log_alpha))
        return log_alpha - log_sigma
841
-
842
def predict_eps_sigma(model, input, sigma_in, **kwargs):
    """Rescale `input` by sqrt(sigma^2 + 1), run `model`, and return (scaled - model output) / sigma."""
    # Reshape sigma to [N, 1, ..., 1] so it broadcasts against `input`.
    broadcast_shape = sigma_in.shape[:1] + (1,) * (input.ndim - 1)
    sigma = sigma_in.view(broadcast_shape)
    scaled = input * ((sigma ** 2 + 1.0) ** 0.5)
    prediction = model(scaled, sigma_in, **kwargs)
    return (scaled - prediction) / sigma
846
-
847
-
848
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
    """
    Run the multistep UniPC sampler over the given sigma schedule.

    Args:
        model: denoiser called as model(input, sigma, **extra_args).
        noise: starting tensor; it is divided by sqrt(1 + sigma_0^2) before sampling.
        sigmas: 1-D tensor of sigmas. It is never modified by this function.
        extra_args: optional dict of keyword arguments forwarded to the model.
        callback: optional per-step callback forwarded to UniPC.sample.
        disable: disables the progress bar when true.
        variant: UniPC variant name passed to the solver (e.g. 'bh1', 'bh2').
    Returns:
        The sampled tensor, divided by alpha at the final timestep.
    """
    # Same normalisation the k-diffusion samplers apply to extra_args.
    extra_args = {} if extra_args is None else extra_args

    # Work on a private copy: the original took a view (`sigmas[:]`) and wrote
    # 0.001 through it, silently changing the caller's schedule.
    timesteps = sigmas.clone()
    if sigmas[-1] == 0:
        # A final sigma of exactly 0 is replaced by a small positive value.
        timesteps[-1] = 0.001
    ns = SigmaConvert()

    noise = noise / torch.sqrt(1.0 + timesteps[0] ** 2.0)
    model_type = "noise"

    model_fn = model_wrapper(
        lambda input, sigma, **kwargs: predict_eps_sigma(model, input, sigma, **kwargs),
        ns,
        model_type=model_type,
        guidance_type="uncond",
        model_kwargs=extra_args,
    )

    order = min(3, len(timesteps) - 2)
    uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=variant)
    x = uni_pc.sample(noise, timesteps=timesteps, skip_type="time_uniform", method="multistep", order=order, lower_order_final=True, callback=callback, disable_pbar=disable)
    x /= ns.marginal_alpha(timesteps[-1])
    return x
873
-
874
def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
    """Run sample_unipc with the 'bh2' variant selected."""
    return sample_unipc(
        model,
        noise,
        sigmas,
        extra_args=extra_args,
        callback=callback,
        disable=disable,
        variant='bh2',
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/gligen.py DELETED
@@ -1,343 +0,0 @@
1
- import torch
2
- from torch import nn
3
- from .ldm.modules.attention import CrossAttention
4
- from inspect import isfunction
5
- import comfy.ops
6
- ops = comfy.ops.manual_cast
7
-
8
def exists(val):
    """Return True for every value except None (falsy values such as 0 count as existing)."""
    if val is None:
        return False
    return True
10
-
11
-
12
def uniq(arr):
    """Return a dict-keys view of the distinct elements of `arr`, in first-seen order."""
    seen = dict.fromkeys(arr, True)
    return seen.keys()
14
-
15
-
16
def default(val, d):
    """Return `val` unless it is None; otherwise return `d`, calling it first if it is a plain function."""
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
20
-
21
-
22
- # feedforward
23
class GEGLU(nn.Module):
    """Linear projection to twice `dim_out`, split into a value half and a GELU-activated gate half."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # One projection produces both halves at once.
        self.proj = ops.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        value, gate = self.proj(x).chunk(2, dim=-1)
        return value * torch.nn.functional.gelu(gate)
31
-
32
-
33
class FeedForward(nn.Module):
    """Input projection (Linear + GELU, or GEGLU when `glu` is set), dropout, then an output Linear."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        # Output width falls back to the input width when not given.
        dim_out = default(dim_out, dim)

        if glu:
            project_in = GEGLU(dim, inner_dim)
        else:
            project_in = nn.Sequential(
                ops.Linear(dim, inner_dim),
                nn.GELU()
            )

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            ops.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)
51
-
52
-
53
class GatedCrossAttentionDense(nn.Module):
    """Residual cross-attention and feed-forward branches, each scaled by tanh of a learnable scalar."""

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            operations=ops)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = ops.LayerNorm(query_dim)
        self.norm2 = ops.LayerNorm(query_dim)

        # Both gates start at 0, so tanh(gate) == 0 and both branches
        # contribute nothing until the parameters change.
        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # External multiplier on both gates; setting it to 0 turns forward()
        # into the identity.
        self.scale = 1

    def forward(self, x, objs):
        attn_gate = self.scale * torch.tanh(self.alpha_attn)
        x = x + attn_gate * self.attn(self.norm1(x), objs, objs)

        dense_gate = self.scale * torch.tanh(self.alpha_dense)
        x = x + dense_gate * self.ff(self.norm2(x))

        return x
84
-
85
-
86
class GatedSelfAttentionDense(nn.Module):
    """Gated self-attention over the concatenation of visual tokens and projected object tokens."""

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # Object features are projected to the visual width so the two token
        # sets can be concatenated.
        self.linear = ops.Linear(context_dim, query_dim)

        self.attn = CrossAttention(
            query_dim=query_dim,
            context_dim=query_dim,
            heads=n_heads,
            dim_head=d_head,
            operations=ops)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = ops.LayerNorm(query_dim)
        self.norm2 = ops.LayerNorm(query_dim)

        # Both gates start at 0, so tanh(gate) == 0 initially.
        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # External multiplier on both gates; setting it to 0 turns forward()
        # into the identity.
        self.scale = 1

    def forward(self, x, objs):
        n_visual = x.shape[1]
        objs = self.linear(objs)

        attn_gate = self.scale * torch.tanh(self.alpha_attn)
        tokens = self.norm1(torch.cat([x, objs], dim=1))
        # Only the attention output at the visual token positions is kept.
        x = x + attn_gate * self.attn(tokens)[:, 0:n_visual, :]

        dense_gate = self.scale * torch.tanh(self.alpha_dense)
        x = x + dense_gate * self.ff(self.norm2(x))

        return x
124
-
125
-
126
class GatedSelfAttentionDense2(nn.Module):
    """
    Gated self-attention variant that uses the grounding-token outputs as the residual.

    The attention output at the grounding token positions is reshaped into a
    square map, bicubically resized to the visual token grid, and added to the
    visual tokens through a tanh gate.
    """

    def __init__(self, query_dim, context_dim, n_heads, d_head):
        super().__init__()

        # Object features are projected to the visual width so the two token
        # sets can be concatenated.
        self.linear = ops.Linear(context_dim, query_dim)

        # NOTE(review): n_heads is accepted but not forwarded to CrossAttention
        # here (unlike GatedSelfAttentionDense) -- confirm this is intended.
        self.attn = CrossAttention(
            query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = ops.LayerNorm(query_dim)
        self.norm2 = ops.LayerNorm(query_dim)

        # Both gates start at 0, so tanh(gate) == 0 initially.
        self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.)))
        self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.)))

        # External multiplier on both gates; setting it to 0 turns forward()
        # into the identity.
        self.scale = 1

    def forward(self, x, objs):
        # Imported locally: the module never imported `math`, so the original
        # forward() raised NameError on its first call.
        import math

        B, N_visual, _ = x.shape
        B, N_ground, _ = objs.shape

        objs = self.linear(objs)

        # Sanity check: both token counts must be perfect squares. The integer
        # square root avoids any float rounding in the comparison.
        size_v = math.isqrt(N_visual)
        size_g = math.isqrt(N_ground)
        assert size_v * size_v == N_visual, "Visual tokens must be square rootable"
        assert size_g * size_g == N_ground, "Grounding tokens must be square rootable"

        # Select the grounding tokens and resize them to the visual token grid.
        out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[
            :, N_visual:, :]
        out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g)
        out = torch.nn.functional.interpolate(
            out, (size_v, size_v), mode='bicubic')
        residual = out.reshape(B, -1, N_visual).permute(0, 2, 1)

        # Add the gated residual to the visual features.
        x = x + self.scale * torch.tanh(self.alpha_attn) * residual
        x = x + self.scale * \
            torch.tanh(self.alpha_dense) * self.ff(self.norm2(x))

        return x
178
-
179
-
180
class FourierEmbedder():
    """Sine/cosine embedding of a tensor at `num_freqs` frequencies."""

    def __init__(self, num_freqs=64, temperature=100):
        self.num_freqs = num_freqs
        self.temperature = temperature
        # Frequencies are temperature ** (k / num_freqs) for k = 0 .. num_freqs - 1.
        exponents = torch.arange(num_freqs) / num_freqs
        self.freq_bands = temperature ** exponents

    @torch.no_grad()
    def __call__(self, x, cat_dim=-1):
        """Embed `x` (any shape); the sin/cos pairs are concatenated along `cat_dim`."""
        pieces = []
        for freq in self.freq_bands:
            scaled = freq * x
            pieces.extend((torch.sin(scaled), torch.cos(scaled)))
        return torch.cat(pieces, cat_dim)
195
-
196
-
197
class PositionNet(nn.Module):
    """Maps per-object embeddings plus Fourier-embedded boxes to grounding tokens of width `out_dim`."""

    def __init__(self, in_dim, out_dim, fourier_freqs=8):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
        # 2 accounts for sin & cos, 4 for the xyxy box coordinates.
        self.position_dim = fourier_freqs * 2 * 4

        self.linears = nn.Sequential(
            ops.Linear(self.in_dim + self.position_dim, 512),
            nn.SiLU(),
            ops.Linear(512, 512),
            nn.SiLU(),
            ops.Linear(512, out_dim),
        )

        # Learnable stand-ins used wherever the mask marks a slot as padding.
        self.null_positive_feature = torch.nn.Parameter(
            torch.zeros([self.in_dim]))
        self.null_position_feature = torch.nn.Parameter(
            torch.zeros([self.position_dim]))

    def forward(self, boxes, masks, positive_embeddings):
        batch, num_objs, _ = boxes.shape
        keep = masks.unsqueeze(-1)
        drop = 1 - keep

        # Box embedding (may include padded placeholder slots): B*N*4 --> B*N*C
        xyxy_embedding = self.fourier_embedder(boxes)

        # Null embeddings, cast to the dtype/device of the boxes.
        positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)
        xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1)

        # Padded slots get the null embeddings, real slots keep their values.
        positive_embeddings = positive_embeddings * keep + drop * positive_null
        xyxy_embedding = xyxy_embedding * keep + drop * xyxy_null

        features = torch.cat([positive_embeddings, xyxy_embedding], dim=-1)
        objs = self.linears(features)
        assert objs.shape == torch.Size([batch, num_objs, self.out_dim])
        return objs
240
-
241
-
242
class Gligen(nn.Module):
    """Holds the per-transformer fuser modules and the position net, and builds patch callables from them."""

    def __init__(self, modules, position_net, key_dim):
        super().__init__()
        self.module_list = nn.ModuleList(modules)
        self.position_net = position_net
        self.key_dim = key_dim
        # Fixed number of object slots; unused slots are zero padded.
        self.max_objs = 30
        self.current_device = torch.device("cpu")

    def _set_position(self, boxes, masks, positive_embeddings):
        """Compute the grounding tokens once and return a callable that applies the selected fuser."""
        objs = self.position_net(boxes, masks, positive_embeddings)

        def patch(x, extra_options):
            fuser = self.module_list[extra_options["transformer_index"]]
            return fuser(x, objs.to(device=x.device, dtype=x.dtype))

        return patch

    def set_position(self, latent_image_shape, position_params, device):
        """
        Build the patch callable for a list of positioned objects.

        Each entry p of `position_params` is read as p[0] = embedding,
        p[1] / p[2] = box height / width and p[3] / p[4] = box top / left; the
        box is normalised by the latent height and width.
        """
        batch, _channels, h, w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu")
        boxes = []
        positive_embeddings = []
        for slot, p in enumerate(position_params):
            left = (p[4]) / w
            top = (p[3]) / h
            right = (p[4] + p[2]) / w
            bottom = (p[3] + p[1]) / h
            masks[slot] = 1.0
            boxes.append(torch.tensor((left, top, right, bottom)).unsqueeze(0))
            positive_embeddings.append(p[0])

        # Zero-pad boxes and embeddings up to max_objs slots.
        padding = self.max_objs - len(boxes)
        if padding > 0:
            boxes.append(torch.zeros([padding, 4], device="cpu"))
            positive_embeddings.append(
                torch.zeros([padding, self.key_dim], device="cpu"))

        box_out = torch.cat(boxes).unsqueeze(0).repeat(batch, 1, 1)
        masks = masks.unsqueeze(0).repeat(batch, 1)
        conds = torch.cat(positive_embeddings).unsqueeze(0).repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))

    def set_empty(self, latent_image_shape, device):
        """Build the patch callable for the no-object case (all slots masked out and zeroed)."""
        batch, _channels, _h, _w = latent_image_shape
        masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1)
        box_out = torch.zeros([self.max_objs, 4], device="cpu").repeat(batch, 1, 1)
        conds = torch.zeros([self.max_objs, self.key_dim], device="cpu").repeat(batch, 1, 1)
        return self._set_position(
            box_out.to(device),
            masks.to(device),
            conds.to(device))
301
-
302
-
303
def load_gligen(sd):
    """
    Build a Gligen module from a state dict.

    For every "<section>.<index>." prefix (sections input_blocks, middle_block
    and output_blocks, indices 0-19) that has ".fuser." keys, a
    GatedSelfAttentionDense is created and loaded from those keys. The position
    net is loaded from the "position_net.*" keys.

    Args:
        sd: mapping from parameter names to tensors.
    Returns:
        A Gligen instance holding the fuser modules and the position net.
    Raises:
        ValueError: if `sd` has no "position_net.null_positive_feature" entry.
            (The original fell through to an unbound local variable here.)
    """
    sd_k = sd.keys()
    output_list = []
    key_dim = 768
    for a in ["input_blocks", "middle_block", "output_blocks"]:
        for b in range(20):
            prefix = "{}.{}.".format(a, b)
            # Keep only this block's fuser weights, keyed by the name after ".fuser.".
            n_sd = {
                k.split(".fuser.")[-1]: sd[k]
                for k in sd_k
                if prefix in k and ".fuser." in k
            }
            if len(n_sd) == 0:
                continue

            query_dim = n_sd["linear.weight"].shape[0]
            key_dim = n_sd["linear.weight"].shape[1]

            if key_dim == 768:  # SD1.x
                n_heads = 8
                d_head = query_dim // n_heads
            else:
                d_head = 64
                n_heads = query_dim // d_head

            gated = GatedSelfAttentionDense(
                query_dim, key_dim, n_heads, d_head)
            gated.load_state_dict(n_sd, strict=False)
            output_list.append(gated)

    if "position_net.null_positive_feature" not in sd_k:
        raise ValueError(
            "state dict has no position_net weights "
            "('position_net.null_positive_feature' is missing)")

    in_dim = sd["position_net.null_positive_feature"].shape[0]
    out_dim = sd["position_net.linears.4.weight"].shape[0]

    # Throwaway container so load_state_dict can resolve the "position_net."
    # prefix of the full state dict.
    class WeightsLoader(torch.nn.Module):
        pass
    w = WeightsLoader()
    w.position_net = PositionNet(in_dim, out_dim)
    w.load_state_dict(sd, strict=False)

    gligen = Gligen(output_list, w.position_net, key_dim)
    return gligen
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/k_diffusion/__pycache__/sampling.cpython-310.pyc DELETED
Binary file (28.2 kB)
 
MagicQuill/comfy/k_diffusion/__pycache__/utils.cpython-310.pyc DELETED
Binary file (14 kB)
 
MagicQuill/comfy/k_diffusion/sampling.py DELETED
@@ -1,843 +0,0 @@
1
- import math
2
-
3
- from scipy import integrate
4
- import torch
5
- from torch import nn
6
- import torchsde
7
- from tqdm.auto import trange, tqdm
8
-
9
- from . import utils
10
-
11
-
12
def append_zero(x):
    """Return `x` with a single zero (same dtype and device) concatenated at the end."""
    zero = x.new_zeros([1])
    return torch.cat([x, zero])
14
-
15
-
16
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7., device='cpu'):
    """Constructs the noise schedule of Karras et al. (2022), with a trailing zero appended."""
    inv_rho_min = sigma_min ** (1 / rho)
    inv_rho_max = sigma_max ** (1 / rho)
    # Interpolate linearly in sigma ** (1 / rho) space, from max down to min.
    steps = torch.linspace(0, 1, n, device=device)
    sigmas = (inv_rho_max + steps * (inv_rho_min - inv_rho_max)) ** rho
    return append_zero(sigmas).to(device)
23
-
24
-
25
def get_sigmas_exponential(n, sigma_min, sigma_max, device='cpu'):
    """Constructs an exponential noise schedule, with a trailing zero appended."""
    # Evenly spaced in log-sigma, from sigma_max down to sigma_min.
    log_sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), n, device=device)
    return append_zero(log_sigmas.exp())
29
-
30
-
31
def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
    """Constructs a polynomial in log sigma noise schedule, with a trailing zero appended."""
    log_min = math.log(sigma_min)
    log_max = math.log(sigma_max)
    # Ramp from 1 to 0, warped by the exponent rho.
    ramp = torch.linspace(1, 0, n, device=device) ** rho
    sigmas = torch.exp(ramp * (log_max - log_min) + log_min)
    return append_zero(sigmas)
36
-
37
-
38
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
    """Constructs a continuous VP noise schedule, with a trailing zero appended."""
    # t runs from 1 down to eps_s.
    t = torch.linspace(1, eps_s, n, device=device)
    exponent = beta_d * t ** 2 / 2 + beta_min * t
    sigmas = torch.sqrt(torch.exp(exponent) - 1)
    return append_zero(sigmas)
43
-
44
-
45
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative: (x - denoised) / sigma."""
    # sigma is given trailing dims so it broadcasts against x.
    sigma_b = utils.append_dims(sigma, x.ndim)
    return (x - denoised) / sigma_b
48
-
49
-
50
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step.

    With a falsy eta no noise is added and the step goes straight to sigma_to.
    """
    if not eta:
        return sigma_to, 0.
    variance = sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2
    # sigma_up is capped at sigma_to.
    sigma_up = min(sigma_to, eta * variance ** 0.5)
    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
    return sigma_down, sigma_up
58
-
59
-
60
def default_noise_sampler(x):
    """Return a sampler that ignores both sigma arguments and draws standard normal noise shaped like `x`."""
    def sampler(sigma, sigma_next):
        return torch.randn_like(x)
    return sampler
62
-
63
-
64
class BatchedBrownianTree:
    """A wrapper around torchsde.BrownianTree that enables batches of entropy.

    Args:
        x (Tensor): reference tensor; a zero tensor shaped like it is the
            default starting value `w0`.
        t0, t1: interval endpoints, given in either order.
        seed (int or sequence of int): one seed builds a single tree; a
            sequence (its length must equal x.shape[0]) builds one tree per
            batch item.
        **kwargs: forwarded to torchsde.BrownianTree, except `cpu` (bool,
            default True: build and query the trees on CPU) and `w0`, which
            are consumed here.
    """

    def __init__(self, x, t0, t1, seed=None, **kwargs):
        self.cpu_tree = kwargs.pop("cpu", True)
        t0, t1, self.sign = self.sort(t0, t1)
        # Pop w0 so it is not forwarded a second time through **kwargs; the
        # original left it in place, which made BrownianTree receive w0 both
        # positionally and by keyword.
        w0 = kwargs.pop('w0') if 'w0' in kwargs else torch.zeros_like(x)
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()
        self.batched = True
        try:
            assert len(seed) == x.shape[0]
            w0 = w0[0]
        except TypeError:
            # A scalar seed has no len(): fall back to a single shared tree.
            seed = [seed]
            self.batched = False
        if self.cpu_tree:
            self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
        else:
            self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]

    @staticmethod
    def sort(a, b):
        """Return (low, high, sign), where sign is -1 if the inputs had to be swapped and 1 otherwise."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """Return the tree increment(s) over [t0, t1], sign-corrected for both orderings."""
        t0, t1, sign = self.sort(t0, t1)
        if self.cpu_tree:
            # Query on CPU in float32, then restore the caller's dtype and device.
            w = torch.stack([tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]) * (self.sign * sign)
        else:
            w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)

        return w if self.batched else w[0]
99
-
100
-
101
class BrownianTreeNoiseSampler:
    """A noise sampler backed by a torchsde.BrownianTree.

    Args:
        x (Tensor): The tensor whose shape, device and dtype to use to generate
            random samples.
        sigma_min (float): The low end of the valid interval.
        sigma_max (float): The high end of the valid interval.
        seed (int or List[int]): The random seed. If a list of seeds is
            supplied instead of a single integer, then the noise sampler will
            use one BrownianTree per batch item, each with its own seed.
        transform (callable): A function that maps sigma to the sampler's
            internal timestep.
        cpu (bool): Forwarded to BatchedBrownianTree as its `cpu` option.
    """

    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
        self.transform = transform
        t_start = self.transform(torch.as_tensor(sigma_min))
        t_end = self.transform(torch.as_tensor(sigma_max))
        self.tree = BatchedBrownianTree(x, t_start, t_end, seed, cpu=cpu)

    def __call__(self, sigma, sigma_next):
        t_cur = self.transform(torch.as_tensor(sigma))
        t_next = self.transform(torch.as_tensor(sigma_next))
        # Divide the increment by sqrt(|dt|).
        increment = self.tree(t_cur, t_next)
        return increment / (t_next - t_cur).abs().sqrt()
124
-
125
-
126
@torch.no_grad()
def sample_euler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        # No churn by default: the model is evaluated at sigmas[i] itself.
        gamma = 0
        sigma_hat = sigmas[i]
        if s_churn > 0:
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            sigma_hat = sigmas[i] * (gamma + 1)

        if gamma > 0:
            # Add noise to raise the sample from sigmas[i] to sigma_hat.
            churn_noise = torch.randn_like(x) * s_noise
            x = x + churn_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * ones, **extra_args)
        derivative = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        # Euler step from sigma_hat to the next sigma.
        step = sigmas[i + 1] - sigma_hat
        x = x + derivative * step
    return x
150
-
151
-
152
@torch.no_grad()
def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with Euler method steps."""
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        derivative = to_d(x, sigmas[i], denoised)
        # Euler step down to sigma_down.
        step = sigma_down - sigmas[i]
        x = x + derivative * step
        # Noise is re-injected only while the next sigma is positive.
        if sigmas[i + 1] > 0:
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
170
-
171
-
172
@torch.no_grad()
def sample_heun(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022).

    Args:
        model: denoiser called as model(x, sigma_batch, **extra_args).
        x: starting sample.
        sigmas: 1-D tensor of noise levels; one step is taken per adjacent pair.
        extra_args: optional dict of keyword arguments forwarded to the model.
        callback: optional callable receiving a dict with 'x', 'i', 'sigma',
            'sigma_hat' and 'denoised' on every step.
        disable: forwarded to the progress bar.
        s_churn, s_tmin, s_tmax, s_noise: churn settings; noise is added only
            when s_churn > 0 and sigmas[i] lies in [s_tmin, s_tmax].
    Returns:
        The sample after the last step.
    """
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        if s_churn > 0:
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            sigma_hat = sigmas[i] * (gamma + 1)
        else:
            gamma = 0
            sigma_hat = sigmas[i]
        # The original recomputed sigma_hat = sigmas[i] * (gamma + 1) again
        # here; both branches above already set it to that value.

        if gamma > 0:
            # Add noise to raise the sample from sigmas[i] to sigma_hat.
            eps = torch.randn_like(x) * s_noise
            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * s_in, **extra_args)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method: no second evaluation when the target sigma is 0.
            x = x + d * dt
        else:
            # Heun's method: average the slopes at both ends of the step.
            x_2 = x + d * dt
            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x
205
-
206
-
207
@torch.no_grad()
def sample_dpm_2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        # No churn by default: the model is evaluated at sigmas[i] itself.
        gamma = 0
        sigma_hat = sigmas[i]
        if s_churn > 0:
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
            sigma_hat = sigmas[i] * (gamma + 1)

        if gamma > 0:
            # Add noise to raise the sample from sigmas[i] to sigma_hat.
            churn_noise = torch.randn_like(x) * s_noise
            x = x + churn_noise * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
        denoised = model(x, sigma_hat * ones, **extra_args)
        derivative = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            step = sigmas[i + 1] - sigma_hat
            x = x + derivative * step
        else:
            # DPM-Solver-2: evaluate at the log-space midpoint, then take the
            # full step using the midpoint derivative.
            sigma_mid = sigma_hat.log().lerp(sigmas[i + 1].log(), 0.5).exp()
            half_step = sigma_mid - sigma_hat
            full_step = sigmas[i + 1] - sigma_hat
            x_mid = x + derivative * half_step
            denoised_mid = model(x_mid, sigma_mid * ones, **extra_args)
            derivative_mid = to_d(x_mid, sigma_mid, denoised_mid)
            x = x + derivative_mid * full_step
    return x
241
-
242
-
243
@torch.no_grad()
def sample_dpm_2_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver second-order steps."""
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    ones = x.new_ones([x.shape[0]])
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        derivative = to_d(x, sigmas[i], denoised)
        if sigma_down == 0:
            # Euler method
            step = sigma_down - sigmas[i]
            x = x + derivative * step
        else:
            # DPM-Solver-2: evaluate at the log-space midpoint, then take the
            # full step using the midpoint derivative.
            sigma_mid = sigmas[i].log().lerp(sigma_down.log(), 0.5).exp()
            half_step = sigma_mid - sigmas[i]
            full_step = sigma_down - sigmas[i]
            x_mid = x + derivative * half_step
            denoised_mid = model(x_mid, sigma_mid * ones, **extra_args)
            derivative_mid = to_d(x_mid, sigma_mid, denoised_mid)
            x = x + derivative_mid * full_step
            # Ancestral noise is re-injected after the deterministic step.
            x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
    return x
270
-
271
-
272
def linear_multistep_coeff(order, t, i, j):
    """Integrate, over [t[i], t[i + 1]], the product basis polynomial that belongs to node t[i - j].

    The nodes are t[i], t[i - 1], ..., t[i - order + 1].

    Raises:
        ValueError: if fewer than `order` nodes are available at step `i`.
    """
    if order - 1 > i:
        raise ValueError(f'Order {order} too high for step {i}')

    def basis(tau):
        # Product over every node except the j-th one.
        value = 1.
        for k in range(order):
            if k != j:
                value *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return value

    integral, _abs_error = integrate.quad(basis, t[i], t[i + 1], epsrel=1e-4)
    return integral
283
-
284
-
285
@torch.no_grad()
def sample_lms(model, x, sigmas, extra_args=None, callback=None, disable=None, order=4):
    """Linear multistep sampler that combines up to `order` of the most recent derivatives per step."""
    if extra_args is None:
        extra_args = {}
    ones = x.new_ones([x.shape[0]])
    # The coefficient integration runs on a NumPy copy of the schedule.
    sigmas_np = sigmas.detach().cpu().numpy()
    history = []
    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * ones, **extra_args)
        history.append(to_d(x, sigmas[i], denoised))
        # Keep only the `order` most recent derivatives.
        if len(history) > order:
            del history[0]
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        # Use a lower order while fewer than `order` derivatives exist.
        cur_order = min(i + 1, order)
        coeffs = [linear_multistep_coeff(cur_order, sigmas_np, i, j) for j in range(cur_order)]
        # Coefficient j pairs with the j-th most recent derivative.
        x = x + sum(coeff * deriv for coeff, deriv in zip(coeffs, reversed(history)))
    return x
303
-
304
-
305
class PIDStepSizeController:
    """A PID controller for ODE adaptive step size control."""

    def __init__(self, h, pcoeff, icoeff, dcoeff, order=1, accept_safety=0.81, eps=1e-8):
        self.h = h
        # Exponents applied to the current, previous and second-previous
        # inverse errors.
        self.b1 = (pcoeff + icoeff + dcoeff) / order
        self.b2 = -(pcoeff + 2 * dcoeff) / order
        self.b3 = dcoeff / order
        self.accept_safety = accept_safety
        self.eps = eps
        self.errs = []

    def limiter(self, x):
        """Squash the raw factor with 1 + atan(x - 1)."""
        return 1 + math.atan(x - 1)

    def propose_step(self, error):
        """Rescale self.h from `error` and return whether the step is accepted.

        The step size is rescaled whether or not the step is accepted; the
        error history only advances on acceptance.
        """
        inv_error = 1 / (float(error) + self.eps)
        if not self.errs:
            # First call: seed the whole history with the current value.
            self.errs = [inv_error] * 3
        self.errs[0] = inv_error
        current, previous, oldest = self.errs
        factor = self.limiter(current ** self.b1 * previous ** self.b2 * oldest ** self.b3)
        accept = factor >= self.accept_safety
        if accept:
            self.errs[2] = self.errs[1]
            self.errs[1] = self.errs[0]
        self.h *= factor
        return accept
332
-
333
-
334
class DPMSolver(nn.Module):
    """DPM-Solver. See https://arxiv.org/abs/2206.00927.

    Wraps a denoiser ``model(x, sigma, ...)`` and integrates in the time
    variable t = -log(sigma). ``eps_callback`` is invoked (with no arguments)
    after each uncached model evaluation. ``info_callback`` receives a dict
    describing the current step.
    """

    def __init__(self, model, extra_args=None, eps_callback=None, info_callback=None):
        super().__init__()
        self.model = model
        self.extra_args = extra_args if extra_args is not None else {}
        self.eps_callback = eps_callback
        self.info_callback = info_callback

    def t(self, sigma):
        """Convert a noise level to solver time: t = -log(sigma)."""
        return -sigma.log()

    def sigma(self, t):
        """Convert solver time to a noise level: sigma = exp(-t)."""
        return t.neg().exp()

    def eps(self, eps_cache, key, x, t, *args, **kwargs):
        """Return (eps, cache) for ``x`` at time ``t``, reusing ``eps_cache[key]`` if present.

        On a cache miss the model is evaluated, ``eps_callback`` is fired, and a
        new cache dict containing the fresh entry is returned. The input cache
        is not mutated.
        """
        if key in eps_cache:
            return eps_cache[key], eps_cache
        sigma_t = self.sigma(t)
        sigma_batch = sigma_t * x.new_ones([x.shape[0]])
        denoised = self.model(x, sigma_batch, *args, **self.extra_args, **kwargs)
        eps = (x - denoised) / sigma_t
        if self.eps_callback is not None:
            self.eps_callback()
        new_cache = {key: eps}
        new_cache.update(eps_cache)
        return eps, new_cache

    def dpm_solver_1_step(self, x, t, t_next, eps_cache=None):
        """Take one first-order step from ``t`` to ``t_next``; return (x_next, eps_cache)."""
        if eps_cache is None:
            eps_cache = {}
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        sigma_next = self.sigma(t_next)
        x_1 = x - sigma_next * h.expm1() * eps
        return x_1, eps_cache

    def dpm_solver_2_step(self, x, t, t_next, r1=1 / 2, eps_cache=None):
        """Take one second-order step with an intermediate evaluation at t + r1 * h."""
        if eps_cache is None:
            eps_cache = {}
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        sigma_next = self.sigma(t_next)
        h_expm1 = h.expm1()
        x_2 = x - sigma_next * h_expm1 * eps - sigma_next / (2 * r1) * h_expm1 * (eps_r1 - eps)
        return x_2, eps_cache

    def dpm_solver_3_step(self, x, t, t_next, r1=1 / 3, r2=2 / 3, eps_cache=None):
        """Take one third-order step with intermediate evaluations at t + r1 * h and t + r2 * h."""
        if eps_cache is None:
            eps_cache = {}
        h = t_next - t
        eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
        s1 = t + r1 * h
        s2 = t + r2 * h
        u1 = x - self.sigma(s1) * (r1 * h).expm1() * eps
        eps_r1, eps_cache = self.eps(eps_cache, 'eps_r1', u1, s1)
        sigma_s2 = self.sigma(s2)
        r2h = r2 * h
        r2h_expm1 = r2h.expm1()
        u2 = x - sigma_s2 * r2h_expm1 * eps - sigma_s2 * (r2 / r1) * (r2h_expm1 / r2h - 1) * (eps_r1 - eps)
        eps_r2, eps_cache = self.eps(eps_cache, 'eps_r2', u2, s2)
        sigma_next = self.sigma(t_next)
        h_expm1 = h.expm1()
        x_3 = x - sigma_next * h_expm1 * eps - sigma_next / r2 * (h_expm1 / h - 1) * (eps_r2 - eps)
        return x_3, eps_cache

    def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1., noise_sampler=None):
        """Integrate from ``t_start`` to ``t_end`` on a uniform time grid.

        The grid has ``floor(nfe / 3) + 1`` intervals. Steps are third-order,
        with the final one or two steps lowered in order based on ``nfe % 3``.

        Raises:
            ValueError: if ``eta`` is non-zero while ``t_end`` is not greater than ``t_start``.
        """
        if noise_sampler is None:
            noise_sampler = default_noise_sampler(x)
        if not t_end > t_start and eta:
            raise ValueError('eta must be 0 for reverse sampling')

        m = math.floor(nfe / 3) + 1
        ts = torch.linspace(t_start, t_end, m + 1, device=x.device)

        remainder = nfe % 3
        if remainder == 0:
            orders = [3] * (m - 2) + [2, 1]
        else:
            orders = [3] * (m - 1) + [remainder]

        low_order_steps = {1: self.dpm_solver_1_step, 2: self.dpm_solver_2_step}

        for i, step_order in enumerate(orders):
            eps_cache = {}
            t, t_next = ts[i], ts[i + 1]
            if eta:
                sd, su = get_ancestral_step(self.sigma(t), self.sigma(t_next), eta)
                t_next_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t_next) ** 2 - self.sigma(t_next_) ** 2) ** 0.5
            else:
                t_next_, su = t_next, 0.

            # Evaluate eps once here; the step below reuses it through the cache.
            eps, eps_cache = self.eps(eps_cache, 'eps', x, t)
            denoised = x - self.sigma(t) * eps
            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': i, 't': ts[i], 't_up': t, 'denoised': denoised})

            step_fn = low_order_steps.get(step_order, self.dpm_solver_3_step)
            x, eps_cache = step_fn(x, t, t_next_, eps_cache=eps_cache)

            x = x + su * s_noise * noise_sampler(self.sigma(t), self.sigma(t_next))

        return x

    def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None):
        """Integrate from ``t_start`` to ``t_end`` with PID-controlled adaptive steps.

        Each trial step computes a lower- and a higher-order solution; their
        scaled difference is the error handed to the step-size controller.

        Returns:
            (x, info) where ``info`` counts 'steps', 'nfe', 'n_accept' and 'n_reject'.

        Raises:
            ValueError: if ``order`` is not 2 or 3, or if ``eta`` is non-zero
                while ``t_end`` is not greater than ``t_start``.
        """
        if noise_sampler is None:
            noise_sampler = default_noise_sampler(x)
        if order not in {2, 3}:
            raise ValueError('order should be 2 or 3')
        forward = t_end > t_start
        if not forward and eta:
            raise ValueError('eta must be 0 for reverse sampling')
        # The sign of the initial step follows the direction of integration.
        h_init = abs(h_init) * (1 if forward else -1)
        atol = torch.tensor(atol)
        rtol = torch.tensor(rtol)
        s = t_start
        x_prev = x
        accept = True
        pid = PIDStepSizeController(h_init, pcoeff, icoeff, dcoeff, 1.5 if eta else order, accept_safety)
        info = {'steps': 0, 'nfe': 0, 'n_accept': 0, 'n_reject': 0}

        while s < t_end - 1e-5 if forward else s > t_end + 1e-5:
            eps_cache = {}
            if forward:
                t = torch.minimum(t_end, s + pid.h)
            else:
                t = torch.maximum(t_end, s + pid.h)
            if eta:
                sd, su = get_ancestral_step(self.sigma(s), self.sigma(t), eta)
                t_ = torch.minimum(t_end, self.t(sd))
                su = (self.sigma(t) ** 2 - self.sigma(t_) ** 2) ** 0.5
            else:
                t_, su = t, 0.

            eps, eps_cache = self.eps(eps_cache, 'eps', x, s)
            denoised = x - self.sigma(s) * eps

            if order == 2:
                x_low, eps_cache = self.dpm_solver_1_step(x, s, t_, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_2_step(x, s, t_, eps_cache=eps_cache)
            else:
                x_low, eps_cache = self.dpm_solver_2_step(x, s, t_, r1=1 / 3, eps_cache=eps_cache)
                x_high, eps_cache = self.dpm_solver_3_step(x, s, t_, eps_cache=eps_cache)
            delta = torch.maximum(atol, rtol * torch.maximum(x_low.abs(), x_prev.abs()))
            error = torch.linalg.norm((x_low - x_high) / delta) / x.numel() ** 0.5
            accept = pid.propose_step(error)
            if accept:
                x_prev = x_low
                x = x_high + su * s_noise * noise_sampler(self.sigma(s), self.sigma(t))
                s = t
                info['n_accept'] += 1
            else:
                info['n_reject'] += 1
            info['nfe'] += order
            info['steps'] += 1

            if self.info_callback is not None:
                self.info_callback({'x': x, 'i': info['steps'] - 1, 't': s, 't_up': s, 'denoised': denoised, 'error': error, 'h': pid.h, **info})

        return x, info
480
-
481
-
482
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1., noise_sampler=None):
    """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.

    Integrates from ``sigma_max`` down to ``sigma_min`` with ``n`` as the
    function-evaluation budget handed to ``DPMSolver.dpm_solver_fast``.
    ``callback``, if given, receives the solver's info dict extended with
    'sigma' and 'sigma_hat'.

    Raises:
        ValueError: if ``sigma_min`` or ``sigma_max`` is not positive.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(total=n, disable=disable) as progress:
        solver = DPMSolver(model, extra_args, eps_callback=progress.update)

        def forward_info(info):
            # Translate the solver's time values back into sigmas for the caller.
            return callback({'sigma': solver.sigma(info['t']), 'sigma_hat': solver.sigma(info['t_up']), **info})

        if callback is not None:
            solver.info_callback = forward_info
        t_start = solver.t(torch.tensor(sigma_max))
        t_end = solver.t(torch.tensor(sigma_min))
        return solver.dpm_solver_fast(x, t_start, t_end, n, eta, s_noise, noise_sampler)
492
-
493
-
494
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, eta=0., s_noise=1., noise_sampler=None, return_info=False):
    """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.

    Integrates from ``sigma_max`` down to ``sigma_min`` via
    ``DPMSolver.dpm_solver_adaptive``. Returns the sample, or ``(sample, info)``
    when ``return_info`` is true.

    Raises:
        ValueError: if ``sigma_min`` or ``sigma_max`` is not positive.
    """
    if sigma_min <= 0 or sigma_max <= 0:
        raise ValueError('sigma_min and sigma_max must not be 0')
    with tqdm(disable=disable) as progress:
        solver = DPMSolver(model, extra_args, eps_callback=progress.update)

        def forward_info(info):
            # Translate the solver's time values back into sigmas for the caller.
            return callback({'sigma': solver.sigma(info['t']), 'sigma_hat': solver.sigma(info['t_up']), **info})

        if callback is not None:
            solver.info_callback = forward_info
        t_start = solver.t(torch.tensor(sigma_max))
        t_end = solver.t(torch.tensor(sigma_min))
        x, info = solver.dpm_solver_adaptive(x, t_start, t_end, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise, noise_sampler)
    if not return_info:
        return x
    return x, info
507
-
508
-
509
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Ancestral sampling with DPM-Solver++(2S) second-order steps.

    Each step moves deterministically to ``sigma_down`` and then, unless the
    next sigma is zero, adds noise scaled by ``sigma_up``. A plain Euler step
    is used when ``sigma_down`` is zero.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    batch_ones = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        return t.neg().exp()

    def t_fn(sigma):
        return sigma.log().neg()

    for step in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[step] * batch_ones, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[step], sigmas[step + 1], eta=eta)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigmas[step], 'sigma_hat': sigmas[step], 'denoised': denoised})
        if sigma_down == 0:
            # Euler method
            deriv = to_d(x, sigmas[step], denoised)
            x = x + deriv * (sigma_down - sigmas[step])
        else:
            # DPM-Solver++(2S)
            t, t_next = t_fn(sigmas[step]), t_fn(sigma_down)
            r = 1 / 2
            h = t_next - t
            t_mid = t + r * h
            x_mid = (sigma_fn(t_mid) / sigma_fn(t)) * x - (-h * r).expm1() * denoised
            denoised_mid = model(x_mid, sigma_fn(t_mid) * batch_ones, **extra_args)
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_mid
        # Noise addition
        if sigmas[step + 1] > 0:
            x = x + noise_sampler(sigmas[step], sigmas[step + 1]) * s_noise * sigma_up
    return x
541
-
542
-
543
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """DPM-Solver++ (stochastic).

    Args:
        model: denoiser called as ``model(x, sigma_batch, **extra_args)``.
        x: initial sample.
        sigmas: noise schedule; returned unchanged input if it has one entry or fewer.
        extra_args: extra keyword arguments for ``model``; its optional
            ``"seed"`` entry seeds the default noise sampler.
        callback: optional per-step hook receiving a dict with 'x', 'i',
            'sigma', 'sigma_hat' and 'denoised'.
        eta, s_noise: ancestral noise strength and noise scale.
        noise_sampler: ``noise_sampler(sigma, sigma_next)``; defaults to a
            CPU ``BrownianTreeNoiseSampler``.
        r: relative position of the intermediate evaluation inside a step.
    """
    if len(sigmas) <= 1:
        return x

    # Normalise before reading the seed: extra_args defaults to None, and
    # calling .get on None would raise AttributeError.
    extra_args = {} if extra_args is None else extra_args
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    seed = extra_args.get("seed", None)
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Euler method
            d = to_d(x, sigmas[i], denoised)
            dt = sigmas[i + 1] - sigmas[i]
            x = x + d * dt
        else:
            # DPM-Solver++
            t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
            h = t_next - t
            s = t + h * r
            fac = 1 / (2 * r)

            # Step 1
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
            s_ = t_fn(sd)
            x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
            x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
            denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)

            # Step 2
            sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
            t_next_ = t_fn(sd)
            denoised_d = (1 - fac) * denoised + fac * denoised_2
            x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
            x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
    return x
587
-
588
-
589
@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M).

    The first step, and any step whose target sigma is zero, uses only the
    current denoised estimate. Every other step also uses the previous
    step's estimate.
    """
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])

    def sigma_fn(t):
        return t.neg().exp()

    def t_fn(sigma):
        return sigma.log().neg()

    prev_denoised = None

    for step in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[step] * batch_ones, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigmas[step], 'sigma_hat': sigmas[step], 'denoised': denoised})
        t, t_next = t_fn(sigmas[step]), t_fn(sigmas[step + 1])
        h = t_next - t
        first_order = prev_denoised is None or sigmas[step + 1] == 0
        if first_order:
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            h_last = t - t_fn(sigmas[step - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * prev_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        prev_denoised = denoised
    return x
613
-
614
@torch.no_grad()
def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """DPM-Solver++(2M) SDE.

    Args:
        model: denoiser called as ``model(x, sigma_batch, **extra_args)``.
        x: initial sample.
        sigmas: noise schedule; ``x`` is returned unchanged if it has one entry or fewer.
        extra_args: extra keyword arguments for ``model``; its optional
            ``"seed"`` entry seeds the default noise sampler.
        callback: optional per-step hook receiving a dict with 'x', 'i',
            'sigma', 'sigma_hat' and 'denoised'.
        eta, s_noise: stochasticity strength and noise scale; no noise is added when ``eta`` is 0.
        noise_sampler: ``noise_sampler(sigma, sigma_next)``; defaults to a
            CPU ``BrownianTreeNoiseSampler``.
        solver_type: ``'heun'`` or ``'midpoint'`` second-order correction.

    Raises:
        ValueError: if ``solver_type`` is not one of the two supported values.
    """
    if len(sigmas) <= 1:
        return x

    if solver_type not in {'heun', 'midpoint'}:
        raise ValueError('solver_type must be \'heun\' or \'midpoint\'')

    # Normalise before reading the seed: extra_args defaults to None, and
    # calling .get on None would raise AttributeError.
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    old_denoised = None
    h_last = None
    h = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            # DPM-Solver++(2M) SDE
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            eta_h = eta * h

            x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + (-h - eta_h).expm1().neg() * denoised

            # Second-order correction; needs the previous step's estimate.
            if old_denoised is not None:
                r = h_last / h
                if solver_type == 'heun':
                    x = x + ((-h - eta_h).expm1().neg() / (-h - eta_h) + 1) * (1 / r) * (denoised - old_denoised)
                elif solver_type == 'midpoint':
                    x = x + 0.5 * (-h - eta_h).expm1().neg() * (1 / r) * (denoised - old_denoised)

            if eta:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

        old_denoised = denoised
        h_last = h
    return x
661
-
662
@torch.no_grad()
def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """DPM-Solver++(3M) SDE.

    Args:
        model: denoiser called as ``model(x, sigma_batch, **extra_args)``.
        x: initial sample.
        sigmas: noise schedule; ``x`` is returned unchanged if it has one entry or fewer.
        extra_args: extra keyword arguments for ``model``; its optional
            ``"seed"`` entry seeds the default noise sampler.
        callback: optional per-step hook receiving a dict with 'x', 'i',
            'sigma', 'sigma_hat' and 'denoised'.
        eta, s_noise: stochasticity strength and noise scale; no noise is added when ``eta`` is 0.
        noise_sampler: ``noise_sampler(sigma, sigma_next)``; defaults to a
            CPU ``BrownianTreeNoiseSampler``.
    """

    if len(sigmas) <= 1:
        return x

    # Normalise before reading the seed: extra_args defaults to None, and
    # calling .get on None would raise AttributeError.
    extra_args = {} if extra_args is None else extra_args
    seed = extra_args.get("seed", None)
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
    s_in = x.new_ones([x.shape[0]])

    # History of the two previous denoised estimates and step sizes.
    denoised_1, denoised_2 = None, None
    h, h_1, h_2 = None, None, None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        if sigmas[i + 1] == 0:
            # Denoising step
            x = denoised
        else:
            t, s = -sigmas[i].log(), -sigmas[i + 1].log()
            h = s - t
            h_eta = h * (eta + 1)

            x = torch.exp(-h_eta) * x + (-h_eta).expm1().neg() * denoised

            if h_2 is not None:
                # Two previous steps available: third-order correction.
                r0 = h_1 / h
                r1 = h_2 / h
                d1_0 = (denoised - denoised_1) / r0
                d1_1 = (denoised_1 - denoised_2) / r1
                d1 = d1_0 + (d1_0 - d1_1) * r0 / (r0 + r1)
                d2 = (d1_0 - d1_1) / (r0 + r1)
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                phi_3 = phi_2 / h_eta - 0.5
                x = x + phi_2 * d1 - phi_3 * d2
            elif h_1 is not None:
                # One previous step available: second-order correction.
                r = h_1 / h
                d = (denoised - denoised_1) / r
                phi_2 = h_eta.neg().expm1() / h_eta + 1
                x = x + phi_2 * d

            if eta:
                x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * h * eta).expm1().neg().sqrt() * s_noise

        denoised_1, denoised_2 = denoised, denoised_1
        h_1, h_2 = h, h_1
    return x
714
-
715
@torch.no_grad()
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
    """Run ``sample_dpmpp_3m_sde`` with a default noise sampler built with ``cpu=False``.

    All arguments are forwarded unchanged. ``x`` is returned as-is when
    ``sigmas`` has one entry or fewer.
    """
    if len(sigmas) <= 1:
        return x

    # extra_args defaults to None; normalise it so reading the seed cannot fail.
    extra_args = {} if extra_args is None else extra_args
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
    return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
723
-
724
@torch.no_grad()
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
    """Run ``sample_dpmpp_2m_sde`` with a default noise sampler built with ``cpu=False``.

    All arguments are forwarded unchanged. ``x`` is returned as-is when
    ``sigmas`` has one entry or fewer.
    """
    if len(sigmas) <= 1:
        return x

    # extra_args defaults to None; normalise it so reading the seed cannot fail.
    extra_args = {} if extra_args is None else extra_args
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
    return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
732
-
733
@torch.no_grad()
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
    """Run ``sample_dpmpp_sde`` with a default noise sampler built with ``cpu=False``.

    All arguments are forwarded unchanged. ``x`` is returned as-is when
    ``sigmas`` has one entry or fewer.
    """
    if len(sigmas) <= 1:
        return x

    # extra_args defaults to None; normalise it so reading the seed cannot fail.
    extra_args = {} if extra_args is None else extra_args
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
    return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
741
-
742
-
743
def DDPMSampler_step(x, sigma, sigma_prev, noise, noise_sampler):
    """Compute one DDPM-style update from ``sigma`` to ``sigma_prev``.

    The cumulative alphas are derived from the sigmas as 1 / (sigma^2 + 1).
    The mean is computed from ``x`` and the predicted ``noise``. When
    ``sigma_prev`` is positive, noise from ``noise_sampler(sigma, sigma_prev)``
    is added in place to the result.
    """
    alpha_bar = 1 / ((sigma * sigma) + 1)
    alpha_bar_prev = 1 / ((sigma_prev * sigma_prev) + 1)
    alpha = alpha_bar / alpha_bar_prev

    mean = (1.0 / alpha).sqrt() * (x - (1 - alpha) * noise / (1 - alpha_bar).sqrt())
    if not sigma_prev > 0:
        return mean

    noise_scale = ((1 - alpha) * (1. - alpha_bar_prev) / (1. - alpha_bar)).sqrt()
    mean += noise_scale * noise_sampler(sigma, sigma_prev)
    return mean
752
-
753
def generic_step_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, step_function=None):
    """Drive ``step_function`` over the sigma schedule.

    Before each call ``x`` is divided by sqrt(1 + sigma^2). The result is
    multiplied back by sqrt(1 + sigma_next^2) unless the next sigma is zero.
    ``step_function`` is called as
    ``step_function(scaled_x, sigma, sigma_next, predicted_noise, noise_sampler)``.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    batch_ones = x.new_ones([x.shape[0]])

    for step in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[step] * batch_ones, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigmas[step], 'sigma_hat': sigmas[step], 'denoised': denoised})
        scaled_x = x / torch.sqrt(1.0 + sigmas[step] ** 2.0)
        predicted_noise = (x - denoised) / sigmas[step]
        x = step_function(scaled_x, sigmas[step], sigmas[step + 1], predicted_noise, noise_sampler)
        if sigmas[step + 1] != 0:
            x *= torch.sqrt(1.0 + sigmas[step + 1] ** 2.0)
    return x
766
-
767
-
768
@torch.no_grad()
def sample_ddpm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
    """Sample by running ``generic_step_sampler`` with ``DDPMSampler_step`` as the step rule."""
    return generic_step_sampler(
        model,
        x,
        sigmas,
        extra_args,
        callback,
        disable,
        noise_sampler,
        DDPMSampler_step,
    )
771
-
772
@torch.no_grad()
def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
    """Sample by jumping to the denoised estimate at every step and re-noising.

    After each model call ``x`` becomes the denoised output. Unless the next
    sigma is zero, fresh noise is mixed back in through the wrapped model's
    ``model_sampling.noise_scaling``.
    """
    if extra_args is None:
        extra_args = {}
    if noise_sampler is None:
        noise_sampler = default_noise_sampler(x)
    batch_ones = x.new_ones([x.shape[0]])
    for step in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[step] * batch_ones, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigmas[step], 'sigma_hat': sigmas[step], 'denoised': denoised})

        x = denoised
        if sigmas[step + 1] > 0:
            fresh_noise = noise_sampler(sigmas[step], sigmas[step + 1])
            x = model.inner_model.inner_model.model_sampling.noise_scaling(sigmas[step + 1], fresh_noise, x)
    return x
786
-
787
-
788
-
789
@torch.no_grad()
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
    """Sample with a Heun-style scheme that blends up to three derivative estimates.

    Per step, the method is chosen by how close the schedule is to its final
    value: Euler when the next sigma equals the last one, a weighted two-stage
    Heun step when the one after that does, and a weighted three-stage step
    otherwise. The blend weights are the upcoming sigmas relative to a
    multiple of ``sigmas[0]``.
    """
    # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
    if extra_args is None:
        extra_args = {}
    batch_ones = x.new_ones([x.shape[0]])
    sigma_last = sigmas[-1]
    for step in trange(len(sigmas) - 1, disable=disable):
        if s_tmin <= sigmas[step] <= s_tmax:
            gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1)
        else:
            gamma = 0.
        # Drawn on every step, even when unused, so the RNG stream is the same with or without churn.
        eps = torch.randn_like(x) * s_noise
        sigma_hat = sigmas[step] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat ** 2 - sigmas[step] ** 2) ** 0.5
        denoised = model(x, sigma_hat * batch_ones, **extra_args)
        d_1 = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback({'x': x, 'i': step, 'sigma': sigmas[step], 'sigma_hat': sigma_hat, 'denoised': denoised})
        dt = sigmas[step + 1] - sigma_hat

        if sigmas[step + 1] == sigma_last:
            # Euler method
            x = x + d_1 * dt
            continue

        # Both remaining variants start from an Euler predictor at the next sigma.
        x_2 = x + d_1 * dt
        denoised_2 = model(x_2, sigmas[step + 1] * batch_ones, **extra_args)
        d_2 = to_d(x_2, sigmas[step + 1], denoised_2)

        if sigmas[step + 2] == sigma_last:
            # Heun's method
            w = 2 * sigmas[0]
            w2 = sigmas[step + 1] / w
            w1 = 1 - w2

            d_prime = d_1 * w1 + d_2 * w2
        else:
            # Heun++
            dt_2 = sigmas[step + 2] - sigmas[step + 1]

            x_3 = x_2 + d_2 * dt_2
            denoised_3 = model(x_3, sigmas[step + 2] * batch_ones, **extra_args)
            d_3 = to_d(x_3, sigmas[step + 2], denoised_3)

            w = 3 * sigmas[0]
            w2 = sigmas[step + 1] / w
            w3 = sigmas[step + 2] / w
            w1 = 1 - w2 - w3

            d_prime = w1 * d_1 + w2 * d_2 + w3 * d_3

        x = x + d_prime * dt
    return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/k_diffusion/utils.py DELETED
@@ -1,313 +0,0 @@
1
- from contextlib import contextmanager
2
- import hashlib
3
- import math
4
- from pathlib import Path
5
- import shutil
6
- import urllib
7
- import warnings
8
-
9
- from PIL import Image
10
- import torch
11
- from torch import nn, optim
12
- from torch.utils import data
13
-
14
-
15
- def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
16
- """Apply passed in transforms for HuggingFace Datasets."""
17
- images = [transform(image.convert(mode)) for image in examples[image_key]]
18
- return {image_key: images}
19
-
20
-
21
- def append_dims(x, target_dims):
22
- """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
23
- dims_to_append = target_dims - x.ndim
24
- if dims_to_append < 0:
25
- raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
26
- expanded = x[(...,) + (None,) * dims_to_append]
27
- # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
28
- # https://github.com/pytorch/pytorch/issues/84364
29
- return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
30
-
31
-
32
- def n_params(module):
33
- """Returns the number of trainable parameters in a module."""
34
- return sum(p.numel() for p in module.parameters())
35
-
36
-
37
- def download_file(path, url, digest=None):
38
- """Downloads a file if it does not exist, optionally checking its SHA-256 hash."""
39
- path = Path(path)
40
- path.parent.mkdir(parents=True, exist_ok=True)
41
- if not path.exists():
42
- with urllib.request.urlopen(url) as response, open(path, 'wb') as f:
43
- shutil.copyfileobj(response, f)
44
- if digest is not None:
45
- file_digest = hashlib.sha256(open(path, 'rb').read()).hexdigest()
46
- if digest != file_digest:
47
- raise OSError(f'hash of {path} (url: {url}) failed to validate')
48
- return path
49
-
50
-
51
- @contextmanager
52
- def train_mode(model, mode=True):
53
- """A context manager that places a model into training mode and restores
54
- the previous mode on exit."""
55
- modes = [module.training for module in model.modules()]
56
- try:
57
- yield model.train(mode)
58
- finally:
59
- for i, module in enumerate(model.modules()):
60
- module.training = modes[i]
61
-
62
-
63
- def eval_mode(model):
64
- """A context manager that places a model into evaluation mode and restores
65
- the previous mode on exit."""
66
- return train_mode(model, False)
67
-
68
-
69
- @torch.no_grad()
70
- def ema_update(model, averaged_model, decay):
71
- """Incorporates updated model parameters into an exponential moving averaged
72
- version of a model. It should be called after each optimizer step."""
73
- model_params = dict(model.named_parameters())
74
- averaged_params = dict(averaged_model.named_parameters())
75
- assert model_params.keys() == averaged_params.keys()
76
-
77
- for name, param in model_params.items():
78
- averaged_params[name].mul_(decay).add_(param, alpha=1 - decay)
79
-
80
- model_buffers = dict(model.named_buffers())
81
- averaged_buffers = dict(averaged_model.named_buffers())
82
- assert model_buffers.keys() == averaged_buffers.keys()
83
-
84
- for name, buf in model_buffers.items():
85
- averaged_buffers[name].copy_(buf)
86
-
87
-
88
- class EMAWarmup:
89
- """Implements an EMA warmup using an inverse decay schedule.
90
- If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
91
- good values for models you plan to train for a million or more steps (reaches decay
92
- factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
93
- you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
94
- 215.4k steps).
95
- Args:
96
- inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
97
- power (float): Exponential factor of EMA warmup. Default: 1.
98
- min_value (float): The minimum EMA decay rate. Default: 0.
99
- max_value (float): The maximum EMA decay rate. Default: 1.
100
- start_at (int): The epoch to start averaging at. Default: 0.
101
- last_epoch (int): The index of last epoch. Default: 0.
102
- """
103
-
104
- def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
105
- last_epoch=0):
106
- self.inv_gamma = inv_gamma
107
- self.power = power
108
- self.min_value = min_value
109
- self.max_value = max_value
110
- self.start_at = start_at
111
- self.last_epoch = last_epoch
112
-
113
- def state_dict(self):
114
- """Returns the state of the class as a :class:`dict`."""
115
- return dict(self.__dict__.items())
116
-
117
- def load_state_dict(self, state_dict):
118
- """Loads the class's state.
119
- Args:
120
- state_dict (dict): scaler state. Should be an object returned
121
- from a call to :meth:`state_dict`.
122
- """
123
- self.__dict__.update(state_dict)
124
-
125
- def get_value(self):
126
- """Gets the current EMA decay rate."""
127
- epoch = max(0, self.last_epoch - self.start_at)
128
- value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
129
- return 0. if epoch < 0 else min(self.max_value, max(self.min_value, value))
130
-
131
- def step(self):
132
- """Updates the step count."""
133
- self.last_epoch += 1
134
-
135
-
136
- class InverseLR(optim.lr_scheduler._LRScheduler):
137
- """Implements an inverse decay learning rate schedule with an optional exponential
138
- warmup. When last_epoch=-1, sets initial lr as lr.
139
- inv_gamma is the number of steps/epochs required for the learning rate to decay to
140
- (1 / 2)**power of its original value.
141
- Args:
142
- optimizer (Optimizer): Wrapped optimizer.
143
- inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
144
- power (float): Exponential factor of learning rate decay. Default: 1.
145
- warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
146
- Default: 0.
147
- min_lr (float): The minimum learning rate. Default: 0.
148
- last_epoch (int): The index of last epoch. Default: -1.
149
- verbose (bool): If ``True``, prints a message to stdout for
150
- each update. Default: ``False``.
151
- """
152
-
153
- def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
154
- last_epoch=-1, verbose=False):
155
- self.inv_gamma = inv_gamma
156
- self.power = power
157
- if not 0. <= warmup < 1:
158
- raise ValueError('Invalid value for warmup')
159
- self.warmup = warmup
160
- self.min_lr = min_lr
161
- super().__init__(optimizer, last_epoch, verbose)
162
-
163
- def get_lr(self):
164
- if not self._get_lr_called_within_step:
165
- warnings.warn("To get the last learning rate computed by the scheduler, "
166
- "please use `get_last_lr()`.")
167
-
168
- return self._get_closed_form_lr()
169
-
170
- def _get_closed_form_lr(self):
171
- warmup = 1 - self.warmup ** (self.last_epoch + 1)
172
- lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
173
- return [warmup * max(self.min_lr, base_lr * lr_mult)
174
- for base_lr in self.base_lrs]
175
-
176
-
177
- class ExponentialLR(optim.lr_scheduler._LRScheduler):
178
- """Implements an exponential learning rate schedule with an optional exponential
179
- warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
180
- continuously by decay (default 0.5) every num_steps steps.
181
- Args:
182
- optimizer (Optimizer): Wrapped optimizer.
183
- num_steps (float): The number of steps to decay the learning rate by decay in.
184
- decay (float): The factor by which to decay the learning rate every num_steps
185
- steps. Default: 0.5.
186
- warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
187
- Default: 0.
188
- min_lr (float): The minimum learning rate. Default: 0.
189
- last_epoch (int): The index of last epoch. Default: -1.
190
- verbose (bool): If ``True``, prints a message to stdout for
191
- each update. Default: ``False``.
192
- """
193
-
194
- def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
195
- last_epoch=-1, verbose=False):
196
- self.num_steps = num_steps
197
- self.decay = decay
198
- if not 0. <= warmup < 1:
199
- raise ValueError('Invalid value for warmup')
200
- self.warmup = warmup
201
- self.min_lr = min_lr
202
- super().__init__(optimizer, last_epoch, verbose)
203
-
204
- def get_lr(self):
205
- if not self._get_lr_called_within_step:
206
- warnings.warn("To get the last learning rate computed by the scheduler, "
207
- "please use `get_last_lr()`.")
208
-
209
- return self._get_closed_form_lr()
210
-
211
- def _get_closed_form_lr(self):
212
- warmup = 1 - self.warmup ** (self.last_epoch + 1)
213
- lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
214
- return [warmup * max(self.min_lr, base_lr * lr_mult)
215
- for base_lr in self.base_lrs]
216
-
217
-
218
- def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
219
- """Draws samples from an lognormal distribution."""
220
- return (torch.randn(shape, device=device, dtype=dtype) * scale + loc).exp()
221
-
222
-
223
- def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
224
- """Draws samples from an optionally truncated log-logistic distribution."""
225
- min_value = torch.as_tensor(min_value, device=device, dtype=torch.float64)
226
- max_value = torch.as_tensor(max_value, device=device, dtype=torch.float64)
227
- min_cdf = min_value.log().sub(loc).div(scale).sigmoid()
228
- max_cdf = max_value.log().sub(loc).div(scale).sigmoid()
229
- u = torch.rand(shape, device=device, dtype=torch.float64) * (max_cdf - min_cdf) + min_cdf
230
- return u.logit().mul(scale).add(loc).exp().to(dtype)
231
-
232
-
233
- def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
234
- """Draws samples from an log-uniform distribution."""
235
- min_value = math.log(min_value)
236
- max_value = math.log(max_value)
237
- return (torch.rand(shape, device=device, dtype=dtype) * (max_value - min_value) + min_value).exp()
238
-
239
-
240
- def rand_v_diffusion(shape, sigma_data=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
241
- """Draws samples from a truncated v-diffusion training timestep distribution."""
242
- min_cdf = math.atan(min_value / sigma_data) * 2 / math.pi
243
- max_cdf = math.atan(max_value / sigma_data) * 2 / math.pi
244
- u = torch.rand(shape, device=device, dtype=dtype) * (max_cdf - min_cdf) + min_cdf
245
- return torch.tan(u * math.pi / 2) * sigma_data
246
-
247
-
248
- def rand_split_log_normal(shape, loc, scale_1, scale_2, device='cpu', dtype=torch.float32):
249
- """Draws samples from a split lognormal distribution."""
250
- n = torch.randn(shape, device=device, dtype=dtype).abs()
251
- u = torch.rand(shape, device=device, dtype=dtype)
252
- n_left = n * -scale_1 + loc
253
- n_right = n * scale_2 + loc
254
- ratio = scale_1 / (scale_1 + scale_2)
255
- return torch.where(u < ratio, n_left, n_right).exp()
256
-
257
-
258
- class FolderOfImages(data.Dataset):
259
- """Recursively finds all images in a directory. It does not support
260
- classes/targets."""
261
-
262
- IMG_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}
263
-
264
- def __init__(self, root, transform=None):
265
- super().__init__()
266
- self.root = Path(root)
267
- self.transform = nn.Identity() if transform is None else transform
268
- self.paths = sorted(path for path in self.root.rglob('*') if path.suffix.lower() in self.IMG_EXTENSIONS)
269
-
270
- def __repr__(self):
271
- return f'FolderOfImages(root="{self.root}", len: {len(self)})'
272
-
273
- def __len__(self):
274
- return len(self.paths)
275
-
276
- def __getitem__(self, key):
277
- path = self.paths[key]
278
- with open(path, 'rb') as f:
279
- image = Image.open(f).convert('RGB')
280
- image = self.transform(image)
281
- return image,
282
-
283
-
284
- class CSVLogger:
285
- def __init__(self, filename, columns):
286
- self.filename = Path(filename)
287
- self.columns = columns
288
- if self.filename.exists():
289
- self.file = open(self.filename, 'a')
290
- else:
291
- self.file = open(self.filename, 'w')
292
- self.write(*self.columns)
293
-
294
- def write(self, *args):
295
- print(*args, sep=',', file=self.file, flush=True)
296
-
297
-
298
- @contextmanager
299
- def tf32_mode(cudnn=None, matmul=None):
300
- """A context manager that sets whether TF32 is allowed on cuDNN or matmul."""
301
- cudnn_old = torch.backends.cudnn.allow_tf32
302
- matmul_old = torch.backends.cuda.matmul.allow_tf32
303
- try:
304
- if cudnn is not None:
305
- torch.backends.cudnn.allow_tf32 = cudnn
306
- if matmul is not None:
307
- torch.backends.cuda.matmul.allow_tf32 = matmul
308
- yield
309
- finally:
310
- if cudnn is not None:
311
- torch.backends.cudnn.allow_tf32 = cudnn_old
312
- if matmul is not None:
313
- torch.backends.cuda.matmul.allow_tf32 = matmul_old
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/latent_formats.py DELETED
@@ -1,141 +0,0 @@
1
- import torch
2
-
3
- class LatentFormat:
4
- scale_factor = 1.0
5
- latent_channels = 4
6
- latent_rgb_factors = None
7
- taesd_decoder_name = None
8
-
9
- def process_in(self, latent):
10
- return latent * self.scale_factor
11
-
12
- def process_out(self, latent):
13
- return latent / self.scale_factor
14
-
15
- class SD15(LatentFormat):
16
- def __init__(self, scale_factor=0.18215):
17
- self.scale_factor = scale_factor
18
- self.latent_rgb_factors = [
19
- # R G B
20
- [ 0.3512, 0.2297, 0.3227],
21
- [ 0.3250, 0.4974, 0.2350],
22
- [-0.2829, 0.1762, 0.2721],
23
- [-0.2120, -0.2616, -0.7177]
24
- ]
25
- self.taesd_decoder_name = "taesd_decoder"
26
-
27
- class SDXL(LatentFormat):
28
- scale_factor = 0.13025
29
-
30
- def __init__(self):
31
- self.latent_rgb_factors = [
32
- # R G B
33
- [ 0.3920, 0.4054, 0.4549],
34
- [-0.2634, -0.0196, 0.0653],
35
- [ 0.0568, 0.1687, -0.0755],
36
- [-0.3112, -0.2359, -0.2076]
37
- ]
38
- self.taesd_decoder_name = "taesdxl_decoder"
39
-
40
- class SDXL_Playground_2_5(LatentFormat):
41
- def __init__(self):
42
- self.scale_factor = 0.5
43
- self.latents_mean = torch.tensor([-1.6574, 1.886, -1.383, 2.5155]).view(1, 4, 1, 1)
44
- self.latents_std = torch.tensor([8.4927, 5.9022, 6.5498, 5.2299]).view(1, 4, 1, 1)
45
-
46
- self.latent_rgb_factors = [
47
- # R G B
48
- [ 0.3920, 0.4054, 0.4549],
49
- [-0.2634, -0.0196, 0.0653],
50
- [ 0.0568, 0.1687, -0.0755],
51
- [-0.3112, -0.2359, -0.2076]
52
- ]
53
- self.taesd_decoder_name = "taesdxl_decoder"
54
-
55
- def process_in(self, latent):
56
- latents_mean = self.latents_mean.to(latent.device, latent.dtype)
57
- latents_std = self.latents_std.to(latent.device, latent.dtype)
58
- return (latent - latents_mean) * self.scale_factor / latents_std
59
-
60
- def process_out(self, latent):
61
- latents_mean = self.latents_mean.to(latent.device, latent.dtype)
62
- latents_std = self.latents_std.to(latent.device, latent.dtype)
63
- return latent * latents_std / self.scale_factor + latents_mean
64
-
65
-
66
- class SD_X4(LatentFormat):
67
- def __init__(self):
68
- self.scale_factor = 0.08333
69
- self.latent_rgb_factors = [
70
- [-0.2340, -0.3863, -0.3257],
71
- [ 0.0994, 0.0885, -0.0908],
72
- [-0.2833, -0.2349, -0.3741],
73
- [ 0.2523, -0.0055, -0.1651]
74
- ]
75
-
76
- class SC_Prior(LatentFormat):
77
- latent_channels = 16
78
- def __init__(self):
79
- self.scale_factor = 1.0
80
- self.latent_rgb_factors = [
81
- [-0.0326, -0.0204, -0.0127],
82
- [-0.1592, -0.0427, 0.0216],
83
- [ 0.0873, 0.0638, -0.0020],
84
- [-0.0602, 0.0442, 0.1304],
85
- [ 0.0800, -0.0313, -0.1796],
86
- [-0.0810, -0.0638, -0.1581],
87
- [ 0.1791, 0.1180, 0.0967],
88
- [ 0.0740, 0.1416, 0.0432],
89
- [-0.1745, -0.1888, -0.1373],
90
- [ 0.2412, 0.1577, 0.0928],
91
- [ 0.1908, 0.0998, 0.0682],
92
- [ 0.0209, 0.0365, -0.0092],
93
- [ 0.0448, -0.0650, -0.1728],
94
- [-0.1658, -0.1045, -0.1308],
95
- [ 0.0542, 0.1545, 0.1325],
96
- [-0.0352, -0.1672, -0.2541]
97
- ]
98
-
99
- class SC_B(LatentFormat):
100
- def __init__(self):
101
- self.scale_factor = 1.0 / 0.43
102
- self.latent_rgb_factors = [
103
- [ 0.1121, 0.2006, 0.1023],
104
- [-0.2093, -0.0222, -0.0195],
105
- [-0.3087, -0.1535, 0.0366],
106
- [ 0.0290, -0.1574, -0.4078]
107
- ]
108
-
109
- class SD3(LatentFormat):
110
- latent_channels = 16
111
- def __init__(self):
112
- self.scale_factor = 1.5305
113
- self.shift_factor = 0.0609
114
- self.latent_rgb_factors = [
115
- [-0.0645, 0.0177, 0.1052],
116
- [ 0.0028, 0.0312, 0.0650],
117
- [ 0.1848, 0.0762, 0.0360],
118
- [ 0.0944, 0.0360, 0.0889],
119
- [ 0.0897, 0.0506, -0.0364],
120
- [-0.0020, 0.1203, 0.0284],
121
- [ 0.0855, 0.0118, 0.0283],
122
- [-0.0539, 0.0658, 0.1047],
123
- [-0.0057, 0.0116, 0.0700],
124
- [-0.0412, 0.0281, -0.0039],
125
- [ 0.1106, 0.1171, 0.1220],
126
- [-0.0248, 0.0682, -0.0481],
127
- [ 0.0815, 0.0846, 0.1207],
128
- [-0.0120, -0.0055, -0.0867],
129
- [-0.0749, -0.0634, -0.0456],
130
- [-0.1418, -0.1457, -0.1259]
131
- ]
132
- self.taesd_decoder_name = "taesd3_decoder"
133
-
134
- def process_in(self, latent):
135
- return (latent - self.shift_factor) * self.scale_factor
136
-
137
- def process_out(self, latent):
138
- return (latent / self.scale_factor) + self.shift_factor
139
-
140
- class StableAudio1(LatentFormat):
141
- latent_channels = 64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/ldm/.DS_Store DELETED
Binary file (6.15 kB)
 
MagicQuill/comfy/ldm/__pycache__/util.cpython-310.pyc DELETED
Binary file (6.19 kB)
 
MagicQuill/comfy/ldm/audio/__pycache__/autoencoder.cpython-310.pyc DELETED
Binary file (8.08 kB)
 
MagicQuill/comfy/ldm/audio/__pycache__/dit.cpython-310.pyc DELETED
Binary file (18.7 kB)
 
MagicQuill/comfy/ldm/audio/__pycache__/embedders.cpython-310.pyc DELETED
Binary file (4.34 kB)
 
MagicQuill/comfy/ldm/audio/autoencoder.py DELETED
@@ -1,282 +0,0 @@
1
- # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
-
3
- import torch
4
- from torch import nn
5
- from typing import Literal, Dict, Any
6
- import math
7
- import comfy.ops
8
- ops = comfy.ops.disable_weight_init
9
-
10
- def vae_sample(mean, scale):
11
- stdev = nn.functional.softplus(scale) + 1e-4
12
- var = stdev * stdev
13
- logvar = torch.log(var)
14
- latents = torch.randn_like(mean) * stdev + mean
15
-
16
- kl = (mean * mean + var - logvar - 1).sum(1).mean()
17
-
18
- return latents, kl
19
-
20
- class VAEBottleneck(nn.Module):
21
- def __init__(self):
22
- super().__init__()
23
- self.is_discrete = False
24
-
25
- def encode(self, x, return_info=False, **kwargs):
26
- info = {}
27
-
28
- mean, scale = x.chunk(2, dim=1)
29
-
30
- x, kl = vae_sample(mean, scale)
31
-
32
- info["kl"] = kl
33
-
34
- if return_info:
35
- return x, info
36
- else:
37
- return x
38
-
39
- def decode(self, x):
40
- return x
41
-
42
-
43
- def snake_beta(x, alpha, beta):
44
- return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
45
-
46
- # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
47
- class SnakeBeta(nn.Module):
48
-
49
- def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
50
- super(SnakeBeta, self).__init__()
51
- self.in_features = in_features
52
-
53
- # initialize alpha
54
- self.alpha_logscale = alpha_logscale
55
- if self.alpha_logscale: # log scale alphas initialized to zeros
56
- self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
57
- self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
58
- else: # linear scale alphas initialized to ones
59
- self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
60
- self.beta = nn.Parameter(torch.ones(in_features) * alpha)
61
-
62
- # self.alpha.requires_grad = alpha_trainable
63
- # self.beta.requires_grad = alpha_trainable
64
-
65
- self.no_div_by_zero = 0.000000001
66
-
67
- def forward(self, x):
68
- alpha = self.alpha.unsqueeze(0).unsqueeze(-1).to(x.device) # line up with x to [B, C, T]
69
- beta = self.beta.unsqueeze(0).unsqueeze(-1).to(x.device)
70
- if self.alpha_logscale:
71
- alpha = torch.exp(alpha)
72
- beta = torch.exp(beta)
73
- x = snake_beta(x, alpha, beta)
74
-
75
- return x
76
-
77
- def WNConv1d(*args, **kwargs):
78
- try:
79
- return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
80
- except:
81
- return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
82
-
83
- def WNConvTranspose1d(*args, **kwargs):
84
- try:
85
- return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
86
- except:
87
- return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
88
-
89
- def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
90
- if activation == "elu":
91
- act = torch.nn.ELU()
92
- elif activation == "snake":
93
- act = SnakeBeta(channels)
94
- elif activation == "none":
95
- act = torch.nn.Identity()
96
- else:
97
- raise ValueError(f"Unknown activation {activation}")
98
-
99
- if antialias:
100
- act = Activation1d(act)
101
-
102
- return act
103
-
104
-
105
- class ResidualUnit(nn.Module):
106
- def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
107
- super().__init__()
108
-
109
- self.dilation = dilation
110
-
111
- padding = (dilation * (7-1)) // 2
112
-
113
- self.layers = nn.Sequential(
114
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
115
- WNConv1d(in_channels=in_channels, out_channels=out_channels,
116
- kernel_size=7, dilation=dilation, padding=padding),
117
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
118
- WNConv1d(in_channels=out_channels, out_channels=out_channels,
119
- kernel_size=1)
120
- )
121
-
122
- def forward(self, x):
123
- res = x
124
-
125
- #x = checkpoint(self.layers, x)
126
- x = self.layers(x)
127
-
128
- return x + res
129
-
130
- class EncoderBlock(nn.Module):
131
- def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
132
- super().__init__()
133
-
134
- self.layers = nn.Sequential(
135
- ResidualUnit(in_channels=in_channels,
136
- out_channels=in_channels, dilation=1, use_snake=use_snake),
137
- ResidualUnit(in_channels=in_channels,
138
- out_channels=in_channels, dilation=3, use_snake=use_snake),
139
- ResidualUnit(in_channels=in_channels,
140
- out_channels=in_channels, dilation=9, use_snake=use_snake),
141
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
142
- WNConv1d(in_channels=in_channels, out_channels=out_channels,
143
- kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
144
- )
145
-
146
- def forward(self, x):
147
- return self.layers(x)
148
-
149
- class DecoderBlock(nn.Module):
150
- def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
151
- super().__init__()
152
-
153
- if use_nearest_upsample:
154
- upsample_layer = nn.Sequential(
155
- nn.Upsample(scale_factor=stride, mode="nearest"),
156
- WNConv1d(in_channels=in_channels,
157
- out_channels=out_channels,
158
- kernel_size=2*stride,
159
- stride=1,
160
- bias=False,
161
- padding='same')
162
- )
163
- else:
164
- upsample_layer = WNConvTranspose1d(in_channels=in_channels,
165
- out_channels=out_channels,
166
- kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
167
-
168
- self.layers = nn.Sequential(
169
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
170
- upsample_layer,
171
- ResidualUnit(in_channels=out_channels, out_channels=out_channels,
172
- dilation=1, use_snake=use_snake),
173
- ResidualUnit(in_channels=out_channels, out_channels=out_channels,
174
- dilation=3, use_snake=use_snake),
175
- ResidualUnit(in_channels=out_channels, out_channels=out_channels,
176
- dilation=9, use_snake=use_snake),
177
- )
178
-
179
- def forward(self, x):
180
- return self.layers(x)
181
-
182
- class OobleckEncoder(nn.Module):
183
- def __init__(self,
184
- in_channels=2,
185
- channels=128,
186
- latent_dim=32,
187
- c_mults = [1, 2, 4, 8],
188
- strides = [2, 4, 8, 8],
189
- use_snake=False,
190
- antialias_activation=False
191
- ):
192
- super().__init__()
193
-
194
- c_mults = [1] + c_mults
195
-
196
- self.depth = len(c_mults)
197
-
198
- layers = [
199
- WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
200
- ]
201
-
202
- for i in range(self.depth-1):
203
- layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
204
-
205
- layers += [
206
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
207
- WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
208
- ]
209
-
210
- self.layers = nn.Sequential(*layers)
211
-
212
- def forward(self, x):
213
- return self.layers(x)
214
-
215
-
216
- class OobleckDecoder(nn.Module):
217
- def __init__(self,
218
- out_channels=2,
219
- channels=128,
220
- latent_dim=32,
221
- c_mults = [1, 2, 4, 8],
222
- strides = [2, 4, 8, 8],
223
- use_snake=False,
224
- antialias_activation=False,
225
- use_nearest_upsample=False,
226
- final_tanh=True):
227
- super().__init__()
228
-
229
- c_mults = [1] + c_mults
230
-
231
- self.depth = len(c_mults)
232
-
233
- layers = [
234
- WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
235
- ]
236
-
237
- for i in range(self.depth-1, 0, -1):
238
- layers += [DecoderBlock(
239
- in_channels=c_mults[i]*channels,
240
- out_channels=c_mults[i-1]*channels,
241
- stride=strides[i-1],
242
- use_snake=use_snake,
243
- antialias_activation=antialias_activation,
244
- use_nearest_upsample=use_nearest_upsample
245
- )
246
- ]
247
-
248
- layers += [
249
- get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
250
- WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
251
- nn.Tanh() if final_tanh else nn.Identity()
252
- ]
253
-
254
- self.layers = nn.Sequential(*layers)
255
-
256
- def forward(self, x):
257
- return self.layers(x)
258
-
259
-
260
- class AudioOobleckVAE(nn.Module):
261
- def __init__(self,
262
- in_channels=2,
263
- channels=128,
264
- latent_dim=64,
265
- c_mults = [1, 2, 4, 8, 16],
266
- strides = [2, 4, 4, 8, 8],
267
- use_snake=True,
268
- antialias_activation=False,
269
- use_nearest_upsample=False,
270
- final_tanh=False):
271
- super().__init__()
272
- self.encoder = OobleckEncoder(in_channels, channels, latent_dim * 2, c_mults, strides, use_snake, antialias_activation)
273
- self.decoder = OobleckDecoder(in_channels, channels, latent_dim, c_mults, strides, use_snake, antialias_activation,
274
- use_nearest_upsample=use_nearest_upsample, final_tanh=final_tanh)
275
- self.bottleneck = VAEBottleneck()
276
-
277
- def encode(self, x):
278
- return self.bottleneck.encode(self.encoder(x))
279
-
280
- def decode(self, x):
281
- return self.decoder(self.bottleneck.decode(x))
282
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/ldm/audio/dit.py DELETED
@@ -1,888 +0,0 @@
1
- # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
-
3
- from comfy.ldm.modules.attention import optimized_attention
4
- import typing as tp
5
-
6
- import torch
7
-
8
- from einops import rearrange
9
- from torch import nn
10
- from torch.nn import functional as F
11
- import math
12
-
13
- class FourierFeatures(nn.Module):
14
- def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
15
- super().__init__()
16
- assert out_features % 2 == 0
17
- self.weight = nn.Parameter(torch.empty(
18
- [out_features // 2, in_features], dtype=dtype, device=device))
19
-
20
- def forward(self, input):
21
- f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
22
- return torch.cat([f.cos(), f.sin()], dim=-1)
23
-
24
- # norms
25
- class LayerNorm(nn.Module):
26
- def __init__(self, dim, bias=False, fix_scale=False, dtype=None, device=None):
27
- """
28
- bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
29
- """
30
- super().__init__()
31
-
32
- self.gamma = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
33
-
34
- if bias:
35
- self.beta = nn.Parameter(torch.empty(dim, dtype=dtype, device=device))
36
- else:
37
- self.beta = None
38
-
39
- def forward(self, x):
40
- beta = self.beta
41
- if self.beta is not None:
42
- beta = beta.to(dtype=x.dtype, device=x.device)
43
- return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
44
-
45
- class GLU(nn.Module):
46
- def __init__(
47
- self,
48
- dim_in,
49
- dim_out,
50
- activation,
51
- use_conv = False,
52
- conv_kernel_size = 3,
53
- dtype=None,
54
- device=None,
55
- operations=None,
56
- ):
57
- super().__init__()
58
- self.act = activation
59
- self.proj = operations.Linear(dim_in, dim_out * 2, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2), dtype=dtype, device=device)
60
- self.use_conv = use_conv
61
-
62
- def forward(self, x):
63
- if self.use_conv:
64
- x = rearrange(x, 'b n d -> b d n')
65
- x = self.proj(x)
66
- x = rearrange(x, 'b d n -> b n d')
67
- else:
68
- x = self.proj(x)
69
-
70
- x, gate = x.chunk(2, dim = -1)
71
- return x * self.act(gate)
72
-
73
- class AbsolutePositionalEmbedding(nn.Module):
74
- def __init__(self, dim, max_seq_len):
75
- super().__init__()
76
- self.scale = dim ** -0.5
77
- self.max_seq_len = max_seq_len
78
- self.emb = nn.Embedding(max_seq_len, dim)
79
-
80
- def forward(self, x, pos = None, seq_start_pos = None):
81
- seq_len, device = x.shape[1], x.device
82
- assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
83
-
84
- if pos is None:
85
- pos = torch.arange(seq_len, device = device)
86
-
87
- if seq_start_pos is not None:
88
- pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
89
-
90
- pos_emb = self.emb(pos)
91
- pos_emb = pos_emb * self.scale
92
- return pos_emb
93
-
94
- class ScaledSinusoidalEmbedding(nn.Module):
95
- def __init__(self, dim, theta = 10000):
96
- super().__init__()
97
- assert (dim % 2) == 0, 'dimension must be divisible by 2'
98
- self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
99
-
100
- half_dim = dim // 2
101
- freq_seq = torch.arange(half_dim).float() / half_dim
102
- inv_freq = theta ** -freq_seq
103
- self.register_buffer('inv_freq', inv_freq, persistent = False)
104
-
105
- def forward(self, x, pos = None, seq_start_pos = None):
106
- seq_len, device = x.shape[1], x.device
107
-
108
- if pos is None:
109
- pos = torch.arange(seq_len, device = device)
110
-
111
- if seq_start_pos is not None:
112
- pos = pos - seq_start_pos[..., None]
113
-
114
- emb = torch.einsum('i, j -> i j', pos, self.inv_freq)
115
- emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
116
- return emb * self.scale
117
-
118
- class RotaryEmbedding(nn.Module):
119
- def __init__(
120
- self,
121
- dim,
122
- use_xpos = False,
123
- scale_base = 512,
124
- interpolation_factor = 1.,
125
- base = 10000,
126
- base_rescale_factor = 1.
127
- ):
128
- super().__init__()
129
- # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
130
- # has some connection to NTK literature
131
- # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
132
- base *= base_rescale_factor ** (dim / (dim - 2))
133
-
134
- inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
135
- self.register_buffer('inv_freq', inv_freq)
136
-
137
- assert interpolation_factor >= 1.
138
- self.interpolation_factor = interpolation_factor
139
-
140
- if not use_xpos:
141
- self.register_buffer('scale', None)
142
- return
143
-
144
- scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
145
-
146
- self.scale_base = scale_base
147
- self.register_buffer('scale', scale)
148
-
149
- def forward_from_seq_len(self, seq_len, device, dtype):
150
- # device = self.inv_freq.device
151
-
152
- t = torch.arange(seq_len, device=device, dtype=dtype)
153
- return self.forward(t)
154
-
155
- def forward(self, t):
156
- # device = self.inv_freq.device
157
- device = t.device
158
- dtype = t.dtype
159
-
160
- # t = t.to(torch.float32)
161
-
162
- t = t / self.interpolation_factor
163
-
164
- freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
165
- freqs = torch.cat((freqs, freqs), dim = -1)
166
-
167
- if self.scale is None:
168
- return freqs, 1.
169
-
170
- power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
171
- scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
172
- scale = torch.cat((scale, scale), dim = -1)
173
-
174
- return freqs, scale
175
-
176
- def rotate_half(x):
177
- x = rearrange(x, '... (j d) -> ... j d', j = 2)
178
- x1, x2 = x.unbind(dim = -2)
179
- return torch.cat((-x2, x1), dim = -1)
180
-
181
- def apply_rotary_pos_emb(t, freqs, scale = 1):
182
- out_dtype = t.dtype
183
-
184
- # cast to float32 if necessary for numerical stability
185
- dtype = t.dtype #reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
186
- rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
187
- freqs, t = freqs.to(dtype), t.to(dtype)
188
- freqs = freqs[-seq_len:, :]
189
-
190
- if t.ndim == 4 and freqs.ndim == 3:
191
- freqs = rearrange(freqs, 'b n d -> b 1 n d')
192
-
193
- # partial rotary embeddings, Wang et al. GPT-J
194
- t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
195
- t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
196
-
197
- t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
198
-
199
- return torch.cat((t, t_unrotated), dim = -1)
200
-
201
- class FeedForward(nn.Module):
202
- def __init__(
203
- self,
204
- dim,
205
- dim_out = None,
206
- mult = 4,
207
- no_bias = False,
208
- glu = True,
209
- use_conv = False,
210
- conv_kernel_size = 3,
211
- zero_init_output = True,
212
- dtype=None,
213
- device=None,
214
- operations=None,
215
- ):
216
- super().__init__()
217
- inner_dim = int(dim * mult)
218
-
219
- # Default to SwiGLU
220
-
221
- activation = nn.SiLU()
222
-
223
- dim_out = dim if dim_out is None else dim_out
224
-
225
- if glu:
226
- linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
227
- else:
228
- linear_in = nn.Sequential(
229
- Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
230
- operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
231
- Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
232
- activation
233
- )
234
-
235
- linear_out = operations.Linear(inner_dim, dim_out, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device)
236
-
237
- # # init last linear layer to 0
238
- # if zero_init_output:
239
- # nn.init.zeros_(linear_out.weight)
240
- # if not no_bias:
241
- # nn.init.zeros_(linear_out.bias)
242
-
243
-
244
- self.ff = nn.Sequential(
245
- linear_in,
246
- Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
247
- linear_out,
248
- Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
249
- )
250
-
251
- def forward(self, x):
252
- return self.ff(x)
253
-
254
- class Attention(nn.Module):
255
- def __init__(
256
- self,
257
- dim,
258
- dim_heads = 64,
259
- dim_context = None,
260
- causal = False,
261
- zero_init_output=True,
262
- qk_norm = False,
263
- natten_kernel_size = None,
264
- dtype=None,
265
- device=None,
266
- operations=None,
267
- ):
268
- super().__init__()
269
- self.dim = dim
270
- self.dim_heads = dim_heads
271
- self.causal = causal
272
-
273
- dim_kv = dim_context if dim_context is not None else dim
274
-
275
- self.num_heads = dim // dim_heads
276
- self.kv_heads = dim_kv // dim_heads
277
-
278
- if dim_context is not None:
279
- self.to_q = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
280
- self.to_kv = operations.Linear(dim_kv, dim_kv * 2, bias=False, dtype=dtype, device=device)
281
- else:
282
- self.to_qkv = operations.Linear(dim, dim * 3, bias=False, dtype=dtype, device=device)
283
-
284
- self.to_out = operations.Linear(dim, dim, bias=False, dtype=dtype, device=device)
285
-
286
- # if zero_init_output:
287
- # nn.init.zeros_(self.to_out.weight)
288
-
289
- self.qk_norm = qk_norm
290
-
291
-
292
- def forward(
293
- self,
294
- x,
295
- context = None,
296
- mask = None,
297
- context_mask = None,
298
- rotary_pos_emb = None,
299
- causal = None
300
- ):
301
- h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
302
-
303
- kv_input = context if has_context else x
304
-
305
- if hasattr(self, 'to_q'):
306
- # Use separate linear projections for q and k/v
307
- q = self.to_q(x)
308
- q = rearrange(q, 'b n (h d) -> b h n d', h = h)
309
-
310
- k, v = self.to_kv(kv_input).chunk(2, dim=-1)
311
-
312
- k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
313
- else:
314
- # Use fused linear projection
315
- q, k, v = self.to_qkv(x).chunk(3, dim=-1)
316
- q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
317
-
318
- # Normalize q and k for cosine sim attention
319
- if self.qk_norm:
320
- q = F.normalize(q, dim=-1)
321
- k = F.normalize(k, dim=-1)
322
-
323
- if rotary_pos_emb is not None and not has_context:
324
- freqs, _ = rotary_pos_emb
325
-
326
- q_dtype = q.dtype
327
- k_dtype = k.dtype
328
-
329
- q = q.to(torch.float32)
330
- k = k.to(torch.float32)
331
- freqs = freqs.to(torch.float32)
332
-
333
- q = apply_rotary_pos_emb(q, freqs)
334
- k = apply_rotary_pos_emb(k, freqs)
335
-
336
- q = q.to(q_dtype)
337
- k = k.to(k_dtype)
338
-
339
- input_mask = context_mask
340
-
341
- if input_mask is None and not has_context:
342
- input_mask = mask
343
-
344
- # determine masking
345
- masks = []
346
- final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
347
-
348
- if input_mask is not None:
349
- input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
350
- masks.append(~input_mask)
351
-
352
- # Other masks will be added here later
353
-
354
- if len(masks) > 0:
355
- final_attn_mask = ~or_reduce(masks)
356
-
357
- n, device = q.shape[-2], q.device
358
-
359
- causal = self.causal if causal is None else causal
360
-
361
- if n == 1 and causal:
362
- causal = False
363
-
364
- if h != kv_h:
365
- # Repeat interleave kv_heads to match q_heads
366
- heads_per_kv_head = h // kv_h
367
- k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
368
-
369
- out = optimized_attention(q, k, v, h, skip_reshape=True)
370
- out = self.to_out(out)
371
-
372
- if mask is not None:
373
- mask = rearrange(mask, 'b n -> b n 1')
374
- out = out.masked_fill(~mask, 0.)
375
-
376
- return out
377
-
378
- class ConformerModule(nn.Module):
379
- def __init__(
380
- self,
381
- dim,
382
- norm_kwargs = {},
383
- ):
384
-
385
- super().__init__()
386
-
387
- self.dim = dim
388
-
389
- self.in_norm = LayerNorm(dim, **norm_kwargs)
390
- self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
391
- self.glu = GLU(dim, dim, nn.SiLU())
392
- self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
393
- self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
394
- self.swish = nn.SiLU()
395
- self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
396
-
397
- def forward(self, x):
398
- x = self.in_norm(x)
399
- x = rearrange(x, 'b n d -> b d n')
400
- x = self.pointwise_conv(x)
401
- x = rearrange(x, 'b d n -> b n d')
402
- x = self.glu(x)
403
- x = rearrange(x, 'b n d -> b d n')
404
- x = self.depthwise_conv(x)
405
- x = rearrange(x, 'b d n -> b n d')
406
- x = self.mid_norm(x)
407
- x = self.swish(x)
408
- x = rearrange(x, 'b n d -> b d n')
409
- x = self.pointwise_conv_2(x)
410
- x = rearrange(x, 'b d n -> b n d')
411
-
412
- return x
413
-
414
- class TransformerBlock(nn.Module):
415
- def __init__(
416
- self,
417
- dim,
418
- dim_heads = 64,
419
- cross_attend = False,
420
- dim_context = None,
421
- global_cond_dim = None,
422
- causal = False,
423
- zero_init_branch_outputs = True,
424
- conformer = False,
425
- layer_ix = -1,
426
- remove_norms = False,
427
- attn_kwargs = {},
428
- ff_kwargs = {},
429
- norm_kwargs = {},
430
- dtype=None,
431
- device=None,
432
- operations=None,
433
- ):
434
-
435
- super().__init__()
436
- self.dim = dim
437
- self.dim_heads = dim_heads
438
- self.cross_attend = cross_attend
439
- self.dim_context = dim_context
440
- self.causal = causal
441
-
442
- self.pre_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
443
-
444
- self.self_attn = Attention(
445
- dim,
446
- dim_heads = dim_heads,
447
- causal = causal,
448
- zero_init_output=zero_init_branch_outputs,
449
- dtype=dtype,
450
- device=device,
451
- operations=operations,
452
- **attn_kwargs
453
- )
454
-
455
- if cross_attend:
456
- self.cross_attend_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
457
- self.cross_attn = Attention(
458
- dim,
459
- dim_heads = dim_heads,
460
- dim_context=dim_context,
461
- causal = causal,
462
- zero_init_output=zero_init_branch_outputs,
463
- dtype=dtype,
464
- device=device,
465
- operations=operations,
466
- **attn_kwargs
467
- )
468
-
469
- self.ff_norm = LayerNorm(dim, dtype=dtype, device=device, **norm_kwargs) if not remove_norms else nn.Identity()
470
- self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, dtype=dtype, device=device, operations=operations,**ff_kwargs)
471
-
472
- self.layer_ix = layer_ix
473
-
474
- self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
475
-
476
- self.global_cond_dim = global_cond_dim
477
-
478
- if global_cond_dim is not None:
479
- self.to_scale_shift_gate = nn.Sequential(
480
- nn.SiLU(),
481
- nn.Linear(global_cond_dim, dim * 6, bias=False)
482
- )
483
-
484
- nn.init.zeros_(self.to_scale_shift_gate[1].weight)
485
- #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
486
-
487
- def forward(
488
- self,
489
- x,
490
- context = None,
491
- global_cond=None,
492
- mask = None,
493
- context_mask = None,
494
- rotary_pos_emb = None
495
- ):
496
- if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
497
-
498
- scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
499
-
500
- # self-attention with adaLN
501
- residual = x
502
- x = self.pre_norm(x)
503
- x = x * (1 + scale_self) + shift_self
504
- x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
505
- x = x * torch.sigmoid(1 - gate_self)
506
- x = x + residual
507
-
508
- if context is not None:
509
- x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
510
-
511
- if self.conformer is not None:
512
- x = x + self.conformer(x)
513
-
514
- # feedforward with adaLN
515
- residual = x
516
- x = self.ff_norm(x)
517
- x = x * (1 + scale_ff) + shift_ff
518
- x = self.ff(x)
519
- x = x * torch.sigmoid(1 - gate_ff)
520
- x = x + residual
521
-
522
- else:
523
- x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
524
-
525
- if context is not None:
526
- x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
527
-
528
- if self.conformer is not None:
529
- x = x + self.conformer(x)
530
-
531
- x = x + self.ff(self.ff_norm(x))
532
-
533
- return x
534
-
535
- class ContinuousTransformer(nn.Module):
536
- def __init__(
537
- self,
538
- dim,
539
- depth,
540
- *,
541
- dim_in = None,
542
- dim_out = None,
543
- dim_heads = 64,
544
- cross_attend=False,
545
- cond_token_dim=None,
546
- global_cond_dim=None,
547
- causal=False,
548
- rotary_pos_emb=True,
549
- zero_init_branch_outputs=True,
550
- conformer=False,
551
- use_sinusoidal_emb=False,
552
- use_abs_pos_emb=False,
553
- abs_pos_emb_max_length=10000,
554
- dtype=None,
555
- device=None,
556
- operations=None,
557
- **kwargs
558
- ):
559
-
560
- super().__init__()
561
-
562
- self.dim = dim
563
- self.depth = depth
564
- self.causal = causal
565
- self.layers = nn.ModuleList([])
566
-
567
- self.project_in = operations.Linear(dim_in, dim, bias=False, dtype=dtype, device=device) if dim_in is not None else nn.Identity()
568
- self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
569
-
570
- if rotary_pos_emb:
571
- self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
572
- else:
573
- self.rotary_pos_emb = None
574
-
575
- self.use_sinusoidal_emb = use_sinusoidal_emb
576
- if use_sinusoidal_emb:
577
- self.pos_emb = ScaledSinusoidalEmbedding(dim)
578
-
579
- self.use_abs_pos_emb = use_abs_pos_emb
580
- if use_abs_pos_emb:
581
- self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
582
-
583
- for i in range(depth):
584
- self.layers.append(
585
- TransformerBlock(
586
- dim,
587
- dim_heads = dim_heads,
588
- cross_attend = cross_attend,
589
- dim_context = cond_token_dim,
590
- global_cond_dim = global_cond_dim,
591
- causal = causal,
592
- zero_init_branch_outputs = zero_init_branch_outputs,
593
- conformer=conformer,
594
- layer_ix=i,
595
- dtype=dtype,
596
- device=device,
597
- operations=operations,
598
- **kwargs
599
- )
600
- )
601
-
602
- def forward(
603
- self,
604
- x,
605
- mask = None,
606
- prepend_embeds = None,
607
- prepend_mask = None,
608
- global_cond = None,
609
- return_info = False,
610
- **kwargs
611
- ):
612
- batch, seq, device = *x.shape[:2], x.device
613
-
614
- info = {
615
- "hidden_states": [],
616
- }
617
-
618
- x = self.project_in(x)
619
-
620
- if prepend_embeds is not None:
621
- prepend_length, prepend_dim = prepend_embeds.shape[1:]
622
-
623
- assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
624
-
625
- x = torch.cat((prepend_embeds, x), dim = -2)
626
-
627
- if prepend_mask is not None or mask is not None:
628
- mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
629
- prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
630
-
631
- mask = torch.cat((prepend_mask, mask), dim = -1)
632
-
633
- # Attention layers
634
-
635
- if self.rotary_pos_emb is not None:
636
- rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1], dtype=x.dtype, device=x.device)
637
- else:
638
- rotary_pos_emb = None
639
-
640
- if self.use_sinusoidal_emb or self.use_abs_pos_emb:
641
- x = x + self.pos_emb(x)
642
-
643
- # Iterate over the transformer layers
644
- for layer in self.layers:
645
- x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
646
- # x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
647
-
648
- if return_info:
649
- info["hidden_states"].append(x)
650
-
651
- x = self.project_out(x)
652
-
653
- if return_info:
654
- return x, info
655
-
656
- return x
657
-
658
- class AudioDiffusionTransformer(nn.Module):
659
- def __init__(self,
660
- io_channels=64,
661
- patch_size=1,
662
- embed_dim=1536,
663
- cond_token_dim=768,
664
- project_cond_tokens=False,
665
- global_cond_dim=1536,
666
- project_global_cond=True,
667
- input_concat_dim=0,
668
- prepend_cond_dim=0,
669
- depth=24,
670
- num_heads=24,
671
- transformer_type: tp.Literal["continuous_transformer"] = "continuous_transformer",
672
- global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
673
- audio_model="",
674
- dtype=None,
675
- device=None,
676
- operations=None,
677
- **kwargs):
678
-
679
- super().__init__()
680
-
681
- self.dtype = dtype
682
- self.cond_token_dim = cond_token_dim
683
-
684
- # Timestep embeddings
685
- timestep_features_dim = 256
686
-
687
- self.timestep_features = FourierFeatures(1, timestep_features_dim, dtype=dtype, device=device)
688
-
689
- self.to_timestep_embed = nn.Sequential(
690
- operations.Linear(timestep_features_dim, embed_dim, bias=True, dtype=dtype, device=device),
691
- nn.SiLU(),
692
- operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device),
693
- )
694
-
695
- if cond_token_dim > 0:
696
- # Conditioning tokens
697
-
698
- cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
699
- self.to_cond_embed = nn.Sequential(
700
- operations.Linear(cond_token_dim, cond_embed_dim, bias=False, dtype=dtype, device=device),
701
- nn.SiLU(),
702
- operations.Linear(cond_embed_dim, cond_embed_dim, bias=False, dtype=dtype, device=device)
703
- )
704
- else:
705
- cond_embed_dim = 0
706
-
707
- if global_cond_dim > 0:
708
- # Global conditioning
709
- global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
710
- self.to_global_embed = nn.Sequential(
711
- operations.Linear(global_cond_dim, global_embed_dim, bias=False, dtype=dtype, device=device),
712
- nn.SiLU(),
713
- operations.Linear(global_embed_dim, global_embed_dim, bias=False, dtype=dtype, device=device)
714
- )
715
-
716
- if prepend_cond_dim > 0:
717
- # Prepend conditioning
718
- self.to_prepend_embed = nn.Sequential(
719
- operations.Linear(prepend_cond_dim, embed_dim, bias=False, dtype=dtype, device=device),
720
- nn.SiLU(),
721
- operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
722
- )
723
-
724
- self.input_concat_dim = input_concat_dim
725
-
726
- dim_in = io_channels + self.input_concat_dim
727
-
728
- self.patch_size = patch_size
729
-
730
- # Transformer
731
-
732
- self.transformer_type = transformer_type
733
-
734
- self.global_cond_type = global_cond_type
735
-
736
- if self.transformer_type == "continuous_transformer":
737
-
738
- global_dim = None
739
-
740
- if self.global_cond_type == "adaLN":
741
- # The global conditioning is projected to the embed_dim already at this point
742
- global_dim = embed_dim
743
-
744
- self.transformer = ContinuousTransformer(
745
- dim=embed_dim,
746
- depth=depth,
747
- dim_heads=embed_dim // num_heads,
748
- dim_in=dim_in * patch_size,
749
- dim_out=io_channels * patch_size,
750
- cross_attend = cond_token_dim > 0,
751
- cond_token_dim = cond_embed_dim,
752
- global_cond_dim=global_dim,
753
- dtype=dtype,
754
- device=device,
755
- operations=operations,
756
- **kwargs
757
- )
758
- else:
759
- raise ValueError(f"Unknown transformer type: {self.transformer_type}")
760
-
761
- self.preprocess_conv = operations.Conv1d(dim_in, dim_in, 1, bias=False, dtype=dtype, device=device)
762
- self.postprocess_conv = operations.Conv1d(io_channels, io_channels, 1, bias=False, dtype=dtype, device=device)
763
-
764
- def _forward(
765
- self,
766
- x,
767
- t,
768
- mask=None,
769
- cross_attn_cond=None,
770
- cross_attn_cond_mask=None,
771
- input_concat_cond=None,
772
- global_embed=None,
773
- prepend_cond=None,
774
- prepend_cond_mask=None,
775
- return_info=False,
776
- **kwargs):
777
-
778
- if cross_attn_cond is not None:
779
- cross_attn_cond = self.to_cond_embed(cross_attn_cond)
780
-
781
- if global_embed is not None:
782
- # Project the global conditioning to the embedding dimension
783
- global_embed = self.to_global_embed(global_embed)
784
-
785
- prepend_inputs = None
786
- prepend_mask = None
787
- prepend_length = 0
788
- if prepend_cond is not None:
789
- # Project the prepend conditioning to the embedding dimension
790
- prepend_cond = self.to_prepend_embed(prepend_cond)
791
-
792
- prepend_inputs = prepend_cond
793
- if prepend_cond_mask is not None:
794
- prepend_mask = prepend_cond_mask
795
-
796
- if input_concat_cond is not None:
797
-
798
- # Interpolate input_concat_cond to the same length as x
799
- if input_concat_cond.shape[2] != x.shape[2]:
800
- input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
801
-
802
- x = torch.cat([x, input_concat_cond], dim=1)
803
-
804
- # Get the batch of timestep embeddings
805
- timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None]).to(x.dtype)) # (b, embed_dim)
806
-
807
- # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
808
- if global_embed is not None:
809
- global_embed = global_embed + timestep_embed
810
- else:
811
- global_embed = timestep_embed
812
-
813
- # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
814
- if self.global_cond_type == "prepend":
815
- if prepend_inputs is None:
816
- # Prepend inputs are just the global embed, and the mask is all ones
817
- prepend_inputs = global_embed.unsqueeze(1)
818
- prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
819
- else:
820
- # Prepend inputs are the prepend conditioning + the global embed
821
- prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
822
- prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
823
-
824
- prepend_length = prepend_inputs.shape[1]
825
-
826
- x = self.preprocess_conv(x) + x
827
-
828
- x = rearrange(x, "b c t -> b t c")
829
-
830
- extra_args = {}
831
-
832
- if self.global_cond_type == "adaLN":
833
- extra_args["global_cond"] = global_embed
834
-
835
- if self.patch_size > 1:
836
- x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
837
-
838
- if self.transformer_type == "x-transformers":
839
- output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
840
- elif self.transformer_type == "continuous_transformer":
841
- output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
842
-
843
- if return_info:
844
- output, info = output
845
- elif self.transformer_type == "mm_transformer":
846
- output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
847
-
848
- output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
849
-
850
- if self.patch_size > 1:
851
- output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
852
-
853
- output = self.postprocess_conv(output) + output
854
-
855
- if return_info:
856
- return output, info
857
-
858
- return output
859
-
860
- def forward(
861
- self,
862
- x,
863
- timestep,
864
- context=None,
865
- context_mask=None,
866
- input_concat_cond=None,
867
- global_embed=None,
868
- negative_global_embed=None,
869
- prepend_cond=None,
870
- prepend_cond_mask=None,
871
- mask=None,
872
- return_info=False,
873
- control=None,
874
- transformer_options={},
875
- **kwargs):
876
- return self._forward(
877
- x,
878
- timestep,
879
- cross_attn_cond=context,
880
- cross_attn_cond_mask=context_mask,
881
- input_concat_cond=input_concat_cond,
882
- global_embed=global_embed,
883
- prepend_cond=prepend_cond,
884
- prepend_cond_mask=prepend_cond_mask,
885
- mask=mask,
886
- return_info=return_info,
887
- **kwargs
888
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/ldm/audio/embedders.py DELETED
@@ -1,108 +0,0 @@
1
- # code adapted from: https://github.com/Stability-AI/stable-audio-tools
2
-
3
- import torch
4
- import torch.nn as nn
5
- from torch import Tensor, einsum
6
- from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
7
- from einops import rearrange
8
- import math
9
- import comfy.ops
10
-
11
class LearnedPositionalEmbedding(nn.Module):
    """Fourier-feature embedding of scalars with learned frequencies.

    Used for continuous time.
    """

    def __init__(self, dim: int):
        """Allocate ``dim // 2`` frequency weights; ``dim`` must be even."""
        super().__init__()
        assert (dim % 2) == 0
        # torch.empty leaves the values uninitialised; presumably they are
        # loaded from a checkpoint before use -- TODO confirm.
        self.weights = nn.Parameter(torch.empty(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        """Embed a 1-D batch ``(b,)`` as ``(b, 2 * len(weights) + 1)``.

        Each output row is the raw value followed by the sines and then the
        cosines of ``value * weight * 2 * pi``.
        """
        column = rearrange(x, "b -> b 1")
        phases = column * rearrange(self.weights, "d -> 1 d") * 2 * math.pi
        waves = torch.cat((phases.sin(), phases.cos()), dim=-1)
        return torch.cat((column, waves), dim=-1)
26
-
27
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a learned positional embedding followed by a linear projection.

    The embedding emits ``dim + 1`` features (the raw value plus ``dim``
    sinusoid features), which the linear layer maps to ``out_features``.
    """
    positional = LearnedPositionalEmbedding(dim)
    projection = comfy.ops.manual_cast.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(positional, projection)
32
-
33
-
34
class NumberEmbedder(nn.Module):
    """Embed numbers of any shape into ``features``-sized vectors."""

    def __init__(
        self,
        features: int,
        dim: int = 256,
    ):
        """Create the embedder.

        Args:
            features: size of each output embedding vector.
            dim: width of the intermediate positional embedding.
        """
        super().__init__()
        self.features = features
        self.embedding = TimePositionalEmbedding(dim=dim, out_features=features)

    def forward(self, x: Union[List[float], Tensor]) -> Tensor:
        """Return embeddings of shape ``(*x.shape, features)``.

        Non-tensor input is converted to a tensor on the device of the
        embedding's parameters.
        """
        values = x
        if not torch.is_tensor(values):
            target_device = next(self.embedding.parameters()).device
            values = torch.tensor(values, device=target_device)
        assert isinstance(values, Tensor)
        original_shape = values.shape
        # Flatten every dimension so the embedding sees a 1-D batch.
        flat = rearrange(values, "... -> (...)")
        embedded = self.embedding(flat)
        return embedded.view(*original_shape, self.features)  # type: ignore
54
-
55
-
56
class Conditioner(nn.Module):
    """Base class for conditioners; subclasses must implement ``forward``."""

    def __init__(
            self,
            dim: int,
            output_dim: int,
            project_out: bool = False
            ):
        """Record the dimensions and build the output projection.

        ``proj_out`` is a linear layer when the dimensions differ or when
        ``project_out`` is set; otherwise it is an identity.
        """
        super().__init__()

        self.dim = dim
        self.output_dim = output_dim
        if project_out or dim != output_dim:
            self.proj_out = nn.Linear(dim, output_dim)
        else:
            self.proj_out = nn.Identity()

    def forward(self, x):
        """Not implemented in the base class."""
        raise NotImplementedError()
72
-
73
class NumberConditioner(Conditioner):
    '''
    Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
    '''
    def __init__(self,
                 output_dim: int,
                 min_val: float=0,
                 max_val: float=1
                 ):
        """Store the clamp range and build the number embedder."""
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val

        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats, device=None):
        """Embed ``floats`` and return ``[embeddings, ones_mask]``.

        Values are clamped to ``[min_val, max_val]`` and rescaled by that
        range before embedding. ``device`` defaults to the device of the
        embedder's parameters.

        NOTE(review): ``min_val == max_val`` makes the divisor zero.
        """
        # Cast the inputs to floats
        values = list(map(float, floats))

        if device is None:
            device = next(self.embedder.parameters()).device

        clamped = torch.tensor(values).to(device).clamp(self.min_val, self.max_val)
        normalized = (clamped - self.min_val) / (self.max_val - self.min_val)

        # Cast floats to same type as embedder
        param_dtype = next(self.embedder.parameters()).dtype
        normalized = normalized.to(param_dtype)

        embeds = self.embedder(normalized).unsqueeze(1)
        ones_mask = torch.ones(embeds.shape[0], 1).to(device)

        return [embeds, ones_mask]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/ldm/cascade/__pycache__/common.cpython-310.pyc DELETED
Binary file (7.69 kB)
 
MagicQuill/comfy/ldm/cascade/__pycache__/controlnet.cpython-310.pyc DELETED
Binary file (3.77 kB)
 
MagicQuill/comfy/ldm/cascade/__pycache__/stage_a.cpython-310.pyc DELETED
Binary file (9.41 kB)
 
MagicQuill/comfy/ldm/cascade/__pycache__/stage_b.cpython-310.pyc DELETED
Binary file (7.77 kB)
 
MagicQuill/comfy/ldm/cascade/__pycache__/stage_c.cpython-310.pyc DELETED
Binary file (8.58 kB)
 
MagicQuill/comfy/ldm/cascade/__pycache__/stage_c_coder.cpython-310.pyc DELETED
Binary file (3.5 kB)
 
MagicQuill/comfy/ldm/cascade/common.py DELETED
@@ -1,161 +0,0 @@
1
- """
2
- This file is part of ComfyUI.
3
- Copyright (C) 2024 Stability AI
4
-
5
- This program is free software: you can redistribute it and/or modify
6
- it under the terms of the GNU General Public License as published by
7
- the Free Software Foundation, either version 3 of the License, or
8
- (at your option) any later version.
9
-
10
- This program is distributed in the hope that it will be useful,
11
- but WITHOUT ANY WARRANTY; without even the implied warranty of
12
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
- GNU General Public License for more details.
14
-
15
- You should have received a copy of the GNU General Public License
16
- along with this program. If not, see <https://www.gnu.org/licenses/>.
17
- """
18
-
19
- import torch
20
- import torch.nn as nn
21
- from comfy.ldm.modules.attention import optimized_attention
22
-
23
class Linear(torch.nn.Linear):
    """``torch.nn.Linear`` whose ``reset_parameters`` does nothing."""

    def reset_parameters(self):
        # Deliberately empty: the parent's parameter initialisation is skipped.
        return
26
-
27
class Conv2d(torch.nn.Conv2d):
    """``torch.nn.Conv2d`` whose ``reset_parameters`` does nothing."""

    def reset_parameters(self):
        # Deliberately empty: the parent's parameter initialisation is skipped.
        return
30
-
31
class OptimizedAttention(nn.Module):
    """Multi-head attention with linear q/k/v and output projections.

    The attention itself is delegated to ``optimized_attention``. The
    ``dropout`` argument is accepted but not used by this class.
    """

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead

        def make_projection():
            # All four projections share the same c -> c biased shape.
            return operations.Linear(c, c, bias=True, dtype=dtype, device=device)

        self.to_q = make_projection()
        self.to_k = make_projection()
        self.to_v = make_projection()

        self.out_proj = make_projection()

    def forward(self, q, k, v):
        """Project q, k and v, attend, then apply the output projection."""
        queries, keys, values = self.to_q(q), self.to_k(k), self.to_v(v)
        attended = optimized_attention(queries, keys, values, self.heads)
        return self.out_proj(attended)
50
-
51
class Attention2D(nn.Module):
    """Attention over the flattened spatial positions of a feature map."""

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        # ``dropout`` is accepted but not forwarded to OptimizedAttention.
        self.attn = OptimizedAttention(c, nhead, dtype=dtype, device=device, operations=operations)

    def forward(self, x, kv, self_attn=False):
        """Attend from ``x`` to ``kv`` and return a tensor shaped like ``x``.

        With ``self_attn`` set, the flattened ``x`` tokens are prepended to
        ``kv`` along dim 1 before attending.
        """
        input_shape = x.shape
        # B x C x H x W -> B x (H*W) x C
        tokens = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)
        context = torch.cat([tokens, kv], dim=1) if self_attn else kv
        attended = self.attn(tokens, context, context)
        return attended.permute(0, 2, 1).view(*input_shape)
66
-
67
-
68
def LayerNorm2d_op(operations):
    """Return a LayerNorm class for channels-first 4-D input.

    The class derives from ``operations.LayerNorm``, so it inherits that
    implementation's constructor arguments.
    """

    class LayerNorm2d(operations.LayerNorm):
        """LayerNorm applied over dim 1 of a 4-D tensor."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def forward(self, x):
            # Move dim 1 last, normalise, then restore the original layout.
            channels_last = x.permute(0, 2, 3, 1)
            normalised = super().forward(channels_last)
            return normalised.permute(0, 3, 1, 2)

    return LayerNorm2d
76
-
77
class GlobalResponseNorm(nn.Module):
    """Global response normalisation with a residual connection.

    from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105

    ``gamma`` and ``beta`` have shape ``(1, 1, 1, dim)`` and start at zero,
    so a freshly built module returns its input unchanged.
    """

    def __init__(self, dim, dtype=None, device=None):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(*param_shape, dtype=dtype, device=device))
        self.beta = nn.Parameter(torch.zeros(*param_shape, dtype=dtype, device=device))

    def forward(self, x):
        """Apply the normalisation; the parameters are cast to match ``x``."""
        # L2 norm over dims 1 and 2, divided by its mean over the last dim.
        energy = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        response = energy / (energy.mean(dim=-1, keepdim=True) + 1e-6)
        gamma = self.gamma.to(device=x.device, dtype=x.dtype)
        beta = self.beta.to(device=x.device, dtype=x.dtype)
        return gamma * (x * response) + beta + x
88
-
89
-
90
class ResBlock(nn.Module):
    """Residual block: depthwise conv, layer norm, then a channelwise MLP.

    An optional skip tensor is concatenated on dim 1 before the MLP, whose
    first layer is sized for ``c + c_skip`` inputs.
    """

    def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        hidden = c * 4
        self.depthwise = operations.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c, **factory)
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, **factory)
        self.channelwise = nn.Sequential(
            operations.Linear(c + c_skip, hidden, **factory),
            nn.GELU(),
            GlobalResponseNorm(hidden, **factory),
            nn.Dropout(dropout),
            operations.Linear(hidden, c, **factory),
        )

    def forward(self, x, x_skip=None):
        """Return ``x`` plus the transformed features."""
        residual = x
        features = self.norm(self.depthwise(x))
        if x_skip is not None:
            features = torch.cat([features, x_skip], dim=1)
        # The MLP works on the last dim, so go channels-last and back.
        features = self.channelwise(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return features + residual
111
-
112
-
113
class AttnBlock(nn.Module):
    """Residual attention block conditioned on an external ``kv`` tensor."""

    def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        self.self_attn = self_attn
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, **factory)
        self.attention = Attention2D(c, nhead, dropout, operations=operations, **factory)
        # Maps the conditioning from c_cond to c features.
        self.kv_mapper = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_cond, c, **factory),
        )

    def forward(self, x, kv):
        """Return ``x`` plus attention over the mapped conditioning."""
        mapped = self.kv_mapper(kv)
        attended = self.attention(self.norm(x), mapped, self_attn=self.self_attn)
        return x + attended
128
-
129
-
130
class FeedForwardBlock(nn.Module):
    """Residual block: layer norm followed by a 4x-wide channelwise MLP."""

    def __init__(self, c, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        factory = dict(dtype=dtype, device=device)
        hidden = c * 4
        self.norm = LayerNorm2d_op(operations)(c, elementwise_affine=False, eps=1e-6, **factory)
        self.channelwise = nn.Sequential(
            operations.Linear(c, hidden, **factory),
            nn.GELU(),
            GlobalResponseNorm(hidden, **factory),
            nn.Dropout(dropout),
            operations.Linear(hidden, c, **factory),
        )

    def forward(self, x):
        """Return ``x`` plus the MLP output."""
        # The MLP works on the last dim, so go channels-last and back.
        channels_last = self.norm(x).permute(0, 2, 3, 1)
        update = self.channelwise(channels_last).permute(0, 3, 1, 2)
        return x + update
145
-
146
-
147
class TimestepBlock(nn.Module):
    """Scale-and-shift modulation of ``x`` from a timestep embedding.

    One linear mapper handles the timestep embedding, and one extra mapper
    (``mapper_<name>``) is registered for every entry in ``conds``.
    """

    def __init__(self, c, c_timestep, conds=['sca'], dtype=None, device=None, operations=None):
        """Create the mappers.

        Args:
            c: number of channels of the tensor being modulated.
            c_timestep: width of each embedding chunk.
            conds: names of the additional conditionings.
            dtype, device: forwarded to every linear layer.
            operations: namespace providing ``Linear``.
        """
        super().__init__()
        self.mapper = operations.Linear(c_timestep, c * 2, dtype=dtype, device=device)
        # Store a copy: keeping the argument itself would make every
        # default-constructed instance share (and be able to mutate) the
        # single list object of the signature default.
        self.conds = list(conds)
        for cname in self.conds:
            setattr(self, f"mapper_{cname}", operations.Linear(c_timestep, c * 2, dtype=dtype, device=device))

    def forward(self, x, t):
        """Return ``x * (1 + a) + b`` with ``a``/``b`` summed over all mappers.

        ``t`` is split along dim 1 into ``len(conds) + 1`` chunks: the first
        feeds ``mapper``, the rest feed the ``mapper_<name>`` layers in order.
        """
        chunks = t.chunk(len(self.conds) + 1, dim=1)
        a, b = self.mapper(chunks[0])[:, :, None, None].chunk(2, dim=1)
        for i, cname in enumerate(self.conds, start=1):
            ac, bc = getattr(self, f"mapper_{cname}")(chunks[i])[:, :, None, None].chunk(2, dim=1)
            a, b = a + ac, b + bc
        return x * (1 + a) + b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/ldm/cascade/controlnet.py DELETED
@@ -1,93 +0,0 @@
1
- """
2
- This file is part of ComfyUI.
3
- Copyright (C) 2024 Stability AI
4
-
5
- This program is free software: you can redistribute it and/or modify
6
- it under the terms of the GNU General Public License as published by
7
- the Free Software Foundation, either version 3 of the License, or
8
- (at your option) any later version.
9
-
10
- This program is distributed in the hope that it will be useful,
11
- but WITHOUT ANY WARRANTY; without even the implied warranty of
12
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
- GNU General Public License for more details.
14
-
15
- You should have received a copy of the GNU General Public License
16
- along with this program. If not, see <https://www.gnu.org/licenses/>.
17
- """
18
-
19
- import torch
20
- import torchvision
21
- from torch import nn
22
- from .common import LayerNorm2d_op
23
-
24
-
25
class CNetResBlock(nn.Module):
    """Residual block of two (LayerNorm2d, GELU, 3x3 conv) stages."""

    def __init__(self, c, dtype=None, device=None, operations=None):
        """Build the block for ``c`` channels.

        ``dtype`` and ``device`` are forwarded to the convolutions as well as
        the norms, so the whole block is created with the requested dtype
        and device, matching every other layer built in this file.
        """
        super().__init__()
        self.blocks = nn.Sequential(
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c, c, kernel_size=3, padding=1, dtype=dtype, device=device),
            LayerNorm2d_op(operations)(c, dtype=dtype, device=device),
            nn.GELU(),
            operations.Conv2d(c, c, kernel_size=3, padding=1, dtype=dtype, device=device),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the two stages."""
        return x + self.blocks(x)
39
-
40
-
41
class ControlNet(nn.Module):
    """Backbone plus one 1x1-conv projection head per entry in ``proj_blocks``."""

    def __init__(self, c_in=3, c_proj=2048, proj_blocks=None, bottleneck_mode=None, dtype=None, device=None, operations=nn):
        """Build the backbone and the projection heads.

        Args:
            c_in: number of input channels.
            c_proj: number of output channels of each projection head.
            proj_blocks: output-slot index for each head; ``None`` means no
                heads (it previously crashed on ``len(None)``).
            bottleneck_mode: ``'effnet'`` (used when ``None``), ``'simple'``
                or ``'large'``.
            dtype, device: forwarded to the layers created here.
            operations: namespace providing ``Conv2d``.

        Raises:
            ValueError: if ``bottleneck_mode`` is not a known mode.
        """
        super().__init__()
        if bottleneck_mode is None:
            bottleneck_mode = 'effnet'
        if proj_blocks is None:
            proj_blocks = []
        self.proj_blocks = proj_blocks
        if bottleneck_mode == 'effnet':
            embd_channels = 1280
            self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
            if c_in != 3:
                # Swap the stem conv for one taking c_in channels, reusing
                # as many of the original input-channel weights as fit.
                in_weights = self.backbone[0][0].weight.data
                self.backbone[0][0] = operations.Conv2d(c_in, 24, kernel_size=3, stride=2, bias=False, dtype=dtype, device=device)
                if c_in > 3:
                    self.backbone[0][0].weight.data[:, :3] = in_weights[:, :3].clone()
                else:
                    self.backbone[0][0].weight.data = in_weights[:, :c_in].clone()
        elif bottleneck_mode == 'simple':
            embd_channels = c_in
            self.backbone = nn.Sequential(
                operations.Conv2d(embd_channels, embd_channels * 4, kernel_size=3, padding=1, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(embd_channels * 4, embd_channels, kernel_size=3, padding=1, dtype=dtype, device=device),
            )
        elif bottleneck_mode == 'large':
            self.backbone = nn.Sequential(
                operations.Conv2d(c_in, 4096 * 4, kernel_size=1, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(4096 * 4, 1024, kernel_size=1, dtype=dtype, device=device),
                *[CNetResBlock(1024, dtype=dtype, device=device, operations=operations) for _ in range(8)],
                operations.Conv2d(1024, 1280, kernel_size=1, dtype=dtype, device=device),
            )
            embd_channels = 1280
        else:
            raise ValueError(f'Unknown bottleneck mode: {bottleneck_mode}')
        self.projections = nn.ModuleList()
        for _ in proj_blocks:
            self.projections.append(nn.Sequential(
                operations.Conv2d(embd_channels, embd_channels, kernel_size=1, bias=False, dtype=dtype, device=device),
                nn.LeakyReLU(0.2, inplace=True),
                operations.Conv2d(embd_channels, c_proj, kernel_size=1, bias=False, dtype=dtype, device=device),
            ))
        self.xl = False
        self.input_channels = c_in
        self.unshuffle_amount = 8

    def forward(self, x):
        """Return a list with each head's output at its ``proj_blocks`` index.

        The list has ``max(proj_blocks) + 1`` entries, with ``None`` in every
        slot no head writes to; it is empty when there are no heads.
        """
        x = self.backbone(x)
        # default=-1 yields an empty list instead of max() raising on [].
        proj_outputs = [None] * (max(self.proj_blocks, default=-1) + 1)
        for i, idx in enumerate(self.proj_blocks):
            proj_outputs[idx] = self.projections[i](x)
        return proj_outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MagicQuill/comfy/ldm/cascade/stage_a.py DELETED
@@ -1,255 +0,0 @@
1
- """
2
- This file is part of ComfyUI.
3
- Copyright (C) 2024 Stability AI
4
-
5
- This program is free software: you can redistribute it and/or modify
6
- it under the terms of the GNU General Public License as published by
7
- the Free Software Foundation, either version 3 of the License, or
8
- (at your option) any later version.
9
-
10
- This program is distributed in the hope that it will be useful,
11
- but WITHOUT ANY WARRANTY; without even the implied warranty of
12
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
- GNU General Public License for more details.
14
-
15
- You should have received a copy of the GNU General Public License
16
- along with this program. If not, see <https://www.gnu.org/licenses/>.
17
- """
18
-
19
- import torch
20
- from torch import nn
21
- from torch.autograd import Function
22
-
23
class vector_quantize(Function):
    """Nearest-codebook-row lookup as a custom autograd ``Function``."""

    @staticmethod
    def forward(ctx, x, codebook):
        """Return ``(nearest_rows, indices)`` for each row of ``x``.

        ``x`` and ``codebook`` are 2-D with the same second dimension.
        """
        with torch.no_grad():
            # ||x - e||^2 = ||e||^2 + ||x||^2 - 2 * x . e
            squared_norms = torch.sum(codebook ** 2, dim=1) + torch.sum(x ** 2, dim=1, keepdim=True)
            distances = torch.addmm(squared_norms, x, codebook.t(), alpha=-2.0, beta=1.0)
            indices = distances.min(dim=1).indices

            ctx.save_for_backward(indices, codebook)
            ctx.mark_non_differentiable(indices)

            nearest = torch.index_select(codebook, 0, indices)
            return nearest, indices

    @staticmethod
    def backward(ctx, grad_output, grad_indices):
        """Copy the gradient to ``x`` and add it into the selected codebook rows."""
        wants_input, wants_codebook = ctx.needs_input_grad[0], ctx.needs_input_grad[1]

        grad_inputs = grad_output.clone() if wants_input else None

        grad_codebook = None
        if wants_codebook:
            # Gradient wrt. the codebook
            indices, codebook = ctx.saved_tensors
            grad_codebook = torch.zeros_like(codebook)
            grad_codebook.index_add_(0, indices, grad_output)

        return (grad_inputs, grad_codebook)
53
-
54
-
55
class VectorQuantize(nn.Module):
    def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False):
        """
        Takes an input of variable size (as long as the last dimension matches the embedding size).
        Returns one tensor containing the nearest neighbour embeddings to each of the inputs,
        with the same size as the input, vq and commitment components for the loss as a tuple
        in the second output and the indices of the quantized vectors in the third:
        quantized, (vq_loss, commit_loss), indices
        """
        super(VectorQuantize, self).__init__()

        self.codebook = nn.Embedding(k, embedding_size)
        bound = 1. / k
        self.codebook.weight.data.uniform_(-bound, bound)
        self.vq = vector_quantize.apply

        self.ema_decay = ema_decay
        self.ema_loss = ema_loss
        if ema_loss:
            # Running statistics used by the EMA codebook update.
            self.register_buffer('ema_element_count', torch.ones(k))
            self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight))

    def _laplace_smoothing(self, x, epsilon):
        """Smooth the counts in ``x`` with ``epsilon``, rescaled by their sum."""
        total = torch.sum(x)
        return ((x + epsilon) / (total + x.size(0) * epsilon) * total)

    def _updateEMA(self, z_e_x, indices):
        """Update the EMA statistics and overwrite the codebook weights."""
        one_hot = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float()
        batch_count = one_hot.sum(dim=0)
        batch_sum = torch.mm(one_hot.t(), z_e_x)
        decay = self.ema_decay

        self.ema_element_count = (decay * self.ema_element_count) + ((1 - decay) * batch_count)
        self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5)
        self.ema_weight_sum = (decay * self.ema_weight_sum) + ((1 - decay) * batch_sum)

        self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1)

    def idx2vq(self, idx, dim=-1):
        """Look up codebook vectors for ``idx``, placing features at ``dim``."""
        vectors = self.codebook(idx)
        if dim == -1:
            return vectors
        return vectors.movedim(-1, dim)

    def forward(self, x, get_losses=True, dim=-1):
        """Quantize ``x`` along ``dim``.

        Returns ``(quantized, (vq_loss, commit_loss), indices)``; both losses
        are ``None`` when ``get_losses`` is false.
        """
        feature_last = dim == -1
        if not feature_last:
            x = x.movedim(dim, -1)
        # Flatten the leading dims so each row is one vector to quantize.
        z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x
        z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach())
        vq_loss, commit_loss = None, None
        if self.ema_loss and self.training:
            self._updateEMA(z_e_x.detach(), indices.detach())
        # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss
        z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices)
        if get_losses:
            vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean()
            commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean()

        z_q_x = z_q_x.view(x.shape)
        if not feature_last:
            z_q_x = z_q_x.movedim(-1, dim)
        return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1])
115
-
116
-
117
class ResBlock(nn.Module):
    """Residual block with a depthwise conv stage and a channelwise MLP stage.

    Six learned scalars in ``gammas`` modulate the two stages. They start at
    zero, so a freshly built block returns its input unchanged.
    """

    def __init__(self, c, c_hidden):
        """Build the block for ``c`` channels with an MLP of width ``c_hidden``."""
        super().__init__()
        # depthwise/attention
        self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.depthwise = nn.Sequential(
            nn.ReplicationPad2d(1),
            nn.Conv2d(c, c, kernel_size=3, groups=c)
        )

        # channelwise
        self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
        self.channelwise = nn.Sequential(
            nn.Linear(c, c_hidden),
            nn.GELU(),
            nn.Linear(c_hidden, c),
        )

        self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

        # Init weights: Xavier-uniform weights and zero biases.
        def _basic_init(module):
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

    def _norm(self, x, norm):
        """Apply ``norm`` over dim 1 of a 4-D tensor."""
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x):
        """Run both residual stages on ``x``."""
        mods = self.gammas

        x_temp = self._norm(x, self.norm1) * (1 + mods[0]) + mods[1]
        try:
            x = x + self.depthwise(x_temp) * mods[2]
        except RuntimeError:  # operation not implemented for bf16
            # Only runtime errors trigger the float32 padding fallback; a bare
            # ``except`` would also swallow KeyboardInterrupt and SystemExit.
            x_temp = self.depthwise[0](x_temp.float()).to(x.dtype)
            x = x + self.depthwise[1](x_temp) * mods[2]

        x_temp = self._norm(x, self.norm2) * (1 + mods[3]) + mods[4]
        x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]

        return x
163
-
164
-
165
class StageA(nn.Module):
    """Convolutional autoencoder with a vector-quantized latent."""

    def __init__(self, levels=2, bottleneck_blocks=12, c_hidden=384, c_latent=4, codebook_size=8192):
        """Build the encoder, quantizer and decoder.

        Args:
            levels: number of resolution levels; widths halve from
                ``c_hidden`` going towards the input.
            bottleneck_blocks: number of ResBlocks at the first decoder level.
            c_hidden: channel width at the deepest level.
            c_latent: number of latent channels.
            codebook_size: number of codebook entries.
        """
        super().__init__()
        self.c_latent = c_latent
        c_levels = [c_hidden // (2 ** i) for i in reversed(range(levels))]

        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(2),
            nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = ResBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(nn.Sequential(
            nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        ))
        self.down_blocks = nn.Sequential(*down_blocks)

        self.codebook_size = codebook_size
        self.vquantizer = VectorQuantize(c_latent, k=codebook_size)

        # Decoder blocks
        up_blocks = [nn.Sequential(
            nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
        )]
        for i in range(levels):
            for _ in range(bottleneck_blocks if i == 0 else 1):
                block = ResBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
                                       padding=1))
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
            nn.PixelShuffle(2),
        )

    def encode(self, x, quantize=False):
        """Encode ``x`` to latents.

        Returns the latent tensor when ``quantize`` is false, otherwise
        ``(quantized, latents, indices, vq_loss + 0.25 * commit_loss)``.
        """
        x = self.in_block(x)
        x = self.down_blocks(x)
        if quantize:
            qe, (vq_loss, commit_loss), indices = self.vquantizer.forward(x, dim=1)
            return qe, x, indices, vq_loss + commit_loss * 0.25
        else:
            return x

    def decode(self, x):
        """Decode latents back to the input layout."""
        x = self.up_blocks(x)
        x = self.out_block(x)
        return x

    def forward(self, x, quantize=False):
        """Encode then decode ``x``; returns ``(reconstruction, vq_loss)``.

        ``vq_loss`` is ``None`` when ``quantize`` is false.
        """
        if not quantize:
            # encode() returns a bare latent tensor here. Unpacking it as a
            # 4-tuple iterated over its first dimension instead, which
            # raised unless that dimension happened to be 4.
            return self.decode(self.encode(x)), None
        qe, x, _, vq_loss = self.encode(x, quantize)
        x = self.decode(qe)
        return x, vq_loss
228
-
229
-
230
class Discriminator(nn.Module):
    """Strided-conv discriminator with an optional conditioning vector.

    The encoder is ``depth`` stride-2 spectral-norm convolutions; a 1x1 conv
    and a sigmoid then produce one score per remaining spatial position.
    """

    def __init__(self, c_in=3, c_cond=0, c_hidden=512, depth=6):
        super().__init__()
        d = max(depth - 3, 3)

        def strided_conv(n_in, n_out):
            return nn.utils.spectral_norm(nn.Conv2d(n_in, n_out, kernel_size=3, stride=2, padding=1))

        layers = [strided_conv(c_in, c_hidden // (2 ** d)), nn.LeakyReLU(0.2)]
        for i in range(depth - 1):
            # Widths double per layer until they reach c_hidden.
            width_in = c_hidden // (2 ** max((d - i), 0))
            width_out = c_hidden // (2 ** max((d - 1 - i), 0))
            layers.extend([
                strided_conv(width_in, width_out),
                nn.InstanceNorm2d(width_out),
                nn.LeakyReLU(0.2),
            ])
        self.encoder = nn.Sequential(*layers)
        head_channels = (c_hidden + c_cond) if c_cond > 0 else c_hidden
        self.shuffle = nn.Conv2d(head_channels, 1, kernel_size=1)
        self.logits = nn.Sigmoid()

    def forward(self, x, cond=None):
        """Score ``x``; ``cond`` is broadcast spatially and concatenated on dim 1."""
        features = self.encoder(x)
        if cond is not None:
            tiled = cond.view(cond.size(0), cond.size(1), 1, 1).expand(-1, -1, features.size(-2), features.size(-1))
            features = torch.cat([features, tiled], dim=1)
        return self.logits(self.shuffle(features))