FitDiT / src /pose_guider.py
BoyuanJiang's picture
Add application file
a62eecb
raw
history blame
1.95 kB
from typing import Tuple
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange
def zero_module(module):
    """Fill every parameter of ``module`` with zeros, in place.

    Args:
        module: Object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        The very same ``module`` object, so the call can wrap a constructor
        expression directly.
    """
    for weight in module.parameters():
        nn.init.zeros_(weight)
    return module
class PoseGuider(ModelMixin):
    """Convolutional encoder that turns a 4-D conditioning tensor into tokens.

    The input goes through ``conv_in`` and then, for every consecutive pair of
    entries in ``block_out_channels``, a channel-preserving 3x3 convolution
    followed by a stride-2 3x3 convolution (SiLU after every convolution).
    The resulting feature map is cut into ``patch_size`` x ``patch_size``
    patches; each patch is flattened together with its channels and projected
    to ``conditioning_embedding_channels`` by a linear layer.

    The projection ``out`` is created with all-zero weight and bias, so a
    freshly constructed module outputs zeros.

    Args:
        conditioning_embedding_channels: Size of every output token.
        conditioning_channels: Number of channels of the input tensor.
        block_out_channels: Channel widths. The first entry is the width after
            ``conv_in``; each further entry adds one downsampling stage that
            ends at that width. Must contain at least one entry.
        patch_size: Side length of the square patches taken from the final
            feature map. The default of 2 reproduces the previously
            hard-coded behaviour.

    Raises:
        ValueError: If ``patch_size`` is smaller than 1.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
        patch_size: int = 2,
    ):
        super().__init__()

        if patch_size < 1:
            raise ValueError(
                f"patch_size must be a positive integer, got {patch_size}"
            )
        # Kept on the instance so forward() and the width of ``out`` cannot
        # drift apart.
        self.patch_size = patch_size

        self.conv_in = nn.Conv2d(
            conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
        )

        self.blocks = nn.ModuleList([])
        # Walk consecutive (in, out) channel pairs; every pair contributes one
        # same-width conv and one stride-2 conv that changes the width.
        for channel_in, channel_out in zip(
            block_out_channels, block_out_channels[1:]
        ):
            self.blocks.append(
                nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            self.blocks.append(
                nn.Conv2d(
                    channel_in, channel_out, kernel_size=3, padding=1, stride=2
                )
            )

        # Each token holds all channels of one patch_size x patch_size patch.
        self.out = zero_module(
            nn.Linear(
                block_out_channels[-1] * patch_size * patch_size,
                conditioning_embedding_channels,
            )
        )

    def forward(self, conditioning):
        """Encode ``conditioning`` into a sequence of patch tokens.

        Args:
            conditioning: Tensor laid out as (batch, channels, height, width)
                with ``conditioning_channels`` channels.

        Returns:
            Tensor of shape (batch, h * w, conditioning_embedding_channels),
            where h and w are the height and width of the final feature map
            divided by ``patch_size``. Each stride-2 convolution maps a side
            length n to ``(n - 1) // 2 + 1``; the final height and width must
            be divisible by ``patch_size``, otherwise ``rearrange`` fails.
        """
        embedding = F.silu(self.conv_in(conditioning))

        for block in self.blocks:
            embedding = F.silu(block(embedding))

        # Fold every patch (and its channels) into the feature axis and
        # flatten the patch grid into a token axis.
        embedding = rearrange(
            embedding,
            "b c (h ph) (w pw) -> b (h w) (c ph pw)",
            ph=self.patch_size,
            pw=self.patch_size,
        )
        return self.out(embedding)
if __name__ == "__main__":
    # Smoke test: build a guider with the default channel widths and push one
    # random batch through it. Nothing is printed or asserted; the run only
    # shows that the forward pass completes for this input size. For this
    # input the result is expected to have shape (4, 3072, 3072).
    import torch

    guider = PoseGuider(
        conditioning_embedding_channels=3072,
        block_out_channels=(16, 32, 96, 256),
    )
    dummy_batch = torch.randn(4, 3, 1024, 768)
    tokens = guider(dummy_batch)