import os
import traceback

import torch
import gradio as gr
import spaces
from PIL import Image
from diffusers import DiffusionPipeline
from huggingface_hub import snapshot_download
from test_ccsr_tile import load_pipeline
import argparse
from accelerate import Accelerator


# Global variables
class ModelContainer:
    """Process-wide holder for the lazily initialized CCSR pipeline and its helpers."""

    def __init__(self):
        self.pipeline = None
        self.generator = None
        self.accelerator = None
        # Set to True only after the pipeline has been loaded and moved to CUDA.
        self.is_initialized = False


model_container = ModelContainer()


class Args:
    """Minimal attribute bag standing in for the argparse namespace used by the CCSR code."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


@spaces.GPU
def initialize_models():
    """Download the CCSR weights and build the pipeline, at most once per process.

    Returns:
        True if the pipeline is ready (already initialized or initialized by
        this call), False if any initialization step failed. Failures are
        printed together with their traceback.
    """
    if model_container.is_initialized:
        return True

    try:
        # Download model repository (only once). The "Read2" secret is read with
        # .get so that a missing secret yields a descriptive hub error (or plain
        # public access) instead of an opaque KeyError('Read2').
        model_path = snapshot_download(
            repo_id="NightRaven109/CCSRModels",
            token=os.environ.get("Read2"),
        )

        # Set up default arguments consumed by load_pipeline.
        args = Args(
            pretrained_model_path=os.path.join(model_path, "stable-diffusion-2-1-base"),
            controlnet_model_path=os.path.join(model_path, "Controlnet"),
            vae_model_path=os.path.join(model_path, "vae"),
            mixed_precision="fp16",
            tile_vae=False,
            sample_method="ddpm",
            vae_encoder_tile_size=1024,
            vae_decoder_tile_size=224,
        )

        # Initialize accelerator
        model_container.accelerator = Accelerator(
            mixed_precision=args.mixed_precision,
        )

        # Load pipeline
        model_container.pipeline = load_pipeline(
            args,
            model_container.accelerator,
            enable_xformers_memory_efficient_attention=False,
        )

        # Set models to eval mode
        model_container.pipeline.unet.eval()
        model_container.pipeline.controlnet.eval()
        model_container.pipeline.vae.eval()
        model_container.pipeline.text_encoder.eval()

        # Move pipeline to CUDA once
        model_container.pipeline = model_container.pipeline.to("cuda")

        # Initialize generator (reseeded per request when a seed is given)
        model_container.generator = torch.Generator("cuda")

        # Only flag success once every step above has completed, so a partial
        # failure is retried on the next request.
        model_container.is_initialized = True
        return True

    except Exception as e:
        print(f"Error initializing models: {str(e)}")
        traceback.print_exc()
        return False


def _prepare_validation_image(input_image, process_size, upscale):
    """Turn the Gradio input array into the RGB PIL image fed to the pipeline.

    Small inputs are first enlarged so that their shorter side reaches
    ``process_size // upscale``. The image is then scaled by ``upscale``, and
    finally each side is floored to a multiple of 8.

    Args:
        input_image: Array accepted by ``PIL.Image.fromarray`` (Gradio's
            default ``gr.Image`` output type).
        process_size: Target processing size in pixels.
        upscale: Integer upscale factor.

    Returns:
        A tuple ``(validation_image, ori_width, ori_height, resize_flag)``.
        ``resize_flag`` is True when the small-input enlargement was applied.
    """
    # .convert("RGB") normalizes RGBA / greyscale uploads to the 3-channel
    # input the diffusion pipeline works with.
    validation_image = Image.fromarray(input_image).convert("RGB")
    ori_width, ori_height = validation_image.size

    min_side = process_size // upscale
    resize_flag = False
    if ori_width < min_side or ori_height < min_side:
        scale = min_side / min(ori_width, ori_height)
        validation_image = validation_image.resize(
            (round(scale * ori_width), round(scale * ori_height))
        )
        resize_flag = True

    validation_image = validation_image.resize(
        (validation_image.size[0] * upscale, validation_image.size[1] * upscale)
    )
    # Floor both sides to a multiple of 8 (presumably required by the latent
    # downsampling of the VAE/UNet — confirm against the pipeline).
    validation_image = validation_image.resize(
        (validation_image.size[0] // 8 * 8, validation_image.size[1] // 8 * 8)
    )
    return validation_image, ori_width, ori_height, resize_flag


# spaces.GPU is the outermost decorator, so torch.no_grad() applies inside
# whatever context spaces.GPU uses to execute the function body.
@spaces.GPU
@torch.no_grad()  # Add no_grad decorator for inference
def process_image(
    input_image,
    prompt="clean, texture, high-resolution, 8k",
    negative_prompt="blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
    guidance_scale=2.5,
    conditioning_scale=1.0,
    num_inference_steps=6,
    seed=None,
    upscale_factor=4,
    color_fix_method="adain",
):
    """Super-resolve ``input_image`` with the CCSR pipeline.

    Args:
        input_image: Image array from ``gr.Image`` (None when nothing is uploaded).
        prompt: Positive prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        guidance_scale: Classifier-free guidance scale.
        conditioning_scale: ControlNet conditioning scale.
        num_inference_steps: Number of sampling steps; cast to int.
        seed: Optional seed. Gradio numeric inputs may arrive as floats, so
            the value is cast to int before seeding.
        upscale_factor: Integer upscale factor; cast to int.
        color_fix_method: One of "none", "wavelet", "adain".

    Returns:
        The upscaled PIL image.

    Raises:
        gr.Error: If no image is provided, model initialization fails, or
            inference fails. The original error is printed with its traceback.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    # Initialize models if not already done
    if not model_container.is_initialized and not initialize_models():
        raise gr.Error("Model initialization failed; see the server logs for details.")

    try:
        # Gradio sliders/numbers can deliver floats (e.g. 4.0). PIL sizes and
        # step counts must be ints.
        upscale_factor = int(upscale_factor)
        num_inference_steps = int(num_inference_steps)

        # Create args object
        args = Args(
            added_prompt=prompt,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            conditioning_scale=conditioning_scale,
            num_inference_steps=num_inference_steps,
            seed=seed,
            upscale=upscale_factor,
            process_size=512,
            align_method=color_fix_method,
            t_max=0.6666,
            t_min=0.0,
            tile_diffusion=False,
            tile_diffusion_size=None,
            tile_diffusion_stride=None,
            start_steps=999,
            start_point="lr",
            use_vae_encode_condition=True,
            sample_times=1,
        )

        # Set seed if provided. torch.Generator.manual_seed rejects floats,
        # which is what gr.Number yields without precision=0.
        if seed is not None:
            model_container.generator.manual_seed(int(seed))

        validation_image, ori_width, ori_height, resize_flag = _prepare_validation_image(
            input_image, args.process_size, args.upscale
        )
        width, height = validation_image.size

        # Generate image. The positional arguments follow the custom CCSR
        # pipeline __call__ returned by load_pipeline (defined in
        # test_ccsr_tile, not visible here).
        inference_time, output = model_container.pipeline(
            args.t_max,
            args.t_min,
            args.tile_diffusion,
            args.tile_diffusion_size,
            args.tile_diffusion_stride,
            args.added_prompt,
            validation_image,
            num_inference_steps=args.num_inference_steps,
            generator=model_container.generator,
            height=height,
            width=width,
            guidance_scale=args.guidance_scale,
            negative_prompt=args.negative_prompt,
            conditioning_scale=args.conditioning_scale,
            start_steps=args.start_steps,
            start_point=args.start_point,
            use_vae_encode_condition=args.use_vae_encode_condition,
        )

        image = output.images[0]

        # Apply color fixing if specified
        if args.align_method != "none":
            from myutils.wavelet_color_fix import wavelet_color_fix, adain_color_fix

            fix_func = wavelet_color_fix if args.align_method == "wavelet" else adain_color_fix
            image = fix_func(image, validation_image)

        # Undo the small-input enlargement so the output is exactly
        # original size * upscale.
        if resize_flag:
            image = image.resize((ori_width * args.upscale, ori_height * args.upscale))

        return image

    except Exception as e:
        print(f"Error processing image: {str(e)}")
        traceback.print_exc()
        # Surface the failure in the UI instead of silently returning an empty output.
        raise gr.Error(f"Error processing image: {e}") from e


# Define default values
DEFAULT_VALUES = {
    "prompt": "clean, texture, high-resolution, 8k",
    "negative_prompt": "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
    "guidance_scale": 3,
    "conditioning_scale": 1.0,
    "num_steps": 6,
    "seed": None,
    "upscale_factor": 4,
    "color_fix_method": "adain",
}

# Define example data
EXAMPLES = [
    [
        "examples/1.png",  # Input image path
        "clean, texture, high-resolution, 8k",  # Prompt
        "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",  # Negative prompt
        3.0,  # Guidance scale
        1.0,  # Conditioning scale
        6,  # Num steps
        42,  # Seed
        4,  # Upscale factor
        "wavelet",  # Color fix method
    ],
    [
        "examples/22.png",
        "clean, texture, high-resolution, 8k",
        "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
        3.0,
        1.0,
        6,
        123,
        4,
        "wavelet",
    ],
    [
        "examples/4.png",
        "clean, texture, high-resolution, 8k",
        "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
        3.0,
        1.0,
        6,
        123,
        4,
        "wavelet",
    ],
    [
        "examples/9D03D7F206775949.png",
        "clean, texture, high-resolution, 8k",
        "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
        3.0,
        1.0,
        6,
        123,
        4,
        "wavelet",
    ],
    [
        "examples/3.jpeg",
        "clean, texture, high-resolution, 8k",
        "blurry, dotted, noise, raster lines, unclear, lowres, over-smoothed",
        2.5,
        1.0,
        6,
        456,
        4,
        "wavelet",
    ],
]

# Create interface components
with gr.Blocks(title="Controllable Conditional Super-Resolution") as demo:
    gr.Markdown("## Controllable Conditional Super-Resolution")
    gr.Markdown("Upload an image to enhance its resolution using CCSR.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image")

            with gr.Accordion("Advanced Options", open=False):
                prompt = gr.Textbox(label="Prompt", value=DEFAULT_VALUES["prompt"])
                negative_prompt = gr.Textbox(
                    label="Negative Prompt", value=DEFAULT_VALUES["negative_prompt"]
                )
                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=20.0,
                    value=DEFAULT_VALUES["guidance_scale"],
                    label="Guidance Scale",
                )
                conditioning_scale = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=DEFAULT_VALUES["conditioning_scale"],
                    label="Conditioning Scale",
                )
                num_steps = gr.Slider(
                    minimum=1,
                    maximum=50,
                    value=DEFAULT_VALUES["num_steps"],
                    step=1,
                    label="Number of Steps",
                )
                # precision=0 makes Gradio deliver an int, which torch seeding requires.
                seed = gr.Number(label="Seed", value=DEFAULT_VALUES["seed"], precision=0)
                upscale_factor = gr.Slider(
                    minimum=1,
                    maximum=8,
                    value=DEFAULT_VALUES["upscale_factor"],
                    step=1,
                    label="Upscale Factor",
                )
                color_fix_method = gr.Dropdown(
                    choices=["none", "wavelet", "adain"],
                    label="Color Fix Method",
                    value=DEFAULT_VALUES["color_fix_method"],
                )

            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit", variant="primary")

        with gr.Column():
            output_image = gr.Image(label="Generated Image")

    # Add examples
    gr.Examples(
        examples=EXAMPLES,
        inputs=[
            input_image,
            prompt,
            negative_prompt,
            guidance_scale,
            conditioning_scale,
            num_steps,
            seed,
            upscale_factor,
            color_fix_method,
        ],
        outputs=output_image,
        fn=process_image,
        cache_examples=True,  # Cache the results for faster loading
    )

    # Define submit action
    submit_btn.click(
        fn=process_image,
        inputs=[
            input_image,
            prompt,
            negative_prompt,
            guidance_scale,
            conditioning_scale,
            num_steps,
            seed,
            upscale_factor,
            color_fix_method,
        ],
        outputs=output_image,
    )

    # Define clear action that resets to default values
    def reset_to_defaults():
        """Return UI values that clear the input image and restore every option to its default."""
        return [
            None,  # input_image
            DEFAULT_VALUES["prompt"],
            DEFAULT_VALUES["negative_prompt"],
            DEFAULT_VALUES["guidance_scale"],
            DEFAULT_VALUES["conditioning_scale"],
            DEFAULT_VALUES["num_steps"],
            DEFAULT_VALUES["seed"],
            DEFAULT_VALUES["upscale_factor"],
            DEFAULT_VALUES["color_fix_method"],
        ]

    clear_btn.click(
        fn=reset_to_defaults,
        inputs=None,
        outputs=[
            input_image,
            prompt,
            negative_prompt,
            guidance_scale,
            conditioning_scale,
            num_steps,
            seed,
            upscale_factor,
            color_fix_method,
        ],
    )

if __name__ == "__main__":
    demo.launch()