BoyuanJiang committed
Commit 57b4b9a · Parent: 08d8dcb
Files changed (1): app.py (+85 -87)

app.py CHANGED
@@ -24,91 +24,95 @@ access_token = os.getenv("HF_TOKEN")
 fitdit_repo = "BoyuanJiang/FitDiT"
 repo_path = snapshot_download(repo_id=fitdit_repo, use_auth_token=access_token)
 
-@spaces.GPU
-class FitDiTGenerator:
-    def __init__(self, model_root, device="cuda", with_fp16=False):
-        weight_dtype = torch.float16 if with_fp16 else torch.bfloat16
-        transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(model_root, "transformer_garm"), torch_dtype=weight_dtype)
-        transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(model_root, "transformer_vton"), torch_dtype=weight_dtype)
-        pose_guider = PoseGuider(conditioning_embedding_channels=1536, conditioning_channels=3, block_out_channels=(32, 64, 256, 512))
-        pose_guider.load_state_dict(torch.load(os.path.join(model_root, "pose_guider", "diffusion_pytorch_model.bin")))
-        image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=weight_dtype)
-        image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype)
-        pose_guider.to(device=device, dtype=weight_dtype)
-        image_encoder_large.to(device=device)
-        image_encoder_bigG.to(device=device)
-        self.pipeline = StableDiffusion3TryOnPipeline.from_pretrained(model_root, torch_dtype=weight_dtype, transformer_garm=transformer_garm, transformer_vton=transformer_vton, pose_guider=pose_guider, image_encoder_large=image_encoder_large, image_encoder_bigG=image_encoder_bigG)
-        self.pipeline.to(device)
-        self.dwprocessor = DWposeDetector(model_root=model_root, device=device)
-        self.parsing_model = Parsing(model_root=model_root, device=device)
-
-    def generate_mask(self, vton_img, category, offset_top, offset_bottom, offset_left, offset_right):
-        with torch.inference_mode():
-            vton_img = Image.open(vton_img)
-            vton_img_det = resize_image(vton_img)
-            pose_image, keypoints, _, candidate = self.dwprocessor(np.array(vton_img_det)[:,:,::-1])
-            candidate[candidate<0]=0
-            candidate = candidate[0]
-
-            candidate[:, 0]*=vton_img_det.width
-            candidate[:, 1]*=vton_img_det.height
-
-            pose_image = pose_image[:,:,::-1] #rgb
-            pose_image = Image.fromarray(pose_image)
-            model_parse, _ = self.parsing_model(vton_img_det)
-
-            mask, mask_gray = get_mask_location(category, model_parse, \
-                                    candidate, model_parse.width, model_parse.height, \
-                                    offset_top, offset_bottom, offset_left, offset_right)
-            mask = mask.resize(vton_img.size)
-            mask_gray = mask_gray.resize(vton_img.size)
-            mask = mask.convert("L")
-            mask_gray = mask_gray.convert("L")
-            masked_vton_img = Image.composite(mask_gray, vton_img, mask)
-
-            im = {}
-            im['background'] = np.array(vton_img.convert("RGBA"))
-            im['layers'] = [np.concatenate((np.array(mask_gray.convert("RGB")), np.array(mask)[:,:,np.newaxis]),axis=2)]
-            im['composite'] = np.array(masked_vton_img.convert("RGBA"))
-
-            return im, pose_image
-
-    def process(self, vton_img, garm_img, pre_mask, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution):
-        assert resolution in ["768x1024", "1152x1536", "1536x2048"]
-        new_width, new_height = resolution.split("x")
-        new_width = int(new_width)
-        new_height = int(new_height)
-        with torch.inference_mode():
-            garm_img = Image.open(garm_img)
-            vton_img = Image.open(vton_img)
-
-            model_image_size = vton_img.size
-            garm_img, _, _ = pad_and_resize(garm_img, new_width=new_width, new_height=new_height)
-            vton_img, pad_w, pad_h = pad_and_resize(vton_img, new_width=new_width, new_height=new_height)
-
-            mask = pre_mask["layers"][0][:,:,3]
-            mask = Image.fromarray(mask)
-            mask, _, _ = pad_and_resize(mask, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
-            mask = mask.convert("L")
-            pose_image = Image.fromarray(pose_image)
-            pose_image, _, _ = pad_and_resize(pose_image, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
-            if seed==-1:
-                seed = random.randint(0, 2147483647)
-            res = self.pipeline(
-                height=new_height,
-                width=new_width,
-                guidance_scale=image_scale,
-                num_inference_steps=n_steps,
-                generator=torch.Generator("cpu").manual_seed(seed),
-                cloth_image=garm_img,
-                model_image=vton_img,
-                mask=mask,
-                pose_image=pose_image,
-                num_images_per_prompt=num_images_per_prompt
-            ).images
-            for idx in range(len(res)):
-                res[idx] = unpad_and_resize(res[idx], pad_w, pad_h, model_image_size[0], model_image_size[1])
-            return res
+weight_dtype = torch.bfloat16
+device = "cuda"
+transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype)
+transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype)
+pose_guider = PoseGuider(conditioning_embedding_channels=1536, conditioning_channels=3, block_out_channels=(32, 64, 256, 512))
+pose_guider.load_state_dict(torch.load(os.path.join(repo_path, "pose_guider", "diffusion_pytorch_model.bin")))
+image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=weight_dtype)
+image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype)
+pose_guider.to(device=device, dtype=weight_dtype)
+image_encoder_large.to(device=device)
+image_encoder_bigG.to(device=device)
+pipeline = StableDiffusion3TryOnPipeline.from_pretrained(repo_path, torch_dtype=weight_dtype, \
+    transformer_garm=transformer_garm, transformer_vton=transformer_vton, pose_guider=pose_guider, \
+    image_encoder_large=image_encoder_large, image_encoder_bigG=image_encoder_bigG)
+pipeline.to(device)
+dwprocessor = DWposeDetector(model_root=repo_path, device=device)
+parsing_model = Parsing(model_root=repo_path, device=device)
+
+
+def generate_mask(vton_img, category, offset_top, offset_bottom, offset_left, offset_right):
+    with torch.inference_mode():
+        vton_img = Image.open(vton_img)
+        vton_img_det = resize_image(vton_img)
+        pose_image, keypoints, _, candidate = dwprocessor(np.array(vton_img_det)[:,:,::-1])
+        candidate[candidate<0]=0
+        candidate = candidate[0]
+
+        candidate[:, 0]*=vton_img_det.width
+        candidate[:, 1]*=vton_img_det.height
+
+        pose_image = pose_image[:,:,::-1] #rgb
+        pose_image = Image.fromarray(pose_image)
+        model_parse, _ = parsing_model(vton_img_det)
+
+        mask, mask_gray = get_mask_location(category, model_parse, \
+                                candidate, model_parse.width, model_parse.height, \
+                                offset_top, offset_bottom, offset_left, offset_right)
+        mask = mask.resize(vton_img.size)
+        mask_gray = mask_gray.resize(vton_img.size)
+        mask = mask.convert("L")
+        mask_gray = mask_gray.convert("L")
+        masked_vton_img = Image.composite(mask_gray, vton_img, mask)
+
+        im = {}
+        im['background'] = np.array(vton_img.convert("RGBA"))
+        im['layers'] = [np.concatenate((np.array(mask_gray.convert("RGB")), np.array(mask)[:,:,np.newaxis]),axis=2)]
+        im['composite'] = np.array(masked_vton_img.convert("RGBA"))
+
+        return im, pose_image
+
+@spaces.GPU
+def process(vton_img, garm_img, pre_mask, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution):
+    assert resolution in ["768x1024", "1152x1536", "1536x2048"]
+    new_width, new_height = resolution.split("x")
+    new_width = int(new_width)
+    new_height = int(new_height)
+    with torch.inference_mode():
+        garm_img = Image.open(garm_img)
+        vton_img = Image.open(vton_img)
+
+        model_image_size = vton_img.size
+        garm_img, _, _ = pad_and_resize(garm_img, new_width=new_width, new_height=new_height)
+        vton_img, pad_w, pad_h = pad_and_resize(vton_img, new_width=new_width, new_height=new_height)
+
+        mask = pre_mask["layers"][0][:,:,3]
+        mask = Image.fromarray(mask)
+        mask, _, _ = pad_and_resize(mask, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
+        mask = mask.convert("L")
+        pose_image = Image.fromarray(pose_image)
+        pose_image, _, _ = pad_and_resize(pose_image, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
+        if seed==-1:
+            seed = random.randint(0, 2147483647)
+        res = pipeline(
+            height=new_height,
+            width=new_width,
+            guidance_scale=image_scale,
+            num_inference_steps=n_steps,
+            generator=torch.Generator("cpu").manual_seed(seed),
+            cloth_image=garm_img,
+            model_image=vton_img,
+            mask=mask,
+            pose_image=pose_image,
+            num_images_per_prompt=num_images_per_prompt
+        ).images
+        for idx in range(len(res)):
+            res[idx] = unpad_and_resize(res[idx], pad_w, pad_h, model_image_size[0], model_image_size[1])
+        return res
 
 
 def pad_and_resize(im, new_width=768, new_height=1024, pad_color=(255, 255, 255), mode=Image.LANCZOS):
@@ -185,8 +189,7 @@ FitDiT is designed for high-fidelity virtual try-on using Diffusion Transformers
 If you like our work, please star <a href="https://github.com/BoyuanJiang/FitDiT" style="color: blue; text-decoration: underline;">our github repository</a>.
 """
 
-def create_demo(model_path, device, with_fp16):
-    generator = FitDiTGenerator(model_path, device, with_fp16)
+def create_demo():
     with gr.Blocks(title="FitDiT") as demo:
         gr.Markdown(HEADER)
         with gr.Row():
@@ -294,15 +297,10 @@ def create_demo(model_path, device, with_fp16):
 
         ips1 = [vton_img, category, offset_top, offset_bottom, offset_left, offset_right]
         ips2 = [vton_img, garm_img, masked_vton_img, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution]
-        run_mask_button.click(fn=generator.generate_mask, inputs=ips1, outputs=[masked_vton_img, pose_image])
-        run_button.click(fn=generator.process, inputs=ips2, outputs=[result_gallery])
+        run_mask_button.click(fn=generate_mask, inputs=ips1, outputs=[masked_vton_img, pose_image])
+        run_button.click(fn=process, inputs=ips2, outputs=[result_gallery])
     return demo
 
 if __name__ == "__main__":
-    import argparse
-    parser = argparse.ArgumentParser(description="FitDiT")
-    parser.add_argument("--device", type=str, default="cuda:0", help="Device to use")
-    parser.add_argument("--fp16", action="store_true", help="Load model with fp16, default is bf16")
-    args = parser.parse_args()
-    demo = create_demo(repo_path, args.device, args.fp16)
+    demo = create_demo()
     demo.launch()
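
Note on the pattern this commit adopts: the `@spaces.GPU`-decorated `FitDiTGenerator` class is dropped, all models are loaded once at module import time, and only the GPU-bound `process` function keeps the `@spaces.GPU` decorator. Below is a minimal sketch of that ZeroGPU layout, using placeholder pipeline and model names that are not taken from this repository; the pipeline call signature is illustrative only.

import spaces                            # Hugging Face ZeroGPU helper package
import torch
import gradio as gr
from diffusers import DiffusionPipeline  # placeholder pipeline class for the sketch

# Heavy weights are loaded once when the Space process starts, not per request.
pipe = DiffusionPipeline.from_pretrained("<some-model-repo>", torch_dtype=torch.bfloat16)
pipe.to("cuda")  # module-level .to(device), as in the commit above

@spaces.GPU  # only the inference call is scheduled on a GPU worker
def generate(prompt):
    with torch.inference_mode():
        # Output access depends on the loaded model; text-to-image pipelines expose .images
        return pipe(prompt).images[0]

demo = gr.Interface(fn=generate, inputs="text", outputs="image")

if __name__ == "__main__":
    demo.launch()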