Ashoka74 committed on
Commit 30d0d94 · 1 Parent(s): 66a74d5

change app

Files changed (1)
  1. app.py +255 -5
app.py CHANGED
@@ -1,7 +1,257 @@
- import gradio as gr
 
- def greet(name):
-     return "Hello " + name + "!!"
 
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import argparse
+
+ import numpy as np
+ import torch
+ from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel
+ from PIL import Image
+ from torchvision import transforms
+ from tqdm import tqdm
+ from transformers import AutoModelForImageSegmentation
+
+ from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline
+ from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler
+ from mvadapter.utils import (
+     get_orthogonal_camera,
+     get_plucker_embeds_from_cameras_ortho,
+     make_image_grid,
+ )
+
+
+ def prepare_pipeline(
+     base_model,
+     vae_model,
+     unet_model,
+     lora_model,
+     adapter_path,
+     scheduler,
+     num_views,
+     device,
+     dtype,
+ ):
+     # Load vae and unet if provided
+     pipe_kwargs = {}
+     if vae_model is not None:
+         pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model)
+     if unet_model is not None:
+         pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model)
+
+     # Prepare pipeline
+     pipe: MVAdapterI2MVSDXLPipeline
+     pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs)
+
+     # Load scheduler if provided
+     scheduler_class = None
+     if scheduler == "ddpm":
+         scheduler_class = DDPMScheduler
+     elif scheduler == "lcm":
+         scheduler_class = LCMScheduler
+
+     pipe.scheduler = ShiftSNRScheduler.from_scheduler(
+         pipe.scheduler,
+         shift_mode="interpolated",
+         shift_scale=8.0,
+         scheduler_class=scheduler_class,
+     )
+     pipe.init_custom_adapter(num_views=num_views)
+     pipe.load_custom_adapter(
+         adapter_path, weight_name="mvadapter_i2mv_sdxl.safetensors"
+     )
+
+     pipe.to(device=device, dtype=dtype)
+     pipe.cond_encoder.to(device=device, dtype=dtype)
+
+     # load lora if provided
+     if lora_model is not None:
+         model_, name_ = lora_model.rsplit("/", 1)
+         pipe.load_lora_weights(model_, weight_name=name_)
+
+     return pipe
+
+
+ def remove_bg(image, net, transform, device):
+     image_size = image.size
+     input_images = transform(image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         preds = net(input_images)[-1].sigmoid().cpu()
+     pred = preds[0].squeeze()
+     pred_pil = transforms.ToPILImage()(pred)
+     mask = pred_pil.resize(image_size)
+     image.putalpha(mask)
+     return image
+
+
+ def preprocess_image(image: Image.Image, height, width):
+     image = np.array(image)
+     alpha = image[..., 3] > 0
+     H, W = alpha.shape
+     # get the bounding box of alpha
+     y, x = np.where(alpha)
+     y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
+     x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
+     image_center = image[y0:y1, x0:x1]
+     # resize the longer side to H * 0.9
+     H, W, _ = image_center.shape
+     if H > W:
+         W = int(W * (height * 0.9) / H)
+         H = int(height * 0.9)
+     else:
+         H = int(H * (width * 0.9) / W)
+         W = int(width * 0.9)
+     image_center = np.array(Image.fromarray(image_center).resize((W, H)))
+     # pad to H, W
+     start_h = (height - H) // 2
+     start_w = (width - W) // 2
+     image = np.zeros((height, width, 4), dtype=np.uint8)
+     image[start_h : start_h + H, start_w : start_w + W] = image_center
+     image = image.astype(np.float32) / 255.0
+     image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
+     image = (image * 255).clip(0, 255).astype(np.uint8)
+     image = Image.fromarray(image)
+
+     return image
+
+
+ def run_pipeline(
+     pipe,
+     num_views,
+     text,
+     image,
+     height,
+     width,
+     num_inference_steps,
+     guidance_scale,
+     seed,
+     remove_bg_fn=None,
+     reference_conditioning_scale=1.0,
+     negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
+     lora_scale=1.0,
+     device="cuda",
+ ):
+     # Prepare cameras
+     cameras = get_orthogonal_camera(
+         elevation_deg=[0, 0, 0, 0, 0, 0],
+         distance=[1.8] * num_views,
+         left=-0.55,
+         right=0.55,
+         bottom=-0.55,
+         top=0.55,
+         azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]],
+         device=device,
+     )
+
+     plucker_embeds = get_plucker_embeds_from_cameras_ortho(
+         cameras.c2w, [1.1] * num_views, width
+     )
+     control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)
+
+     # Prepare image
+     reference_image = Image.open(image) if isinstance(image, str) else image
+     if remove_bg_fn is not None:
+         reference_image = remove_bg_fn(reference_image)
+         reference_image = preprocess_image(reference_image, height, width)
+     elif reference_image.mode == "RGBA":
+         reference_image = preprocess_image(reference_image, height, width)
+
+     pipe_kwargs = {}
+     if seed != -1 and isinstance(seed, int):
+         pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed)
+
+     images = pipe(
+         text,
+         height=height,
+         width=width,
+         num_inference_steps=num_inference_steps,
+         guidance_scale=guidance_scale,
+         num_images_per_prompt=num_views,
+         control_image=control_images,
+         control_conditioning_scale=1.0,
+         reference_image=reference_image,
+         reference_conditioning_scale=reference_conditioning_scale,
+         negative_prompt=negative_prompt,
+         cross_attention_kwargs={"scale": lora_scale},
+         **pipe_kwargs,
+     ).images
+
+     return images, reference_image
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     # Models
+     parser.add_argument(
+         "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
+     )
+     parser.add_argument(
+         "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix"
+     )
+     parser.add_argument("--unet_model", type=str, default=None)
+     parser.add_argument("--scheduler", type=str, default=None)
+     parser.add_argument("--lora_model", type=str, default=None)
+     parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter")
+     parser.add_argument("--num_views", type=int, default=6)
+     # Device
+     parser.add_argument("--device", type=str, default="cuda")
+     # Inference
+     parser.add_argument("--image", type=str, required=True)
+     parser.add_argument("--text", type=str, default="high quality")
+     parser.add_argument("--num_inference_steps", type=int, default=50)
+     parser.add_argument("--guidance_scale", type=float, default=3.0)
+     parser.add_argument("--seed", type=int, default=-1)
+     parser.add_argument("--lora_scale", type=float, default=1.0)
+     parser.add_argument("--reference_conditioning_scale", type=float, default=1.0)
+     parser.add_argument(
+         "--negative_prompt",
+         type=str,
+         default="watermark, ugly, deformed, noisy, blurry, low contrast",
+     )
+     parser.add_argument("--output", type=str, default="output.png")
+     # Extra
+     parser.add_argument("--remove_bg", action="store_true", help="Remove background")
+     args = parser.parse_args()
+
+     pipe = prepare_pipeline(
+         base_model=args.base_model,
+         vae_model=args.vae_model,
+         unet_model=args.unet_model,
+         lora_model=args.lora_model,
+         adapter_path=args.adapter_path,
+         scheduler=args.scheduler,
+         num_views=args.num_views,
+         device=args.device,
+         dtype=torch.float16,
+     )
+
+     if args.remove_bg:
+         birefnet = AutoModelForImageSegmentation.from_pretrained(
+             "ZhengPeng7/BiRefNet", trust_remote_code=True
+         )
+         birefnet.to(args.device)
+         transform_image = transforms.Compose(
+             [
+                 transforms.Resize((1024, 1024)),
+                 transforms.ToTensor(),
+                 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+             ]
+         )
+         remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device)
+     else:
+         remove_bg_fn = None
+
+     images, reference_image = run_pipeline(
+         pipe,
+         num_views=args.num_views,
+         text=args.text,
+         image=args.image,
+         height=768,
+         width=768,
+         num_inference_steps=args.num_inference_steps,
+         guidance_scale=args.guidance_scale,
+         seed=args.seed,
+         lora_scale=args.lora_scale,
+         reference_conditioning_scale=args.reference_conditioning_scale,
+         negative_prompt=args.negative_prompt,
+         device=args.device,
+         remove_bg_fn=remove_bg_fn,
+     )
+     make_image_grid(images, rows=1).save(args.output)
+     reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png")
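
For context, a minimal programmatic usage sketch of the new app.py (not part of this commit): it calls the two helpers added above directly instead of going through argparse. Model IDs and defaults are taken from the diff; "input.png" and "output.png" are placeholder paths, and the mvadapter package plus a CUDA device are assumed to be available.

# Hypothetical usage sketch, not part of the commit. Assumes the mvadapter
# package is installed, a CUDA device is available, and "input.png" exists.
import torch

from app import prepare_pipeline, run_pipeline
from mvadapter.utils import make_image_grid

pipe = prepare_pipeline(
    base_model="stabilityai/stable-diffusion-xl-base-1.0",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    unet_model=None,
    lora_model=None,
    adapter_path="huanngzh/mv-adapter",
    scheduler=None,
    num_views=6,
    device="cuda",
    dtype=torch.float16,
)

# Generate 6 views of the reference image and save them as a single grid.
images, reference_image = run_pipeline(
    pipe,
    num_views=6,
    text="high quality",
    image="input.png",  # placeholder input path
    height=768,
    width=768,
    num_inference_steps=50,
    guidance_scale=3.0,
    seed=42,
)
make_image_grid(images, rows=1).save("output.png")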