File size: 3,091 Bytes
5b344d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb7a796
9d157fe
5b344d3
 
eb7a796
 
9d157fe
5b344d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7057feb
5b344d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7057feb
5b344d3
 
 
 
 
eb7a796
5b344d3
 
 
 
 
 
 
 
 
 
 
eb7a796
5b344d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import gc
import datetime
import os
import re
from typing import Literal

import streamlit as st
import torch
from diffusers import (
    StableDiffusionPipeline,
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    EulerDiscreteScheduler,
    DDIMScheduler,
)

# Names of the pipelines that get_pipelines() knows how to build.
PIPELINES = Literal["txt2img", "sketch2img"]

# Base Stable Diffusion checkpoint shared by both pipelines.
_BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
# ControlNet checkpoint used only by the "sketch2img" pipeline.
_CONTROLNET_MODEL_ID = "Abhi5ingh/model_dresscode"
# Default Hugging Face download cache; override with get_pipelines(cache_dir=...).
_DEFAULT_CACHE_DIR = "D:/huggingface/CACHE/"


def _disabled_safety_checker(images, **kwargs):
    """Stand-in safety checker: return *images* unchanged and flag none of them."""
    return images, [False] * len(images)


@st.cache_resource(max_entries=1)
def get_pipelines(
    name: PIPELINES,
    enable_cpu_offload=False,
    cache_dir=_DEFAULT_CACHE_DIR,
) -> "StableDiffusionPipeline | StableDiffusionControlNetPipeline":
    """Build (or fetch from Streamlit's resource cache) the named diffusion pipeline.

    Only one pipeline is kept cached (``max_entries=1``). Requesting a different
    name or option combination therefore loads a fresh pipeline.

    Args:
        name: ``"txt2img"`` for plain Stable Diffusion, or ``"sketch2img"`` for
            Stable Diffusion driven by the ControlNet checkpoint.
        enable_cpu_offload: If true, call ``enable_model_cpu_offload()`` instead
            of moving the whole pipeline to CUDA.
        cache_dir: Directory passed to ``from_pretrained`` as the download cache.

    Returns:
        The loaded pipeline: a ``StableDiffusionPipeline`` for "txt2img", or a
        ``StableDiffusionControlNetPipeline`` for "sketch2img".

    Raises:
        ValueError: If *name* is not a known pipeline.
    """
    # Shared by every from_pretrained call below.
    load_kwargs = dict(torch_dtype=torch.float16, cache_dir=cache_dir)

    if name == "txt2img":
        pipe = StableDiffusionPipeline.from_pretrained(_BASE_MODEL_ID, **load_kwargs)
    elif name == "sketch2img":
        controlnet = ControlNetModel.from_pretrained(_CONTROLNET_MODEL_ID, **load_kwargs)
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            _BASE_MODEL_ID, controlnet=controlnet, **load_kwargs
        )
    else:
        raise ValueError(f"Pipeline not Found {name}")

    # Load the attention-processor weights stored in the current working
    # directory into the UNet. Both pipelines use the same weights.
    pipe.unet.load_attn_procs("./")
    # Replace the safety checker so every generated image passes through unflagged.
    pipe.safety_checker = _disabled_safety_checker

    if enable_cpu_offload:
        print("Enabling cpu offloading for the given pipeline")
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to("cuda")
    return pipe

# Characters that are unsafe in file names: path separators, the
# Windows-reserved punctuation set, and ASCII control characters.
_UNSAFE_FILENAME_CHARS = re.compile(r'[\\/:*?"<>|\x00-\x1f]')


def _output_stem(prompt):
    """Return the extension-less ``outputs/...`` path for an image made from *prompt*.

    Whitespace runs become ``_`` and unsafe characters are dropped. This keeps
    prompts such as ``"a/b"`` or ``"red: dress"`` inside ``outputs/`` and valid
    as a Windows file name.
    """
    safe = _UNSAFE_FILENAME_CHARS.sub("", re.sub(r"\s+", "_", prompt))[:30]
    return "outputs/" + safe + f"_{datetime.datetime.now().timestamp()}"


def _save_output(image, prompt, negative_prompt):
    """Save *image* as a PNG plus a sidecar ``.txt`` with its prompts under ``outputs/``."""
    os.makedirs("outputs", exist_ok=True)
    stem = _output_stem(prompt)
    image.save(f"{stem}.png")
    # Explicit UTF-8: the platform default encoding may not cover every prompt character.
    with open(f"{stem}.txt", "w", encoding="utf-8") as f:
        f.write(f"Prompt: {prompt}\n\nNegative Prompt:{negative_prompt}")


def generate(
        prompt,
        pipeline_name: PIPELINES,
        image=None,
        num_inference_steps=30,
        negative_prompt=None,
        width=512,
        height=512,
        guidance_scale=7.5,
        controlnet_conditioning_scale=None,
        enable_cpu_offload=False):
    """Generate one image with *pipeline_name*, save it under ``outputs/``, and return it.

    Args:
        prompt: Text prompt. Also used, sanitised, as the output file-name prefix.
        pipeline_name: ``"txt2img"`` or ``"sketch2img"``.
        image: Conditioning sketch; required for ``"sketch2img"``, ignored otherwise.
        num_inference_steps: Number of denoising steps. Also drives the progress bar.
        negative_prompt: Optional negative prompt. Empty values are sent as ``None``.
        width, height: Output size; only forwarded for ``"txt2img"``.
        guidance_scale: Classifier-free guidance scale.
        controlnet_conditioning_scale: ControlNet strength for ``"sketch2img"``.
            When ``None`` it is not passed, so the pipeline's own default applies.
        enable_cpu_offload: Forwarded to ``get_pipelines``.

    Returns:
        The first image returned by the pipeline.

    Raises:
        ValueError: If the pipeline name is unknown, or ``"sketch2img"`` is
            requested without an *image*.
    """
    # Validate first so an invalid request never pays the model-loading cost.
    if pipeline_name == "sketch2img" and image is not None:
        extra = {"image": image}
        if controlnet_conditioning_scale is not None:
            extra["controlnet_conditioning_scale"] = float(controlnet_conditioning_scale)
    elif pipeline_name == "txt2img":
        extra = {"width": width, "height": height}
    else:
        raise ValueError(
            f"Cannot generate image for pipeline {pipeline_name} and {prompt}")

    negative_prompt = negative_prompt if negative_prompt else None
    progress_bar = st.progress(0)

    def on_step(step, *_):
        # Assumes diffusers passes a 0-based step index, so add 1 to let the bar
        # reach 100%. Clamp because st.progress only accepts values up to 1.0.
        progress_bar.progress(min((step + 1) / num_inference_steps, 1.0))

    pipe = get_pipelines(pipeline_name, enable_cpu_offload=enable_cpu_offload)
    torch.cuda.empty_cache()

    kwargs = dict(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        callback=on_step,
        guidance_scale=guidance_scale,
    )
    print("kwargs", kwargs)
    kwargs.update(extra)

    result = pipe(**kwargs).images[0]
    _save_output(result, prompt, negative_prompt)
    return result