# File: InvSR/src/diffusers/models/unets/unet_2d_blocks_flax.py
# Hosting-page metadata: OAOA, commit bfa59ab ("first commit"), 15.6 kB
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import flax.linen as nn
import jax.numpy as jnp
from ..attention_flax import FlaxTransformer2DModel
from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
class FlaxCrossAttnDownBlock2D(nn.Module):
    r"""
    Cross-attention 2D down block: `num_layers` pairs of (ResNet block, transformer block), optionally followed by a
    downsampling layer. Architecture reference given by the original authors: https://arxiv.org/abs/2103.06104

    Parameters:
        in_channels (:obj:`int`):
            Input channels of the first ResNet block
        out_channels (:obj:`int`):
            Output channels of every ResNet block (and input channels of every ResNet block after the first)
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, forwarded to each ResNet block as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (ResNet, transformer) pairs
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each transformer block; the per-head size is
            `out_channels // num_attention_heads`
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append a downsampling layer after the last pair
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each transformer block
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each transformer block
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Forwarded to each transformer block as its `depth`
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_downsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1

    def setup(self):
        """Instantiate the ResNet blocks, the transformer blocks and the optional downsampler."""
        # NOTE: the attribute names (`resnets`, `attentions`, `downsamplers_0`) are kept exactly as they were;
        # presumably Flax derives parameter-tree keys from them, so renaming would break checkpoints.
        # Only the first ResNet block consumes `in_channels`; all later ones take `out_channels`.
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels if layer_idx == 0 else self.out_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for layer_idx in range(self.num_layers)
        ]

        # One transformer block per ResNet block, all identically configured.
        self.attentions = [
            FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            for _ in range(self.num_layers)
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        """
        Run every (ResNet, transformer) pair and then the optional downsampler.

        Returns a 2-tuple `(hidden_states, output_states)`: the final activations and a tuple holding the activations
        after each pair plus, when `add_downsample` is set, the downsampled activations.
        """
        collected = []
        for res_layer, attn_layer in zip(self.resnets, self.attentions):
            hidden_states = res_layer(hidden_states, temb, deterministic=deterministic)
            hidden_states = attn_layer(hidden_states, encoder_hidden_states, deterministic=deterministic)
            collected.append(hidden_states)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
class FlaxDownBlock2D(nn.Module):
    r"""
    Flax 2D down block: a stack of ResNet blocks, optionally followed by a downsampling layer.

    Parameters:
        in_channels (:obj:`int`):
            Input channels of the first ResNet block
        out_channels (:obj:`int`):
            Output channels of every ResNet block (and input channels of every ResNet block after the first)
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, forwarded to each ResNet block as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet blocks
        add_downsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append a downsampling layer after the last ResNet block
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    add_downsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Instantiate the ResNet blocks and the optional downsampler."""
        # Attribute names (`resnets`, `downsamplers_0`) are intentionally unchanged.
        # Only the first ResNet block consumes `in_channels`; all later ones take `out_channels`.
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels if layer_idx == 0 else self.out_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for layer_idx in range(self.num_layers)
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, deterministic=True):
        """
        Run every ResNet block and then the optional downsampler.

        Returns a 2-tuple `(hidden_states, output_states)`: the final activations and a tuple holding the activations
        after each ResNet block plus, when `add_downsample` is set, the downsampled activations.
        """
        collected = []
        for res_layer in self.resnets:
            hidden_states = res_layer(hidden_states, temb, deterministic=deterministic)
            collected.append(hidden_states)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
class FlaxCrossAttnUpBlock2D(nn.Module):
    r"""
    Cross-attention 2D up block: `num_layers` pairs of (ResNet block, transformer block), each fed the running
    activations concatenated with one skip connection, optionally followed by an upsampling layer. Architecture
    reference given by the original authors: https://arxiv.org/abs/2103.06104

    Parameters:
        in_channels (:obj:`int`):
            Channel count assumed for the skip connection consumed by the last ResNet block
        out_channels (:obj:`int`):
            Output channels of every ResNet block; also the channel count assumed for the skip connections consumed
            by all ResNet blocks except the last
        prev_output_channel (:obj:`int`):
            Channel count of the incoming `hidden_states`, i.e. the main input of the first ResNet block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, forwarded to each ResNet block as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (ResNet, transformer) pairs
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each transformer block; the per-head size is
            `out_channels // num_attention_heads`
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append an upsampling layer after the last pair
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each transformer block
        only_cross_attention (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each transformer block
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Forwarded to each transformer block as its `depth`
    """

    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    add_upsample: bool = True
    use_linear_projection: bool = False
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1

    def setup(self):
        """Instantiate the ResNet blocks, the transformer blocks and the optional upsampler."""
        # Attribute names (`resnets`, `attentions`, `upsamplers_0`) are intentionally unchanged.
        final_idx = self.num_layers - 1

        # Each ResNet block takes [running activations ++ skip connection]:
        #   - running activations: `prev_output_channel` for the first block, `out_channels` afterwards
        #   - skip connection: `in_channels` for the last block, `out_channels` before that
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=(self.prev_output_channel if layer_idx == 0 else self.out_channels)
                + (self.in_channels if layer_idx == final_idx else self.out_channels),
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for layer_idx in range(self.num_layers)
        ]

        # One transformer block per ResNet block, all identically configured.
        self.attentions = [
            FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            for _ in range(self.num_layers)
        ]

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, deterministic=True):
        """
        Consume skip connections from the end of `res_hidden_states_tuple` (one per layer pair), run every
        (ResNet, transformer) pair, then the optional upsampler, and return the resulting activations.
        """
        for offset, (res_layer, attn_layer) in enumerate(zip(self.resnets, self.attentions), start=1):
            # Walk the skip connections back-to-front: last entry first.
            skip = res_hidden_states_tuple[-offset]
            # Joined along the last axis (presumably the channel axis of a channels-last layout).
            merged = jnp.concatenate((hidden_states, skip), axis=-1)
            merged = res_layer(merged, temb, deterministic=deterministic)
            hidden_states = attn_layer(merged, encoder_hidden_states, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states
class FlaxUpBlock2D(nn.Module):
    r"""
    Flax 2D up block: a stack of ResNet blocks, each fed the running activations concatenated with one skip
    connection, optionally followed by an upsampling layer.

    Parameters:
        in_channels (:obj:`int`):
            Channel count assumed for the skip connection consumed by the last ResNet block
        out_channels (:obj:`int`):
            Output channels of every ResNet block; also the channel count assumed for the skip connections consumed
            by all ResNet blocks except the last
        prev_output_channel (:obj:`int`):
            Output channels from the previous block, i.e. the main input of the first ResNet block
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, forwarded to each ResNet block as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of ResNet blocks
        add_upsample (:obj:`bool`, *optional*, defaults to `True`):
            Whether to append an upsampling layer after the last ResNet block
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    in_channels: int
    out_channels: int
    prev_output_channel: int
    dropout: float = 0.0
    num_layers: int = 1
    add_upsample: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        """Instantiate the ResNet blocks and the optional upsampler."""
        # Attribute names (`resnets`, `upsamplers_0`) are intentionally unchanged.
        final_idx = self.num_layers - 1

        # Each ResNet block takes [running activations ++ skip connection]:
        #   - running activations: `prev_output_channel` for the first block, `out_channels` afterwards
        #   - skip connection: `in_channels` for the last block, `out_channels` before that
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=(self.prev_output_channel if layer_idx == 0 else self.out_channels)
                + (self.in_channels if layer_idx == final_idx else self.out_channels),
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for layer_idx in range(self.num_layers)
        ]

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, deterministic=True):
        """
        Consume skip connections from the end of `res_hidden_states_tuple` (one per ResNet block), run every ResNet
        block, then the optional upsampler, and return the resulting activations.
        """
        for offset, res_layer in enumerate(self.resnets, start=1):
            # Walk the skip connections back-to-front: last entry first.
            skip = res_hidden_states_tuple[-offset]
            # Joined along the last axis (presumably the channel axis of a channels-last layout).
            merged = jnp.concatenate((hidden_states, skip), axis=-1)
            hidden_states = res_layer(merged, temb, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states
class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    r"""
    Cross Attention 2D Mid-level block: one leading ResNet block followed by `num_layers` pairs of (transformer block,
    ResNet block), all operating on `in_channels` channels. Architecture reference given by the original authors:
    https://arxiv.org/abs/2103.06104

    Parameters:
        in_channels (:obj:`int`):
            Input channels (also the output channels of every sub-block)
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate, forwarded to each ResNet block as `dropout_prob`
        num_layers (:obj:`int`, *optional*, defaults to 1):
            Number of (transformer, ResNet) pairs following the leading ResNet block
        num_attention_heads (:obj:`int`, *optional*, defaults to 1):
            Number of attention heads of each spatial transformer block; the per-head size is
            `in_channels // num_attention_heads`
        use_linear_projection (:obj:`bool`, *optional*, defaults to `False`):
            Forwarded unchanged to each transformer block
        use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
            enable memory efficient attention https://arxiv.org/abs/2112.05682
        split_head_dim (`bool`, *optional*, defaults to `False`):
            Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
            enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
        transformer_layers_per_block (:obj:`int`, *optional*, defaults to 1):
            Forwarded to each transformer block as its `depth`
    """

    in_channels: int
    dropout: float = 0.0
    num_layers: int = 1
    num_attention_heads: int = 1
    use_linear_projection: bool = False
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    dtype: jnp.dtype = jnp.float32
    transformer_layers_per_block: int = 1

    def setup(self):
        """Instantiate `num_layers + 1` ResNet blocks and `num_layers` transformer blocks."""
        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
        ]

        attentions = []

        # Every further layer adds one transformer block and the ResNet block that follows it.
        for _ in range(self.num_layers):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.in_channels,
                n_heads=self.num_attention_heads,
                d_head=self.in_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)

        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, temb, encoder_hidden_states, deterministic=True):
        """
        Apply the leading ResNet block, then alternate transformer and ResNet blocks, and return the activations.

        `deterministic` is forwarded to every sub-block.
        """
        # Forward `deterministic` to the leading ResNet block as well; previously it was omitted here, so this
        # block ignored the caller's flag while every other sub-block honoured it.
        hidden_states = self.resnets[0](hidden_states, temb, deterministic=deterministic)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)

        return hidden_states