import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter
from safetensors.torch import save_file, load_file
import os, time
from models import AsymmetricResidualUDiT, xATGLU
from torch.cuda.amp import autocast

from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.distributions import Normal
from schedulefree import AdamWScheduleFree
from distributed_shampoo import AdamGraftingConfig, DistributedShampoo

def preload_dataset(image_size=256, device="cuda", max_images=50000):
    """Preload and cache the entire dataset in GPU memory."""
    print("Loading and preprocessing dataset...")
    dataset = load_dataset("jiovine/pixel-art-nouns-2k", split="train")

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((image_size, image_size), antialias=True),
        transforms.Lambda(lambda x: (x * 2) - 1),  # Scale from [0, 1] to [-1, 1]
    ])

    all_images = []
    for i, example in enumerate(dataset):
        if max_images and i >= max_images:
            break
        all_images.append(transform(example['image']))

    images_tensor = torch.stack(all_images).to(device)
    print(f"Dataset loaded: {images_tensor.shape} ({images_tensor.element_size() * images_tensor.nelement() / 1024/1024:.2f} MB)")

    return TensorDataset(images_tensor)

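# Rough memory estimate for the cached tensor (assuming roughly 2,000 images, as the dataset
# name suggests): 2,000 * 3 * 256 * 256 * 4 bytes ~= 1.5 GB at float32, which fits comfortably
# on a single modern GPU. The print above reports the exact figure for the loaded split.
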
def count_parameters(model):
    total_params = sum(p.numel() for p in model.parameters())
    print(f'Total parameters: {total_params:,} ({total_params/1e6:.2f}M)')

def save_checkpoint(model, optimizer, filename="checkpoint.safetensors"):
    # Note: only the model weights are written; safetensors does not store optimizer state.
    model_state = model.state_dict()
    save_file(model_state, filename)

def load_checkpoint(model, optimizer, filename="checkpoint.safetensors"):
    model_state = load_file(filename)
    model.load_state_dict(model_state)

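# Resume sketch (hypothetical filename; only model weights are restored, so any optimizer
# state starts fresh, and key names may need adjusting if the checkpoint was written from a
# torch.compile-wrapped model):
#
#   model = AsymmetricResidualUDiT(...).to("cuda")   # same constructor args as in train_udit_flow
#   load_checkpoint(model, optimizer=None, filename="step_49.safetensors")
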
class OptimalTransportLinearFlowGenerator:
    """Conditional optimal-transport (linear) flow matching objective."""

    def __init__(self, sigma_min=0.001):
        self.sigma_min = sigma_min

    def loss(self, model, x1, device):
        batch_size = x1.shape[0]

        # Sample a time t ~ U(0, 1) per example and a Gaussian noise sample x0.
        t = torch.rand(batch_size, 1, 1, 1, device=device)
        x0 = torch.randn_like(x1)

        # Linear interpolation path: x_t = (1 - (1 - sigma_min) * t) * x0 + t * x1
        sigma_t = 1 - (1 - self.sigma_min) * t
        mu_t = t * x1
        x_t = sigma_t * x0 + mu_t

        # Regression target is the conditional velocity d(x_t)/dt.
        target = x1 - (1 - self.sigma_min) * x0
        v_t = model(x_t, t)

        # Constant loss scaling factor.
        loss = F.mse_loss(v_t, target) * 69

        return loss

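# Derivation note for the loss above: with the conditional path
#     x_t = (1 - (1 - sigma_min) * t) * x0 + t * x1,
# differentiating with respect to t gives the conditional velocity
#     d(x_t)/dt = x1 - (1 - sigma_min) * x0,
# which is exactly the regression target used in loss(). As sigma_min -> 0 this reduces to
# straight-line interpolation between noise and data with velocity x1 - x0.
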
def write_logs(writer, model, loss, batch_idx, epoch, epoch_time, batch_size, lr, log_gradients=True):
    """
    TensorBoard logging

    Args:
        writer: torch.utils.tensorboard.SummaryWriter instance
        model: torch.nn.Module - the model being trained
        loss: float or torch.Tensor - the loss value to log
        batch_idx: int - current batch index
        epoch: int - current epoch
        epoch_time: float - time taken for the epoch
        batch_size: int - current batch size
        lr: float - current learning rate
        log_gradients: bool - whether to log gradient norms
    """
    # Approximate global step; this function is called once per epoch with the final batch index.
    total_steps = epoch * batch_idx

    writer.add_scalar('Loss/batch', loss, total_steps)
    writer.add_scalar('Time/epoch', epoch_time, epoch)
    writer.add_scalar('Training/batch_size', batch_size, epoch)
    writer.add_scalar('Training/learning_rate', lr, epoch)

    if log_gradients:
        # L2 norm over all parameter gradients.
        total_norm = 0.0
        for p in model.parameters():
            if p.grad is not None:
                param_norm = p.grad.detach().data.norm(2)
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        writer.add_scalar('Gradients/total_norm', total_norm, total_steps)

def train_udit_flow(num_epochs=1000, initial_batch_sizes=[8, 16, 32, 64, 128], epoch_batch_drop_at=40, device="cuda", dtype=torch.float32):
    dataset = preload_dataset(device=device)
    temp_loader = DataLoader(dataset, batch_size=initial_batch_sizes[0], shuffle=True)
    first_batch = next(iter(temp_loader))
    image_shape = first_batch[0].shape[1:]

    writer = SummaryWriter('logs/current_run')

    model = AsymmetricResidualUDiT(
        in_channels=3,
        base_channels=128,
        num_levels=3,
        patch_size=4,
        encoder_blocks=3,
        decoder_blocks=7,
        encoder_transformer_thresh=2,
        decoder_transformer_thresh=4,
        mid_blocks=16
    ).to(device).to(torch.float32)
    model.train()
    count_parameters(model)

    optimizer = DistributedShampoo(
        model.parameters(),
        lr=0.001,
        betas=(0.9, 0.999),
        epsilon=1e-10,
        weight_decay=1e-05,
        max_preconditioner_dim=2048,
        precondition_frequency=100,
        start_preconditioning_step=250,
        use_decoupled_weight_decay=False,
        grafting_config=AdamGraftingConfig(
            beta2=0.999,
            epsilon=1e-10,
        ),
    )

    # Gradient scaling is primarily needed for float16 autocast; it is harmless with bfloat16.
    scaler = torch.amp.GradScaler("cuda")

    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-5
    )

    current_batch_sizes = initial_batch_sizes.copy()
    next_drop_epoch = epoch_batch_drop_at
    interval_multiplier = 2

    torch.set_float32_matmul_precision('high')

    model = torch.compile(
        model,
        backend='inductor',
        dynamic=False,
        fullgraph=True,
        options={
            "epilogue_fusion": True,
            "max_autotune": True,
            "cuda.use_fast_math": True,
        }
    )

    flow_transport = OptimalTransportLinearFlowGenerator(sigma_min=0.001)

    current_batch_size = current_batch_sizes[-1]
    dataloader = DataLoader(dataset, batch_size=current_batch_size, shuffle=True)

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        total_loss = 0

        # Disabled: progressive batch-size reduction schedule.
        if False:
            if epoch > 0 and epoch == next_drop_epoch and len(current_batch_sizes) > 1:
                current_batch_sizes.pop()
                next_interval = epoch_batch_drop_at * interval_multiplier
                next_drop_epoch += next_interval
                interval_multiplier += 1
                print(f"\nEpoch {epoch}: Reducing batch size to {current_batch_sizes[-1]}")
                print(f"Next drop will occur at epoch {next_drop_epoch} (interval: {next_interval})")

        curr_lr = optimizer.param_groups[0]['lr']

        for batch_idx, batch in enumerate(dataloader):
            optimizer.zero_grad()
            with torch.autocast(device_type='cuda', dtype=dtype):
                x1 = batch[0]
                batch_size = x1.shape[0]

                loss = flow_transport.loss(model, x1, device)

            scaler.scale(loss).backward()
            # Unscale gradients before the optimizer step.
            scaler.unscale_(optimizer)

            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)

        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch}, Took: {epoch_time:.2f}s, Batch Size: {current_batch_size}, "
              f"Average Loss: {avg_loss:.4f}, Learning Rate: {curr_lr:.2e}")

        write_logs(writer, model, avg_loss, batch_idx, epoch, epoch_time, current_batch_size, curr_lr)

        if (epoch + 1) % 10 == 0:
            with torch.amp.autocast('cuda', dtype=dtype):
                sampling_start_time = time.time()
                samples = sample(model, device=device, dtype=dtype)
                os.makedirs("samples", exist_ok=True)
                # Map samples from [-1, 1] back to [0, 1] before writing the image grid.
                vutils.save_image((samples + 1) / 2, f"samples/epoch_{epoch}.png", nrow=4, padding=2)

                sample_time = time.time() - sampling_start_time
                print(f"Sampling took: {sample_time:.2f}s")

        if (epoch + 1) % 50 == 0:
            save_checkpoint(model, optimizer, f"step_{epoch}.safetensors")

        scheduler.step()

    return model

def sample(model, n_samples=16, n_steps=50, image_size=256, device="cuda", sigma_min=0.001, dtype=torch.float32):
    with torch.amp.autocast('cuda', dtype=dtype):
        # Start from pure Gaussian noise and integrate the learned velocity field from t=0 to t=1.
        x = torch.randn(n_samples, 3, image_size, image_size, device=device)
        # Left endpoints of n_steps Euler steps covering [0, 1].
        ts = torch.linspace(0, 1, n_steps + 1, device=device)[:-1]
        dt = 1 / n_steps

        with torch.no_grad():
            for t in ts:
                t_input = t.repeat(n_samples, 1, 1, 1)

                v_t = model(x, t_input)

                x = x + v_t * dt

    return x.float()

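# Standalone sampling sketch (hypothetical checkpoint name; assumes the weights load cleanly
# into a freshly built, uncompiled model):
#
#   model = AsymmetricResidualUDiT(...).to("cuda")   # same constructor args as in train_udit_flow
#   load_checkpoint(model, optimizer=None, filename="step_49.safetensors")
#   model.eval()
#   imgs = sample(model, n_samples=4, n_steps=50, dtype=torch.bfloat16)
#   vutils.save_image((imgs + 1) / 2, "preview.png", nrow=2)
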
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    model = train_udit_flow(
        device=device,
        initial_batch_sizes=[16, 32, 64],
        epoch_batch_drop_at=100,
        dtype=torch.bfloat16
    )

    print("Training complete! Samples saved in 'samples' directory")