hysts HF staff committed on
Commit
bb7edb9
·
1 Parent(s): c453122
Files changed (2) hide show
  1. app.py +3 -12
  2. model.py +135 -39
app.py CHANGED
@@ -31,6 +31,7 @@ def run(
31
  image: PIL.Image.Image,
32
  prompt: str,
33
  negative_prompt: str,
 
34
  num_inference_steps: int = 30,
35
  guidance_scale: float = 5.0,
36
  adapter_conditioning_scale: float = 1.0,
@@ -43,6 +44,7 @@ def run(
43
  image=image,
44
  prompt=prompt,
45
  negative_prompt=negative_prompt,
 
46
  num_inference_steps=num_inference_steps,
47
  guidance_scale=guidance_scale,
48
  adapter_conditioning_scale=adapter_conditioning_scale,
@@ -116,6 +118,7 @@ with gr.Blocks(css="style.css") as demo:
116
  image,
117
  prompt,
118
  negative_prompt,
 
119
  num_inference_steps,
120
  guidance_scale,
121
  adapter_conditioning_scale,
@@ -130,10 +133,6 @@ with gr.Blocks(css="style.css") as demo:
130
  queue=False,
131
  api_name=False,
132
  ).then(
133
- fn=model.change_adapter,
134
- inputs=adapter_name,
135
- api_name=False,
136
- ).success(
137
  fn=run,
138
  inputs=inputs,
139
  outputs=result,
@@ -146,10 +145,6 @@ with gr.Blocks(css="style.css") as demo:
146
  queue=False,
147
  api_name=False,
148
  ).then(
149
- fn=model.change_adapter,
150
- inputs=adapter_name,
151
- api_name=False,
152
- ).success(
153
  fn=run,
154
  inputs=inputs,
155
  outputs=result,
@@ -162,10 +157,6 @@ with gr.Blocks(css="style.css") as demo:
162
  queue=False,
163
  api_name=False,
164
  ).then(
165
- fn=model.change_adapter,
166
- inputs=adapter_name,
167
- api_name=False,
168
- ).success(
169
  fn=run,
170
  inputs=inputs,
171
  outputs=result,
 
31
  image: PIL.Image.Image,
32
  prompt: str,
33
  negative_prompt: str,
34
+ adapter_name: str,
35
  num_inference_steps: int = 30,
36
  guidance_scale: float = 5.0,
37
  adapter_conditioning_scale: float = 1.0,
 
44
  image=image,
45
  prompt=prompt,
46
  negative_prompt=negative_prompt,
47
+ adapter_name=adapter_name,
48
  num_inference_steps=num_inference_steps,
49
  guidance_scale=guidance_scale,
50
  adapter_conditioning_scale=adapter_conditioning_scale,
 
118
  image,
119
  prompt,
120
  negative_prompt,
121
+ adapter_name,
122
  num_inference_steps,
123
  guidance_scale,
124
  adapter_conditioning_scale,
 
133
  queue=False,
134
  api_name=False,
135
  ).then(
 
 
 
 
136
  fn=run,
137
  inputs=inputs,
138
  outputs=result,
 
145
  queue=False,
146
  api_name=False,
147
  ).then(
 
 
 
 
148
  fn=run,
149
  inputs=inputs,
150
  outputs=result,
 
157
  queue=False,
158
  api_name=False,
159
  ).then(
 
 
 
 
160
  fn=run,
161
  inputs=inputs,
162
  outputs=result,
model.py CHANGED
@@ -1,4 +1,6 @@
1
- from typing import Callable
 
 
2
 
3
  import PIL.Image
4
  import torch
@@ -26,74 +28,149 @@ ADAPTER_NAMES = [
26
  ]
27
 
28
 
29
- class CannyPreprocessor:
 
 
 
 
 
 
 
 
 
 
30
  def __init__(self):
31
  self.model = CannyDetector()
32
 
 
 
 
33
  def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
34
  return self.model(image, detect_resolution=384, image_resolution=1024)
35
 
36
 
37
- class LineartPreprocessor:
38
  def __init__(self):
39
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
- self.model = LineartDetector.from_pretrained("lllyasviel/Annotators").to(device)
 
 
41
 
42
  def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
43
  return self.model(image, detect_resolution=384, image_resolution=1024)
44
 
45
 
46
- class MidasPreprocessor:
47
  def __init__(self):
48
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
  self.model = MidasDetector.from_pretrained(
50
  "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
51
- ).to(device)
 
 
 
52
 
53
  def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
54
  return self.model(image, detect_resolution=512, image_resolution=1024)
55
 
56
 
57
- class PidiNetPreprocessor:
58
  def __init__(self):
59
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
- self.model = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to(device)
 
 
61
 
62
  def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
63
  return self.model(image, detect_resolution=512, image_resolution=1024, apply_filter=True)
64
 
65
 
66
- class RecolorPreprocessor:
 
 
 
67
  def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
68
  return image.convert("L").convert("RGB")
69
 
70
 
71
- class ZoePreprocessor:
72
  def __init__(self):
73
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
74
  self.model = ZoeDetector.from_pretrained(
75
  "valhalla/t2iadapter-aux-models", filename="zoed_nk.pth", model_type="zoedepth_nk"
76
- ).to(device)
 
 
 
77
 
78
  def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
79
  return self.model(image, gamma_corrected=True, image_resolution=1024)
80
 
81
 
82
- def get_preprocessor(adapter_name: str) -> Callable[[PIL.Image.Image], PIL.Image.Image]:
83
- if adapter_name == "TencentARC/t2i-adapter-canny-sdxl-1.0":
84
- return CannyPreprocessor()
85
- elif adapter_name == "TencentARC/t2i-adapter-sketch-sdxl-1.0":
86
- return PidiNetPreprocessor()
87
- elif adapter_name == "TencentARC/t2i-adapter-lineart-sdxl-1.0":
88
- return LineartPreprocessor()
89
- elif adapter_name == "TencentARC/t2i-adapter-depth-midas-sdxl-1.0":
90
- return MidasPreprocessor()
91
- elif adapter_name == "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0":
92
- return ZoePreprocessor()
93
- elif adapter_name == "TencentARC/t2i-adapter-recolor-sdxl-1.0":
94
- return RecolorPreprocessor()
95
- else:
96
- raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
 
99
  class Model:
@@ -103,11 +180,12 @@ class Model:
103
  if adapter_name not in ADAPTER_NAMES:
104
  raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
105
 
 
106
  self.adapter_name = adapter_name
107
 
108
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
109
  if torch.cuda.is_available():
110
- self.preprocessor = get_preprocessor(adapter_name)
111
 
112
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
113
  adapter = T2IAdapter.from_pretrained(
@@ -127,27 +205,39 @@ class Model:
127
  ).to(self.device)
128
  self.pipe.enable_xformers_memory_efficient_attention()
129
  else:
 
130
  self.pipe = None
131
 
132
- def change_adapter(self, adapter_name: str) -> None:
133
- if not torch.cuda.is_available():
134
- raise RuntimeError("This demo does not work on CPU.")
135
  if adapter_name not in ADAPTER_NAMES:
136
  raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
137
- if adapter_name == self.adapter_name:
138
  return
139
 
140
- self.preprocessor = None # type: ignore
 
 
 
 
 
 
 
 
141
  torch.cuda.empty_cache()
142
- self.preprocessor = get_preprocessor(adapter_name)
143
 
144
- self.pipe.adapter = None
145
- torch.cuda.empty_cache()
 
 
 
146
  self.pipe.adapter = T2IAdapter.from_pretrained(
147
  adapter_name,
148
  torch_dtype=torch.float16,
149
  varient="fp16",
150
  ).to(self.device)
 
 
 
151
 
152
  def resize_image(self, image: PIL.Image.Image) -> PIL.Image.Image:
153
  w, h = image.size
@@ -161,6 +251,7 @@ class Model:
161
  image: PIL.Image.Image,
162
  prompt: str,
163
  negative_prompt: str,
 
164
  num_inference_steps: int = 30,
165
  guidance_scale: float = 5.0,
166
  adapter_conditioning_scale: float = 1.0,
@@ -168,12 +259,17 @@ class Model:
168
  seed: int = 0,
169
  apply_preprocess: bool = True,
170
  ) -> list[PIL.Image.Image]:
 
 
171
  if num_inference_steps > self.MAX_NUM_INFERENCE_STEPS:
172
  raise ValueError(f"Number of steps must be less than {self.MAX_NUM_INFERENCE_STEPS}")
173
 
174
  # Resize image to avoid OOM
175
  image = self.resize_image(image)
176
 
 
 
 
177
  if apply_preprocess:
178
  image = self.preprocessor(image)
179
 
 
1
+ import gc
2
+ import os
3
+ from abc import ABC, abstractmethod
4
 
5
  import PIL.Image
6
  import torch
 
28
  ]
29
 
30
 
31
class Preprocessor(ABC):
    """Abstract interface shared by the image preprocessors in this module.

    A preprocessor is a callable that receives a PIL image and returns a PIL
    image, and that exposes a ``to`` method for device placement.
    """

    @abstractmethod
    def to(self, device: torch.device | str) -> "Preprocessor":
        """Handle a request to place the preprocessor on ``device``.

        Implementations are annotated to return a ``Preprocessor`` so that
        calls can be chained, e.g. ``SomePreprocessor().to(device)``.
        """

    @abstractmethod
    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """Produce the preprocessed version of ``image``."""
39
+
40
+
41
class CannyPreprocessor(Preprocessor):
    """Preprocessor backed by ``CannyDetector``."""

    def __init__(self):
        self.model = CannyDetector()

    def to(self, device: torch.device | str) -> Preprocessor:
        """Return ``self`` unchanged; ``device`` is not used by this class."""
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """Run the detector on ``image`` with this module's fixed resolutions."""
        resolution_options = {"detect_resolution": 384, "image_resolution": 1024}
        return self.model(image, **resolution_options)
50
 
51
 
52
class LineartPreprocessor(Preprocessor):
    """Preprocessor backed by ``LineartDetector`` from ``lllyasviel/Annotators``."""

    def __init__(self):
        self.model = LineartDetector.from_pretrained("lllyasviel/Annotators")

    def to(self, device: torch.device | str) -> Preprocessor:
        """Move the wrapped detector to ``device`` and return this preprocessor.

        The detector's own ``to`` result is deliberately discarded: returning
        it would hand callers the raw detector, and calling that directly
        skips the resolution arguments applied in ``__call__``.
        """
        self.model.to(device)
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """Run the detector on ``image`` with this module's fixed resolutions."""
        return self.model(image, detect_resolution=384, image_resolution=1024)
61
 
62
 
63
class MidasPreprocessor(Preprocessor):
    """Preprocessor backed by ``MidasDetector`` (``dpt_large`` checkpoint)."""

    def __init__(self):
        self.model = MidasDetector.from_pretrained(
            "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
        )

    def to(self, device: torch.device | str) -> Preprocessor:
        """Move the wrapped detector to ``device`` and return this preprocessor.

        The detector's own ``to`` result is deliberately discarded: returning
        it would hand callers the raw detector, and calling that directly
        skips the resolution arguments applied in ``__call__``.
        """
        self.model.to(device)
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """Run the detector on ``image`` with this module's fixed resolutions."""
        return self.model(image, detect_resolution=512, image_resolution=1024)
74
 
75
 
76
class PidiNetPreprocessor(Preprocessor):
    """Preprocessor backed by ``PidiNetDetector`` from ``lllyasviel/Annotators``."""

    def __init__(self):
        self.model = PidiNetDetector.from_pretrained("lllyasviel/Annotators")

    def to(self, device: torch.device | str) -> Preprocessor:
        """Move the wrapped detector to ``device`` and return this preprocessor.

        The detector's own ``to`` result is deliberately discarded: returning
        it would hand callers the raw detector, and calling that directly
        skips the resolution and ``apply_filter`` arguments applied in
        ``__call__``.
        """
        self.model.to(device)
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """Run the detector on ``image`` with fixed resolutions and filtering on."""
        return self.model(image, detect_resolution=512, image_resolution=1024, apply_filter=True)
85
 
86
 
87
class RecolorPreprocessor(Preprocessor):
    """Model-free preprocessor that strips colour from the input image."""

    def to(self, device: torch.device | str) -> Preprocessor:
        """Return ``self`` unchanged; there is no model to move."""
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """Convert ``image`` to mode ``"L"`` and then back to mode ``"RGB"``."""
        grayscale = image.convert("L")
        return grayscale.convert("RGB")
93
 
94
 
95
class ZoePreprocessor(Preprocessor):
    """Preprocessor backed by ``ZoeDetector`` (``zoedepth_nk`` checkpoint)."""

    def __init__(self):
        self.model = ZoeDetector.from_pretrained(
            "valhalla/t2iadapter-aux-models", filename="zoed_nk.pth", model_type="zoedepth_nk"
        )

    def to(self, device: torch.device | str) -> Preprocessor:
        """Move the wrapped detector to ``device`` and return this preprocessor.

        The detector's own ``to`` result is deliberately discarded: returning
        it would hand callers the raw detector, and calling that directly
        skips the ``gamma_corrected`` and resolution arguments applied in
        ``__call__``.
        """
        self.model.to(device)
        return self

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        """Run the detector on ``image`` with gamma correction enabled."""
        return self.model(image, gamma_corrected=True, image_resolution=1024)
106
 
107
 
108
# Preloading strategy, selected through environment variables. The GPU flag
# defaults to on and takes precedence over the CPU flag.
PRELOAD_PREPROCESSORS_IN_GPU_MEMORY = os.getenv("PRELOAD_PREPROCESSORS_IN_GPU_MEMORY", "1") == "1"
PRELOAD_PREPROCESSORS_IN_CPU_MEMORY = os.getenv("PRELOAD_PREPROCESSORS_IN_CPU_MEMORY", "0") == "1"

# Single source of truth for which preprocessor class serves which adapter.
# Insertion order is the order in which preloading instantiates the classes.
_PREPROCESSOR_CLASSES: dict[str, type[Preprocessor]] = {
    "TencentARC/t2i-adapter-canny-sdxl-1.0": CannyPreprocessor,
    "TencentARC/t2i-adapter-sketch-sdxl-1.0": PidiNetPreprocessor,
    "TencentARC/t2i-adapter-lineart-sdxl-1.0": LineartPreprocessor,
    "TencentARC/t2i-adapter-depth-midas-sdxl-1.0": MidasPreprocessor,
    "TencentARC/t2i-adapter-depth-zoe-sdxl-1.0": ZoePreprocessor,
    "TencentARC/t2i-adapter-recolor-sdxl-1.0": RecolorPreprocessor,
}

if PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _build_on_device(preprocessor_cls: type[Preprocessor], target: torch.device) -> Preprocessor:
        """Instantiate ``preprocessor_cls``, move it to ``target`` and return the instance."""
        preprocessor = preprocessor_cls()
        # Keep the instance itself rather than whatever ``to`` returns, so the
        # table is guaranteed to hold ``Preprocessor`` objects.
        preprocessor.to(target)
        return preprocessor

    preprocessors_gpu: dict[str, Preprocessor] = {
        name: _build_on_device(preprocessor_cls, device) for name, preprocessor_cls in _PREPROCESSOR_CLASSES.items()
    }

    def get_preprocessor(adapter_name: str) -> Preprocessor:
        """Return the preloaded preprocessor for ``adapter_name`` (``KeyError`` if unknown)."""
        return preprocessors_gpu[adapter_name]

elif PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
    preprocessors_cpu: dict[str, Preprocessor] = {
        name: preprocessor_cls() for name, preprocessor_cls in _PREPROCESSOR_CLASSES.items()
    }

    def get_preprocessor(adapter_name: str) -> Preprocessor:
        """Return the preloaded preprocessor for ``adapter_name`` (``KeyError`` if unknown)."""
        return preprocessors_cpu[adapter_name]

else:

    def get_preprocessor(adapter_name: str) -> Preprocessor:
        """Create a fresh preprocessor for ``adapter_name``.

        Raises:
            ValueError: If ``adapter_name`` has no registered preprocessor.
        """
        try:
            preprocessor_cls = _PREPROCESSOR_CLASSES[adapter_name]
        except KeyError:
            raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}") from None
        return preprocessor_cls()

    def download_all_preprocessors():
        """Instantiate every preprocessor once and discard it.

        Presumably this exists so that model weights are fetched into the
        local cache at import time -- confirm against the detector libraries.
        """
        for adapter_name in ADAPTER_NAMES:
            get_preprocessor(adapter_name)
            # Collect after each one so discarded models do not pile up.
            gc.collect()

    download_all_preprocessors()
161
+
162
+
163
def download_all_adapters():
    """Load every adapter in ``ADAPTER_NAMES`` once and discard the result.

    Presumably this warms the local weight cache at import time -- confirm
    against the ``T2IAdapter.from_pretrained`` documentation.
    """
    # NOTE(review): "varient" looks like a misspelling of "variant". It is
    # kept as-is here because correcting it changes which weight files are
    # requested; confirm each adapter repo ships fp16 weights before fixing.
    load_options = {"torch_dtype": torch.float16, "varient": "fp16"}
    for adapter_name in ADAPTER_NAMES:
        T2IAdapter.from_pretrained(adapter_name, **load_options)
        gc.collect()


download_all_adapters()
174
 
175
 
176
  class Model:
 
180
  if adapter_name not in ADAPTER_NAMES:
181
  raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
182
 
183
+ self.preprocessor_name = adapter_name
184
  self.adapter_name = adapter_name
185
 
186
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
187
  if torch.cuda.is_available():
188
+ self.preprocessor = get_preprocessor(adapter_name).to(self.device)
189
 
190
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
191
  adapter = T2IAdapter.from_pretrained(
 
205
  ).to(self.device)
206
  self.pipe.enable_xformers_memory_efficient_attention()
207
  else:
208
+ self.preprocessor = None # type: ignore
209
  self.pipe = None
210
 
211
def change_preprocessor(self, adapter_name: str) -> None:
    """Switch ``self.preprocessor`` to the one matching ``adapter_name``.

    Does nothing when ``adapter_name`` is already the active preprocessor.

    Raises:
        ValueError: If ``adapter_name`` is not listed in ``ADAPTER_NAMES``.
    """
    if adapter_name not in ADAPTER_NAMES:
        raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
    if adapter_name == self.preprocessor_name:
        return

    # Retire the current preprocessor according to the preloading strategy.
    # With GPU preloading nothing has to be moved or dropped.
    if not PRELOAD_PREPROCESSORS_IN_GPU_MEMORY:
        if PRELOAD_PREPROCESSORS_IN_CPU_MEMORY:
            self.preprocessor.to("cpu")
        else:
            # Drop the reference but keep the attribute defined, so a failed
            # load below does not leave the instance without `preprocessor`.
            self.preprocessor = None  # type: ignore

    # Store the object returned by get_preprocessor rather than the result of
    # `.to()`, which is not guaranteed to be the preprocessor itself.
    replacement = get_preprocessor(adapter_name)
    replacement.to(self.device)
    self.preprocessor = replacement
    self.preprocessor_name = adapter_name
    gc.collect()
    torch.cuda.empty_cache()
 
227
 
228
def change_adapter(self, adapter_name: str) -> None:
    """Replace ``self.pipe.adapter`` with the adapter named ``adapter_name``.

    Does nothing when ``adapter_name`` is already the active adapter.

    Raises:
        ValueError: If ``adapter_name`` is not listed in ``ADAPTER_NAMES``.
    """
    if adapter_name not in ADAPTER_NAMES:
        raise ValueError(f"Adapter name must be one of {ADAPTER_NAMES}")
    if adapter_name != self.adapter_name:
        # NOTE(review): "varient" looks like a misspelling of "variant". It
        # is kept as-is because correcting it changes which weight files are
        # requested; confirm the adapter repos ship fp16 weights first.
        loaded_adapter = T2IAdapter.from_pretrained(
            adapter_name,
            torch_dtype=torch.float16,
            varient="fp16",
        )
        self.pipe.adapter = loaded_adapter.to(self.device)
        self.adapter_name = adapter_name
        gc.collect()
        torch.cuda.empty_cache()
241
 
242
  def resize_image(self, image: PIL.Image.Image) -> PIL.Image.Image:
243
  w, h = image.size
 
251
  image: PIL.Image.Image,
252
  prompt: str,
253
  negative_prompt: str,
254
+ adapter_name: str,
255
  num_inference_steps: int = 30,
256
  guidance_scale: float = 5.0,
257
  adapter_conditioning_scale: float = 1.0,
 
259
  seed: int = 0,
260
  apply_preprocess: bool = True,
261
  ) -> list[PIL.Image.Image]:
262
+ if not torch.cuda.is_available():
263
+ raise RuntimeError("This demo does not work on CPU.")
264
  if num_inference_steps > self.MAX_NUM_INFERENCE_STEPS:
265
  raise ValueError(f"Number of steps must be less than {self.MAX_NUM_INFERENCE_STEPS}")
266
 
267
  # Resize image to avoid OOM
268
  image = self.resize_image(image)
269
 
270
+ self.change_preprocessor(adapter_name)
271
+ self.change_adapter(adapter_name)
272
+
273
  if apply_preprocess:
274
  image = self.preprocessor(image)
275