svjack committed on
Commit
168444f
·
verified ·
1 Parent(s): 200b932

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -151
app.py CHANGED
@@ -1,151 +1,151 @@
1
- import os
2
- import gradio as gr
3
- from gradio_imageslider import ImageSlider
4
- from loadimg import load_img
5
- import spaces
6
- from transformers import AutoModelForImageSegmentation
7
- import torch
8
- from torchvision import transforms
9
- from PIL import Image, ImageChops
10
- from moviepy.editor import VideoFileClip, ImageSequenceClip
11
- import numpy as np
12
- from tqdm import tqdm
13
- from uuid import uuid1
14
-
15
- # Check CUDA availability
16
- if torch.cuda.is_available():
17
- device = "cuda"
18
- else:
19
- device = "cpu"
20
-
21
- torch.set_float32_matmul_precision(["high", "highest"][0])
22
-
23
- # Load the model
24
- birefnet = AutoModelForImageSegmentation.from_pretrained(
25
- "briaai/RMBG-2.0", trust_remote_code=True
26
- )
27
- birefnet.to(device)
28
- transform_image = transforms.Compose(
29
- [
30
- transforms.Resize((1024, 1024)),
31
- transforms.ToTensor(),
32
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
33
- ]
34
- )
35
-
36
- output_folder = 'output_images'
37
- if not os.path.exists(output_folder):
38
- os.makedirs(output_folder)
39
-
40
- def fn(image):
41
- im = load_img(image, output_type="pil")
42
- im = im.convert("RGB")
43
- origin = im.copy()
44
- image = process(im)
45
- image_path = os.path.join(output_folder, "no_bg_image.png")
46
- image.save(image_path)
47
- return (image, origin), image_path
48
-
49
- @spaces.GPU
50
- def process(image):
51
- image_size = image.size
52
- input_images = transform_image(image).unsqueeze(0).to(device)
53
- # Prediction
54
- with torch.no_grad():
55
- preds = birefnet(input_images)[-1].sigmoid().cpu()
56
- pred = preds[0].squeeze()
57
- pred_pil = transforms.ToPILImage()(pred)
58
- mask = pred_pil.resize(image_size)
59
- image.putalpha(mask)
60
- return image
61
-
62
- def process_file(f):
63
- name_path = f.rsplit(".",1)[0]+".png"
64
- im = load_img(f, output_type="pil")
65
- im = im.convert("RGB")
66
- transparent = process(im)
67
- transparent.save(name_path)
68
- return name_path
69
-
70
- def remove_background(image):
71
- """Remove background from a single image."""
72
- input_images = transform_image(image).unsqueeze(0).to(device)
73
-
74
- # Prediction
75
- with torch.no_grad():
76
- preds = birefnet(input_images)[-1].sigmoid().cpu()
77
- pred = preds[0].squeeze()
78
-
79
- # Convert the prediction to a mask
80
- mask = (pred * 255).byte() # Convert to 0-255 range
81
- mask_pil = transforms.ToPILImage()(mask).convert("L")
82
- mask_resized = mask_pil.resize(image.size, Image.LANCZOS)
83
-
84
- # Apply the mask to the image
85
- image.putalpha(mask_resized)
86
-
87
- return image, mask_resized
88
-
89
- def process_video(input_video_path):
90
- """Process a video to remove the background from each frame."""
91
- # Load the video
92
- video_clip = VideoFileClip(input_video_path)
93
-
94
- # Process each frame
95
- frames = []
96
- for frame in tqdm(video_clip.iter_frames()):
97
- frame_pil = Image.fromarray(frame)
98
- frame_no_bg, mask_resized = remove_background(frame_pil)
99
- path = "{}.png".format(uuid1())
100
- frame_no_bg.save(path)
101
- frame_no_bg = Image.open(path).convert("RGBA")
102
- os.remove(path)
103
-
104
- # Convert mask_resized to RGBA mode
105
- mask_resized_rgba = mask_resized.convert("RGBA")
106
-
107
- # Apply the mask using ImageChops.multiply
108
- output = ImageChops.multiply(frame_no_bg, mask_resized_rgba)
109
- output_np = np.array(output)
110
- frames.append(output_np)
111
-
112
- # Save the processed frames as a new video
113
- output_video_path = os.path.join(output_folder, "no_bg_video.mp4")
114
- processed_clip = ImageSequenceClip(frames, fps=video_clip.fps)
115
- processed_clip.write_videofile(output_video_path, codec='libx264', ffmpeg_params=['-pix_fmt', 'yuva420p'])
116
-
117
- return output_video_path
118
-
119
- # Gradio components
120
- slider1 = ImageSlider(label="RMBG-2.0", type="pil")
121
- slider2 = ImageSlider(label="RMBG-2.0", type="pil")
122
- image = gr.Image(label="Upload an image")
123
- image2 = gr.Image(label="Upload an image", type="filepath")
124
- text = gr.Textbox(label="Paste an image URL")
125
- png_file = gr.File(label="output png file")
126
- video_input = gr.Video(label="Upload a video")
127
- video_output = gr.Video(label="Processed video")
128
-
129
- # Example videos
130
- example_videos = [
131
- "pexels-cottonbro-5319934.mp4",
132
- "300_A_car_is_running_on_the_road.mp4",
133
- "A_Terracotta_Warrior_is_skateboarding_9033688.mp4"
134
- ]
135
-
136
- # Gradio interfaces
137
- tab1 = gr.Interface(
138
- fn, inputs=image, outputs=[slider1, gr.File(label="output png file")], examples=[load_img("giraffe.jpg", output_type="pil")], api_name="image"
139
- )
140
-
141
- tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=["http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"], api_name="text")
142
- #tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")
143
- tab4 = gr.Interface(process_video, inputs=video_input, outputs=video_output, examples=example_videos, api_name="video")
144
-
145
- # Gradio tabbed interface
146
- demo = gr.TabbedInterface(
147
- [tab4, tab1, tab2], ["input video", "input image", "input url"], title="RMBG-2.0 for background removal"
148
- )
149
-
150
- if __name__ == "__main__":
151
- demo.launch(share=True, show_error=True)
 
1
+ import os
2
+ import gradio as gr
3
+ from gradio_imageslider import ImageSlider
4
+ from loadimg import load_img
5
+ import spaces
6
+ from transformers import AutoModelForImageSegmentation
7
+ import torch
8
+ from torchvision import transforms
9
+ from PIL import Image, ImageChops
10
+ from moviepy.editor import VideoFileClip, ImageSequenceClip
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ from uuid import uuid1
14
+
15
+ # Check CUDA availability
16
+ if torch.cuda.is_available():
17
+ device = "cuda"
18
+ else:
19
+ device = "cpu"
20
+
21
+ torch.set_float32_matmul_precision(["high", "highest"][0])
22
+
23
+ # Load the model
24
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
25
+ "briaai/RMBG-2.0", trust_remote_code=True
26
+ )
27
+ birefnet.to(device)
28
+ transform_image = transforms.Compose(
29
+ [
30
+ transforms.Resize((1024, 1024)),
31
+ transforms.ToTensor(),
32
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
33
+ ]
34
+ )
35
+
36
+ output_folder = 'output_images'
37
+ if not os.path.exists(output_folder):
38
+ os.makedirs(output_folder)
39
+
40
+ def fn(image):
41
+ im = load_img(image, output_type="pil")
42
+ im = im.convert("RGB")
43
+ origin = im.copy()
44
+ image = process(im)
45
+ image_path = os.path.join(output_folder, "no_bg_image.png")
46
+ image.save(image_path)
47
+ return (image, origin), image_path
48
+
49
+ @spaces.GPU
50
+ def process(image):
51
+ image_size = image.size
52
+ input_images = transform_image(image).unsqueeze(0).to(device)
53
+ # Prediction
54
+ with torch.no_grad():
55
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
56
+ pred = preds[0].squeeze()
57
+ pred_pil = transforms.ToPILImage()(pred)
58
+ mask = pred_pil.resize(image_size)
59
+ image.putalpha(mask)
60
+ return image
61
+
62
+ def process_file(f):
63
+ name_path = f.rsplit(".",1)[0]+".png"
64
+ im = load_img(f, output_type="pil")
65
+ im = im.convert("RGB")
66
+ transparent = process(im)
67
+ transparent.save(name_path)
68
+ return name_path
69
+
70
+ def remove_background(image):
71
+ """Remove background from a single image."""
72
+ input_images = transform_image(image).unsqueeze(0).to(device)
73
+
74
+ # Prediction
75
+ with torch.no_grad():
76
+ preds = birefnet(input_images)[-1].sigmoid().cpu()
77
+ pred = preds[0].squeeze()
78
+
79
+ # Convert the prediction to a mask
80
+ mask = (pred * 255).byte() # Convert to 0-255 range
81
+ mask_pil = transforms.ToPILImage()(mask).convert("L")
82
+ mask_resized = mask_pil.resize(image.size, Image.LANCZOS)
83
+
84
+ # Apply the mask to the image
85
+ image.putalpha(mask_resized)
86
+
87
+ return image, mask_resized
88
+
89
+ def process_video(input_video_path):
90
+ """Process a video to remove the background from each frame."""
91
+ # Load the video
92
+ video_clip = VideoFileClip(input_video_path)
93
+
94
+ # Process each frame
95
+ frames = []
96
+ for frame in tqdm(video_clip.iter_frames()):
97
+ frame_pil = Image.fromarray(frame)
98
+ frame_no_bg, mask_resized = remove_background(frame_pil)
99
+ path = "{}.png".format(uuid1())
100
+ frame_no_bg.save(path)
101
+ frame_no_bg = Image.open(path).convert("RGBA")
102
+ os.remove(path)
103
+
104
+ # Convert mask_resized to RGBA mode
105
+ mask_resized_rgba = mask_resized.convert("RGBA")
106
+
107
+ # Apply the mask using ImageChops.multiply
108
+ output = ImageChops.multiply(frame_no_bg, mask_resized_rgba)
109
+ output_np = np.array(output)
110
+ frames.append(output_np)
111
+
112
+ # Save the processed frames as a new video
113
+ output_video_path = os.path.join(output_folder, "no_bg_video.mp4")
114
+ processed_clip = ImageSequenceClip(frames, fps=video_clip.fps)
115
+ processed_clip.write_videofile(output_video_path, codec='libx264', ffmpeg_params=['-pix_fmt', 'yuva420p'])
116
+
117
+ return output_video_path
118
+
119
+ # Gradio components
120
+ slider1 = ImageSlider(label="RMBG-2.0", type="pil")
121
+ slider2 = ImageSlider(label="RMBG-2.0", type="pil")
122
+ image = gr.Image(label="Upload an image")
123
+ image2 = gr.Image(label="Upload an image", type="filepath")
124
+ text = gr.Textbox(label="Paste an image URL")
125
+ png_file = gr.File(label="output png file")
126
+ video_input = gr.Video(label="Upload a video")
127
+ video_output = gr.Video(label="Processed video")
128
+
129
+ # Example videos
130
+ example_videos = [
131
+ "pexels-cottonbro-5319934.mp4",
132
+ "300_A_car_is_running_on_the_road.mp4",
133
+ "A_Terracotta_Warrior_is_skateboarding_9033688.mp4"
134
+ ]
135
+
136
+ # Gradio interfaces
137
+ tab1 = gr.Interface(
138
+ fn, inputs=image, outputs=[slider1, gr.File(label="output png file")], examples=[load_img("giraffe.jpg", output_type="pil")], api_name="image"
139
+ )
140
+
141
+ tab2 = gr.Interface(fn, inputs=text, outputs=[slider2, gr.File(label="output png file")], examples=["http://farm9.staticflickr.com/8488/8228323072_76eeddfea3_z.jpg"], api_name="text")
142
+ #tab3 = gr.Interface(process_file, inputs=image2, outputs=png_file, examples=["giraffe.jpg"], api_name="png")
143
+ tab4 = gr.Interface(process_video, inputs=video_input, outputs=video_output, examples=example_videos, api_name="video", cache_examples = False)
144
+
145
+ # Gradio tabbed interface
146
+ demo = gr.TabbedInterface(
147
+ [tab4, tab1, tab2], ["input video", "input image", "input url"], title="RMBG-2.0 for background removal"
148
+ )
149
+
150
+ if __name__ == "__main__":
151
+ demo.launch(share=True, show_error=True)