Spaces:
Running
on
Zero
Running
on
Zero
# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from typing import Dict, Optional, Union | |
import torch | |
from torch import nn | |
from ...configuration_utils import ConfigMixin, register_to_config | |
from ...utils import logging | |
from ...utils.torch_utils import maybe_allow_in_graph | |
from ..attention import FeedForward | |
from ..attention_processor import Attention, AttentionProcessor, FusedHunyuanAttnProcessor2_0, HunyuanAttnProcessor2_0 | |
from ..embeddings import ( | |
HunyuanCombinedTimestepTextSizeStyleEmbedding, | |
PatchEmbed, | |
PixArtAlphaTextProjection, | |
) | |
from ..modeling_outputs import Transformer2DModelOutput | |
from ..modeling_utils import ModelMixin | |
from ..normalization import AdaLayerNormContinuous, FP32LayerNorm | |
# Module-level logger, named after this module via the project's `logging` helper.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
class AdaLayerNormShift(nn.Module):
    r"""
    Norm layer modified to incorporate timestep embeddings.

    The input is normalized with an `FP32LayerNorm` and then shifted by a linear projection of the
    SiLU-activated conditioning embedding. Only a shift is added; no scale is applied.

    Parameters:
        embedding_dim (`int`):
            The size of each embedding vector. Used as the normalized dimension and as both the input and
            output width of the shift projection.
        elementwise_affine (`bool`, *optional*, defaults to `True`):
            Forwarded to `FP32LayerNorm`.
        eps (`float`, *optional*, defaults to 1e-6):
            Forwarded to `FP32LayerNorm`.
    """

    def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim)
        self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """
        Normalize `x` and add a shift computed from `emb`.

        Args:
            x (`torch.Tensor`): Tensor to normalize.
            emb (`torch.Tensor`): Conditioning embedding the shift is derived from.

        Returns:
            `torch.Tensor`: `norm(x) + shift`, where the shift has an extra axis inserted at position 1.
        """
        # The activation is evaluated in float32 and cast back to the embedding's own dtype
        # before it goes through the linear projection.
        activated = self.silu(emb.to(torch.float32)).to(emb.dtype)
        shift = self.linear(activated)
        # Insert an axis at dim 1 so the shift broadcasts against the normalized input.
        x = self.norm(x) + shift.unsqueeze(dim=1)
        return x
# `maybe_allow_in_graph` is imported at the top of this file but was never applied; the decorator
# line is restored here.
@maybe_allow_in_graph
class HunyuanDiTBlock(nn.Module):
    r"""
    Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
    QKNorm.

    The block runs, in order: an optional long-skip merge, self-attention, cross-attention and a feed-forward
    layer. Each of the three sub-layers is preceded by its own normalization and wrapped in a residual add.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        cross_attention_dim (`int`, *optional*, defaults to 1024):
            The size of the encoder_hidden_states vector for cross attention.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, *optional*, defaults to 1e-6):
            A small constant added to the denominator in normalization layers to prevent division by zero.
        final_dropout (`bool`, *optional*, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*):
            The size of the hidden layer in the feed-forward block. Defaults to `None`.
        ff_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the feed-forward block.
        skip (`bool`, *optional*, defaults to `False`):
            Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
        qk_norm (`bool`, *optional*, defaults to `True`):
            Whether to use normalization in QK calculation. Defaults to `True`.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        cross_attention_dim: int = 1024,
        dropout=0.0,
        activation_fn: str = "geglu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-6,
        final_dropout: bool = False,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        skip: bool = False,
        qk_norm: bool = True,
    ):
        super().__init__()

        # Define 3 blocks. Each block has its own normalization layer.
        # NOTE: when new version comes, check norm2 and norm 3

        # 1. Self-Attn: the only sub-layer whose norm is conditioned on the timestep embedding.
        self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )

        # 2. Cross-Attn
        self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=cross_attention_dim,
            dim_head=dim // num_attention_heads,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=True,
            processor=HunyuanAttnProcessor2_0(),
        )

        # 3. Feed-forward
        self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)

        self.ff = FeedForward(
            dim,
            dropout=dropout,  # 0.0
            activation_fn=activation_fn,  # approx GeLU
            final_dropout=final_dropout,  # 0.0
            inner_dim=ff_inner_dim,  # int(dim * mlp_ratio)
            bias=ff_bias,
        )

        # 4. Skip Connection: the skip tensor is concatenated with the hidden states on the last
        # axis (hence 2 * dim), normalized, and projected back down to `dim`.
        if skip:
            self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            # `forward` keys off `skip_linear is None`; `skip_norm` is not created in this case.
            self.skip_linear = None

        # let chunk size default to None
        # NOTE: these two values are stored by `set_chunk_feed_forward` but `forward` below calls
        # `self.ff` directly and does not read them.
        self._chunk_size = None
        self._chunk_dim = 0

    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb=None,
        skip=None,
    ) -> torch.Tensor:
        """
        Run the block on `hidden_states`.

        Args:
            hidden_states (`torch.Tensor`): Input token features.
            encoder_hidden_states (`torch.Tensor`, *optional*): Conditioning passed to the cross-attention layer.
            temb (`torch.Tensor`, *optional*): Embedding passed to `norm1` to compute the shift.
            image_rotary_emb (*optional*): Forwarded unchanged to both attention layers.
            skip (*optional*):
                Tensor merged with `hidden_states` when the block was built with `skip=True`; it is concatenated
                along the last dimension. Ignored when the block has no skip projection.

        Returns:
            `torch.Tensor`: The updated hidden states.
        """
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Long Skip Connection
        if self.skip_linear is not None:
            cat = torch.cat([hidden_states, skip], dim=-1)
            cat = self.skip_norm(cat)
            hidden_states = self.skip_linear(cat)

        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states, temb)  # checked: self.norm1 is correct
        attn_output = self.attn1(
            norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_output

        # 2. Cross-Attention
        hidden_states = hidden_states + self.attn2(
            self.norm2(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # FFN Layer
        # TODO: switch norm2 and norm3 in the state dict
        mlp_inputs = self.norm3(hidden_states)
        hidden_states = hidden_states + self.ff(mlp_inputs)

        return hidden_states
class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
    """
    HunYuanDiT: Diffusion model with a Transformer backbone.

    Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.

    Parameters:
        num_attention_heads (`int`, *optional*, defaults to 16):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, *optional*, defaults to 88):
            The number of channels in each head.
        in_channels (`int`, *optional*):
            The number of channels in the input. The output has twice as many channels when `learn_sigma` is
            `True`, otherwise the same number.
        patch_size (`int`, *optional*):
            The size of the patch to use for the input.
        activation_fn (`str`, *optional*, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        sample_size (`int`, *optional*, defaults to 32):
            The width of the latent images. This is fixed during training since it is used to learn a number of
            position embeddings.
        hidden_size (`int`, *optional*, defaults to 1152):
            The size of hidden layer in the conditioning embedding layers.
        num_layers (`int`, *optional*, defaults to 28):
            The number of layers of Transformer blocks to use.
        mlp_ratio (`float`, *optional*, defaults to 4.0):
            The ratio of the hidden layer size to the input size.
        learn_sigma (`bool`, *optional*, defaults to `True`):
            Whether to predict variance.
        cross_attention_dim (`int`, *optional*, defaults to 1024):
            The number of dimension in the clip text embedding.
        norm_type (`str`, *optional*, defaults to `"layer_norm"`):
            Accepted by the constructor; not read anywhere in this class body.
        cross_attention_dim_t5 (`int`, *optional*, defaults to 2048):
            The number dimensions in t5 text embedding.
        pooled_projection_dim (`int`, *optional*, defaults to 1024):
            The size of the pooled projection.
        text_len (`int`, *optional*, defaults to 77):
            The length of the clip text embedding.
        text_len_t5 (`int`, *optional*, defaults to 256):
            The length of the T5 text embedding.
        use_style_cond_and_image_meta_size (`bool`, *optional*, defaults to `True`):
            Whether or not to use style condition and image meta size. True for version <=1.1, False for version
            >= 1.2
    """

    # `register_to_config` is imported at the top of this file and `self.config` is read both in
    # `__init__` and in `forward`, so the constructor has to be wrapped with it.
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        patch_size: Optional[int] = None,
        activation_fn: str = "gelu-approximate",
        sample_size=32,
        hidden_size=1152,
        num_layers: int = 28,
        mlp_ratio: float = 4.0,
        learn_sigma: bool = True,
        cross_attention_dim: int = 1024,
        norm_type: str = "layer_norm",
        cross_attention_dim_t5: int = 2048,
        pooled_projection_dim: int = 1024,
        text_len: int = 77,
        text_len_t5: int = 256,
        use_style_cond_and_image_meta_size: bool = True,
    ):
        super().__init__()
        # NOTE(review): `in_channels` and `patch_size` default to `None` but are used in arithmetic
        # here and in `proj_out`, so both must be supplied in practice.
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.num_heads = num_attention_heads
        self.inner_dim = num_attention_heads * attention_head_dim

        # Projects T5 features to the cross-attention width so they can be concatenated with
        # the other text-encoder features in `forward`.
        self.text_embedder = PixArtAlphaTextProjection(
            in_features=cross_attention_dim_t5,
            hidden_size=cross_attention_dim_t5 * 4,
            out_features=cross_attention_dim,
            act_fn="silu_fp32",
        )

        # Learned replacement values used in `forward` for text positions whose mask is False.
        self.text_embedding_padding = nn.Parameter(
            torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
        )

        # NOTE(review): the patch embedding and the conditioning embedding are sized with
        # `hidden_size`, while the blocks are sized with `inner_dim`; presumably configs are expected
        # to make the two equal (the defaults 1152 vs 16 * 88 are not) - confirm.
        self.pos_embed = PatchEmbed(
            height=sample_size,
            width=sample_size,
            in_channels=in_channels,
            embed_dim=hidden_size,
            patch_size=patch_size,
            pos_embed_type=None,
        )

        self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
            hidden_size,
            pooled_projection_dim=pooled_projection_dim,
            seq_len=text_len_t5,
            cross_attention_dim=cross_attention_dim_t5,
            use_style_cond_and_image_meta_size=use_style_cond_and_image_meta_size,
        )

        # HunyuanDiT Blocks
        self.blocks = nn.ModuleList(
            [
                HunyuanDiTBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    activation_fn=activation_fn,
                    ff_inner_dim=int(self.inner_dim * mlp_ratio),
                    cross_attention_dim=cross_attention_dim,
                    qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
                    # Only blocks past the midpoint get a skip projection.
                    skip=layer > num_layers // 2,
                )
                for layer in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedHunyuanAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        # Remember the current processors so `unfuse_qkv_projections` can restore them.
        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedHunyuanAttnProcessor2_0())

    # Adapted from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        Does nothing when `fuse_qkv_projections` has not been called on this model.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        # `original_attn_processors` only exists after `fuse_qkv_projections` has run; look it up
        # defensively so calling this first is a no-op rather than an AttributeError.
        original_processors = getattr(self, "original_attn_processors", None)
        if original_processors is not None:
            self.set_attn_processor(original_processors)

    # `attn_processors` is read as an attribute (`self.attn_processors.items()` / `.keys()`) by the
    # other methods of this class, so it must be a property.
    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        Raises:
            ValueError: If a dict is passed whose length differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # Entries are popped, so a dict passed in is consumed as it is applied.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(HunyuanAttnProcessor2_0())

    def forward(
        self,
        hidden_states,
        timestep,
        encoder_hidden_states=None,
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        image_rotary_emb=None,
        controlnet_block_samples=None,
        return_dict=True,
    ):
        """
        The [`HunyuanDiT2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
                The input tensor.
            timestep ( `torch.LongTensor`, *optional*):
                Used to indicate denoising step.
            encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. This is the output of `BertModel`.
            text_embedding_mask: torch.Tensor
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
                of `BertModel`.
            encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
            text_embedding_mask_t5: torch.Tensor
                An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
                of T5 Text Encoder.
            image_meta_size (torch.Tensor):
                Conditional embedding indicate the image sizes
            style: torch.Tensor:
                Conditional embedding indicate the style
            image_rotary_emb (`torch.Tensor`):
                The image rotary embeddings to apply on query and key tensors during attention calculation.
            controlnet_block_samples (sequence of `torch.Tensor`, *optional*):
                Residuals added to the skip tensors, consumed from the end of the sequence, one per block that
                has a skip connection. The caller's sequence is not modified.
            return_dict: bool
                Whether to return a dictionary.

        Returns:
            [`Transformer2DModelOutput`] if `return_dict` is `True`, otherwise a one-element `tuple` holding the
            output tensor.

        Raises:
            ValueError: If `controlnet_block_samples` still has entries after all blocks have run.
        """
        height, width = hidden_states.shape[-2:]

        hidden_states = self.pos_embed(hidden_states)

        temb = self.time_extra_emb(
            timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
        )  # [B, D]

        # text projection: flatten (batch, seq) so the projection sees 2D input, then restore.
        batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
        encoder_hidden_states_t5 = self.text_embedder(
            encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
        )
        encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)

        # Concatenate both text encoders along the sequence axis, and their masks to match.
        encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
        text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
        text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()

        # Positions whose mask is False take the learned padding embedding instead.
        encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)

        # Work on a private copy: the loop below pops entries, which would otherwise empty the
        # list object owned by the caller.
        if controlnet_block_samples is not None:
            controlnet_block_samples = list(controlnet_block_samples)

        # Blocks with index < num_layers // 2 - 1 push their output; blocks with index
        # > num_layers // 2 pop one (last in, first out) as their skip input.
        skips = []
        for layer, block in enumerate(self.blocks):
            if layer > self.config.num_layers // 2:
                if controlnet_block_samples is not None:
                    skip = skips.pop() + controlnet_block_samples.pop()
                else:
                    skip = skips.pop()
                hidden_states = block(
                    hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_rotary_emb=image_rotary_emb,
                    skip=skip,
                )  # (N, L, D)
            else:
                hidden_states = block(
                    hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_rotary_emb=image_rotary_emb,
                )  # (N, L, D)

            if layer < (self.config.num_layers // 2 - 1):
                skips.append(hidden_states)

        if controlnet_block_samples is not None and len(controlnet_block_samples) != 0:
            raise ValueError("The number of controls is not equal to the number of skip connections.")

        # final layer
        hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
        hidden_states = self.proj_out(hidden_states)
        # (N, L, patch_size ** 2 * out_channels)

        # unpatchify: (N, out_channels, H, W)
        patch_size = self.pos_embed.patch_size
        height = height // patch_size
        width = width // patch_size

        hidden_states = hidden_states.reshape(
            shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels)
        )
        # Move channels first and interleave each patch's rows/columns with the patch grid.
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        output = hidden_states.reshape(
            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
        )

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).

        Raises:
            ValueError: If `dim` is neither 0 nor 1.
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
    def disable_forward_chunking(self):
        """Reset feed-forward chunking on every sub-module that supports it (chunk size `None`, dim 0)."""

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, None, 0)