Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2024 The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from dataclasses import dataclass | |
from math import gcd | |
from typing import Any, Dict, List, Optional, Tuple, Union | |
import torch | |
import torch.utils.checkpoint | |
from torch import Tensor, nn | |
from ..configuration_utils import ConfigMixin, register_to_config | |
from ..utils import BaseOutput, is_torch_version, logging | |
from ..utils.torch_utils import apply_freeu | |
from .attention_processor import ( | |
ADDED_KV_ATTENTION_PROCESSORS, | |
CROSS_ATTENTION_PROCESSORS, | |
Attention, | |
AttentionProcessor, | |
AttnAddedKVProcessor, | |
AttnProcessor, | |
FusedAttnProcessor2_0, | |
) | |
from .controlnet import ControlNetConditioningEmbedding | |
from .embeddings import TimestepEmbedding, Timesteps | |
from .modeling_utils import ModelMixin | |
from .unets.unet_2d_blocks import ( | |
CrossAttnDownBlock2D, | |
CrossAttnUpBlock2D, | |
Downsample2D, | |
ResnetBlock2D, | |
Transformer2DModel, | |
UNetMidBlock2DCrossAttn, | |
Upsample2D, | |
) | |
from .unets.unet_2d_condition import UNet2DConditionModel | |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
@dataclass
class ControlNetXSOutput(BaseOutput):
    """
    The output of [`UNetControlNetXSModel`].

    Args:
        sample (`Tensor` of shape `(batch_size, num_channels, height, width)`):
            The output of the `UNetControlNetXSModel`. Unlike `ControlNetOutput` this is NOT to be added to the base
            model output, but is already the final output.
    """

    # The `@dataclass` decorator turns this annotation into a real instance field. Without it, `sample` is only a
    # class attribute that stays `None` no matter what is passed to the constructor.
    sample: Tensor = None
class DownBlockControlNetXSAdapter(nn.Module):
    """Holds the control-side components of one down block.

    Combined with the matching components of the base model, these form a `ControlNetXSCrossAttnDownBlock2D`.

    Args:
        resnets (`nn.ModuleList`): Stored as `self.resnets`.
        base_to_ctrl (`nn.ModuleList`): Stored as `self.base_to_ctrl`.
        ctrl_to_base (`nn.ModuleList`): Stored as `self.ctrl_to_base`.
        attentions (`nn.ModuleList`, *optional*): Stored as `self.attentions`; `None` when not given.
        downsampler (`nn.Conv2d`, *optional*): Stored as `self.downsamplers` (note the plural attribute name).
    """

    def __init__(
        self,
        resnets: nn.ModuleList,
        base_to_ctrl: nn.ModuleList,
        ctrl_to_base: nn.ModuleList,
        attentions: Optional[nn.ModuleList] = None,
        downsampler: Optional[nn.Conv2d] = None,
    ):
        super().__init__()

        # Assignment order is kept as-is: it determines the order in which submodules are registered.
        self.resnets = resnets
        self.base_to_ctrl = base_to_ctrl
        self.ctrl_to_base = ctrl_to_base

        # Optional parts; each stays `None` when the block has no attention / no downsampler.
        self.attentions = attentions
        self.downsamplers = downsampler
class MidBlockControlNetXSAdapter(nn.Module):
    """Holds the control-side components of the mid block.

    Combined with the matching components of the base model, these form a `ControlNetXSCrossAttnMidBlock2D`.

    Args:
        midblock (`UNetMidBlock2DCrossAttn`): Stored as `self.midblock`.
        base_to_ctrl: Stored as `self.base_to_ctrl`.
        ctrl_to_base: Stored as `self.ctrl_to_base`.
    """

    def __init__(self, midblock: UNetMidBlock2DCrossAttn, base_to_ctrl: nn.ModuleList, ctrl_to_base: nn.ModuleList):
        super().__init__()

        # Assignment order is kept as-is: it determines the order in which submodules are registered.
        self.midblock = midblock
        self.base_to_ctrl = base_to_ctrl
        self.ctrl_to_base = ctrl_to_base
class UpBlockControlNetXSAdapter(nn.Module):
    """Holds the control-side components of one up block.

    Combined with the matching components of the base model, these form a `ControlNetXSCrossAttnUpBlock2D`.

    Args:
        ctrl_to_base (`nn.ModuleList`): Stored as `self.ctrl_to_base`.
    """

    def __init__(self, ctrl_to_base: nn.ModuleList):
        super().__init__()
        self.ctrl_to_base = ctrl_to_base
def get_down_block_adapter(
    base_in_channels: int,
    base_out_channels: int,
    ctrl_in_channels: int,
    ctrl_out_channels: int,
    temb_channels: int,
    max_norm_num_groups: Optional[int] = 32,
    has_crossattn=True,
    transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
    num_attention_heads: Optional[int] = 1,
    cross_attention_dim: Optional[int] = 1024,
    add_downsample: bool = True,
    upcast_attention: Optional[bool] = False,
    use_linear_projection: Optional[bool] = True,
):
    """Build the control-side components of one down block.

    Two resnet layers are always created (only SD and SDXL layouts are supported). Each is preceded by a
    base-to-control connection and followed by a control-to-base connection. When `has_crossattn` is set, a
    `Transformer2DModel` is created per layer. When `add_downsample` is set, a `Downsample2D` with its own pair of
    connections is appended.

    Returns:
        `DownBlockControlNetXSAdapter`: the assembled components.
    """
    num_layers = 2  # only support sd + sdxl

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * num_layers

    resnet_layers = []
    attention_layers = []
    base_to_ctrl_convs = []
    ctrl_to_base_convs = []

    for layer_idx in range(num_layers):
        # Only the first layer sees the block's input widths; later layers consume the block's output widths.
        is_first_layer = layer_idx == 0
        layer_base_in = base_in_channels if is_first_layer else base_out_channels
        layer_ctrl_in = ctrl_in_channels if is_first_layer else ctrl_out_channels
        # Information from base is concatenated to ctrl before the resnet, so the resnet input is the sum of both.
        concat_channels = layer_ctrl_in + layer_base_in

        # base -> ctrl happens by concatenation, so the connection keeps the channel count.
        base_to_ctrl_convs.append(make_zero_conv(layer_base_in, layer_base_in))

        resnet_layers.append(
            ResnetBlock2D(
                in_channels=concat_channels,
                out_channels=ctrl_out_channels,
                temb_channels=temb_channels,
                groups=find_largest_factor(concat_channels, max_factor=max_norm_num_groups),
                groups_out=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
                eps=1e-5,
            )
        )

        if has_crossattn:
            attention_layers.append(
                Transformer2DModel(
                    num_attention_heads,
                    ctrl_out_channels // num_attention_heads,
                    in_channels=ctrl_out_channels,
                    num_layers=transformer_layers_per_block[layer_idx],
                    cross_attention_dim=cross_attention_dim,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=max_norm_num_groups),
                )
            )

        # ctrl -> base happens by addition, so the connection must map to the base channel count.
        ctrl_to_base_convs.append(make_zero_conv(ctrl_out_channels, base_out_channels))

    downsampler = None
    if add_downsample:
        # Same pattern around the downsampler: concat from base first, add back to base afterwards.
        base_to_ctrl_convs.append(make_zero_conv(base_out_channels, base_out_channels))
        downsampler = Downsample2D(
            ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
        )
        ctrl_to_base_convs.append(make_zero_conv(ctrl_out_channels, base_out_channels))

    return DownBlockControlNetXSAdapter(
        resnets=nn.ModuleList(resnet_layers),
        base_to_ctrl=nn.ModuleList(base_to_ctrl_convs),
        ctrl_to_base=nn.ModuleList(ctrl_to_base_convs),
        attentions=nn.ModuleList(attention_layers) if has_crossattn else None,
        downsampler=downsampler,
    )
def get_mid_block_adapter(
    base_channels: int,
    ctrl_channels: int,
    temb_channels: Optional[int] = None,
    max_norm_num_groups: Optional[int] = 32,
    transformer_layers_per_block: int = 1,
    num_attention_heads: Optional[int] = 1,
    cross_attention_dim: Optional[int] = 1024,
    upcast_attention: bool = False,
    use_linear_projection: bool = True,
):
    """Build the control-side components of the mid block.

    Creates a base-to-control connection, a `UNetMidBlock2DCrossAttn` that takes the concatenated
    (`ctrl_channels + base_channels`) input and outputs `ctrl_channels`, and a control-to-base connection.

    Returns:
        `MidBlockControlNetXSAdapter`: the assembled components.
    """
    concat_channels = ctrl_channels + base_channels

    # base -> ctrl happens by concatenation before the mid block, so the channel count is unchanged.
    base_to_ctrl_conv = make_zero_conv(base_channels, base_channels)

    # The number of norm groups must divide both the mid block's input and output channel counts,
    # hence the gcd of the two.
    shared_norm_groups = find_largest_factor(gcd(ctrl_channels, concat_channels), max_norm_num_groups)
    ctrl_midblock = UNetMidBlock2DCrossAttn(
        transformer_layers_per_block=transformer_layers_per_block,
        in_channels=concat_channels,
        out_channels=ctrl_channels,
        temb_channels=temb_channels,
        resnet_groups=shared_norm_groups,
        cross_attention_dim=cross_attention_dim,
        num_attention_heads=num_attention_heads,
        use_linear_projection=use_linear_projection,
        upcast_attention=upcast_attention,
    )

    # ctrl -> base happens by addition after the mid block, so it must map to the base channel count.
    ctrl_to_base_conv = make_zero_conv(ctrl_channels, base_channels)

    return MidBlockControlNetXSAdapter(
        base_to_ctrl=base_to_ctrl_conv,
        midblock=ctrl_midblock,
        ctrl_to_base=ctrl_to_base_conv,
    )
def get_up_block_adapter(
    out_channels: int,
    prev_output_channel: int,
    ctrl_skip_channels: List[int],
):
    """Build the control-to-base connections of one up block.

    Three connections are created (only SD and SDXL layouts are supported). Connection `i` maps
    `ctrl_skip_channels[i]` to `prev_output_channel` for the first layer and to `out_channels` for the others.

    Returns:
        `UpBlockControlNetXSAdapter`: wraps the connections in an `nn.ModuleList`.
    """
    num_layers = 3  # only support sd + sdxl

    skip_convs = [
        make_zero_conv(
            ctrl_skip_channels[layer_idx],
            prev_output_channel if layer_idx == 0 else out_channels,
        )
        for layer_idx in range(num_layers)
    ]

    return UpBlockControlNetXSAdapter(ctrl_to_base=nn.ModuleList(skip_convs))
class ControlNetXSAdapter(ModelMixin, ConfigMixin):
    r"""
    A `ControlNetXSAdapter` model. To use it, pass it into a `UNetControlNetXSModel` (together with a
    `UNet2DConditionModel` base model).

    This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
    methods implemented for all models (such as downloading or saving).

    Like `UNetControlNetXSModel`, `ControlNetXSAdapter` is compatible with StableDiffusion and StableDiffusion-XL. Its
    default parameters are compatible with StableDiffusion.

    Parameters:
        conditioning_channels (`int`, defaults to 3):
            Number of channels of conditioning input (e.g. an image)
        conditioning_channel_order (`str`, defaults to `"rgb"`):
            The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
        conditioning_embedding_out_channels (`tuple[int]`, defaults to `(16, 32, 96, 256)`):
            The tuple of output channels for each block in the `controlnet_cond_embedding` layer.
        time_embedding_mix (`float`, defaults to 1.0):
            If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time
            embedding is used. Otherwise, both are combined.
        learn_time_embedding (`bool`, defaults to `False`):
            Whether a time embedding should be learned. If yes, `UNetControlNetXSModel` will combine the time
            embeddings of the base model and the control adapter. If no, `UNetControlNetXSModel` will use the base
            model's time embedding.
        num_attention_heads (`int` or `tuple[int]`, defaults to 4):
            The number of attention heads. A single value is repeated for every down block.
        block_out_channels (`tuple[int]`, defaults to `(4, 8, 16, 16)`):
            The tuple of output channels for each block.
        base_block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block in the base unet.
        cross_attention_dim (`int`, defaults to 1024):
            The dimension of the cross attention features.
        down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
            The tuple of downsample blocks to use.
        sample_size (`int`, defaults to 96):
            Height and width of input/output sample.
        transformer_layers_per_block (`Union[int, Tuple[int]]`, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
        upcast_attention (`bool`, defaults to `True`):
            Whether the attention computation should always be upcasted.
        max_norm_num_groups (`int`, defaults to 32):
            Maximum number of groups in group normal. The actual number will be the largest divisor of the respective
            channels, that is <= max_norm_num_groups.
        use_linear_projection (`bool`, defaults to `True`):
            Forwarded as `use_linear_projection` to the attention blocks created for the down and mid blocks.
    """

    @register_to_config
    def __init__(
        self,
        conditioning_channels: int = 3,
        conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
        time_embedding_mix: float = 1.0,
        learn_time_embedding: bool = False,
        num_attention_heads: Union[int, Tuple[int]] = 4,
        block_out_channels: Tuple[int] = (4, 8, 16, 16),
        base_block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        cross_attention_dim: int = 1024,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        sample_size: Optional[int] = 96,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        upcast_attention: bool = True,
        max_norm_num_groups: int = 32,
        use_linear_projection: bool = True,
    ):
        super().__init__()

        # The time embedding dimensions follow the base model's first block width.
        time_embedding_input_dim = base_block_out_channels[0]
        time_embedding_dim = base_block_out_channels[0] * 4

        # Check inputs
        if conditioning_channel_order not in ["rgb", "bgr"]:
            raise ValueError(f"unknown `conditioning_channel_order`: {conditioning_channel_order}")

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        # Broadcast scalar settings to one entry per down block.
        if not isinstance(transformer_layers_per_block, (list, tuple)):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
        if not isinstance(cross_attention_dim, (list, tuple)):
            cross_attention_dim = [cross_attention_dim] * len(down_block_types)
        # see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why `ControlNetXSAdapter` takes `num_attention_heads` instead of `attention_head_dim`
        if not isinstance(num_attention_heads, (list, tuple)):
            num_attention_heads = [num_attention_heads] * len(down_block_types)

        if len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        # Create conditioning hint embedding
        self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=block_out_channels[0],
            block_out_channels=conditioning_embedding_out_channels,
            conditioning_channels=conditioning_channels,
        )

        # time
        if learn_time_embedding:
            self.time_embedding = TimestepEmbedding(time_embedding_input_dim, time_embedding_dim)
        else:
            self.time_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.up_connections = nn.ModuleList([])

        # input
        self.conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
        self.control_to_base_for_conv_in = make_zero_conv(block_out_channels[0], base_block_out_channels[0])

        # down
        base_out_channels = base_block_out_channels[0]
        ctrl_out_channels = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            # Each block's input width is the previous block's output width.
            base_in_channels = base_out_channels
            base_out_channels = base_block_out_channels[i]
            ctrl_in_channels = ctrl_out_channels
            ctrl_out_channels = block_out_channels[i]
            has_crossattn = "CrossAttn" in down_block_type
            is_final_block = i == len(down_block_types) - 1

            self.down_blocks.append(
                get_down_block_adapter(
                    base_in_channels=base_in_channels,
                    base_out_channels=base_out_channels,
                    ctrl_in_channels=ctrl_in_channels,
                    ctrl_out_channels=ctrl_out_channels,
                    temb_channels=time_embedding_dim,
                    max_norm_num_groups=max_norm_num_groups,
                    has_crossattn=has_crossattn,
                    transformer_layers_per_block=transformer_layers_per_block[i],
                    num_attention_heads=num_attention_heads[i],
                    cross_attention_dim=cross_attention_dim[i],
                    add_downsample=not is_final_block,
                    upcast_attention=upcast_attention,
                    use_linear_projection=use_linear_projection,
                )
            )

        # mid
        # NOTE(review): `max_norm_num_groups` is not forwarded here, so the mid block always uses the default of
        # `get_mid_block_adapter`. Left unchanged to keep the existing module configuration - confirm intent.
        self.mid_block = get_mid_block_adapter(
            base_channels=base_block_out_channels[-1],
            ctrl_channels=block_out_channels[-1],
            temb_channels=time_embedding_dim,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            num_attention_heads=num_attention_heads[-1],
            cross_attention_dim=cross_attention_dim[-1],
            upcast_attention=upcast_attention,
            use_linear_projection=use_linear_projection,
        )

        # up
        # The skip connection channels are the output of the conv_in and of all the down subblocks
        ctrl_skip_channels = [block_out_channels[0]]
        for i, out_channels in enumerate(block_out_channels):
            # every block has 3 subblocks, except last one, which has 2 as it has no downsampler
            number_of_subblocks = 3 if i < len(block_out_channels) - 1 else 2
            ctrl_skip_channels.extend([out_channels] * number_of_subblocks)

        reversed_base_block_out_channels = list(reversed(base_block_out_channels))

        base_out_channels = reversed_base_block_out_channels[0]
        for i in range(len(down_block_types)):
            prev_base_output_channel = base_out_channels
            base_out_channels = reversed_base_block_out_channels[i]
            # Skip channels are consumed from the end of the list, three per up block.
            ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)]

            self.up_connections.append(
                get_up_block_adapter(
                    out_channels=base_out_channels,
                    prev_output_channel=prev_base_output_channel,
                    ctrl_skip_channels=ctrl_skip_channels_,
                )
            )

    @classmethod
    def from_unet(
        cls,
        unet: UNet2DConditionModel,
        size_ratio: Optional[float] = None,
        block_out_channels: Optional[List[int]] = None,
        num_attention_heads: Optional[List[int]] = None,
        learn_time_embedding: bool = False,
        time_embedding_mix: float = 1.0,
        conditioning_channels: int = 3,
        conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
    ):
        r"""
        Instantiate a [`ControlNetXSAdapter`] from a [`UNet2DConditionModel`].

        Parameters:
            unet (`UNet2DConditionModel`):
                The UNet model we want to control. The dimensions of the ControlNetXSAdapter will be adapted to it.
            size_ratio (float, *optional*, defaults to `None`):
                When given, block_out_channels is set to a fraction of the base model's block_out_channels. Either this
                or `block_out_channels` must be given.
            block_out_channels (`List[int]`, *optional*, defaults to `None`):
                Down blocks output channels in control model. Either this or `size_ratio` must be given.
            num_attention_heads (`List[int]`, *optional*, defaults to `None`):
                The dimension of the attention heads. The naming seems a bit confusing and it is, see
                https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
            learn_time_embedding (`bool`, defaults to `False`):
                Whether the `ControlNetXSAdapter` should learn a time embedding.
            time_embedding_mix (`float`, defaults to 1.0):
                If 0, then only the control adapter's time embedding is used. If 1, then only the base unet's time
                embedding is used. Otherwise, both are combined.
            conditioning_channels (`int`, defaults to 3):
                Number of channels of conditioning input (e.g. an image)
            conditioning_channel_order (`str`, defaults to `"rgb"`):
                The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
            conditioning_embedding_out_channels (`Tuple[int]`, defaults to `(16, 32, 96, 256)`):
                The tuple of output channel for each block in the `controlnet_cond_embedding` layer.

        Returns:
            [`ControlNetXSAdapter`]: a new adapter, moved to the dtype of `unet`.

        Raises:
            ValueError: If both or neither of `block_out_channels` and `size_ratio` are given.
        """
        # Check input
        fixed_size = block_out_channels is not None
        relative_size = size_ratio is not None
        if not (fixed_size ^ relative_size):
            raise ValueError(
                "Pass exactly one of `block_out_channels` (for absolute sizing) or `size_ratio` (for relative sizing)."
            )

        # Create model
        # Test against `None` explicitly so an (invalid) empty `block_out_channels` reaches the constructor's own
        # validation instead of failing on `size_ratio` being `None`.
        if block_out_channels is None:
            block_out_channels = [int(b * size_ratio) for b in unet.config.block_out_channels]
        if num_attention_heads is None:
            # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
            num_attention_heads = unet.config.attention_head_dim

        model = cls(
            conditioning_channels=conditioning_channels,
            conditioning_channel_order=conditioning_channel_order,
            conditioning_embedding_out_channels=conditioning_embedding_out_channels,
            time_embedding_mix=time_embedding_mix,
            learn_time_embedding=learn_time_embedding,
            num_attention_heads=num_attention_heads,
            block_out_channels=block_out_channels,
            base_block_out_channels=unet.config.block_out_channels,
            cross_attention_dim=unet.config.cross_attention_dim,
            down_block_types=unet.config.down_block_types,
            sample_size=unet.config.sample_size,
            transformer_layers_per_block=unet.config.transformer_layers_per_block,
            upcast_attention=unet.config.upcast_attention,
            max_norm_num_groups=unet.config.norm_num_groups,
            use_linear_projection=unet.config.use_linear_projection,
        )

        # ensure that the ControlNetXSAdapter is the same dtype as the UNet2DConditionModel
        model.to(unet.dtype)

        return model

    def forward(self, *args, **kwargs):
        """Always raises: the adapter is not runnable on its own.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError(
            "A ControlNetXSAdapter cannot be run by itself. Use it together with a UNet2DConditionModel to instantiate a UNetControlNetXSModel."
        )
class UNetControlNetXSModel(ModelMixin, ConfigMixin): | |
r""" | |
A UNet fused with a ControlNet-XS adapter model | |
This model inherits from [`ModelMixin`] and [`ConfigMixin`]. Check the superclass documentation for its generic
methods implemented for all models (such as downloading or saving).
`UNetControlNetXSModel` is compatible with StableDiffusion and StableDiffusion-XL. Its default parameters are
compatible with StableDiffusion.
Its parameters are either passed to the underlying `UNet2DConditionModel` or used exactly like in
`ControlNetXSAdapter`. See their documentation for details.
""" | |
_supports_gradient_checkpointing = True | |
@register_to_config
def __init__(
    self,
    # unet configs
    sample_size: Optional[int] = 96,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    norm_num_groups: Optional[int] = 32,
    cross_attention_dim: Union[int, Tuple[int]] = 1024,
    transformer_layers_per_block: Union[int, Tuple[int]] = 1,
    num_attention_heads: Union[int, Tuple[int]] = 8,
    addition_embed_type: Optional[str] = None,
    addition_time_embed_dim: Optional[int] = None,
    upcast_attention: bool = True,
    use_linear_projection: bool = True,
    time_cond_proj_dim: Optional[int] = None,
    projection_class_embeddings_input_dim: Optional[int] = None,
    # additional controlnet configs
    time_embedding_mix: float = 1.0,
    ctrl_conditioning_channels: int = 3,
    ctrl_conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256),
    ctrl_conditioning_channel_order: str = "rgb",
    ctrl_learn_time_embedding: bool = False,
    ctrl_block_out_channels: Tuple[int] = (4, 8, 16, 16),
    ctrl_num_attention_heads: Union[int, Tuple[int]] = 4,
    ctrl_max_norm_num_groups: int = 32,
):
    """Build the fused base-UNet / ControlNet-XS model.

    The un-prefixed arguments configure the base UNet side, the `ctrl_*` arguments (and `time_embedding_mix`)
    configure the control side. Scalar values of `transformer_layers_per_block`, `cross_attention_dim`,
    `num_attention_heads` and `ctrl_num_attention_heads` are repeated once per down block.

    Raises:
        ValueError: If `time_embedding_mix` is outside `[0, 1]`; if it is below 1 while
            `ctrl_learn_time_embedding` is `False`; or if `addition_embed_type` is neither `None` nor
            `"text_time"`.
    """
    super().__init__()

    if time_embedding_mix < 0 or time_embedding_mix > 1:
        raise ValueError("`time_embedding_mix` needs to be between 0 and 1.")
    if time_embedding_mix < 1 and not ctrl_learn_time_embedding:
        raise ValueError("To use `time_embedding_mix` < 1, `ctrl_learn_time_embedding` must be `True`")

    if addition_embed_type is not None and addition_embed_type != "text_time":
        raise ValueError(
            "As `UNetControlNetXSModel` currently only supports StableDiffusion and StableDiffusion-XL, `addition_embed_type` must be `None` or `'text_time'`."
        )

    # Broadcast scalar settings to one entry per down block.
    if not isinstance(transformer_layers_per_block, (list, tuple)):
        transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
    if not isinstance(cross_attention_dim, (list, tuple)):
        cross_attention_dim = [cross_attention_dim] * len(down_block_types)
    if not isinstance(num_attention_heads, (list, tuple)):
        num_attention_heads = [num_attention_heads] * len(down_block_types)
    if not isinstance(ctrl_num_attention_heads, (list, tuple)):
        ctrl_num_attention_heads = [ctrl_num_attention_heads] * len(down_block_types)

    base_num_attention_heads = num_attention_heads

    self.in_channels = 4

    # # Input
    self.base_conv_in = nn.Conv2d(4, block_out_channels[0], kernel_size=3, padding=1)
    self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
        conditioning_embedding_channels=ctrl_block_out_channels[0],
        block_out_channels=ctrl_conditioning_embedding_out_channels,
        conditioning_channels=ctrl_conditioning_channels,
    )
    self.ctrl_conv_in = nn.Conv2d(4, ctrl_block_out_channels[0], kernel_size=3, padding=1)
    self.control_to_base_for_conv_in = make_zero_conv(ctrl_block_out_channels[0], block_out_channels[0])

    # # Time
    time_embed_input_dim = block_out_channels[0]
    time_embed_dim = block_out_channels[0] * 4

    self.base_time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos=True, downscale_freq_shift=0)
    self.base_time_embedding = TimestepEmbedding(
        time_embed_input_dim,
        time_embed_dim,
        cond_proj_dim=time_cond_proj_dim,
    )
    # The control side only gets its own time embedding when it is supposed to learn one.
    if ctrl_learn_time_embedding:
        self.ctrl_time_embedding = TimestepEmbedding(
            in_channels=time_embed_input_dim, time_embed_dim=time_embed_dim
        )
    else:
        self.ctrl_time_embedding = None

    # Additional embedding modules exist only for `addition_embed_type == "text_time"` (validated above).
    if addition_embed_type is None:
        self.base_add_time_proj = None
        self.base_add_embedding = None
    else:
        self.base_add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.base_add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

    # # Create down blocks
    down_blocks = []
    base_out_channels = block_out_channels[0]
    ctrl_out_channels = ctrl_block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        # Each block's input width is the previous block's output width.
        base_in_channels = base_out_channels
        base_out_channels = block_out_channels[i]
        ctrl_in_channels = ctrl_out_channels
        ctrl_out_channels = ctrl_block_out_channels[i]
        has_crossattn = "CrossAttn" in down_block_type
        is_final_block = i == len(down_block_types) - 1

        down_blocks.append(
            ControlNetXSCrossAttnDownBlock2D(
                base_in_channels=base_in_channels,
                base_out_channels=base_out_channels,
                ctrl_in_channels=ctrl_in_channels,
                ctrl_out_channels=ctrl_out_channels,
                temb_channels=time_embed_dim,
                norm_num_groups=norm_num_groups,
                ctrl_max_norm_num_groups=ctrl_max_norm_num_groups,
                has_crossattn=has_crossattn,
                transformer_layers_per_block=transformer_layers_per_block[i],
                base_num_attention_heads=base_num_attention_heads[i],
                ctrl_num_attention_heads=ctrl_num_attention_heads[i],
                cross_attention_dim=cross_attention_dim[i],
                add_downsample=not is_final_block,
                upcast_attention=upcast_attention,
                use_linear_projection=use_linear_projection,
            )
        )

    # # Create mid block
    self.mid_block = ControlNetXSCrossAttnMidBlock2D(
        base_channels=block_out_channels[-1],
        ctrl_channels=ctrl_block_out_channels[-1],
        temb_channels=time_embed_dim,
        norm_num_groups=norm_num_groups,
        ctrl_max_norm_num_groups=ctrl_max_norm_num_groups,
        transformer_layers_per_block=transformer_layers_per_block[-1],
        base_num_attention_heads=base_num_attention_heads[-1],
        ctrl_num_attention_heads=ctrl_num_attention_heads[-1],
        cross_attention_dim=cross_attention_dim[-1],
        upcast_attention=upcast_attention,
        use_linear_projection=use_linear_projection,
    )

    # # Create up blocks
    up_blocks = []
    # Up blocks walk the per-block settings in reverse order.
    rev_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
    rev_num_attention_heads = list(reversed(base_num_attention_heads))
    rev_cross_attention_dim = list(reversed(cross_attention_dim))

    # The skip connection channels are the output of the conv_in and of all the down subblocks
    ctrl_skip_channels = [ctrl_block_out_channels[0]]
    for i, out_channels in enumerate(ctrl_block_out_channels):
        # every block has 3 subblocks, except last one, which has 2 as it has no downsampler
        number_of_subblocks = 3 if i < len(ctrl_block_out_channels) - 1 else 2
        ctrl_skip_channels.extend([out_channels] * number_of_subblocks)

    reversed_block_out_channels = list(reversed(block_out_channels))

    out_channels = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        prev_output_channel = out_channels
        out_channels = reversed_block_out_channels[i]
        # Clamp the index so the last up block reuses the final entry instead of running off the list.
        in_channels = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
        # Skip channels are consumed from the end of the list, three per up block.
        ctrl_skip_channels_ = [ctrl_skip_channels.pop() for _ in range(3)]

        has_crossattn = "CrossAttn" in up_block_type
        is_final_block = i == len(block_out_channels) - 1

        up_blocks.append(
            ControlNetXSCrossAttnUpBlock2D(
                in_channels=in_channels,
                out_channels=out_channels,
                prev_output_channel=prev_output_channel,
                ctrl_skip_channels=ctrl_skip_channels_,
                temb_channels=time_embed_dim,
                resolution_idx=i,
                has_crossattn=has_crossattn,
                transformer_layers_per_block=rev_transformer_layers_per_block[i],
                num_attention_heads=rev_num_attention_heads[i],
                cross_attention_dim=rev_cross_attention_dim[i],
                add_upsample=not is_final_block,
                upcast_attention=upcast_attention,
                norm_num_groups=norm_num_groups,
                use_linear_projection=use_linear_projection,
            )
        )

    self.down_blocks = nn.ModuleList(down_blocks)
    self.up_blocks = nn.ModuleList(up_blocks)

    # Output head of the base model.
    self.base_conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups)
    self.base_conv_act = nn.SiLU()
    self.base_conv_out = nn.Conv2d(block_out_channels[0], 4, kernel_size=3, padding=1)
@classmethod
def from_unet(
    cls,
    unet: UNet2DConditionModel,
    controlnet: Optional[ControlNetXSAdapter] = None,
    size_ratio: Optional[float] = None,
    ctrl_block_out_channels: Optional[List[float]] = None,
    time_embedding_mix: Optional[float] = None,
    ctrl_optional_kwargs: Optional[Dict] = None,
):
    r"""
    Instantiate a [`UNetControlNetXSModel`] from a [`UNet2DConditionModel`] and an optional
    [`ControlNetXSAdapter`].

    Parameters:
        unet (`UNet2DConditionModel`):
            The UNet model we want to control.
        controlnet (`ControlNetXSAdapter`):
            The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
            adapter will be created.
        size_ratio (float, *optional*, defaults to `None`):
            Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
        ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
            Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
            where this parameter is called `block_out_channels`.
        time_embedding_mix (`float`, *optional*, defaults to None):
            Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
            When given, it takes precedence over a `time_embedding_mix` entry in `ctrl_optional_kwargs`.
        ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
            Passed to the `init` of the new controlnet if no controlnet was given.

    Returns:
        The fused model, holding the weights of `unet` and of the adapter, cast to `unet.dtype`.

    Raises:
        ValueError: If a `controlnet` is passed together with any of the adapter-construction parameters.
    """
    if controlnet is None:
        # `ctrl_optional_kwargs` defaults to None, which cannot be unpacked with `**`.
        adapter_kwargs = dict(ctrl_optional_kwargs or {})
        # Forward the explicitly requested mix instead of silently ignoring it.
        if time_embedding_mix is not None:
            adapter_kwargs["time_embedding_mix"] = time_embedding_mix
        controlnet = ControlNetXSAdapter.from_unet(unet, size_ratio, ctrl_block_out_channels, **adapter_kwargs)
    else:
        if any(
            o is not None for o in (size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs)
        ):
            raise ValueError(
                "When a controlnet is passed, none of these parameters should be passed: size_ratio, ctrl_block_out_channels, time_embedding_mix, ctrl_optional_kwargs."
            )

    # # get params
    unet_config_keys = [
        "sample_size",
        "down_block_types",
        "up_block_types",
        "block_out_channels",
        "norm_num_groups",
        "cross_attention_dim",
        "transformer_layers_per_block",
        "addition_embed_type",
        "addition_time_embed_dim",
        "upcast_attention",
        "use_linear_projection",
        "time_cond_proj_dim",
        "projection_class_embeddings_input_dim",
    ]
    params_for_unet = {k: v for k, v in unet.config.items() if k in unet_config_keys}
    # The naming seems a bit confusing and it is, see https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 for why.
    params_for_unet["num_attention_heads"] = unet.config.attention_head_dim

    controlnet_config_keys = [
        "conditioning_channels",
        "conditioning_embedding_out_channels",
        "conditioning_channel_order",
        "learn_time_embedding",
        "block_out_channels",
        "num_attention_heads",
        "max_norm_num_groups",
    ]
    # Adapter settings are stored on the fused model under a `ctrl_` prefix ...
    params_for_controlnet = {
        "ctrl_" + k: v for k, v in controlnet.config.items() if k in controlnet_config_keys
    }
    # ... except for `time_embedding_mix`, which keeps its name.
    params_for_controlnet["time_embedding_mix"] = controlnet.config.time_embedding_mix

    # # create model
    model = cls.from_config({**params_for_unet, **params_for_controlnet})

    # # load weights
    # from unet
    modules_from_unet = [
        "time_embedding",
        "conv_in",
        "conv_norm_out",
        "conv_out",
    ]
    for m in modules_from_unet:
        getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())

    # these modules only exist on some UNets, so copy them only when present
    optional_modules_from_unet = [
        "add_time_proj",
        "add_embedding",
    ]
    for m in optional_modules_from_unet:
        if hasattr(unet, m) and getattr(unet, m) is not None:
            getattr(model, "base_" + m).load_state_dict(getattr(unet, m).state_dict())

    # from controlnet
    model.controlnet_cond_embedding.load_state_dict(controlnet.controlnet_cond_embedding.state_dict())
    model.ctrl_conv_in.load_state_dict(controlnet.conv_in.state_dict())
    if controlnet.time_embedding is not None:
        model.ctrl_time_embedding.load_state_dict(controlnet.time_embedding.state_dict())
    model.control_to_base_for_conv_in.load_state_dict(controlnet.control_to_base_for_conv_in.state_dict())

    # from both: the blocks created by `from_config` are replaced by blocks built from the given modules
    model.down_blocks = nn.ModuleList(
        ControlNetXSCrossAttnDownBlock2D.from_modules(b, c)
        for b, c in zip(unet.down_blocks, controlnet.down_blocks)
    )
    model.mid_block = ControlNetXSCrossAttnMidBlock2D.from_modules(unet.mid_block, controlnet.mid_block)
    model.up_blocks = nn.ModuleList(
        ControlNetXSCrossAttnUpBlock2D.from_modules(b, c)
        for b, c in zip(unet.up_blocks, controlnet.up_connections)
    )

    # ensure that the UNetControlNetXSModel is the same dtype as the UNet2DConditionModel
    model.to(unet.dtype)

    return model
def freeze_unet_params(self) -> None:
    """Make only the control-specific weights trainable.

    Every parameter is first marked as trainable. Afterwards the top-level modules whose attribute names start
    with `base_` are frozen, and each down, mid and up block is asked to freeze its own base part through
    `freeze_base_params()`.
    """
    # Start from a fully trainable model
    for weight in self.parameters():
        weight.requires_grad = True

    # Collect the top-level base modules; attributes that are set to None are skipped
    base_attr_names = (
        "base_time_proj",
        "base_time_embedding",
        "base_add_time_proj",
        "base_add_embedding",
        "base_conv_in",
        "base_conv_norm_out",
        "base_conv_act",
        "base_conv_out",
    )
    base_modules = []
    for attr_name in base_attr_names:
        base_module = getattr(self, attr_name)
        if base_module is not None:
            base_modules.append(base_module)

    # Freeze them again
    for base_module in base_modules:
        for weight in base_module.parameters():
            weight.requires_grad = False

    # The blocks mix base and control weights, so each one freezes its own base part
    for block in (*self.down_blocks, self.mid_block, *self.up_blocks):
        block.freeze_base_params()
def _set_gradient_checkpointing(self, module, value=False):
    """Set `module.gradient_checkpointing` to `value` if `module` has such an attribute; otherwise do nothing."""
    if not hasattr(module, "gradient_checkpointing"):
        return
    module.gradient_checkpointing = value
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model,
        indexed by the dotted path of the owning module followed by `.processor`.
    """
    # This must be a property: the other methods read it as `self.attn_processors.keys()` / `.values()`.
    processors = {}

    def collect(prefix: str, module: torch.nn.Module) -> None:
        # Any module exposing `get_processor` contributes one entry ...
        if hasattr(module, "get_processor"):
            processors[f"{prefix}.processor"] = module.get_processor()
        # ... and its children are always searched as well.
        for child_name, child in module.named_children():
            collect(f"{prefix}.{child_name}", child)

    for child_name, child in self.named_children():
        collect(child_name, child)

    return processors
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Install the attention processor(s) used to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            Either a single instantiated processor, which is set on **all** modules that expose `set_processor`,
            or a dictionary with one processor per such module.

            In the dictionary form each key is the path of the module followed by `.processor`. Entries are
            removed from the passed dictionary as they are consumed. This form is strongly recommended when
            setting trainable attention processors.

    Raises:
        ValueError: If a dictionary is passed whose size differs from the number of attention processors.
    """
    per_layer = isinstance(processor, dict)
    expected = len(self.attn_processors.keys())

    if per_layer and len(processor) != expected:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {expected}. Please make sure to pass {expected} processor classes."
        )

    def assign(path: str, module: torch.nn.Module) -> None:
        if hasattr(module, "set_processor"):
            # a dict hands out (and forgets) the entry for exactly this module
            chosen = processor.pop(f"{path}.processor") if per_layer else processor
            module.set_processor(chosen)

        for child_name, child in module.named_children():
            assign(f"{path}.{child_name}", child)

    for child_name, child in self.named_children():
        assign(child_name, child)
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
    """
    Replace custom attention processors with the default attention implementation.

    `AttnAddedKVProcessor` is used when every current processor class is listed in
    `ADDED_KV_ATTENTION_PROCESSORS`, `AttnProcessor` when every one is listed in `CROSS_ATTENTION_PROCESSORS`.

    Raises:
        ValueError: If the current processors fit neither of the two groups.
    """
    current = self.attn_processors

    if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in current.values()):
        default_processor = AttnAddedKVProcessor()
    elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in current.values()):
        default_processor = AttnProcessor()
    else:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(current.values()))}"
        )

    self.set_attn_processor(default_processor)
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
    r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.

    The four factors are stored as the attributes `s1`, `s2`, `b1` and `b2` on every block in `self.up_blocks`.
    The number in each name is the stage the factor is applied to.

    Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
    are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.

    Args:
        s1 (`float`):
            Stage 1 scaling factor that attenuates the contributions of the skip features, to mitigate the
            "oversmoothing effect" in the enhanced denoising process.
        s2 (`float`):
            Stage 2 scaling factor that attenuates the contributions of the skip features, to mitigate the
            "oversmoothing effect" in the enhanced denoising process.
        b1 (`float`): Stage 1 scaling factor that amplifies the contributions of backbone features.
        b2 (`float`): Stage 2 scaling factor that amplifies the contributions of backbone features.
    """
    factors = {"s1": s1, "s2": s2, "b1": b1, "b2": b2}
    for upsample_block in self.up_blocks:
        for attr_name, factor in factors.items():
            setattr(upsample_block, attr_name, factor)
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
def disable_freeu(self):
    """Disables the FreeU mechanism.

    Each of the attributes `s1`, `s2`, `b1` and `b2` that exists on a block of `self.up_blocks` is reset to
    `None`; attributes that were never set are not created.
    """
    for upsample_block in self.up_blocks:
        for attr_name in ("s1", "s2", "b1", "b2"):
            if hasattr(upsample_block, attr_name):
                setattr(upsample_block, attr_name, None)
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
    are fused. For cross-attention modules, key and value projection matrices are fused.

    The processors that were active before fusing are remembered in `self.original_attn_processors`.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    Raises:
        ValueError: If the class name of any current attention processor contains "Added".
    """
    self.original_attn_processors = None

    # Processors with added KV projections cannot be fused
    if any("Added" in str(proc.__class__.__name__) for proc in self.attn_processors.values()):
        raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    # Remember the current processors before they get replaced
    self.original_attn_processors = self.attn_processors

    for module in self.modules():
        if isinstance(module, Attention):
            module.fuse_projections(fuse=True)

    self.set_attn_processor(FusedAttnProcessor2_0())
# Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    The processors remembered in `self.original_attn_processors` are installed again. Nothing happens when no
    processors are remembered, including when the attribute has never been set.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>
    """
    # The attribute is only created by `fuse_qkv_projections`, so it may be missing entirely.
    original = getattr(self, "original_attn_processors", None)
    if original is None:
        return

    # `set_attn_processor` pops the entries of a dict it is given; hand over a copy so that the remembered
    # processors stay intact and a repeated call keeps working.
    self.set_attn_processor(dict(original) if isinstance(original, dict) else original)
def forward(
    self,
    sample: Tensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    controlnet_cond: Optional[torch.Tensor] = None,
    conditioning_scale: Optional[float] = 1.0,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    return_dict: bool = True,
    apply_control: bool = True,
) -> Union[ControlNetXSOutput, Tuple]:
    """
    The [`UNetControlNetXSModel`] forward method.

    Args:
        sample (`Tensor`):
            The noisy input tensor.
        timestep (`Union[torch.Tensor, float, int]`):
            The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.Tensor`):
            The encoder hidden states.
        controlnet_cond (`Tensor`, *optional*, defaults to `None`):
            The conditioning input tensor. Dimension 1 is treated as the channel axis: it is flipped when the
            configured conditioning channel order is `"bgr"`. If `None`, no conditioning hint is added to the
            control branch.
        conditioning_scale (`float`, defaults to `1.0`):
            How much the control model affects the base model outputs.
        class_labels (`torch.Tensor`, *optional*, defaults to `None`):
            Optional class labels. Currently not used by this method.
        timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
            Additional conditional embeddings for timestep. It is passed to the time embedding layer(s) together
            with the projected timesteps.
        attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
            An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
            is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
            negative values to the attention scores corresponding to "discard" tokens.
        cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
            A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
        added_cond_kwargs (`dict`):
            Additional conditions for the Stable Diffusion XL UNet. Must contain `text_embeds` and `time_ids`
            when the config param `addition_embed_type` is `"text_time"`.
        return_dict (`bool`, defaults to `True`):
            Whether or not to return a [`~models.controlnetxs.ControlNetXSOutput`] instead of a plain tuple.
        apply_control (`bool`, defaults to `True`):
            If `False`, the input is run only through the base model.

    Returns:
        [`~models.controlnetxs.ControlNetXSOutput`] **or** `tuple`:
            If `return_dict` is `True`, a [`~models.controlnetxs.ControlNetXSOutput`] is returned, otherwise a
            tuple is returned where the first element is the sample tensor.

    Raises:
        ValueError: If `addition_embed_type` is `"text_time"` and `text_embeds` or `time_ids` is missing from
            `added_cond_kwargs`, or if `addition_embed_type` has any other non-`None` value.
    """
    # check channel order (nothing to flip when no conditioning input was given)
    if controlnet_cond is not None and self.config.ctrl_conditioning_channel_order == "bgr":
        controlnet_cond = torch.flip(controlnet_cond, dims=[1])

    # prepare attention_mask: turn the keep(1)/discard(0) mask into an additive bias
    if attention_mask is not None:
        attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    timesteps = timesteps.expand(sample.shape[0])

    t_emb = self.base_time_proj(timesteps)

    # timesteps does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=sample.dtype)

    if self.config.ctrl_learn_time_embedding and apply_control:
        # blend the control and base time embeddings
        ctrl_temb = self.ctrl_time_embedding(t_emb, timestep_cond)
        base_temb = self.base_time_embedding(t_emb, timestep_cond)
        interpolation_param = self.config.time_embedding_mix**0.3

        temb = ctrl_temb * interpolation_param + base_temb * (1 - interpolation_param)
    else:
        # `timestep_cond` has to reach the base embedding in this branch as well
        temb = self.base_time_embedding(t_emb, timestep_cond)

    # added time & text embeddings
    aug_emb = None

    if self.config.addition_embed_type is None:
        pass
    elif self.config.addition_embed_type == "text_time":
        # SDXL - style
        # a missing dict is reported through the same errors as a missing key
        added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
        if "text_embeds" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
            )
        text_embeds = added_cond_kwargs.get("text_embeds")
        if "time_ids" not in added_cond_kwargs:
            raise ValueError(
                f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
            )
        time_ids = added_cond_kwargs.get("time_ids")
        time_embeds = self.base_add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(temb.dtype)
        aug_emb = self.base_add_embedding(add_embeds)
    else:
        raise ValueError(
            f"ControlNet-XS currently only supports StableDiffusion and StableDiffusion-XL, so addition_embed_type = {self.config.addition_embed_type} is currently not supported."
        )

    temb = temb + aug_emb if aug_emb is not None else temb

    # text embeddings
    cemb = encoder_hidden_states

    # Preparation
    h_ctrl = h_base = sample
    hs_base, hs_ctrl = [], []

    # Cross Control
    guided_hint = self.controlnet_cond_embedding(controlnet_cond) if controlnet_cond is not None else None

    # 1 - conv in & down
    h_base = self.base_conv_in(h_base)
    h_ctrl = self.ctrl_conv_in(h_ctrl)
    if guided_hint is not None:
        h_ctrl += guided_hint
    if apply_control:
        h_base = h_base + self.control_to_base_for_conv_in(h_ctrl) * conditioning_scale  # add ctrl -> base

    hs_base.append(h_base)
    hs_ctrl.append(h_ctrl)

    for down in self.down_blocks:
        h_base, h_ctrl, residual_hb, residual_hc = down(
            hidden_states_base=h_base,
            hidden_states_ctrl=h_ctrl,
            temb=temb,
            encoder_hidden_states=cemb,
            conditioning_scale=conditioning_scale,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            apply_control=apply_control,
        )
        hs_base.extend(residual_hb)
        hs_ctrl.extend(residual_hc)

    # 2 - mid
    h_base, h_ctrl = self.mid_block(
        hidden_states_base=h_base,
        hidden_states_ctrl=h_ctrl,
        temb=temb,
        encoder_hidden_states=cemb,
        conditioning_scale=conditioning_scale,
        cross_attention_kwargs=cross_attention_kwargs,
        attention_mask=attention_mask,
        apply_control=apply_control,
    )

    # 3 - up: every up block consumes as many of the most recent skip connections as it has resnets
    for up in self.up_blocks:
        n_resnets = len(up.resnets)
        skips_hb = hs_base[-n_resnets:]
        skips_hc = hs_ctrl[-n_resnets:]
        hs_base = hs_base[:-n_resnets]
        hs_ctrl = hs_ctrl[:-n_resnets]
        h_base = up(
            hidden_states=h_base,
            res_hidden_states_tuple_base=skips_hb,
            res_hidden_states_tuple_ctrl=skips_hc,
            temb=temb,
            encoder_hidden_states=cemb,
            conditioning_scale=conditioning_scale,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            apply_control=apply_control,
        )

    # 4 - conv out
    h_base = self.base_conv_norm_out(h_base)
    h_base = self.base_conv_act(h_base)
    h_base = self.base_conv_out(h_base)

    if not return_dict:
        return (h_base,)

    return ControlNetXSOutput(sample=h_base)
class ControlNetXSCrossAttnDownBlock2D(nn.Module):
    """Down block that runs a base branch and a control branch side by side.

    Each subblock (two resnet stages, each optionally followed by a transformer, plus an optional downsampler)
    exists once for the base and once for the control branch. Before a subblock, base features go through a
    `base_to_ctrl` connection and are concatenated to the control features; after it, control features go
    through a `ctrl_to_base` connection and are added to the base features.
    """

    def __init__(
        self,
        base_in_channels: int,
        base_out_channels: int,
        ctrl_in_channels: int,
        ctrl_out_channels: int,
        temb_channels: int,
        norm_num_groups: int = 32,
        ctrl_max_norm_num_groups: int = 32,
        has_crossattn=True,
        transformer_layers_per_block: Optional[Union[int, Tuple[int]]] = 1,
        base_num_attention_heads: Optional[int] = 1,
        ctrl_num_attention_heads: Optional[int] = 1,
        cross_attention_dim: Optional[int] = 1024,
        add_downsample: bool = True,
        upcast_attention: Optional[bool] = False,
        use_linear_projection: Optional[bool] = True,
    ):
        """Build the base and control subblocks and the connections between them.

        Args:
            base_in_channels: Input channels of the base branch.
            base_out_channels: Output channels of the base branch.
            ctrl_in_channels: Input channels of the control branch (before base features are concatenated).
            ctrl_out_channels: Output channels of the control branch.
            temb_channels: Channels of the time embedding given to the resnets.
            norm_num_groups: Number of norm groups used in the base branch.
            ctrl_max_norm_num_groups: Upper bound handed to `find_largest_factor` when choosing the number of
                norm groups in the control branch.
            has_crossattn: Whether each resnet is followed by a transformer.
            transformer_layers_per_block: Transformer layers per stage; an `int` is used for both stages.
            base_num_attention_heads: Attention heads of the base transformers.
            ctrl_num_attention_heads: Attention heads of the control transformers.
            cross_attention_dim: Cross attention dimension of the transformers.
            add_downsample: Whether a downsampler (with its own pair of connections) is added.
            upcast_attention: Forwarded to the transformers.
            use_linear_projection: Forwarded to the transformers.
        """
        super().__init__()
        base_resnets = []
        base_attentions = []
        ctrl_resnets = []
        ctrl_attentions = []
        ctrl_to_base = []
        base_to_ctrl = []

        num_layers = 2  # only support sd + sdxl

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        for i in range(num_layers):
            # only the first stage sees the block's input channels
            base_in_channels = base_in_channels if i == 0 else base_out_channels
            ctrl_in_channels = ctrl_in_channels if i == 0 else ctrl_out_channels

            # Before the resnet/attention application, information is concatted from base to control.
            # Concat doesn't require change in number of channels
            base_to_ctrl.append(make_zero_conv(base_in_channels, base_in_channels))

            base_resnets.append(
                ResnetBlock2D(
                    in_channels=base_in_channels,
                    out_channels=base_out_channels,
                    temb_channels=temb_channels,
                    groups=norm_num_groups,
                )
            )
            ctrl_resnets.append(
                ResnetBlock2D(
                    in_channels=ctrl_in_channels + base_in_channels,  # information from base is concatted to ctrl
                    out_channels=ctrl_out_channels,
                    temb_channels=temb_channels,
                    groups=find_largest_factor(
                        ctrl_in_channels + base_in_channels, max_factor=ctrl_max_norm_num_groups
                    ),
                    groups_out=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
                    eps=1e-5,
                )
            )

            if has_crossattn:
                base_attentions.append(
                    Transformer2DModel(
                        base_num_attention_heads,
                        base_out_channels // base_num_attention_heads,
                        in_channels=base_out_channels,
                        num_layers=transformer_layers_per_block[i],
                        cross_attention_dim=cross_attention_dim,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                        norm_num_groups=norm_num_groups,
                    )
                )
                ctrl_attentions.append(
                    Transformer2DModel(
                        ctrl_num_attention_heads,
                        ctrl_out_channels // ctrl_num_attention_heads,
                        in_channels=ctrl_out_channels,
                        num_layers=transformer_layers_per_block[i],
                        cross_attention_dim=cross_attention_dim,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                        norm_num_groups=find_largest_factor(ctrl_out_channels, max_factor=ctrl_max_norm_num_groups),
                    )
                )

            # After the resnet/attention application, information is added from control to base
            # Addition requires change in number of channels
            ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))

        if add_downsample:
            # Before the downsampler application, information is concatted from base to control
            # Concat doesn't require change in number of channels
            base_to_ctrl.append(make_zero_conv(base_out_channels, base_out_channels))

            self.base_downsamplers = Downsample2D(
                base_out_channels, use_conv=True, out_channels=base_out_channels, name="op"
            )
            self.ctrl_downsamplers = Downsample2D(
                ctrl_out_channels + base_out_channels, use_conv=True, out_channels=ctrl_out_channels, name="op"
            )

            # After the downsampler application, information is added from control to base
            # Addition requires change in number of channels
            ctrl_to_base.append(make_zero_conv(ctrl_out_channels, base_out_channels))
        else:
            self.base_downsamplers = None
            self.ctrl_downsamplers = None

        self.base_resnets = nn.ModuleList(base_resnets)
        self.ctrl_resnets = nn.ModuleList(ctrl_resnets)
        # without cross attention a plain list of Nones keeps the zip in `forward` working
        self.base_attentions = nn.ModuleList(base_attentions) if has_crossattn else [None] * num_layers
        self.ctrl_attentions = nn.ModuleList(ctrl_attentions) if has_crossattn else [None] * num_layers
        self.base_to_ctrl = nn.ModuleList(base_to_ctrl)
        self.ctrl_to_base = nn.ModuleList(ctrl_to_base)

        self.gradient_checkpointing = False

    @classmethod
    def from_modules(cls, base_downblock: CrossAttnDownBlock2D, ctrl_downblock: DownBlockControlNetXSAdapter):
        """Build a fused down block from a base down block and the matching adapter down block.

        The constructor arguments are read off the given modules, a new block is created, and the weights of
        both given blocks are loaded into it.

        Args:
            base_downblock: Down block of the base UNet.
            ctrl_downblock: Down block of the ControlNet-XS adapter, including its connections.

        Returns:
            The new block with the weights of both inputs.
        """

        # get params
        def get_first_cross_attention(block):
            return block.attentions[0].transformer_blocks[0].attn2

        base_in_channels = base_downblock.resnets[0].in_channels
        base_out_channels = base_downblock.resnets[0].out_channels
        ctrl_in_channels = (
            ctrl_downblock.resnets[0].in_channels - base_in_channels
        )  # base channels are concatted to ctrl channels in init
        ctrl_out_channels = ctrl_downblock.resnets[0].out_channels
        temb_channels = base_downblock.resnets[0].time_emb_proj.in_features
        num_groups = base_downblock.resnets[0].norm1.num_groups
        ctrl_num_groups = ctrl_downblock.resnets[0].norm1.num_groups
        if hasattr(base_downblock, "attentions"):
            has_crossattn = True
            transformer_layers_per_block = len(base_downblock.attentions[0].transformer_blocks)
            base_num_attention_heads = get_first_cross_attention(base_downblock).heads
            ctrl_num_attention_heads = get_first_cross_attention(ctrl_downblock).heads
            cross_attention_dim = get_first_cross_attention(base_downblock).cross_attention_dim
            upcast_attention = get_first_cross_attention(base_downblock).upcast_attention
            use_linear_projection = base_downblock.attentions[0].use_linear_projection
        else:
            # attention-related settings are not used when there is no cross attention
            has_crossattn = False
            transformer_layers_per_block = None
            base_num_attention_heads = None
            ctrl_num_attention_heads = None
            cross_attention_dim = None
            upcast_attention = None
            use_linear_projection = None
        add_downsample = base_downblock.downsamplers is not None

        # create model
        model = cls(
            base_in_channels=base_in_channels,
            base_out_channels=base_out_channels,
            ctrl_in_channels=ctrl_in_channels,
            ctrl_out_channels=ctrl_out_channels,
            temb_channels=temb_channels,
            norm_num_groups=num_groups,
            ctrl_max_norm_num_groups=ctrl_num_groups,
            has_crossattn=has_crossattn,
            transformer_layers_per_block=transformer_layers_per_block,
            base_num_attention_heads=base_num_attention_heads,
            ctrl_num_attention_heads=ctrl_num_attention_heads,
            cross_attention_dim=cross_attention_dim,
            add_downsample=add_downsample,
            upcast_attention=upcast_attention,
            use_linear_projection=use_linear_projection,
        )

        # # load weights
        model.base_resnets.load_state_dict(base_downblock.resnets.state_dict())
        model.ctrl_resnets.load_state_dict(ctrl_downblock.resnets.state_dict())
        if has_crossattn:
            model.base_attentions.load_state_dict(base_downblock.attentions.state_dict())
            model.ctrl_attentions.load_state_dict(ctrl_downblock.attentions.state_dict())
        if add_downsample:
            model.base_downsamplers.load_state_dict(base_downblock.downsamplers[0].state_dict())
            model.ctrl_downsamplers.load_state_dict(ctrl_downblock.downsamplers.state_dict())
        model.base_to_ctrl.load_state_dict(ctrl_downblock.base_to_ctrl.state_dict())
        model.ctrl_to_base.load_state_dict(ctrl_downblock.ctrl_to_base.state_dict())

        return model

    def freeze_base_params(self) -> None:
        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
        tuning."""
        # Unfreeze everything
        for param in self.parameters():
            param.requires_grad = True

        # Freeze base part
        base_parts = [self.base_resnets]
        if isinstance(self.base_attentions, nn.ModuleList):  # attentions can be a list of Nones
            base_parts.append(self.base_attentions)
        if self.base_downsamplers is not None:
            base_parts.append(self.base_downsamplers)
        for part in base_parts:
            for param in part.parameters():
                param.requires_grad = False

    def _apply_resnet(self, resnet: nn.Module, hidden_states: Tensor, temb: Tensor) -> Tensor:
        """Run `resnet`, through activation checkpointing while training with `gradient_checkpointing` enabled."""
        if not (self.training and self.gradient_checkpointing):
            return resnet(hidden_states, temb)

        def custom_forward(*inputs):
            return resnet(*inputs)

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        return torch.utils.checkpoint.checkpoint(custom_forward, hidden_states, temb, **ckpt_kwargs)

    @staticmethod
    def _apply_attention(
        attention: Optional[nn.Module],
        hidden_states: Tensor,
        encoder_hidden_states: Optional[Tensor],
        cross_attention_kwargs: Optional[Dict[str, Any]],
        attention_mask: Optional[Tensor],
        encoder_attention_mask: Optional[Tensor],
    ) -> Tensor:
        """Run `attention` on `hidden_states`; return them unchanged when the stage has no attention."""
        if attention is None:
            return hidden_states
        return attention(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=False,
        )[0]

    def forward(
        self,
        hidden_states_base: Tensor,
        temb: Tensor,
        encoder_hidden_states: Optional[Tensor] = None,
        hidden_states_ctrl: Optional[Tensor] = None,
        conditioning_scale: Optional[float] = 1.0,
        attention_mask: Optional[Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        apply_control: bool = True,
    ) -> Tuple[Tensor, Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
        """Run both branches through all subblocks.

        Args:
            hidden_states_base: Input of the base branch.
            temb: Time embedding given to the resnets.
            encoder_hidden_states: Forwarded to the transformers.
            hidden_states_ctrl: Input of the control branch.
            conditioning_scale: Factor applied to the control features before they are added to the base.
            attention_mask: Forwarded to the transformers.
            cross_attention_kwargs: Forwarded to the transformers. A `scale` entry only triggers a warning.
            encoder_attention_mask: Forwarded to the transformers.
            apply_control: If `False`, only the base branch is computed and the control hidden states are
                passed through unchanged.

        Returns:
            The final base hidden states, the final control hidden states, and for each branch a tuple with the
            hidden states after every subblock.
        """
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        h_base = hidden_states_base
        h_ctrl = hidden_states_ctrl

        base_output_states = ()
        ctrl_output_states = ()

        base_blocks = list(zip(self.base_resnets, self.base_attentions))
        ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions))

        # zip stops after the resnet stages; the connections of the downsampler are picked up further below
        for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(
            base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base
        ):
            # concat base -> ctrl
            if apply_control:
                h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)

            # apply base subblock
            h_base = self._apply_resnet(b_res, h_base, temb)
            h_base = self._apply_attention(
                b_attn, h_base, encoder_hidden_states, cross_attention_kwargs, attention_mask, encoder_attention_mask
            )

            if apply_control:
                # apply ctrl subblock
                h_ctrl = self._apply_resnet(c_res, h_ctrl, temb)
                h_ctrl = self._apply_attention(
                    c_attn,
                    h_ctrl,
                    encoder_hidden_states,
                    cross_attention_kwargs,
                    attention_mask,
                    encoder_attention_mask,
                )

                # add ctrl -> base
                h_base = h_base + c2b(h_ctrl) * conditioning_scale

            base_output_states = base_output_states + (h_base,)
            ctrl_output_states = ctrl_output_states + (h_ctrl,)

        if self.base_downsamplers is not None:  # if we have a base_downsampler, then also a ctrl_downsampler
            b2c = self.base_to_ctrl[-1]
            c2b = self.ctrl_to_base[-1]

            # concat base -> ctrl
            if apply_control:
                h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
            # apply base subblock
            h_base = self.base_downsamplers(h_base)
            if apply_control:
                # apply ctrl subblock
                h_ctrl = self.ctrl_downsamplers(h_ctrl)
                # add ctrl -> base
                h_base = h_base + c2b(h_ctrl) * conditioning_scale

            base_output_states = base_output_states + (h_base,)
            ctrl_output_states = ctrl_output_states + (h_ctrl,)

        return h_base, h_ctrl, base_output_states, ctrl_output_states
class ControlNetXSCrossAttnMidBlock2D(nn.Module):
    """Fused mid block holding a base and a control `UNetMidBlock2DCrossAttn`.

    In `forward`, base features are projected by `base_to_ctrl` and concatenated onto the control features, both
    mid blocks are applied, and the control output is projected by `ctrl_to_base` and added (scaled) to the base
    output. Both projections are created with `make_zero_conv`.

    Args:
        base_channels (`int`):
            Number of channels of the base mid block.
        ctrl_channels (`int`):
            Number of output channels of the control mid block. Its input has `ctrl_channels + base_channels`
            channels because of the concatenation.
        temb_channels (`int`, *optional*):
            Passed as `temb_channels` to both mid blocks.
        norm_num_groups (`int`, defaults to 32):
            `resnet_groups` of the base mid block.
        ctrl_max_norm_num_groups (`int`, defaults to 32):
            Upper bound for the `resnet_groups` of the control mid block; the actual value is chosen with
            `find_largest_factor` so that it divides both its input and output channel counts.
        transformer_layers_per_block (`int`, defaults to 1):
            Passed to both mid blocks.
        base_num_attention_heads (`int`, *optional*, defaults to 1):
            `num_attention_heads` of the base mid block.
        ctrl_num_attention_heads (`int`, *optional*, defaults to 1):
            `num_attention_heads` of the control mid block.
        cross_attention_dim (`int`, *optional*, defaults to 1024):
            Passed to both mid blocks.
        upcast_attention (`bool`, defaults to `False`):
            Passed to both mid blocks.
        use_linear_projection (`bool`, *optional*, defaults to `True`):
            Passed to both mid blocks.
    """

    def __init__(
        self,
        base_channels: int,
        ctrl_channels: int,
        temb_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        ctrl_max_norm_num_groups: int = 32,
        transformer_layers_per_block: int = 1,
        base_num_attention_heads: Optional[int] = 1,
        ctrl_num_attention_heads: Optional[int] = 1,
        cross_attention_dim: Optional[int] = 1024,
        upcast_attention: bool = False,
        use_linear_projection: Optional[bool] = True,
    ):
        super().__init__()

        # Before the midblock application, information is concatted from base to control.
        # Concat doesn't require change in number of channels
        self.base_to_ctrl = make_zero_conv(base_channels, base_channels)

        self.base_midblock = UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=base_channels,
            temb_channels=temb_channels,
            resnet_groups=norm_num_groups,
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=base_num_attention_heads,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

        self.ctrl_midblock = UNetMidBlock2DCrossAttn(
            transformer_layers_per_block=transformer_layers_per_block,
            in_channels=ctrl_channels + base_channels,
            out_channels=ctrl_channels,
            temb_channels=temb_channels,
            # number or norm groups must divide both in_channels and out_channels
            resnet_groups=find_largest_factor(
                gcd(ctrl_channels, ctrl_channels + base_channels), ctrl_max_norm_num_groups
            ),
            cross_attention_dim=cross_attention_dim,
            num_attention_heads=ctrl_num_attention_heads,
            use_linear_projection=use_linear_projection,
            upcast_attention=upcast_attention,
        )

        # After the midblock application, information is added from control to base
        # Addition requires change in number of channels
        self.ctrl_to_base = make_zero_conv(ctrl_channels, base_channels)

        self.gradient_checkpointing = False

    @classmethod
    def from_modules(
        cls,
        base_midblock: UNetMidBlock2DCrossAttn,
        ctrl_midblock: MidBlockControlNetXSAdapter,
    ):
        """Build a fused mid block from an existing base mid block and a control adapter.

        The constructor arguments are read off the given modules, a new instance is created, and the state dicts
        of the base mid block, the adapter's mid block and the adapter's two connection convs are loaded into it.

        Args:
            base_midblock: Base mid block whose configuration and weights are reused.
            ctrl_midblock: Adapter providing `base_to_ctrl`, `ctrl_to_base` and `midblock`.

        Returns:
            A new instance of `cls` carrying the weights of the given modules.
        """
        base_to_ctrl = ctrl_midblock.base_to_ctrl
        ctrl_to_base = ctrl_midblock.ctrl_to_base
        ctrl_inner_midblock = ctrl_midblock.midblock

        # get params
        def get_first_cross_attention(midblock):
            return midblock.attentions[0].transformer_blocks[0].attn2

        base_cross_attn = get_first_cross_attention(base_midblock)
        ctrl_cross_attn = get_first_cross_attention(ctrl_inner_midblock)

        base_channels = ctrl_to_base.out_channels
        ctrl_channels = ctrl_to_base.in_channels
        transformer_layers_per_block = len(base_midblock.attentions[0].transformer_blocks)
        temb_channels = base_midblock.resnets[0].time_emb_proj.in_features
        num_groups = base_midblock.resnets[0].norm1.num_groups
        ctrl_num_groups = ctrl_inner_midblock.resnets[0].norm1.num_groups
        base_num_attention_heads = base_cross_attn.heads
        ctrl_num_attention_heads = ctrl_cross_attn.heads
        cross_attention_dim = base_cross_attn.cross_attention_dim
        upcast_attention = base_cross_attn.upcast_attention
        use_linear_projection = base_midblock.attentions[0].use_linear_projection

        # create model
        model = cls(
            base_channels=base_channels,
            ctrl_channels=ctrl_channels,
            temb_channels=temb_channels,
            norm_num_groups=num_groups,
            ctrl_max_norm_num_groups=ctrl_num_groups,
            transformer_layers_per_block=transformer_layers_per_block,
            base_num_attention_heads=base_num_attention_heads,
            ctrl_num_attention_heads=ctrl_num_attention_heads,
            cross_attention_dim=cross_attention_dim,
            upcast_attention=upcast_attention,
            use_linear_projection=use_linear_projection,
        )

        # load weights
        model.base_to_ctrl.load_state_dict(base_to_ctrl.state_dict())
        model.base_midblock.load_state_dict(base_midblock.state_dict())
        model.ctrl_midblock.load_state_dict(ctrl_inner_midblock.state_dict())
        model.ctrl_to_base.load_state_dict(ctrl_to_base.state_dict())

        return model

    def freeze_base_params(self) -> None:
        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
        tuning."""
        # Unfreeze everything
        for param in self.parameters():
            param.requires_grad = True

        # Freeze base part
        for param in self.base_midblock.parameters():
            param.requires_grad = False

    def forward(
        self,
        hidden_states_base: Tensor,
        temb: Tensor,
        encoder_hidden_states: Tensor,
        hidden_states_ctrl: Optional[Tensor] = None,
        conditioning_scale: Optional[float] = 1.0,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[Tensor] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        apply_control: bool = True,
    ) -> Tuple[Tensor, Tensor]:
        """Run the base mid block and, if `apply_control` is set, the control mid block.

        Args:
            hidden_states_base: Input of the base mid block.
            temb: Passed as `temb` to both mid blocks.
            encoder_hidden_states: Passed to both mid blocks.
            hidden_states_ctrl: Input of the control stream. It is concatenated with the projected base features,
                so it must be a tensor when `apply_control` is `True`.
            conditioning_scale: Factor applied to the projected control output before it is added to the base.
            cross_attention_kwargs: Passed to both mid blocks. A non-`None` `"scale"` entry only triggers a
                deprecation warning here.
            attention_mask: Passed to both mid blocks.
            encoder_attention_mask: Passed to both mid blocks.
            apply_control: If `False`, only the base mid block is applied and `hidden_states_ctrl` is returned
                unchanged.

        Returns:
            The tuple `(h_base, h_ctrl)` of base and control hidden states.
        """
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        h_base = hidden_states_base
        h_ctrl = hidden_states_ctrl

        # Arguments shared by the base and the control mid block.
        joint_args = {
            "temb": temb,
            "encoder_hidden_states": encoder_hidden_states,
            "attention_mask": attention_mask,
            "cross_attention_kwargs": cross_attention_kwargs,
            "encoder_attention_mask": encoder_attention_mask,
        }

        if apply_control:
            h_ctrl = torch.cat([h_ctrl, self.base_to_ctrl(h_base)], dim=1)  # concat base -> ctrl
        h_base = self.base_midblock(h_base, **joint_args)  # apply base mid block
        if apply_control:
            h_ctrl = self.ctrl_midblock(h_ctrl, **joint_args)  # apply ctrl mid block
            h_base = h_base + self.ctrl_to_base(h_ctrl) * conditioning_scale  # add ctrl -> base

        return h_base, h_ctrl
class ControlNetXSCrossAttnUpBlock2D(nn.Module):
    """Up block of the base UNet extended with control-to-base skip connections.

    The block holds three resnets (optionally each followed by a `Transformer2DModel`), one `ctrl_to_base`
    connection per resnet created with `make_zero_conv`, and an optional `Upsample2D`.

    Args:
        in_channels (`int`):
            Channel count of the base skip connection concatenated before the last resnet.
        out_channels (`int`):
            Output channels of every resnet, and channel count of the skip connections concatenated before all
            but the last resnet.
        prev_output_channel (`int`):
            Channel count of the hidden states entering the first resnet (before concatenation).
        ctrl_skip_channels (`List[int]`):
            Input channels of the `ctrl_to_base` connections, indexed per layer.
        temb_channels (`int`):
            Passed as `temb_channels` to every resnet.
        norm_num_groups (`int`, defaults to 32):
            Group count passed to the resnets and transformers.
        resolution_idx (`int`, *optional*):
            Stored on the module and passed to `apply_freeu` in `forward`.
        has_crossattn (`bool`, defaults to `True`):
            Whether a `Transformer2DModel` follows each resnet. If `False`, `self.attentions` is a plain list of
            `None`.
        transformer_layers_per_block (`int`, defaults to 1):
            Number of layers of each transformer; an `int` is repeated for all layers.
        num_attention_heads (`int`, defaults to 1):
            Number of attention heads of each transformer.
        cross_attention_dim (`int`, defaults to 1024):
            Passed to each transformer.
        add_upsample (`bool`, defaults to `True`):
            Whether to create the final `Upsample2D`.
        upcast_attention (`bool`, defaults to `False`):
            Passed to each transformer.
        use_linear_projection (`bool`, *optional*, defaults to `True`):
            Passed to each transformer.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        ctrl_skip_channels: List[int],
        temb_channels: int,
        norm_num_groups: int = 32,
        resolution_idx: Optional[int] = None,
        has_crossattn: bool = True,
        transformer_layers_per_block: int = 1,
        num_attention_heads: int = 1,
        cross_attention_dim: int = 1024,
        add_upsample: bool = True,
        upcast_attention: bool = False,
        use_linear_projection: Optional[bool] = True,
    ):
        super().__init__()
        resnets = []
        attentions = []
        ctrl_to_base = []

        num_layers = 3  # only support sd + sdxl

        self.has_cross_attention = has_crossattn
        self.num_attention_heads = num_attention_heads

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * num_layers

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            # The control skip is added to the hidden states before the base skip is concatenated,
            # so it has to be projected to the un-concatenated channel count.
            ctrl_to_base.append(make_zero_conv(ctrl_skip_channels[i], resnet_in_channels))

            resnets.append(
                ResnetBlock2D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    groups=norm_num_groups,
                )
            )

            if has_crossattn:
                attentions.append(
                    Transformer2DModel(
                        num_attention_heads,
                        out_channels // num_attention_heads,
                        in_channels=out_channels,
                        num_layers=transformer_layers_per_block[i],
                        cross_attention_dim=cross_attention_dim,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                        norm_num_groups=norm_num_groups,
                    )
                )

        self.resnets = nn.ModuleList(resnets)
        self.attentions = nn.ModuleList(attentions) if has_crossattn else [None] * num_layers
        self.ctrl_to_base = nn.ModuleList(ctrl_to_base)

        if add_upsample:
            self.upsamplers = Upsample2D(out_channels, use_conv=True, out_channels=out_channels)
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False
        self.resolution_idx = resolution_idx

    @classmethod
    def from_modules(cls, base_upblock: CrossAttnUpBlock2D, ctrl_upblock: UpBlockControlNetXSAdapter):
        """Build a fused up block from an existing base up block and a control adapter.

        The constructor arguments are read off the given modules, a new instance is created, and the state dicts
        of the base resnets, attentions (if present), first upsampler (if present) and the adapter's `ctrl_to_base`
        connections are loaded into it.

        Args:
            base_upblock: Base up block whose configuration and weights are reused.
            ctrl_upblock: Adapter providing the `ctrl_to_base` connections.

        Returns:
            A new instance of `cls` carrying the weights of the given modules.
        """
        ctrl_to_base_skip_connections = ctrl_upblock.ctrl_to_base

        # get params
        def get_first_cross_attention(block):
            return block.attentions[0].transformer_blocks[0].attn2

        out_channels = base_upblock.resnets[0].out_channels
        in_channels = base_upblock.resnets[-1].in_channels - out_channels
        prev_output_channels = base_upblock.resnets[0].in_channels - out_channels
        ctrl_skip_channels = [c.in_channels for c in ctrl_to_base_skip_connections]
        temb_channels = base_upblock.resnets[0].time_emb_proj.in_features
        num_groups = base_upblock.resnets[0].norm1.num_groups
        resolution_idx = base_upblock.resolution_idx
        if hasattr(base_upblock, "attentions"):
            has_crossattn = True
            transformer_layers_per_block = len(base_upblock.attentions[0].transformer_blocks)
            num_attention_heads = get_first_cross_attention(base_upblock).heads
            cross_attention_dim = get_first_cross_attention(base_upblock).cross_attention_dim
            upcast_attention = get_first_cross_attention(base_upblock).upcast_attention
            use_linear_projection = base_upblock.attentions[0].use_linear_projection
        else:
            # Without attention layers these values are never used by `__init__`.
            has_crossattn = False
            transformer_layers_per_block = None
            num_attention_heads = None
            cross_attention_dim = None
            upcast_attention = None
            use_linear_projection = None
        add_upsample = base_upblock.upsamplers is not None

        # create model
        model = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channels,
            ctrl_skip_channels=ctrl_skip_channels,
            temb_channels=temb_channels,
            norm_num_groups=num_groups,
            resolution_idx=resolution_idx,
            has_crossattn=has_crossattn,
            transformer_layers_per_block=transformer_layers_per_block,
            num_attention_heads=num_attention_heads,
            cross_attention_dim=cross_attention_dim,
            add_upsample=add_upsample,
            upcast_attention=upcast_attention,
            use_linear_projection=use_linear_projection,
        )

        # load weights
        model.resnets.load_state_dict(base_upblock.resnets.state_dict())
        if has_crossattn:
            model.attentions.load_state_dict(base_upblock.attentions.state_dict())
        if add_upsample:
            model.upsamplers.load_state_dict(base_upblock.upsamplers[0].state_dict())
        model.ctrl_to_base.load_state_dict(ctrl_to_base_skip_connections.state_dict())

        return model

    def freeze_base_params(self) -> None:
        """Freeze the weights of the parts belonging to the base UNet2DConditionModel, and leave everything else unfrozen for fine
        tuning."""
        # Unfreeze everything
        for param in self.parameters():
            param.requires_grad = True

        # Freeze base part
        base_parts = [self.resnets]
        if isinstance(self.attentions, nn.ModuleList):  # attentions can be a list of Nones
            base_parts.append(self.attentions)
        if self.upsamplers is not None:
            base_parts.append(self.upsamplers)
        for part in base_parts:
            for param in part.parameters():
                param.requires_grad = False

    def forward(
        self,
        hidden_states: Tensor,
        res_hidden_states_tuple_base: Tuple[Tensor, ...],
        res_hidden_states_tuple_ctrl: Tuple[Tensor, ...],
        temb: Tensor,
        encoder_hidden_states: Optional[Tensor] = None,
        conditioning_scale: Optional[float] = 1.0,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        attention_mask: Optional[Tensor] = None,
        upsample_size: Optional[int] = None,
        encoder_attention_mask: Optional[Tensor] = None,
        apply_control: bool = True,
    ) -> Tensor:
        """Apply the up block, merging control and base skip connections before every resnet.

        Args:
            hidden_states: Input hidden states. They are not modified in place.
            res_hidden_states_tuple_base: Base skip connections; consumed from the last element backwards.
            res_hidden_states_tuple_ctrl: Control skip connections; consumed from the last element backwards.
            temb: Passed as second argument to every resnet.
            encoder_hidden_states: Passed to every attention layer.
            conditioning_scale: Factor applied to each projected control skip before it is added.
            cross_attention_kwargs: Passed to every attention layer. A non-`None` `"scale"` entry only triggers a
                deprecation warning here.
            attention_mask: Passed to every attention layer.
            upsample_size: Passed as second argument to the upsampler, if there is one.
            encoder_attention_mask: Passed to every attention layer.
            apply_control: If `False`, the control skip connections are ignored.

        Returns:
            The resulting hidden states.
        """
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

        # Truthy only if all four FreeU attributes exist on the module and are truthy.
        is_freeu_enabled = (
            getattr(self, "s1", None)
            and getattr(self, "s2", None)
            and getattr(self, "b1", None)
            and getattr(self, "b2", None)
        )

        def create_custom_forward(module, return_dict=None):
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                else:
                    return module(*inputs)

            return custom_forward

        def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
            # FreeU: Only operate on the first two stages
            if is_freeu_enabled:
                return apply_freeu(
                    self.resolution_idx,
                    hidden_states,
                    res_h_base,
                    s1=self.s1,
                    s2=self.s2,
                    b1=self.b1,
                    b2=self.b2,
                )
            else:
                return hidden_states, res_h_base

        # `zip` stops at the shortest input, so at most `len(self.resnets)` skip connections are consumed.
        for resnet, attn, c2b, res_h_base, res_h_ctrl in zip(
            self.resnets,
            self.attentions,
            self.ctrl_to_base,
            reversed(res_hidden_states_tuple_base),
            reversed(res_hidden_states_tuple_ctrl),
        ):
            if apply_control:
                # Out-of-place on purpose: an in-place `+=` would overwrite the tensor passed in by the caller.
                hidden_states = hidden_states + c2b(res_h_ctrl) * conditioning_scale

            hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
            hidden_states = torch.cat([hidden_states, res_h_base], dim=1)

            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet),
                    hidden_states,
                    temb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            if attn is not None:
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )[0]

        if self.upsamplers is not None:
            hidden_states = self.upsamplers(hidden_states, upsample_size)

        return hidden_states
def make_zero_conv(in_channels, out_channels=None):
    """Create a 1x1 `nn.Conv2d` (no padding) whose parameters are all set to zero.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels. If `None`, `in_channels` is used, since `nn.Conv2d` does not
            accept `None` as a channel count.

    Returns:
        The zeroed `nn.Conv2d` module.
    """
    if out_channels is None:
        out_channels = in_channels
    return zero_module(nn.Conv2d(in_channels, out_channels, 1, padding=0))
def zero_module(module):
    """Set every parameter of `module` to zero in place and return the same module object."""
    for p in module.parameters():
        nn.init.zeros_(p)
    return module
def find_largest_factor(number, max_factor):
    """Return the largest divisor of `number` that does not exceed `max_factor`.

    If `max_factor` is greater than or equal to `number`, `number` itself is returned.

    Args:
        number: The value to find a divisor of.
        max_factor: Upper bound (inclusive) for the returned divisor.

    Returns:
        The largest factor of `number` in the range `[1, max_factor]`, or `number` if `max_factor >= number`.

    Raises:
        ValueError: If `max_factor` is smaller than 1 while `number` is larger than `max_factor`, in which case no
            positive divisor exists in the requested range.
    """
    if max_factor >= number:
        return number

    factor = max_factor
    # Search downwards; 1 divides everything, so a positive `max_factor` always yields a result.
    while factor > 0:
        if number % factor == 0:
            return factor
        factor -= 1

    raise ValueError(f"`max_factor` must be at least 1 to find a factor of {number}, but is {max_factor}.")