Blackroot committed
Commit e9959b7 · verified · 1 Parent(s): 3d16315

Upload 4 files

Files changed (4):
  1. models/__init__.py +3 -0
  2. models/uvit.py +368 -0
  3. step_799.safetensors +3 -0
  4. train.py +307 -0
models/__init__.py ADDED
@@ -0,0 +1,3 @@
from .uvit import AsymmetricResidualUDiT, xATGLU

__all__ = ['AsymmetricResidualUDiT', 'xATGLU']
models/uvit.py ADDED
@@ -0,0 +1,368 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

# Changelog since original version:
# xATGLU instead of top linear in transformer block
# Added a learned residual scale to all blocks and all residuals. This allowed bfloat16 training to stabilize; before that it simply exploded.

# This architecture was my attempt at the following Simple Diffusion paper, with some modifications:
# https://arxiv.org/pdf/2410.19324v1

# Very similar to GeGLU or SwiGLU: there's a learned gate function, and arctan is used as the activation.
class xATGLU(nn.Module):
    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        # GATE path | VALUE path
        self.proj = nn.Linear(input_dim, output_dim * 2, bias=bias)
        nn.init.kaiming_normal_(self.proj.weight, nonlinearity='linear')

        self.alpha = nn.Parameter(torch.zeros(1))
        self.half_pi = torch.pi / 2
        self.inv_pi = 1 / torch.pi

    def forward(self, x):
        projected = self.proj(x)
        gate_path, value_path = projected.chunk(2, dim=-1)

        # Apply arctan gating with expanded range via learned alpha -- https://arxiv.org/pdf/2405.20768
        gate = (torch.arctan(gate_path) + self.half_pi) * self.inv_pi
        expanded_gate = gate * (1 + 2 * self.alpha) - self.alpha

        return expanded_gate * value_path  # g(x) × y

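# Illustrative usage (assumed example, not from the original file): arctan maps to
# (-pi/2, pi/2), so the base gate (arctan(g) + pi/2) / pi lies in (0, 1), and the
# learned alpha expands that range to (-alpha, 1 + alpha).
#     layer = xATGLU(input_dim=64, output_dim=128)
#     out = layer(torch.randn(2, 16, 64))   # out.shape == torch.Size([2, 16, 128])
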
# Tensor product attention, modified. Original code from:
# https://github.com/tensorgi/T6/blob/main/model/T6_ropek.py
# https://arxiv.org/pdf/2501.06425

class CPLinear(nn.Module):
    def __init__(self, in_features, n_head, head_dim, rank: int = 1, q_rank: int = 12):
        super(CPLinear, self).__init__()
        self.in_features = in_features
        self.n_head = n_head
        self.head_dim = head_dim
        self.rank = rank
        self.q_rank = q_rank

        self.W_A_q = nn.Linear(in_features, n_head * q_rank, bias=False)
        self.W_A_k = nn.Linear(in_features, n_head * rank, bias=False)
        self.W_A_v = nn.Linear(in_features, n_head * rank, bias=False)

        nn.init.xavier_normal_(self.W_A_q.weight)
        nn.init.xavier_normal_(self.W_A_k.weight)
        nn.init.xavier_normal_(self.W_A_v.weight)

        self.W_B_q = nn.Linear(in_features, q_rank * head_dim, bias=False)
        self.W_B_k = nn.Linear(in_features, rank * head_dim, bias=False)
        self.W_B_v = nn.Linear(in_features, rank * head_dim, bias=False)

        nn.init.xavier_normal_(self.W_B_q.weight)
        nn.init.xavier_normal_(self.W_B_k.weight)
        nn.init.xavier_normal_(self.W_B_v.weight)

    def forward(self, x):
        batch_size, seq_len, _ = x.size()

        # A clarification on the naming: it's fairly standard to call the two low-rank matrices A and B, so that convention is followed here.

        # Compute intermediate variables A for Q, K, and V
        A_q = self.W_A_q(x).view(batch_size, seq_len, self.n_head, self.q_rank)
        A_k = self.W_A_k(x).view(batch_size, seq_len, self.n_head, self.rank)
        A_v = self.W_A_v(x).view(batch_size, seq_len, self.n_head, self.rank)

        # Compute intermediate variables B for Q, K, and V
        B_q = self.W_B_q(x).view(batch_size, seq_len, self.q_rank, self.head_dim)
        B_k = self.W_B_k(x).view(batch_size, seq_len, self.rank, self.head_dim)
        B_v = self.W_B_v(x).view(batch_size, seq_len, self.rank, self.head_dim)

        # Reshape A_q, A_k, A_v
        A_q = A_q.view(batch_size * seq_len, self.n_head, self.q_rank)
        A_k = A_k.view(batch_size * seq_len, self.n_head, self.rank)
        A_v = A_v.view(batch_size * seq_len, self.n_head, self.rank)

        # Reshape B_q, B_k, B_v
        B_q = B_q.view(batch_size * seq_len, self.q_rank, self.head_dim)
        B_k = B_k.view(batch_size * seq_len, self.rank, self.head_dim)
        B_v = B_v.view(batch_size * seq_len, self.rank, self.head_dim)

        q = torch.bmm(A_q, B_q).div_(self.q_rank).view(batch_size, seq_len, self.n_head, self.head_dim)
        k = torch.bmm(A_k, B_k).div_(self.rank).view(batch_size, seq_len, self.n_head, self.head_dim)
        v = torch.bmm(A_v, B_v).div_(self.rank).view(batch_size, seq_len, self.n_head, self.head_dim)

        return q, k, v

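# Rough parameter accounting for the factorization above (illustrative numbers, not
# from the original file): with in_features=512, n_head=8, head_dim=64, rank=2 and
# q_rank=6, the A/B projections hold
#     512*(8*6 + 8*2 + 8*2) + 512*(6*64 + 2*64 + 2*64) = 40,960 + 327,680 = 368,640
# weights, versus 3 * 512 * 512 = 786,432 for dense Q/K/V projections.
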
# Quite possibly this is not a good method for positional encoding in a DiT; in fact it may be actively harmful. It does help on small datasets though.
# Using no positional embedding should be a serious consideration for high-compute / large-data scenarios.
class Rotary(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x):
        seq_len = x.shape[1]
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            freqs = torch.outer(t, self.inv_freq).to(x.device)
            self.cos_cached = freqs.cos().bfloat16()
            self.sin_cached = freqs.sin().bfloat16()
        return self.cos_cached[None, :, None, :], self.sin_cached[None, :, None, :]

def apply_rotary_emb(x, cos, sin):
    assert x.ndim == 4  # multihead attention
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3).type_as(x)

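# For reference, this is the "half-split" RoPE convention: each feature pair
# (x1_i, x2_i), taken from the first and second halves of the head dimension, is
# rotated by the position-dependent angle theta_i = pos * inv_freq[i]:
#     y1_i =  cos(theta_i) * x1_i + sin(theta_i) * x2_i
#     y2_i = -sin(theta_i) * x1_i + cos(theta_i) * x2_i
# which preserves vector norms and encodes position purely through the rotation.
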
class TensorProductAttentionWithRope(nn.Module):
    def __init__(self, n_head, head_dim, n_embd, kv_rank=2, q_rank=6):
        super().__init__()
        self.n_head = n_head
        self.head_dim = head_dim
        self.n_embd = n_embd
        self.kv_rank = kv_rank
        self.q_rank = q_rank

        self.c_qkv = CPLinear(self.n_embd, self.n_head, self.head_dim, self.kv_rank, self.q_rank)

        # Output projection. Bias seems sensible here; each head can learn a shift.
        self.o_proj = xATGLU(self.n_head * self.head_dim, self.n_embd, bias=True)

        # Not a layer, just a helper
        self.rotary = Rotary(self.head_dim)

    def forward(self, x):
        B, T, C = x.size()  # batch_size, seq_length (T), embedding_dim

        # Get Q, K, V through the CPLinear factorization
        q, k, v = self.c_qkv(x)  # Each shape: (B, T, n_head, head_dim)

        cos, sin = self.rotary(q)
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        # SDPA expects (B, n_head, T, head_dim)
        q = q.permute(0, 2, 1, 3)  # batch seq heads dim -> batch heads seq dim
        k = k.permute(0, 2, 1, 3)  # batch seq heads dim -> batch heads seq dim
        v = v.permute(0, 2, 1, 3)  # batch seq heads dim -> batch heads seq dim

        # Compute attention using scaled_dot_product_attention
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)

        # Back to B T C
        y = y.transpose(1, 2).flatten(2)
        y = self.o_proj(y)

        return y

class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm1 = nn.GroupNorm(32, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.norm2 = nn.GroupNorm(32, channels)

        self.learned_residual_scale = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x)))
        h = self.conv2(F.silu(self.norm2(h)))
        return x + h * self.learned_residual_scale

class TransformerBlock(nn.Module):
    def __init__(self, channels, num_heads=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(channels)
        self.norm2 = nn.LayerNorm(channels)

        # Params recommended by the TPA paper; they seem to work fine.
        self.attn = TensorProductAttentionWithRope(
            n_head=num_heads,
            head_dim=channels // num_heads,
            n_embd=channels,
            kv_rank=2,
            q_rank=6
        )

        self.mlp = nn.Sequential(
            xATGLU(channels, 2 * channels, bias=False),
            nn.Linear(2 * channels, channels, bias=False)  # Candidate for a bias
        )

        self.learned_residual_scale_attn = nn.Parameter(torch.ones(1) * 0.1)
        self.learned_residual_scale_mlp = nn.Parameter(torch.ones(1) * 0.1)

    def forward(self, x):
        # Input shape B C H W
        b, c, h, w = x.shape

        # Flatten spatial dims and move channels last: B C H W -> [B, H*W, C]
        x = x.flatten(2).transpose(1, 2)

        # Pre-norm architecture; this was really helpful for network stability when using bf16
        identity = x
        x = self.norm1(x)
        h_attn = self.attn(x)
        #h_attn, _ = self.attn(x, x, x)
        x = identity + h_attn * self.learned_residual_scale_attn

        identity = x
        x = self.norm2(x)
        h_mlp = self.mlp(x)
        x = identity + h_mlp * self.learned_residual_scale_mlp

        # Reshape back to B C H W
        x = x.transpose(1, 2).reshape(b, c, h, w)
        return x

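# Note on the residual pattern used by ResBlock and TransformerBlock above:
# out = x + f(x) * s, where f is a normalized sub-block and the learned scale s is
# initialized to 0.1, so every block starts close to the identity. Per the changelog
# at the top of this file, this is what allowed bfloat16 training to stabilize.
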
class LevelBlock(nn.Module):
    def __init__(self, channels, num_blocks, block_type='res'):
        super().__init__()
        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            if block_type == 'transformer':
                self.blocks.append(TransformerBlock(channels))
            else:
                self.blocks.append(ResBlock(channels))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

class AsymmetricResidualUDiT(nn.Module):
    def __init__(self,
                 in_channels=3,                 # Input color channels
                 base_channels=128,             # Initial feature size; dramatically increases the parameter count of the network.
                 patch_size=2,                  # Smaller patches dramatically increase flops and compute expense. Recommend >=4 unless you have real compute.
                 num_levels=3,                  # Feature downsampling depth, essentially the U-Net depth -- we down/upsample between successive levels. Dramatically increases parameters as you increase it.
                 encoder_blocks=3,              # Can be a different number of blocks vs decoder_blocks
                 decoder_blocks=7,              # Can be a different number of blocks vs encoder_blocks
                 encoder_transformer_thresh=2,  # When to start using transformer blocks instead of res blocks in the encoder. (>=)
                 decoder_transformer_thresh=4,  # When to stop using transformer blocks instead of res blocks in the decoder. (<=)
                 mid_blocks=16,                 # Number of middle transformer blocks. Relatively cheap, as this sits at the bottom of the U-Net feature bottleneck.
                 ):
        super().__init__()
        self.learned_middle_residual_scale = nn.Parameter(torch.ones(1) * 0.1)
        # Initial projection from image space
        self.patch_embed = nn.Conv2d(in_channels, base_channels,
                                     kernel_size=patch_size, stride=patch_size)

        self.encoders = nn.ModuleList()
        curr_channels = base_channels

        for level in range(num_levels):
            use_transformer = level >= encoder_transformer_thresh  # Use transformers for the later levels

            # Encoder blocks -- N = encoder_blocks
            self.encoders.append(
                LevelBlock(curr_channels, encoder_blocks, 'transformer' if use_transformer else 'res')
            )

            # Each successive encoder level doubles the channel count, except for the last level.
            if level < num_levels - 1:
                self.encoders.append(
                    nn.Conv2d(curr_channels, curr_channels * 2, 1)
                )
                curr_channels *= 2

        # Middle transformer blocks -- N = mid_blocks
        self.middle = nn.ModuleList([
            TransformerBlock(curr_channels) for _ in range(mid_blocks)
        ])

        # Create decoder levels
        self.decoders = nn.ModuleList()

        for level in range(num_levels):
            use_transformer = level <= decoder_transformer_thresh  # Use transformers for the early levels (inverse of the encoder)

            # Decoder blocks -- N = decoder_blocks
            self.decoders.append(
                LevelBlock(curr_channels, decoder_blocks, 'transformer' if use_transformer else 'res')
            )

            # Each successive decoder level halves the channel count, except for the last level.
            if level < num_levels - 1:
                self.decoders.append(
                    nn.Conv2d(curr_channels, curr_channels // 2, 1)
                )
                curr_channels //= 2

        # Final projection back to image space
        self.final_proj = nn.ConvTranspose2d(base_channels, in_channels,
                                             kernel_size=patch_size, stride=patch_size)

    def downsample(self, x):
        return F.avg_pool2d(x, kernel_size=2)

    def upsample(self, x):
        return F.interpolate(x, scale_factor=2, mode='nearest')

    def forward(self, x, t=None):
        # x shape B C H W
        # This patchifies our input; for example (with patch_size=4, as in train.py) an input shaped
        # 2, 3, 256, 256
        x = self.patch_embed(x)
        # now has more channels and a smaller H and W:
        # 2, 128, 64, 64

        # *Per resolution, i.e. per num_levels resolution block, more or less:
        # f(x) = fu( U(fm(D(h)) - D(h)) + h ) where h = fd(x)
        #
        # Where
        # 1. h = fd(x)      : Encoder path processes input
        # 2. D(h)           : Downsample the encoded features
        # 3. fm(D(h))       : Middle transformer blocks process downsampled features
        # 4. fm(D(h))-D(h)  : Subtract original downsampled features (residual connection)
        # 5. U(...)         : Upsample the processed features
        # 6. ... + h        : Add back original encoder features (skip connection)
        # 7. fu(...)        : Decoder path processes the combined features

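        # Illustrative shape trace (assumed settings matching train.py: base_channels=128,
        # patch_size=4, num_levels=3; not part of the original file):
        #   input                     B, 3,   256, 256
        #   patch_embed               B, 128, 64,  64
        #   encoder level 0 + down    B, 256, 32,  32
        #   encoder level 1 + down    B, 512, 16,  16
        #   middle (fm) minus D(h)    B, 512, 16,  16
        #   decoder level 0 + up      B, 256, 32,  32
        #   decoder level 1 + up      B, 128, 64,  64
        #   final_proj                B, 3,   256, 256
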
        residuals = []
        curr_res = x

        # Encoder path (computing h = fd(x))
        h = x
        for i, blocks in enumerate(self.encoders):
            if isinstance(blocks, LevelBlock):
                h = blocks(h)
            else:
                # Save the residual before downsampling
                residuals.append(curr_res)
                # Downsample and update the current residual
                h = self.downsample(blocks(h))
                curr_res = h

        # Middle blocks (fm)
        x = h
        for block in self.middle:
            x = block(x)

        # Subtract the residual at this level (D(h))
        x = x - curr_res * self.learned_middle_residual_scale

        # Decoder path (fu)
        for i, blocks in enumerate(self.decoders):
            if isinstance(blocks, LevelBlock):
                x = blocks(x)
            else:
                # Channel reduction
                x = blocks(x)
                # Upsample
                x = self.upsample(x)
                # Add the residual from the encoder at this level -- LIFO: the last residual pushed is the first one we want, since the network is U-shaped.
                curr_res = residuals.pop()
                x = x + curr_res * self.learned_middle_residual_scale

        # Final projection
        x = self.final_proj(x)

        return x
step_799.safetensors ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:29410f1aac9ed73a51a1b225f6d3c5cbe5560fa5a6521c8f464030b1a2de6157
size 407377304
train.py ADDED
@@ -0,0 +1,307 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from safetensors.torch import save_file, load_file
import os, time
from models import AsymmetricResidualUDiT, xATGLU
from torch.cuda.amp import autocast

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.distributions import Normal
from schedulefree import AdamWScheduleFree
from distributed_shampoo import AdamGraftingConfig, DistributedShampoo

# Changes
# MAE replace MSE
# Larger shampoo preconditioner step for stability
# Larger shampoo preconditioner dim 1024 -> 2048
# Commented out norm.

def preload_dataset(image_size=256, device="cuda", max_images=50000):
    """Preload and cache the entire dataset in GPU memory"""
    print("Loading and preprocessing dataset...")
    dataset = load_dataset("jiovine/pixel-art-nouns-2k", split="train")
    #dataset = load_dataset("reach-vb/pokemon-blip-captions", split="train")
    #dataset = load_from_disk("./new_dataset")

    transform = transforms.Compose([
        transforms.ToTensor(),
        #transforms.Pad((35, 0), fill=0),  # Add 35 pixels on each side horizontally (70 total, to get from 186 to 256)
        transforms.Resize((256, 256), antialias=True),
        transforms.Lambda(lambda x: (x * 2) - 1)  # Scale to [-1, 1]
    ])

    all_images = []

    for i, example in enumerate(dataset):
        if max_images and i >= max_images:
            break

        img_tensor = transform(example['image'])

        all_images.extend([
            img_tensor,
        ])

    # Stack the entire dataset onto the GPU
    images_tensor = torch.stack(all_images).to(device)
    print(f"Dataset loaded: {images_tensor.shape} ({images_tensor.element_size() * images_tensor.nelement() / 1024/1024:.2f} MB)")

    return TensorDataset(images_tensor)

def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    print(f'Total parameters: {total_params:,} ({total_params/1e6:.2f}M)')

def save_checkpoint(model, optimizer, filename="checkpoint.safetensors"):
    model_state = model.state_dict()
    save_file(model_state, filename)

def load_checkpoint(model, optimizer, filename="checkpoint.safetensors"):
    model_state = load_file(filename)
    model.load_state_dict(model_state)

# https://arxiv.org/abs/2210.02747
class OptimalTransportLinearFlowGenerator():
    def __init__(self, sigma_min=0.001):
        self.sigma_min = sigma_min

    def loss(self, model, x1, device):
        batch_size = x1.shape[0]
        # Uniform dist 0..1 -- t ~ U[0, 1]
        t = torch.rand(batch_size, 1, 1, 1, device=device)

        # Sample noise -- x0 ~ N[0, I]
        x0 = torch.randn_like(x1)

        # Compute the OT conditional flow matching path interpolation

        # My understanding of this process -- We start at some random time t (per sample).
        # We have a pure noise value at x0, which is a totally destroyed signal.
        # We have the actual image as x1, which is a perfect signal.
        # We are going to destroy an amount of the image equal to t% of the signal. So if t is 0.3, we're destroying about 30% of the signal (image).
        # The final x_t represents our combined noisy signal; you can imagine 30% random noise overlaid onto the normal image.
        # We calculate the shortest path between x0 and x1, a straight line segment (call it a displacement vector) in their respective space, conditioned on the timestep.
        # We then try to predict that displacement vector, given our partially noisy signal and our conditioning timestep.
        # We check the prediction against the real displacement vector we calculated to see how good the prediction was. Then we backpropagate, baby.

        sigma_t = 1 - (1 - self.sigma_min) * t  # As t increases this value decreases. This is almost 1 - t.
        mu_t = t * x1                           # As t increases this increases.
        x_t = sigma_t * x0 + mu_t               # This is essentially a mixture of noise and signal: ((1-t) * x0) + (t * x1)

        # Compute target
        target = x1 - (1 - self.sigma_min) * x0  # This is the target displacement vector (direction and magnitude) that we need to travel from x0 to x1.
        v_t = model(x_t, t)  # v_t is our displacement vector prediction

        # Magnitude-corrected MSE
        # The 69 factor helps with very small gradients; this loss tends to be in [0..1], so this rescales it to something more like [0..69].
        # Other values like 420 might lead to numerical instability if the loss gets too large.
        loss = F.mse_loss(v_t, target) * 69  # Compare the predicted displacement vector to the actual displacement we calculated, as a mean squared error.

        return loss

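# Worked example of the path above (illustrative numbers, not from the original file):
# with sigma_min=0.001 and t=0.3 for a given sample,
#     sigma_t = 1 - 0.999 * 0.3 ≈ 0.7003
#     x_t     ≈ 0.7003 * x0 + 0.3 * x1    (about 30% signal, 70% noise)
#     target  = x1 - 0.999 * x0           (the displacement the model must predict)
# and at t=1 the path lands on x_t = sigma_min * x0 + x1, essentially the clean image.
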
def write_logs(writer, model, loss, batch_idx, epoch, epoch_time, batch_size, lr, log_gradients=True):
    """
    TensorBoard logging

    Args:
        writer: torch.utils.tensorboard.SummaryWriter instance
        model: torch.nn.Module - the model being trained
        loss: float or torch.Tensor - the loss value to log
        batch_idx: int - current batch index
        epoch: int - current epoch
        epoch_time: float - time taken for the epoch
        batch_size: int - current batch size
        lr: float - current learning rate
        log_gradients: bool - whether to log gradient norms
    """
    total_steps = epoch * batch_idx

    writer.add_scalar('Loss/batch', loss, total_steps)
    writer.add_scalar('Time/epoch', epoch_time, epoch)
    writer.add_scalar('Training/batch_size', batch_size, epoch)
    writer.add_scalar('Training/learning_rate', lr, epoch)

    # Gradient logging
    if log_gradients:
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        writer.add_scalar('Gradients/total_norm', total_norm, total_steps)

def train_udit_flow(num_epochs=1000, initial_batch_sizes=[8, 16, 32, 64, 128], epoch_batch_drop_at=40, device="cuda", dtype=torch.float32):
    dataset = preload_dataset(device=device)
    temp_loader = DataLoader(dataset, batch_size=initial_batch_sizes[0], shuffle=True)
    first_batch = next(iter(temp_loader))
    image_shape = first_batch[0].shape[1:]

    writer = SummaryWriter('logs/current_run')

    model = AsymmetricResidualUDiT(
        in_channels=3,
        base_channels=128,
        num_levels=3,
        patch_size=4,
        encoder_blocks=3,
        decoder_blocks=7,
        encoder_transformer_thresh=2,
        decoder_transformer_thresh=4,
        mid_blocks=16
    ).to(device).to(torch.float32)
    model.train()
    count_parameters(model)

    # optimizer = AdamWScheduleFree(
    #     model.parameters(),
    #     lr=4e-5,
    #     warmup_steps=100
    # )
    # optimizer.train()

    optimizer = DistributedShampoo(
        model.parameters(),
        lr=0.001,
        betas=(0.9, 0.999),
        epsilon=1e-10,
        weight_decay=1e-05,
        max_preconditioner_dim=2048,
        precondition_frequency=100,
        start_preconditioning_step=250,
        use_decoupled_weight_decay=False,
        grafting_config=AdamGraftingConfig(
            beta2=0.999,
            epsilon=1e-10,
        ),
    )

    scaler = torch.amp.GradScaler("cuda")

    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-5
    )

    current_batch_sizes = initial_batch_sizes.copy()
    next_drop_epoch = epoch_batch_drop_at
    interval_multiplier = 2

    torch.set_float32_matmul_precision('high')
    # torch.backends.cudnn.benchmark = True
    # torch.backends.cuda.matmul.allow_fp16_accumulation = True

    model = torch.compile(
        model,
        backend='inductor',
        dynamic=False,
        fullgraph=True,
        options={
            "epilogue_fusion": True,
            "max_autotune": True,
            "cuda.use_fast_math": True,
        }
    )

    flow_transport = OptimalTransportLinearFlowGenerator(sigma_min=0.001)

    current_batch_size = current_batch_sizes[-1]
    dataloader = DataLoader(dataset, batch_size=current_batch_size, shuffle=True)

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        total_loss = 0

        # Batch size decay logic -- currently disabled.
        # Geometric growth of the drop interval: drops happen after X epochs, then another 2X, then another 3X, ...; the last batch size in the list is the one in use.
        if False:
            if epoch > 0 and epoch == next_drop_epoch and len(current_batch_sizes) > 1:
                current_batch_sizes.pop()
                next_interval = epoch_batch_drop_at * interval_multiplier
                next_drop_epoch += next_interval
                interval_multiplier += 1
                print(f"\nEpoch {epoch}: Reducing batch size to {current_batch_sizes[-1]}")
                print(f"Next drop will occur at epoch {next_drop_epoch} (interval: {next_interval})")

        curr_lr = optimizer.param_groups[0]['lr']

        for batch_idx, batch in enumerate(dataloader):
            optimizer.zero_grad()
            with torch.autocast(device_type='cuda', dtype=dtype):
                x1 = batch[0]
                batch_size = x1.shape[0]

                # x1 shape: B, C, H, W
                loss = flow_transport.loss(model, x1, device)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)

        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch}, Took: {epoch_time:.2f}s, Batch Size: {current_batch_size}, "
              f"Average Loss: {avg_loss:.4f}, Learning Rate: {curr_lr:.2e}")

        write_logs(writer, model, avg_loss, batch_idx, epoch, epoch_time, current_batch_size, curr_lr)
        if (epoch + 1) % 10 == 0:
            with torch.amp.autocast('cuda', dtype=dtype):
                sampling_start_time = time.time()
                samples = sample(model, device=device, dtype=dtype)
                os.makedirs("samples", exist_ok=True)
                vutils.save_image(samples, f"samples/epoch_{epoch}.png", nrow=4, padding=2)

                sample_time = time.time() - sampling_start_time
                print(f"Sampling took: {sample_time:.2f}s")

        if (epoch + 1) % 50 == 0:
            save_checkpoint(model, optimizer, f"step_{epoch}.safetensors")

        scheduler.step()

    return model

def sample(model, n_samples=16, n_steps=50, image_size=256, device="cuda", sigma_min=0.001, dtype=torch.float32):
    with torch.amp.autocast('cuda', dtype=dtype):

        x = torch.randn(n_samples, 3, image_size, image_size, device=device)
        ts = torch.linspace(0, 1, n_steps, device=device)
        dt = 1 / n_steps

        # Forward Euler integration from t=0 to t=1
        with torch.no_grad():
            for i in range(len(ts)):
                t = ts[i]
                t_input = t.repeat(n_samples, 1, 1, 1)

                v_t = model(x, t_input)

                x = x + v_t * dt

        return x.float()

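# Note on the integrator above (illustrative, not from the original file): each Euler
# step applies
#     x_{k+1} = x_k + v(x_k, t_k) * dt,    dt = 1 / n_steps,
# moving from pure noise at t=0 toward the data distribution at t=1; more steps (or a
# higher-order solver) trades sampling time for integration accuracy.
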
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model = train_udit_flow(
        device=device,
        initial_batch_sizes=[16, 32, 64],
        epoch_batch_drop_at=100,
        dtype=torch.bfloat16
    )

    print("Training complete! Samples saved in 'samples' directory")