laudavid commited on
Commit
4bfc75e
·
1 Parent(s): 4cf63f7

modify od fashion

Browse files
data/dior_show/dior1.jpg ADDED
data/dior_show/dior2.jpg ADDED
data/dior_show/dior3.jpg ADDED
data/dior_show/dior4.jpg ADDED
images/fashion_ai.jpg ADDED
images/fashion_od.jpg ADDED
images/fashion_od2.png ADDED
pages/object_detection.py CHANGED
@@ -4,14 +4,24 @@ import streamlit as st
4
  import matplotlib.pyplot as plt
5
  import pandas as pd
6
  import numpy as np
7
- import altair as alt
 
 
8
 
9
  from PIL import Image
10
  from transformers import YolosFeatureExtractor, YolosForObjectDetection
11
  from torchvision.transforms import ToTensor, ToPILImage
 
 
12
 
13
  st.set_page_config(layout="wide")
14
 
 
 
 
 
 
 
15
 
16
  def rgb_to_hex(rgb):
17
  """Converts an RGB tuple to an HTML-style Hex string."""
@@ -76,7 +86,7 @@ def plot_results(pil_img, prob, boxes):
76
 
77
  plt.savefig("results_od.png",
78
  bbox_inches ="tight")
79
- #plt.show()
80
  st.image("results_od.png")
81
 
82
  return colors_used
@@ -112,15 +122,23 @@ def visualize_probas(probas, threshold, colors):
112
  top_label_df["colors"] = colors
113
  top_label_df.sort_values(by=["proba"], ascending=False, inplace=True)
114
 
115
- st.dataframe(top_label_df.drop(columns=["colors"]))
116
 
117
  mode_func = lambda x: x.mode().iloc[0]
118
  top_label_df_agg = top_label_df.groupby("label").agg({"proba":"mean", "colors":mode_func})
119
  top_label_df_agg = top_label_df_agg.reset_index().sort_values(by=["proba"], ascending=False)
 
 
 
 
 
 
 
 
120
 
121
- chart = alt.Chart(top_label_df_agg).mark_bar().encode(x="proba", y="label",
122
- color=alt.Color('colors:N', scale=None)).interactive()
123
- #st.altair_chart(chart)
124
 
125
 
126
 
@@ -156,34 +174,38 @@ st.markdown("""Common applications of Object Detection include:
156
  st.markdown(" ")
157
  st.divider()
158
 
159
- st.markdown("### Fashion object detection 👗")
160
- st.markdown(""" The following example showcases the use of an **Object detection algorithm** for clothing items/features on fashion images. <br>
161
- This use case can be seen as an application of AI models for Fashion and E-commerce. <br>
162
- """, unsafe_allow_html=True)
 
 
 
163
 
164
- st.image("images/od_fashion.jpg", width=700)
 
 
 
 
165
 
166
- #images_dior = [os.path.join("data/dior_show",url) for url in os.listdir("data/dior_show") if url != "results"]
167
- #st.image(images_dior, width=250, caption=[file for file in os.listdir("data/dior_show") if file != "results"])
168
 
169
  st.markdown(" ")
170
- #st.markdown("##### Select an image")
171
 
172
 
173
  ############## SELECT AN IMAGE ###############
174
 
175
- st.markdown("#### Step 1: Select an image")
176
- st.info("""First, select the image that you wish to use the Object detecion model on.""")
177
- st.markdown("**Note:** The model was trained to detect clothing items on a single person. If your image has more than individuals, the model will ignore one of them in its detection.")
178
 
179
  image_ = None
180
  select_image_box = st.radio(
181
- "",
182
  ["Choose an existing image", "Load your own image"],
183
- index=None, label_visibility="collapsed")
184
 
185
  if select_image_box == "Choose an existing image":
186
- fashion_images_path = r"data/pinterest"
187
  list_images = os.listdir(fashion_images_path)
188
  image_ = st.selectbox("", list_images, label_visibility="collapsed")
189
 
@@ -198,6 +220,8 @@ elif select_image_box == "Load your own image":
198
 
199
  st.warning("""**Note**: The model tends to perform better with images of people/clothing items facing forward.
200
  Choose this type of image if you want optimal results.""")
 
 
201
 
202
  if image_ is not None:
203
  st.image(Image.open(image_), width=300)
@@ -216,7 +240,7 @@ cats = ['shirt, blouse', 'top, t-shirt, sweatshirt', 'sweater', 'cardigan', 'jac
216
 
217
  dict_cats = dict(zip(np.arange(len(cats)), cats))
218
 
219
- st.markdown("#### Step 2: Choose the elements you want to detect")
220
 
221
  # Select one or more elements to detect
222
  container = st.container()
@@ -239,21 +263,31 @@ st.markdown(" ")
239
 
240
  ############## SELECT A THRESHOLD ###############
241
 
242
- st.markdown("#### Step 3: Select a threshold")
 
 
 
243
 
244
- st.markdown("""Finally, select a threshold for the model.
245
- The threshold helps you decide how confident you want your model to be with its predictions.
246
- Elements that were identified with a lower probability than the given threshold will be ignored in the final results.""")
247
 
248
- threshold = st.slider('**Select a threshold**', min_value=0.0, step=0.05, max_value=1.0, value=0.75, label_visibility="collapsed")
249
- # min_value=0.000000, step=0.000001, max_value=0.0005, value=0.0000045, format="%f"
250
 
251
- if threshold < 0.6:
252
- st.warning("""**Warning**: Selecting a low threshold (below 0.6) could lead the model to make errors and detect too many objects.""")
 
253
 
254
- st.write("You've selected a threshold at", threshold)
 
 
 
 
 
255
 
 
256
 
 
 
 
 
257
  st.markdown(" ")
258
 
259
 
@@ -269,29 +303,37 @@ if run_model:
269
  image = fix_channels(ToTensor()(image))
270
 
271
  ## LOAD OBJECT DETECTION MODEL
272
- MODEL_NAME = "valentinafeve/yolos-fashionpedia"
273
- feature_extractor = YolosFeatureExtractor.from_pretrained('hustvl/yolos-small')
274
- model = YolosForObjectDetection.from_pretrained(MODEL_NAME)
 
 
275
 
276
  # RUN MODEL ON IMAGE
277
  inputs = feature_extractor(images=image, return_tensors="pt")
278
  outputs = model(**inputs)
279
  probas, keep = return_probas(outputs, threshold)
280
 
 
 
281
  # PLOT BOUNDING BOX AND BARS/PROBA
282
  col1, col2 = st.columns(2)
283
  with col1:
284
- st.markdown("##### Bounding box results")
285
  bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0, keep].cpu(), image.size)
286
  colors_used = plot_results(image, probas[keep], bboxes_scaled)
287
 
288
  with col2:
289
- visualize_probas(probas, threshold, colors_used)
 
 
 
 
 
290
 
291
- st.info("Done")
292
 
293
  else:
294
- st.warning("You must select an **image**, **elements to detect** and a **threshold** to run the model !")
295
 
296
 
297
 
 
4
  import matplotlib.pyplot as plt
5
  import pandas as pd
6
  import numpy as np
7
+ #import altair as alt
8
+ import plotly.express as px
9
+
10
 
11
  from PIL import Image
12
  from transformers import YolosFeatureExtractor, YolosForObjectDetection
13
  from torchvision.transforms import ToTensor, ToPILImage
14
+ #from utils import load_model_huggingface
15
+
16
 
17
  st.set_page_config(layout="wide")
18
 
19
+ @st.cache_data(ttl=3600, show_spinner=False)
20
+ def load_model(feature_extractor_url, model_url):
21
+ feature_extractor_ = YolosFeatureExtractor.from_pretrained(feature_extractor_url)
22
+ model_ = YolosForObjectDetection.from_pretrained(model_url)
23
+ return feature_extractor_, model_
24
+
25
 
26
  def rgb_to_hex(rgb):
27
  """Converts an RGB tuple to an HTML-style Hex string."""
 
86
 
87
  plt.savefig("results_od.png",
88
  bbox_inches ="tight")
89
+ plt.show()
90
  st.image("results_od.png")
91
 
92
  return colors_used
 
122
  top_label_df["colors"] = colors
123
  top_label_df.sort_values(by=["proba"], ascending=False, inplace=True)
124
 
125
+ #st.dataframe(top_label_df.drop(columns=["colors"]))
126
 
127
  mode_func = lambda x: x.mode().iloc[0]
128
  top_label_df_agg = top_label_df.groupby("label").agg({"proba":"mean", "colors":mode_func})
129
  top_label_df_agg = top_label_df_agg.reset_index().sort_values(by=["proba"], ascending=False)
130
+ top_label_df_agg.columns = ["Item","Score","Colors"]
131
+
132
+ color_map = dict(zip(top_label_df_agg["Item"].to_list(),
133
+ top_label_df_agg["Colors"].to_list()))
134
+
135
+ fig = px.bar(top_label_df_agg, y='Item', x='Score',
136
+ color="Item", title="Probability scores")
137
+ st.plotly_chart(fig, use_container_width=True)
138
 
139
+ # chart = alt.Chart(top_label_df_agg).mark_bar().encode(x="proba", y="label",
140
+ # color=alt.Color('colors:N', scale=None)).interactive()
141
+ # st.altair_chart(chart)
142
 
143
 
144
 
 
174
  st.markdown(" ")
175
  st.divider()
176
 
177
+ st.markdown("## Fashion Object Detection 👗")
178
+ # st.info("""This use case showcases the application of **Object detection** to detect clothing items/features on images. <br>
179
+ # The images used were gathered from Dior's""")
180
+ st.info("""In this use case, we are going to identify and locate different articles of clothings, as well as finer details such as a collar or pocket using an object detection AI model.
181
+ The images used were taken from **Dior's 2020 Fall Women Fashion Show**.""")
182
+
183
+ st.markdown(" ")
184
 
185
+ images_dior = [os.path.join("data/dior_show",url) for url in os.listdir("data/dior_show") if url != "results"]
186
+ columns_img = st.columns(4)
187
+ for img, col in zip(images_dior,columns_img):
188
+ with col:
189
+ st.image(img)
190
 
 
 
191
 
192
  st.markdown(" ")
 
193
 
194
 
195
  ############## SELECT AN IMAGE ###############
196
 
197
+ st.markdown("#### Select an image 🖼️")
198
+ #st.markdown("""**Select an image that you wish to run the Object Detection model on.**""")
199
+
200
 
201
  image_ = None
202
  select_image_box = st.radio(
203
+ "**Select the image you wish to run the model on**",
204
  ["Choose an existing image", "Load your own image"],
205
+ index=None,)# #label_visibility="collapsed")
206
 
207
  if select_image_box == "Choose an existing image":
208
+ fashion_images_path = r"data/dior_show"
209
  list_images = os.listdir(fashion_images_path)
210
  image_ = st.selectbox("", list_images, label_visibility="collapsed")
211
 
 
220
 
221
  st.warning("""**Note**: The model tends to perform better with images of people/clothing items facing forward.
222
  Choose this type of image if you want optimal results.""")
223
+ st.warning("""**Note:** The model was trained to detect clothing items on a single person.
224
+ If your image contains more than one person, the model won't detect the items of the other persons.""")
225
 
226
  if image_ is not None:
227
  st.image(Image.open(image_), width=300)
 
240
 
241
  dict_cats = dict(zip(np.arange(len(cats)), cats))
242
 
243
+ st.markdown("#### Choose the elements you want to detect 👉")
244
 
245
  # Select one or more elements to detect
246
  container = st.container()
 
263
 
264
  ############## SELECT A THRESHOLD ###############
265
 
266
+ st.markdown("#### Define a threshold for predictions 🔎")
267
+ st.markdown("""Object detection models assign to each element detected a **probability score**. <br>
268
+ This score represents the model's belief in the accuracy of its prediction for a specific object.
269
+ """, unsafe_allow_html=True)
270
 
271
+ st.warning("**Note:** Objects that are assigned a lower score than the chosen threshold will be ignored in the final results.")
 
 
272
 
 
 
273
 
274
+ _, col, _ = st.columns([0.2,0.6,0.2])
275
+ with col:
276
+ st.image("images/probability_od.png", caption="Example of object detection with probability scores")
277
 
278
+ st.markdown(" ")
279
+
280
+ st.markdown("**Select a threshold** ")
281
+
282
+ # st.warning("""**Note**: The threshold helps you decide how confident you want your model to be with its predictions.
283
+ # Elements that are identified with a lower probability than the given threshold will be ignored in the final results.""")
284
 
285
+ threshold = st.slider('**Select a threshold**', min_value=0.5, step=0.05, max_value=1.0, value=0.75, label_visibility="collapsed")
286
 
287
+ if threshold < 0.6:
288
+ st.error("""**Warning**: Selecting a low threshold (below 0.6) could lead the model to make errors and detect too many objects.""")
289
+
290
+ st.write("You've selected a threshold at", threshold)
291
  st.markdown(" ")
292
 
293
 
 
303
  image = fix_channels(ToTensor()(image))
304
 
305
  ## LOAD OBJECT DETECTION MODEL
306
+ FEATURE_EXTRACTOR_PATH = "hustvl/yolos-small"
307
+ MODEL_PATH = "valentinafeve/yolos-fashionpedia"
308
+ feature_extractor, model = load_model(FEATURE_EXTRACTOR_PATH, MODEL_PATH)
309
+ # feature_extractor = YolosFeatureExtractor.from_pretrained('hustvl/yolos-small')
310
+ # model = YolosForObjectDetection.from_pretrained(MODEL)
311
 
312
  # RUN MODEL ON IMAGE
313
  inputs = feature_extractor(images=image, return_tensors="pt")
314
  outputs = model(**inputs)
315
  probas, keep = return_probas(outputs, threshold)
316
 
317
+ st.markdown("#### See the results ☑️")
318
+
319
  # PLOT BOUNDING BOX AND BARS/PROBA
320
  col1, col2 = st.columns(2)
321
  with col1:
322
+ #st.markdown("**Bounding box results**")
323
  bboxes_scaled = rescale_bboxes(outputs.pred_boxes[0, keep].cpu(), image.size)
324
  colors_used = plot_results(image, probas[keep], bboxes_scaled)
325
 
326
  with col2:
327
+ #st.markdown("**Probability scores**")
328
+ if not any(keep.tolist()):
329
+ st.error("""No objects were detected on the image.
330
+ Decrease your threshold or choose differents items to detect.""")
331
+ else:
332
+ visualize_probas(probas, threshold, colors_used)
333
 
 
334
 
335
  else:
336
+ st.error("You must select an **image**, **elements to detect** and a **threshold** to run the model !")
337
 
338
 
339