Spaces: Running on Zero
BoyuanJiang committed
Commit · 57b4b9a
1 Parent(s): 08d8dcb
update
app.py
CHANGED
@@ -24,91 +24,95 @@ access_token = os.getenv("HF_TOKEN")
 fitdit_repo = "BoyuanJiang/FitDiT"
 repo_path = snapshot_download(repo_id=fitdit_repo, use_auth_token=access_token)

-… (the removed lines of the previous version are not preserved in this view)
+weight_dtype = torch.bfloat16
+device = "cuda"
+transformer_garm = SD3Transformer2DModel_Garm.from_pretrained(os.path.join(repo_path, "transformer_garm"), torch_dtype=weight_dtype)
+transformer_vton = SD3Transformer2DModel_Vton.from_pretrained(os.path.join(repo_path, "transformer_vton"), torch_dtype=weight_dtype)
+pose_guider = PoseGuider(conditioning_embedding_channels=1536, conditioning_channels=3, block_out_channels=(32, 64, 256, 512))
+pose_guider.load_state_dict(torch.load(os.path.join(repo_path, "pose_guider", "diffusion_pytorch_model.bin")))
+image_encoder_large = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=weight_dtype)
+image_encoder_bigG = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", torch_dtype=weight_dtype)
+pose_guider.to(device=device, dtype=weight_dtype)
+image_encoder_large.to(device=device)
+image_encoder_bigG.to(device=device)
+pipeline = StableDiffusion3TryOnPipeline.from_pretrained(repo_path, torch_dtype=weight_dtype, \
+    transformer_garm=transformer_garm, transformer_vton=transformer_vton, pose_guider=pose_guider, \
+    image_encoder_large=image_encoder_large, image_encoder_bigG=image_encoder_bigG)
+pipeline.to(device)
+dwprocessor = DWposeDetector(model_root=repo_path, device=device)
+parsing_model = Parsing(model_root=repo_path, device=device)
+
+
+

+def generate_mask(vton_img, category, offset_top, offset_bottom, offset_left, offset_right):
+    with torch.inference_mode():
+        vton_img = Image.open(vton_img)
+        vton_img_det = resize_image(vton_img)
+        pose_image, keypoints, _, candidate = dwprocessor(np.array(vton_img_det)[:,:,::-1])
+        candidate[candidate<0]=0
+        candidate = candidate[0]

+        candidate[:, 0]*=vton_img_det.width
+        candidate[:, 1]*=vton_img_det.height

+        pose_image = pose_image[:,:,::-1] #rgb
+        pose_image = Image.fromarray(pose_image)
+        model_parse, _ = parsing_model(vton_img_det)

+        mask, mask_gray = get_mask_location(category, model_parse, \
+                                            candidate, model_parse.width, model_parse.height, \
+                                            offset_top, offset_bottom, offset_left, offset_right)
+        mask = mask.resize(vton_img.size)
+        mask_gray = mask_gray.resize(vton_img.size)
+        mask = mask.convert("L")
+        mask_gray = mask_gray.convert("L")
+        masked_vton_img = Image.composite(mask_gray, vton_img, mask)

+        im = {}
+        im['background'] = np.array(vton_img.convert("RGBA"))
+        im['layers'] = [np.concatenate((np.array(mask_gray.convert("RGB")), np.array(mask)[:,:,np.newaxis]),axis=2)]
+        im['composite'] = np.array(masked_vton_img.convert("RGBA"))
+
+        return im, pose_image

+@spaces.GPU
+def process(vton_img, garm_img, pre_mask, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution):
+    assert resolution in ["768x1024", "1152x1536", "1536x2048"]
+    new_width, new_height = resolution.split("x")
+    new_width = int(new_width)
+    new_height = int(new_height)
+    with torch.inference_mode():
+        garm_img = Image.open(garm_img)
+        vton_img = Image.open(vton_img)

+        model_image_size = vton_img.size
+        garm_img, _, _ = pad_and_resize(garm_img, new_width=new_width, new_height=new_height)
+        vton_img, pad_w, pad_h = pad_and_resize(vton_img, new_width=new_width, new_height=new_height)

+        mask = pre_mask["layers"][0][:,:,3]
+        mask = Image.fromarray(mask)
+        mask, _, _ = pad_and_resize(mask, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
+        mask = mask.convert("L")
+        pose_image = Image.fromarray(pose_image)
+        pose_image, _, _ = pad_and_resize(pose_image, new_width=new_width, new_height=new_height, pad_color=(0,0,0))
+        if seed==-1:
+            seed = random.randint(0, 2147483647)
+        res = pipeline(
+            height=new_height,
+            width=new_width,
+            guidance_scale=image_scale,
+            num_inference_steps=n_steps,
+            generator=torch.Generator("cpu").manual_seed(seed),
+            cloth_image=garm_img,
+            model_image=vton_img,
+            mask=mask,
+            pose_image=pose_image,
+            num_images_per_prompt=num_images_per_prompt
+        ).images
+        for idx in range(len(res)):
+            res[idx] = unpad_and_resize(res[idx], pad_w, pad_h, model_image_size[0], model_image_size[1])
+        return res


 def pad_and_resize(im, new_width=768, new_height=1024, pad_color=(255, 255, 255), mode=Image.LANCZOS):
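A note on the mask hand-off in the hunk above: generate_mask returns a dict with 'background', 'layers', and 'composite' keys, which matches the value format used by a Gradio image-editing component, and process later recovers the brush mask from the alpha channel of that single layer via pre_mask["layers"][0][:,:,3]. Below is a minimal, self-contained sketch of that round trip with dummy arrays; the array contents and sizes are illustrative, not taken from the commit.

import numpy as np
from PIL import Image

# Pack a grayscale mask into the alpha channel of an RGB layer, as generate_mask does.
mask = np.zeros((1024, 768), dtype=np.uint8)
mask[200:800, 100:600] = 255                                    # pretend this is the garment region
gray_preview = np.stack([mask // 2] * 3, axis=2)                # stand-in for the mask_gray RGB preview
layer = np.concatenate((gray_preview, mask[:, :, np.newaxis]), axis=2)   # H x W x 4
im = {"background": None, "layers": [layer], "composite": None}  # background/composite omitted in this sketch

# Unpack it the way process does: the alpha channel is the mask again.
recovered = Image.fromarray(im["layers"][0][:, :, 3]).convert("L")
assert np.array_equal(np.array(recovered), mask)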
@@ -185,8 +189,7 @@ FitDiT is designed for high-fidelity virtual try-on using Diffusion Transformers
 If you like our work, please star <a href="https://github.com/BoyuanJiang/FitDiT" style="color: blue; text-decoration: underline;">our github repository</a>.
 """

-def create_demo(model_path, device, with_fp16):
-    generator = FitDiTGenerator(model_path, device, with_fp16)
+def create_demo():
     with gr.Blocks(title="FitDiT") as demo:
         gr.Markdown(HEADER)
         with gr.Row():
@@ -294,15 +297,10 @@ def create_demo(model_path, device, with_fp16):

         ips1 = [vton_img, category, offset_top, offset_bottom, offset_left, offset_right]
         ips2 = [vton_img, garm_img, masked_vton_img, pose_image, n_steps, image_scale, seed, num_images_per_prompt, resolution]
-        run_mask_button.click(fn=
-        run_button.click(fn=
+        run_mask_button.click(fn=generate_mask, inputs=ips1, outputs=[masked_vton_img, pose_image])
+        run_button.click(fn=process, inputs=ips2, outputs=[result_gallery])
     return demo

 if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser(description="FitDiT")
-    parser.add_argument("--device", type=str, default="cuda:0", help="Device to use")
-    parser.add_argument("--fp16", action="store_true", help="Load model with fp16, default is bf16")
-    args = parser.parse_args()
-    demo = create_demo(repo_path, args.device, args.fp16)
+    demo = create_demo()
     demo.launch()
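Overall, the commit moves from a CLI-configured app (create_demo(model_path, device, with_fp16) plus argparse --device/--fp16 flags) to the layout ZeroGPU Spaces expect: models are built once at module scope, only the inference entry point carries @spaces.GPU, and create_demo() takes no arguments. A minimal sketch of that shape follows; load_models, run_button, and the dummy output are placeholders standing in for the FitDiT-specific setup, not code from the repository.

import gradio as gr
import numpy as np
import spaces           # Hugging Face Spaces helper; @spaces.GPU is a no-op outside a Space
import torch


def load_models():
    # Placeholder for the module-level FitDiT setup (transformers, pose guider,
    # CLIP encoders, StableDiffusion3TryOnPipeline); here it returns a dummy callable.
    return lambda: [np.zeros((1024, 768, 3), dtype=np.uint8)]


pipeline = load_models()         # heavy objects are created once, when the Space boots


@spaces.GPU                      # ZeroGPU attaches a GPU only for the duration of this call
def process():
    with torch.inference_mode():
        return pipeline()


def create_demo():
    with gr.Blocks(title="FitDiT") as demo:
        run_button = gr.Button("Run")
        result_gallery = gr.Gallery()
        run_button.click(fn=process, inputs=[], outputs=[result_gallery])
    return demo


if __name__ == "__main__":
    create_demo().launch()

With this structure the Space is started as a plain python app.py; the old --device and --fp16 switches are gone because the device is fixed to "cuda" and the dtype to torch.bfloat16 at module scope.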
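Both generate_mask and process lean on the unchanged helpers pad_and_resize and unpad_and_resize; only pad_and_resize's signature appears above as diff context. For orientation, here is a sketch of how such a letterbox pair typically behaves, so the pad_w/pad_h bookkeeping in process is easier to follow. This illustrates the idea under stated assumptions and is not the repository's implementation, which may differ in details.

from PIL import Image


def pad_and_resize_sketch(im, new_width=768, new_height=1024,
                          pad_color=(255, 255, 255), mode=Image.LANCZOS):
    """Letterbox `im` onto a new_width x new_height canvas, keeping its aspect ratio.

    Returns the padded image plus the horizontal/vertical padding, which the
    caller needs later to undo the operation (this mirrors how process() uses
    pad_w and pad_h).
    """
    scale = min(new_width / im.width, new_height / im.height)
    w, h = int(round(im.width * scale)), int(round(im.height * scale))
    resized = im.resize((w, h), mode)
    canvas = Image.new("RGB", (new_width, new_height), pad_color)
    pad_w, pad_h = (new_width - w) // 2, (new_height - h) // 2
    canvas.paste(resized, (pad_w, pad_h))
    return canvas, pad_w, pad_h


def unpad_and_resize_sketch(im, pad_w, pad_h, orig_width, orig_height):
    """Crop the padding back off and return to the original resolution."""
    cropped = im.crop((pad_w, pad_h, im.width - pad_w, im.height - pad_h))
    return cropped.resize((orig_width, orig_height), Image.LANCZOS)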