beingcognitive committed (verified)
Commit 9843137 · 1 Parent(s): 31361ed

For Tech Campus class

Files changed (1)
  1. app.py +66 -96
app.py CHANGED
@@ -1,14 +1,55 @@
 import streamlit as st
-# from transformers import AutoProcessor, AutoModelForMaskGeneration
-from transformers import SamModel, SamProcessor
-from transformers import pipeline
+from transformers import SamModel, SamProcessor, pipeline
 from PIL import Image, ImageOps
-# from PIL import Image
 import numpy as np
-# import matplotlib.pyplot as plt
 import torch
-import requests
-from io import BytesIO
+
+# Constants
+XS_YS = [(2.0, 2.0), (2.5, 2.5)]
+WIDTH = 600
+
+# Load models
+@st.cache_resource
+def load_models():
+    model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
+    processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
+    od_pipe = pipeline("object-detection", "facebook/detr-resnet-50")
+    return model, processor, od_pipe
+
+def process_image(image, model, processor, bounding_box=None, input_point=None):
+    try:
+        # Convert image to RGB mode
+        image = image.convert('RGB')
+        # Convert image to numpy array
+        image_array = np.array(image)
+
+        if bounding_box:
+            inputs = processor(images=image_array, input_boxes=[bounding_box], return_tensors="pt")
+        elif input_point:
+            inputs = processor(images=image_array, input_points=[[input_point]], return_tensors="pt")
+        else:
+            raise ValueError("Either bounding_box or input_point must be provided")
+
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        predicted_masks = processor.image_processor.post_process_masks(
+            outputs.pred_masks,
+            inputs["original_sizes"],
+            inputs["reshaped_input_sizes"]
+        )
+        return predicted_masks[0]
+    except Exception as e:
+        st.error(f"Error processing image: {str(e)}")
+        return None
+
+def display_masked_images(raw_image, predicted_mask, caption_prefix):
+    for i in range(3):
+        mask = predicted_mask[0][i]
+        int_mask = np.array(mask).astype(int) * 255
+        mask_image = Image.fromarray(int_mask.astype('uint8'), mode='L')
+        final_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (255,255,255)), mask_image)
+        st.image(final_image, caption=f"{caption_prefix} {i+1}", width=WIDTH)
 
 def main():
     st.title("Image Segmentation with Object Detection")
@@ -29,106 +70,35 @@ def main():
     st.write("- Object Detection Model: `facebook/detr-resnet-50`")
     st.write("- Segmentation Model: `Zigeng/SlimSAM-uniform-77`")
 
-
-    # Load SAM by Facebook
-    # processor = AutoProcessor.from_pretrained("facebook/sam-vit-huge")
-    # model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-huge")
-    model = SamModel.from_pretrained("Zigeng/SlimSAM-uniform-77")
-    processor = SamProcessor.from_pretrained("Zigeng/SlimSAM-uniform-77")
-    # Load Object Detection
-    od_pipe = pipeline("object-detection", "facebook/detr-resnet-50")
+    model, processor, od_pipe = load_models()
 
     uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
-    xs_ys = [(2.0, 2.0), (2.5, 2.5)] #, (2.5, 2.0), (2.0, 2.5), (1.5, 1.5)]
-    alpha = 20
-    width = 600
-
     if uploaded_file is not None:
         raw_image = Image.open(uploaded_file)
-
         st.subheader("Uploaded Image")
-        st.image(raw_image, caption="Uploaded Image", width=width)
-
-        ### STEP 1. Object Detection
-        pipeline_output = od_pipe(raw_image)
-
-        # Convert the bounding boxes from the pipeline output into the expected format for the SAM processor
-        input_boxes_format = [[[b['box']['xmin'], b['box']['ymin']], [b['box']['xmax'], b['box']['ymax']]] for b in pipeline_output]
-        labels_format = [b['label'] for b in pipeline_output]
-        print(input_boxes_format)
-        print(labels_format)
+        st.image(raw_image, caption="Uploaded Image", width=WIDTH)
 
-        # Now use these formatted boxes with the processor
-        for b, l in zip(input_boxes_format, labels_format):
-            with st.spinner('Processing...'):
+        with st.spinner('Processing image...'):
+            # Object Detection
+            pipeline_output = od_pipe(raw_image)
+            input_boxes_format = [[[b['box']['xmin'], b['box']['ymin']], [b['box']['xmax'], b['box']['ymax']]] for b in pipeline_output]
+            labels_format = [b['label'] for b in pipeline_output]
 
+            # Process bounding boxes
+            for b, l in zip(input_boxes_format, labels_format):
                 st.subheader(f'bounding box : {l}')
-                inputs = processor(images=raw_image,
-                                   input_boxes=[b],
-                                   return_tensors="pt")
-
-                with torch.no_grad():
-                    outputs = model(**inputs)
+                predicted_mask = process_image(raw_image, model, processor, bounding_box=b)
+                if predicted_mask is not None:
+                    display_masked_images(raw_image, predicted_mask, "Masked Image")
 
-                predicted_masks = processor.image_processor.post_process_masks(
-                    outputs.pred_masks,
-                    inputs["original_sizes"],
-                    inputs["reshaped_input_sizes"]
-                )
-                predicted_mask = predicted_masks[0]
-
-                for i in range(0, 3):
-                    # 2D array (boolean mask)
-                    mask = predicted_mask[0][i]
-                    int_mask = np.array(mask).astype(int) * 255
-                    mask_image = Image.fromarray(int_mask.astype('uint8'), mode='L')
-
-                    # Apply the mask to the image
-                    # Convert mask to a 3-channel image if your base image is in RGB
-                    mask_image_rgb = ImageOps.colorize(mask_image, (0, 0, 0), (255, 255, 255))
-                    final_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (255,255,255)), mask_image)
-
-                    #display the final image
-                    st.image(final_image, caption=f"Masked Image {i+1}", width=width)
-
-        ###
-        for (x, y) in xs_ys:
-            with st.spinner('Processing...'):
-
-                # Calculate input points
-                point_x = raw_image.size[0] // x
-                point_y = raw_image.size[1] // y
-                input_points = [[[ point_x, point_y ]]]
-
-                # Prepare inputs
-                inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
-
-                # Generate masks
-                with torch.no_grad():
-                    outputs = model(**inputs)
-
-                # Post-process masks
-                predicted_masks = processor.image_processor.post_process_masks(
-                    outputs.pred_masks,
-                    inputs["original_sizes"],
-                    inputs["reshaped_input_sizes"]
-                )
-
-                predicted_mask = predicted_masks[0]
-
-                # Display masked images
+            # Process input points
+            for (x, y) in XS_YS:
+                point_x, point_y = raw_image.size[0] // x, raw_image.size[1] // y
                 st.subheader(f"Input points : ({1/x},{1/y})")
-                for i in range(3):
-                    mask = predicted_mask[0][i]
-                    int_mask = np.array(mask).astype(int) * 255
-                    mask_image = Image.fromarray(int_mask.astype('uint8'), mode='L')
-
-                    ###
-                    mask_image_rgb = ImageOps.colorize(mask_image, (0, 0, 0), (255, 255, 255))
-                    final_image = Image.composite(raw_image, Image.new('RGB', raw_image.size, (255,255,255)), mask_image)
-
-                    st.image(final_image, caption=f"Masked Image {i+1}", width=width)
+                predicted_mask = process_image(raw_image, model, processor, input_point=[point_x, point_y])
+                if predicted_mask is not None:
+                    display_masked_images(raw_image, predicted_mask, "Masked Image")
 
 if __name__ == "__main__":
     main()
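
Note (not part of the commit): the step in the diff that is easiest to misread is the handoff from the object-detection pipeline to the box prompt that `process_image` forwards to the SAM processor. Below is a minimal sketch of that handoff; the detection dict shown in the comments reflects the standard `transformers` object-detection pipeline output, and the image file name is a placeholder.

    # Sketch of the DETR -> SlimSAM box handoff used in app.py (illustrative only).
    from transformers import pipeline
    from PIL import Image

    od_pipe = pipeline("object-detection", "facebook/detr-resnet-50")
    image = Image.open("example.jpg")  # placeholder file name

    # Each detection is a dict like:
    #   {'score': 0.998, 'label': 'cat', 'box': {'xmin': 12, 'ymin': 30, 'xmax': 310, 'ymax': 245}}
    detections = od_pipe(image)

    # app.py rewraps each box as [[xmin, ymin], [xmax, ymax]] before passing it
    # to process_image(..., bounding_box=b), which hands it to the SAM processor.
    boxes = [[[d['box']['xmin'], d['box']['ymin']], [d['box']['xmax'], d['box']['ymax']]] for d in detections]
    labels = [d['label'] for d in detections]
    print(list(zip(labels, boxes)))

The `for i in range(3)` loop in `display_masked_images` corresponds to the three candidate masks SAM proposes per prompt by default; `process_image` returns them together and the app renders each one.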