"""Train an AsymmetricResidualUDiT with optimal-transport flow matching.

The script preloads an image dataset onto the target device, trains the model
to predict the straight-line displacement between noise and data
(https://arxiv.org/abs/2210.02747), logs to TensorBoard, and periodically
writes sample grids and safetensors checkpoints.
"""

import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.utils as vutils
from datasets import load_dataset, load_from_disk
from distributed_shampoo import AdamGraftingConfig, DistributedShampoo
from safetensors.torch import load_file, save_file
from schedulefree import AdamWScheduleFree
from torch.cuda.amp import autocast
from torch.distributions import Normal
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

from models import AsymmetricResidualUDiT, xATGLU

# Changes
# MAE replace MSE
#   NOTE(review): the loss below currently calls F.mse_loss; confirm which of
#   the two is intended.
# Larger shampoo preconditioner step for stability
# Larger shampoo preconditioner dim 1024 -> 2048
# Commented out norm.

# Prefix that torch.compile's wrapper module adds to every state-dict key.
_COMPILED_PREFIX = "_orig_mod."


def preload_dataset(image_size=256, device="cuda", max_images=50000):
    """Preload and cache the entire dataset in device memory.

    Args:
        image_size: Side length (pixels) every image is resized to.
        device: Device the stacked image tensor is moved to.
        max_images: Upper bound on the number of images loaded; a falsy value
            (``None`` or ``0``) disables the limit.

    Returns:
        A ``TensorDataset`` wrapping one tensor of all images, scaled to
        ``[-1, 1]``.
    """
    print("Loading and preprocessing dataset...")
    dataset = load_dataset("jiovine/pixel-art-nouns-2k", split="train")
    # Alternative sources used during experimentation:
    # dataset = load_dataset("reach-vb/pokemon-blip-captions", split="train")
    # dataset = load_from_disk("./new_dataset")

    transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Pad((35, 0), fill=0),  # 35 px per side: 186 -> 256 wide
        # Honour the image_size argument (the size used to be hard-coded to 256).
        transforms.Resize((image_size, image_size), antialias=True),
        transforms.Lambda(lambda x: (x * 2) - 1),  # Scale to [-1, 1]
    ])

    all_images = []
    for i, example in enumerate(dataset):
        if max_images and i >= max_images:
            break
        all_images.append(transform(example['image']))

    # Stack the entire dataset and move it onto the target device in one go.
    images_tensor = torch.stack(all_images).to(device)
    print(f"Dataset loaded: {images_tensor.shape} ({images_tensor.element_size() * images_tensor.nelement() / 1024/1024:.2f} MB)")
    return TensorDataset(images_tensor)


def count_parameters(model):
    """Print the total number of parameters in ``model``."""
    total_params = sum(p.numel() for p in model.parameters())
    print(f'Total parameters: {total_params:,} ({total_params/1e6:.2f}M)')


def _unwrap_compiled(model):
    """Return the module wrapped by ``torch.compile``, or ``model`` itself."""
    return getattr(model, "_orig_mod", model)


def save_checkpoint(model, optimizer, filename="checkpoint.safetensors"):
    """Write the model weights to ``filename`` in safetensors format.

    The weights of the underlying module are saved, so a checkpoint written
    from a ``torch.compile``-wrapped model has no ``_orig_mod.`` key prefix
    and can be loaded into a plain model.

    Args:
        model: Module (compiled or not) whose weights are saved.
        optimizer: Accepted for interface symmetry; its state is NOT saved.
        filename: Destination path.
    """
    save_file(_unwrap_compiled(model).state_dict(), filename)


def load_checkpoint(model, optimizer, filename="checkpoint.safetensors"):
    """Load model weights from a safetensors file into ``model``.

    Keys carrying the ``_orig_mod.`` prefix (checkpoints written from a
    compiled model by earlier versions of this script) are accepted too.

    Args:
        model: Module (compiled or not) that receives the weights.
        optimizer: Accepted for interface symmetry; it is not modified.
        filename: Source path.
    """
    prefix_len = len(_COMPILED_PREFIX)
    model_state = {
        (key[prefix_len:] if key.startswith(_COMPILED_PREFIX) else key): value
        for key, value in load_file(filename).items()
    }
    _unwrap_compiled(model).load_state_dict(model_state)


# https://arxiv.org/abs/2210.02747
class OptimalTransportLinearFlowGenerator():
    """Conditional flow-matching loss along straight noise-to-data paths."""

    def __init__(self, sigma_min=0.001, loss_scale=69):
        """
        Args:
            sigma_min: Residual noise scale left at ``t = 1``.
            loss_scale: Constant multiplier applied to the MSE loss.
        """
        self.sigma_min = sigma_min
        self.loss_scale = loss_scale

    def loss(self, model, x1, device):
        """Return the scaled flow-matching loss for one batch.

        Args:
            model: Callable ``model(x_t, t)`` predicting the displacement.
            x1: Batch of clean samples; the first dimension is the batch.
            device: Device on which the random timesteps are created.

        Returns:
            Scalar tensor: ``mse(model(x_t, t), target) * loss_scale``.
        """
        batch_size = x1.shape[0]
        # Uniform Dist 0..1 -- t ~ U[0, 1], one timestep per sample.
        t = torch.rand(batch_size, 1, 1, 1, device=device)

        # Sample noise -- x0 ~ N[0, I]
        x0 = torch.randn_like(x1)

        # OT conditional flow matching path interpolation.
        # x0 is pure noise (a totally destroyed signal) and x1 is the actual
        # image (a perfect signal). t is the fraction of signal present:
        # t = 0 is pure noise, t = 1 is (almost) the clean image, so t = 0.3
        # is roughly 30% image mixed with 70% noise.
        # The shortest path between x0 and x1 is a straight line segment (the
        # displacement vector). The model is given the partially noisy signal
        # and the timestep and must predict that displacement; the prediction
        # is compared to the real displacement and backpropagated.
        sigma_t = 1 - (1 - self.sigma_min) * t  # Decreases as t grows; almost 1 - t
        mu_t = t * x1  # Increases as t grows
        x_t = sigma_t * x0 + mu_t  # Essentially ((1 - t) * x0) + (t * x1)

        # Target displacement (direction and magnitude) from x0 to x1.
        target = x1 - (1 - self.sigma_min) * x0

        v_t = model(x_t, t)  # Predicted displacement vector

        # Magnitude-corrected MSE.
        # The scale factor helps with very small gradients: the raw loss tends
        # to lie in [0..1] and is rescaled to roughly [0..loss_scale]. Much
        # larger factors might lead to numerical instability.
        loss = F.mse_loss(v_t, target) * self.loss_scale  # Mean squared error
        return loss


def write_logs(writer, model, loss, batch_idx, epoch, epoch_time, batch_size, lr, log_gradients=True):
    """
    TensorBoard logging

    Args:
        writer: torch.utils.tensorboard.SummaryWriter instance
        model: torch.nn.Module - the model being trained
        loss: float or torch.Tensor - the loss value to log
        batch_idx: int - index of the last batch processed in this epoch
        epoch: int - current (0-based) epoch
        epoch_time: float - time taken for epoch
        batch_size: int - current batch size
        lr: float - current learning rate
        log_gradients: bool - whether to log gradient norms
    """
    # Optimizer steps completed so far, assuming this is called once at the
    # end of each epoch with the index of the last batch. The previous
    # `epoch * batch_idx` was 0 for all of epoch 0 and stayed 0 forever when
    # an epoch had a single batch.
    total_steps = (epoch + 1) * (batch_idx + 1)

    writer.add_scalar('Loss/batch', loss, total_steps)
    writer.add_scalar('Time/epoch', epoch_time, epoch)
    writer.add_scalar('Training/batch_size', batch_size, epoch)
    writer.add_scalar('Training/learning_rate', lr, epoch)

    # Gradient logging: global L2 norm over all parameters that have a grad.
    if log_gradients:
        squared_sum = 0.0
        for p in model.parameters():
            if p.grad is not None:
                squared_sum += p.grad.detach().norm(2).item() ** 2
        writer.add_scalar('Gradients/total_norm', squared_sum ** 0.5, total_steps)


def train_udit_flow(num_epochs=1000, initial_batch_sizes=None, epoch_batch_drop_at=40,
                    device="cuda", dtype=torch.float32):
    """Train the UDiT flow-matching model and return it.

    Args:
        num_epochs: Number of passes over the dataset; also the cosine
            schedule length.
        initial_batch_sizes: Batch sizes; the LAST entry is the one used.
            Defaults to ``[8, 16, 32, 64, 128]``.
        epoch_batch_drop_at: Base interval of the (currently disabled)
            batch-size decay schedule.
        device: Device for the dataset and the model.
        dtype: Autocast dtype used for training and sampling.

    Returns:
        The trained (``torch.compile``-wrapped) model.
    """
    if initial_batch_sizes is None:
        initial_batch_sizes = [8, 16, 32, 64, 128]

    dataset = preload_dataset(device=device)
    writer = SummaryWriter('logs/current_run')

    model = AsymmetricResidualUDiT(
        in_channels=3,
        base_channels=128,
        num_levels=3,
        patch_size=4,
        encoder_blocks=3,
        decoder_blocks=7,
        encoder_transformer_thresh=2,
        decoder_transformer_thresh=4,
        mid_blocks=16
    ).to(device).to(torch.float32)
    model.train()
    count_parameters(model)

    # Alternative optimizer used previously:
    # optimizer = AdamWScheduleFree(
    #     model.parameters(),
    #     lr=4e-5,
    #     warmup_steps=100
    # )
    # optimizer.train()

    optimizer = DistributedShampoo(
        model.parameters(),
        lr=0.001,
        betas=(0.9, 0.999),
        epsilon=1e-10,
        weight_decay=1e-05,
        max_preconditioner_dim=2048,
        precondition_frequency=100,
        start_preconditioning_step=250,
        use_decoupled_weight_decay=False,
        grafting_config=AdamGraftingConfig(
            beta2=0.999,
            epsilon=1e-10,
        ),
    )

    scaler = torch.amp.GradScaler("cuda")

    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=num_epochs,
        eta_min=1e-5
    )

    # Copy so the caller's list is never mutated by the decay schedule below.
    current_batch_sizes = list(initial_batch_sizes)
    next_drop_epoch = epoch_batch_drop_at
    interval_multiplier = 2
    # The batch-size decay schedule is switched off. NOTE(review): when it
    # fires it only pops from current_batch_sizes; it does not rebuild the
    # dataloader, so enabling it as-is would not change the batch size.
    batch_size_decay_enabled = False

    torch.set_float32_matmul_precision('high')
    # torch.backends.cudnn.benchmark = True
    # torch.backends.cuda.matmul.allow_fp16_accumulation = True

    model = torch.compile(
        model,
        backend='inductor',
        dynamic=False,
        fullgraph=True,
        options={
            "epilogue_fusion": True,
            "max_autotune": True,
            "cuda.use_fast_math": True,
        }
    )

    flow_transport = OptimalTransportLinearFlowGenerator(sigma_min=0.001)

    current_batch_size = current_batch_sizes[-1]
    dataloader = DataLoader(dataset, batch_size=current_batch_size, shuffle=True)

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        total_loss = 0

        # Batch size decay logic
        # Geometric growth of the interval between drops: each time a drop
        # happens the next interval is epoch_batch_drop_at * multiplier.
        if batch_size_decay_enabled:
            if epoch > 0 and epoch == next_drop_epoch and len(current_batch_sizes) > 1:
                current_batch_sizes.pop()
                next_interval = epoch_batch_drop_at * interval_multiplier
                next_drop_epoch += next_interval
                interval_multiplier += 1
                print(f"\nEpoch {epoch}: Reducing batch size to {current_batch_sizes[-1]}")
                print(f"Next drop will occur at epoch {next_drop_epoch} (interval: {next_interval})")

        curr_lr = optimizer.param_groups[0]['lr']

        for batch_idx, batch in enumerate(dataloader):
            optimizer.zero_grad()
            with torch.autocast(device_type='cuda', dtype=dtype):
                x1 = batch[0]  # x1 shape: B, C, H, W
                loss = flow_transport.loss(model, x1, device)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)

        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch}, Took: {epoch_time:.2f}s, Batch Size: {current_batch_size}, "
              f"Average Loss: {avg_loss:.4f}, Learning Rate: {curr_lr:.2e}")

        write_logs(writer, model, avg_loss, batch_idx, epoch, epoch_time, current_batch_size, curr_lr)

        if (epoch + 1) % 10 == 0:
            with torch.amp.autocast('cuda', dtype=dtype):
                sampling_start_time = time.time()
                samples = sample(model, device=device, dtype=dtype)
                os.makedirs("samples", exist_ok=True)
                # Training data lives in [-1, 1] but save_image expects
                # [0, 1] and clamps everything below 0 to black, so map the
                # samples back to [0, 1] before writing.
                images = (samples.clamp(-1, 1) + 1) / 2
                vutils.save_image(images, f"samples/epoch_{epoch}.png", nrow=4, padding=2)

                sample_time = time.time() - sampling_start_time
                print(f"Sampling took: {sample_time:.2f}s")

        if (epoch + 1) % 50 == 0:
            save_checkpoint(model, optimizer, f"step_{epoch}.safetensors")

        scheduler.step()

    # Flush pending events and release the log file handle.
    writer.close()
    return model


def sample(model, n_samples=16, n_steps=50, image_size=256, device="cuda", sigma_min=0.001, dtype=torch.float32):
    """Generate images by forward-Euler integration of the learned flow.

    Args:
        model: Callable ``model(x, t)`` returning the predicted displacement.
        n_samples: Number of images to generate.
        n_steps: Number of Euler steps taken from t = 0 to t = 1.
        image_size: Side length of the square, 3-channel samples.
        device: Device on which sampling runs.
        sigma_min: Unused; kept for interface compatibility.
        dtype: Autocast dtype.

    Returns:
        Float32 tensor of shape ``(n_samples, 3, image_size, image_size)``,
        not clamped or rescaled.
    """
    with torch.amp.autocast('cuda', dtype=dtype):
        x = torch.randn(n_samples, 3, image_size, image_size, device=device)
        dt = 1 / n_steps
        # Evaluate the model at t_k = k * dt so the query time matches the
        # integration state. The previous linspace(0, 1, n_steps) grid had
        # spacing 1 / (n_steps - 1), which disagreed with dt and ended by
        # querying the model at t = 1.
        ts = torch.arange(n_steps, device=device) * dt

        # Forward Euler Integration step 0..1
        with torch.no_grad():
            for t in ts:
                t_input = t.repeat(n_samples, 1, 1, 1)
                v_t = model(x, t_input)
                x = x + v_t * dt

        return x.float()


def main():
    """Run training on CUDA when available, otherwise on CPU."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    train_udit_flow(
        device=device,
        initial_batch_sizes=[16, 32, 64],
        epoch_batch_drop_at=100,
        dtype=torch.bfloat16
    )

    print("Training complete! Samples saved in 'samples' directory")


if __name__ == "__main__":
    main()