Spaces:
Running
on
L4
Running
on
L4
Use simple adapter names
Browse files- app_base.py +6 -6
- model.py +30 -29
app_base.py
CHANGED
@@ -69,35 +69,35 @@ def create_demo(model: Model) -> gr.Blocks:
|
|
69 |
[
|
70 |
"assets/org_canny.jpg",
|
71 |
"Mystical fairy in real, magic, 4k picture, high quality",
|
72 |
-
"
|
73 |
0,
|
74 |
True,
|
75 |
],
|
76 |
[
|
77 |
"assets/org_sketch.png",
|
78 |
"a robot, mount fuji in the background, 4k photo, highly detailed",
|
79 |
-
"
|
80 |
0,
|
81 |
True,
|
82 |
],
|
83 |
[
|
84 |
"assets/org_lin.jpg",
|
85 |
"Ice dragon roar, 4k photo",
|
86 |
-
"
|
87 |
0,
|
88 |
True,
|
89 |
],
|
90 |
[
|
91 |
"assets/org_mid.jpg",
|
92 |
"A photo of a room, 4k photo, highly detailed",
|
93 |
-
"
|
94 |
0,
|
95 |
True,
|
96 |
],
|
97 |
[
|
98 |
"assets/org_zoe.jpg",
|
99 |
"A photo of a orchid, 4k photo, highly detailed",
|
100 |
-
"
|
101 |
0,
|
102 |
True,
|
103 |
],
|
@@ -109,7 +109,7 @@ def create_demo(model: Model) -> gr.Blocks:
|
|
109 |
with gr.Group():
|
110 |
image = gr.Image(label="Input image", type="pil", height=600)
|
111 |
prompt = gr.Textbox(label="Prompt")
|
112 |
-
adapter_name = gr.Dropdown(label="Adapter", choices=ADAPTER_NAMES, value=ADAPTER_NAMES[0])
|
113 |
run_button = gr.Button("Run")
|
114 |
with gr.Accordion("Advanced options", open=False):
|
115 |
apply_preprocess = gr.Checkbox(label="Apply preprocess", value=True)
|
|
|
69 |
[
|
70 |
"assets/org_canny.jpg",
|
71 |
"Mystical fairy in real, magic, 4k picture, high quality",
|
72 |
+
"canny",
|
73 |
0,
|
74 |
True,
|
75 |
],
|
76 |
[
|
77 |
"assets/org_sketch.png",
|
78 |
"a robot, mount fuji in the background, 4k photo, highly detailed",
|
79 |
+
"sketch",
|
80 |
0,
|
81 |
True,
|
82 |
],
|
83 |
[
|
84 |
"assets/org_lin.jpg",
|
85 |
"Ice dragon roar, 4k photo",
|
86 |
+
"lineart",
|
87 |
0,
|
88 |
True,
|
89 |
],
|
90 |
[
|
91 |
"assets/org_mid.jpg",
|
92 |
"A photo of a room, 4k photo, highly detailed",
|
93 |
+
"depth-midas",
|
94 |
0,
|
95 |
True,
|
96 |
],
|
97 |
[
|
98 |
"assets/org_zoe.jpg",
|
99 |
"A photo of a orchid, 4k photo, highly detailed",
|
100 |
+
"depth-zoe",
|
101 |
0,
|
102 |
True,
|
103 |
],
|
|
|
109 |
with gr.Group():
|
110 |
image = gr.Image(label="Input image", type="pil", height=600)
|
111 |
prompt = gr.Textbox(label="Prompt")
|
112 |
+
adapter_name = gr.Dropdown(label="Adapter name", choices=ADAPTER_NAMES, value=ADAPTER_NAMES[0])
|
113 |
run_button = gr.Button("Run")
|
114 |
with gr.Accordion("Advanced options", open=False):
|
115 |
apply_preprocess = gr.Checkbox(label="Apply preprocess", value=True)
|
model.py
CHANGED
@@ -77,14 +77,15 @@ def resize_to_closest_aspect_ratio(image: PIL.Image.Image) -> PIL.Image.Image:
|
|
77 |
return resized_image
|
78 |
|
79 |
|
80 |
-
|
81 |
-
"TencentARC/t2i-adapter-canny-sdxl-1.0",
|
82 |
-
"TencentARC/t2i-adapter-sketch-sdxl-1.0",
|
83 |
-
"TencentARC/t2i-adapter-lineart-sdxl-1.0",
|
84 |
-
"TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
|
85 |
-
"TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
|
86 |
-
# "TencentARC/t2i-adapter-recolor-sdxl-1.0",
|
87 |
-
|
|
|
88 |
|
89 |
|
90 |
class Preprocessor(ABC):
|
@@ -169,12 +170,12 @@ PRELOAD_PREPROCESSORS_IN_CPU_MEMORY = os.getenv("PRELOAD_PREPROCESSORS_IN_CPU_ME
|
|
169 |
if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
|
170 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
171 |
preprocessors_gpu: dict[str, Preprocessor] = {
|
172 |
-
"
|
173 |
-
"
|
174 |
-
"
|
175 |
-
"
|
176 |
-
"
|
177 |
-
"
|
178 |
}
|
179 |
|
180 |
def get_preprocessor(adapter_name: str) -> Preprocessor:
|
@@ -182,12 +183,12 @@ if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
|
|
182 |
|
183 |
elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
|
184 |
preprocessors_cpu: dict[str, Preprocessor] = {
|
185 |
-
"
|
186 |
-
"
|
187 |
-
"
|
188 |
-
"
|
189 |
-
"
|
190 |
-
"
|
191 |
}
|
192 |
|
193 |
def get_preprocessor(adapter_name: str) -> Preprocessor:
|
@@ -196,17 +197,17 @@ elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
|
|
196 |
else:
|
197 |
|
198 |
def get_preprocessor(adapter_name: str) -> Preprocessor:
|
199 |
-
if adapter_name == "
|
200 |
return CannyPreprocessor()
|
201 |
-
elif adapter_name == "
|
202 |
return PidiNetPreprocessor()
|
203 |
-
elif adapter_name == "
|
204 |
return LineartPreprocessor()
|
205 |
-
elif adapter_name == "
|
206 |
return MidasPreprocessor()
|
207 |
-
elif adapter_name == "
|
208 |
return ZoePreprocessor()
|
209 |
-
elif adapter_name == "
|
210 |
return RecolorPreprocessor()
|
211 |
else:
|
212 |
raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
|
@@ -222,7 +223,7 @@ else:
|
|
222 |
def download_all_adapters():
|
223 |
for adapter_name in ADAPTER_NAMES:
|
224 |
T2IAdapter.from_pretrained(
|
225 |
-
adapter_name,
|
226 |
torch_dtype=torch.float16,
|
227 |
varient="fp16",
|
228 |
)
|
@@ -248,7 +249,7 @@ class Model:
|
|
248 |
|
249 |
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
250 |
adapter = T2IAdapter.from_pretrained(
|
251 |
-
adapter_name,
|
252 |
torch_dtype=torch.float16,
|
253 |
varient="fp16",
|
254 |
).to(self.device)
|
@@ -292,7 +293,7 @@ class Model:
|
|
292 |
if adapter_name == self.adapter_name:
|
293 |
return
|
294 |
self.pipe.adapter = T2IAdapter.from_pretrained(
|
295 |
-
adapter_name,
|
296 |
torch_dtype=torch.float16,
|
297 |
varient="fp16",
|
298 |
).to(self.device)
|
|
|
77 |
return resized_image
|
78 |
|
79 |
|
80 |
+
ADAPTER_REPO_IDS = {
|
81 |
+
"canny": "TencentARC/t2i-adapter-canny-sdxl-1.0",
|
82 |
+
"sketch": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
|
83 |
+
"lineart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
|
84 |
+
"depth-midas": "TencentARC/t2i-adapter-depth-midas-sdxl-1.0",
|
85 |
+
"depth-zoe": "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0",
|
86 |
+
# "recolor": "TencentARC/t2i-adapter-recolor-sdxl-1.0",
|
87 |
+
}
|
88 |
+
ADAPTER_NAMES = list(ADAPTER_REPO_IDS.keys())
|
89 |
|
90 |
|
91 |
class Preprocessor(ABC):
|
|
|
170 |
if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
|
171 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
172 |
preprocessors_gpu: dict[str, Preprocessor] = {
|
173 |
+
"canny": CannyPreprocessor().to(device),
|
174 |
+
"sketch": PidiNetPreprocessor().to(device),
|
175 |
+
"lineart": LineartPreprocessor().to(device),
|
176 |
+
"depth-midas": MidasPreprocessor().to(device),
|
177 |
+
"depth-zoe": ZoePreprocessor().to(device),
|
178 |
+
"recolor": RecolorPreprocessor().to(device),
|
179 |
}
|
180 |
|
181 |
def get_preprocessor(adapter_name: str) -> Preprocessor:
|
|
|
183 |
|
184 |
elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
|
185 |
preprocessors_cpu: dict[str, Preprocessor] = {
|
186 |
+
"canny": CannyPreprocessor(),
|
187 |
+
"sketch": PidiNetPreprocessor(),
|
188 |
+
"lineart": LineartPreprocessor(),
|
189 |
+
"depth-midas": MidasPreprocessor(),
|
190 |
+
"depth-zoe": ZoePreprocessor(),
|
191 |
+
"recolor": RecolorPreprocessor(),
|
192 |
}
|
193 |
|
194 |
def get_preprocessor(adapter_name: str) -> Preprocessor:
|
|
|
197 |
else:
|
198 |
|
199 |
def get_preprocessor(adapter_name: str) -> Preprocessor:
|
200 |
+
if adapter_name == "canny":
|
201 |
return CannyPreprocessor()
|
202 |
+
elif adapter_name == "sketch":
|
203 |
return PidiNetPreprocessor()
|
204 |
+
elif adapter_name == "lineart":
|
205 |
return LineartPreprocessor()
|
206 |
+
elif adapter_name == "depth-midas":
|
207 |
return MidasPreprocessor()
|
208 |
+
elif adapter_name == "depth-zoe":
|
209 |
return ZoePreprocessor()
|
210 |
+
elif adapter_name == "recolor":
|
211 |
return RecolorPreprocessor()
|
212 |
else:
|
213 |
raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
|
|
|
223 |
def download_all_adapters():
|
224 |
for adapter_name in ADAPTER_NAMES:
|
225 |
T2IAdapter.from_pretrained(
|
226 |
+
ADAPTER_REPO_IDS[adapter_name],
|
227 |
torch_dtype=torch.float16,
|
228 |
varient="fp16",
|
229 |
)
|
|
|
249 |
|
250 |
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
251 |
adapter = T2IAdapter.from_pretrained(
|
252 |
+
ADAPTER_REPO_IDS[adapter_name],
|
253 |
torch_dtype=torch.float16,
|
254 |
varient="fp16",
|
255 |
).to(self.device)
|
|
|
293 |
if adapter_name == self.adapter_name:
|
294 |
return
|
295 |
self.pipe.adapter = T2IAdapter.from_pretrained(
|
296 |
+
ADAPTER_REPO_IDS[adapter_name],
|
297 |
torch_dtype=torch.float16,
|
298 |
varient="fp16",
|
299 |
).to(self.device)
|