Spaces:
Runtime error
Runtime error
import torch | |
from torchvision.utils import make_grid | |
import math | |
from PIL import Image | |
from diffusion import create_diffusion | |
from diffusers.models import AutoencoderKL | |
import gradio as gr | |
from imagenet_class_data import IMAGENET_1K_CLASSES | |
from download import find_model | |
from models import DiT_XL_2 | |
def load_model(image_size=256):
    """Build a DiT-XL/2 model and load its pretrained checkpoint.

    Args:
        image_size: Pixel resolution of the checkpoint to load; only 256 and
            512 are supported. The model is built for a latent grid that is
            ``image_size // 8`` on each side.

    Returns:
        The model, moved to the module-level ``device`` and set to eval mode.

    Raises:
        ValueError: If ``image_size`` is not 256 or 512.
    """
    # An explicit raise still validates under `python -O`, unlike `assert`.
    if image_size not in (256, 512):
        raise ValueError(f"image_size must be 256 or 512, got {image_size!r}")
    latent_size = image_size // 8
    model = DiT_XL_2(input_size=latent_size).to(device)
    state_dict = find_model(f"DiT-XL-2-{image_size}x{image_size}.pt")
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Inference only: turn autograd off globally so sampling records no graph.
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Touch the 512x512 checkpoint up front; the returned state dict is discarded.
# Presumably this warms find_model's download cache so a later switch to
# 512x512 does not block on a download — confirm against download.find_model.
find_model("DiT-XL-2-512x512.pt")
# Track which checkpoint / decoder are currently loaded so generate() can
# swap them only when the user picks something different.
current_image_size = 256
current_vae_model = "stabilityai/sd-vae-ft-mse"
model = load_model(image_size=current_image_size)
vae = AutoencoderKL.from_pretrained(current_vae_model).to(device)
def _swap_model(image_size):
    """Replace the global DiT model when a different resolution is requested."""
    global model, current_image_size
    if image_size == current_image_size:
        return
    # Move the old weights off the device and free them before loading the
    # new checkpoint, so both copies do not occupy device memory at once.
    model = model.to("cpu")
    del model
    if device == "cuda":
        torch.cuda.empty_cache()
    model = load_model(image_size=image_size)
    current_image_size = image_size


def _swap_vae(vae_model):
    """Replace the global VAE decoder when a different checkpoint is requested."""
    global vae, current_vae_model
    if vae_model == current_vae_model:
        return
    if device == "cuda":
        vae.to("cpu")
    del vae
    if device == "cuda":
        torch.cuda.empty_cache()
    vae = AutoencoderKL.from_pretrained(vae_model).to(device)
    # Record what is loaded now; without this, switching back to the
    # previous VAE would be skipped and the wrong decoder kept.
    current_vae_model = vae_model


def generate(image_size, vae_model, class_label, cfg_scale, num_sampling_steps, seed):
    """Sample one class-conditional image with DiT and decode it with a VAE.

    Args:
        image_size: Resolution string such as ``"256x256"``; the number before
            ``"x"`` selects the DiT checkpoint (see ``load_model``).
        vae_model: Hugging Face id passed to ``AutoencoderKL.from_pretrained``.
        class_label: ImageNet-1K class index (the UI dropdown uses
            ``type="index"``).
        cfg_scale: Classifier-free guidance scale.
        num_sampling_steps: Number of diffusion sampling steps.
        seed: Seed for ``torch.manual_seed``.

    Returns:
        A list containing one ``PIL.Image.Image``.
    """
    n = 1  # images per request; the UI exposes no control for this
    image_size = int(image_size.split("x")[0])
    _swap_model(image_size)
    _swap_vae(vae_model)

    torch.manual_seed(int(seed))
    # create_diffusion takes the step count as a string (presumably a
    # timestep-respacing spec). Normalize through int so a float from the UI
    # such as 75.0 becomes "75" rather than "75.0".
    diffusion = create_diffusion(str(int(num_sampling_steps)))

    # Starting noise in latent space: 4 channels, 1/8 of the pixel resolution.
    latent_size = image_size // 8
    z = torch.randn(n, 4, latent_size, latent_size, device=device)
    y = torch.tensor([class_label] * n, device=device)

    # Classifier-free guidance: duplicate the noise and pair the copy with the
    # null label 1000 (ImageNet-1K labels are presumably 0-999), so the
    # conditional and unconditional passes run as one batch.
    z = torch.cat([z, z], 0)
    y_null = torch.tensor([1000] * n, device=device)
    y = torch.cat([y, y_null], 0)
    model_kwargs = dict(y=y, cfg_scale=cfg_scale)

    samples = diffusion.p_sample_loop(
        model.forward_with_cfg,
        z.shape,
        z,
        clip_denoised=False,
        model_kwargs=model_kwargs,
        progress=True,
        device=device,
    )
    samples, _ = samples.chunk(2, dim=0)  # drop the null-class half
    # 0.18215 is presumably the SD VAE latent scale factor; undo it to decode.
    samples = vae.decode(samples / 0.18215).sample

    # Map decoder output (assumed roughly in [-1, 1]) to uint8 NHWC for PIL.
    samples = (
        samples.mul(127.5)
        .add_(128.0)
        .clamp_(0, 255)
        .permute(0, 2, 3, 1)
        .to("cpu", torch.uint8)
        .numpy()
    )
    return [Image.fromarray(sample) for sample in samples]
# Markdown snippets rendered at the top of the Gradio page.
description = '''This is a demo of our DiT image generation models. DiTs are a new class of diffusion models with
transformer backbones. They are class-conditional models trained on ImageNet-1K, and they outperform prior DDPMs.'''
duplicate = '''Skip the queue by duplicating this space and upgrading to GPU in settings
<a href="https://huggingface.co/spaces/wpeebles/DiT?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>'''
project_links = '''
<p style="text-align: center">
<a href="https://www.wpeebles.com/DiT.html">Project Page</a> &#183;
<a href="http://colab.research.google.com/github/facebookresearch/DiT/blob/main/run_DiT.ipynb">Colab</a> &#183;
<a href="http://arxiv.org/abs/2212.09748">Paper</a> &#183;
<a href="https://github.com/facebookresearch/DiT">GitHub</a></p>'''

# Example rows for gr.Examples. Each row must match the six UI inputs in order:
# [resolution, VAE id, ImageNet class name, CFG scale, sampling steps, seed].
# (A stale sample-count column, left over from the removed "Number of
# Samples" slider, used to sit before the seed and misaligned every row.)
examples = [
    ["512x512", "stabilityai/sd-vae-ft-mse", "golden retriever", 4.0, 200, 1000],
    ["512x512", "stabilityai/sd-vae-ft-mse", "macaw", 4.0, 200, 1],
    ["512x512", "stabilityai/sd-vae-ft-mse", "balloon", 4.0, 200, 1],
    ["512x512", "stabilityai/sd-vae-ft-mse", "cliff, drop, drop-off", 4.0, 200, 7],
    ["512x512", "stabilityai/sd-vae-ft-mse", "Pembroke, Pembroke Welsh corgi", 4.0, 200, 0],
    ["256x256", "stabilityai/sd-vae-ft-mse", "sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita",
     4.0, 200, 1],
    ["256x256", "stabilityai/sd-vae-ft-mse", "teddy, teddy bear", 4.0, 200, 3],
    ["256x256", "stabilityai/sd-vae-ft-mse", "cheeseburger", 4.0, 200, 2],
]
# Build the Gradio UI. Components use the top-level gradio API (gr.Radio,
# gr.Dropdown, gr.Slider, gr.Number) rather than the deprecated `gr.inputs`
# namespace; `value=` replaces the old `default=` keyword.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center'>Scalable Diffusion Models with Transformers (DiT)</h1>")
    gr.Markdown(project_links)
    gr.Markdown(description)
    gr.Markdown(duplicate)
    with gr.Tabs():
        with gr.TabItem('Generate'):
            with gr.Row():
                # Left column: generation controls.
                with gr.Column():
                    with gr.Row():
                        image_size = gr.Radio(
                            choices=["256x256", "512x512"],
                            value="256x256",
                            label='DiT Model Resolution',
                        )
                        vae_model = gr.Radio(
                            choices=["stabilityai/sd-vae-ft-mse", "stabilityai/sd-vae-ft-ema"],
                            value="stabilityai/sd-vae-ft-mse",
                            label='VAE Decoder',
                        )
                    with gr.Row():
                        # type="index" hands generate() the class index, not its name.
                        i1k_class = gr.Dropdown(
                            list(IMAGENET_1K_CLASSES.values()),
                            value='golden retriever',
                            type="index",
                            label='ImageNet-1K Class',
                        )
                    cfg_scale = gr.Slider(
                        minimum=1, maximum=25, step=0.1, value=4.0,
                        label='Classifier-free Guidance Scale',
                    )
                    steps = gr.Slider(minimum=4, maximum=1000, step=1, value=75, label='Sampling Steps')
                    # precision=0 makes the box deliver an int seed.
                    seed = gr.Number(value=0, precision=0, label='Seed')
                    button = gr.Button("Generate", variant="primary")
                # Right column: output gallery.
                with gr.Column():
                    output = gr.Gallery(label='Generated Images').style(grid=[2], height="auto")
                    button.click(
                        generate,
                        inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, seed],
                        outputs=[output],
                    )
            with gr.Row():
                # cache_examples=True runs generate() on every example while the app starts.
                ex = gr.Examples(
                    examples=examples,
                    fn=generate,
                    inputs=[image_size, vae_model, i1k_class, cfg_scale, steps, seed],
                    outputs=[output],
                    cache_examples=True,
                )
demo.queue()
demo.launch()