h-siyuan committed on
Commit
5028d04
·
verified ·
1 Parent(s): d337207

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -98
app.py CHANGED
@@ -66,13 +66,11 @@ def array_to_image_path(image_array, session_id):
66
  img.save(filename)
67
  return os.path.abspath(filename)
68
 
69
- # Function to upload the file to S3
70
  def upload_to_s3(file_name, bucket, object_name=None):
71
  """Upload a file to an S3 bucket."""
72
  if object_name is None:
73
  object_name = file_name
74
 
75
- # Create an S3 client
76
  s3 = boto3.client('s3')
77
 
78
  try:
@@ -87,14 +85,10 @@ def upload_to_s3(file_name, bucket, object_name=None):
87
  return False
88
 
89
  @spaces.GPU
90
- def run_showui(image, query):
91
  """Main function for inference."""
92
- session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
93
  image_path = array_to_image_path(image, session_id)
94
 
95
- # Upload the image to S3
96
- upload_to_s3(image_path, 'altair.storage', object_name=f"ootb/images/{os.path.basename(image_path)}")
97
-
98
  messages = [
99
  {
100
  "role": "user",
@@ -106,7 +100,6 @@ def run_showui(image, query):
106
  }
107
  ]
108
 
109
- # Prepare inputs for the model
110
  global model
111
  model = model.to("cuda")
112
 
@@ -121,7 +114,6 @@ def run_showui(image, query):
121
  )
122
  inputs = inputs.to("cuda")
123
 
124
- # Generate output
125
  generated_ids = model.generate(**inputs, max_new_tokens=128)
126
  generated_ids_trimmed = [
127
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
@@ -130,50 +122,53 @@ def run_showui(image, query):
130
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
131
  )[0]
132
 
133
- # Parse the output into coordinates
134
  click_xy = ast.literal_eval(output_text)
135
-
136
- # Draw the point on the image
137
  result_image = draw_point(image_path, click_xy, radius=10)
138
- return result_image, str(click_xy), image_path, session_id
139
 
140
- # Modify the record_vote function
141
- def record_vote(vote_type, image_path, query, action_generated, session_id):
142
- """Record a vote in a JSON file and upload to S3."""
143
- vote_data = {
144
- "vote_type": vote_type,
145
  "image_path": image_path,
146
  "query": query,
147
- "action_generated": action_generated,
148
  "timestamp": datetime.now().isoformat()
149
  }
150
 
151
- local_file_name = f"votes_{session_id}.json"
152
 
153
- # Append vote data to the local JSON file
154
- with open(local_file_name, "a") as f:
155
- f.write(json.dumps(vote_data) + "\n")
156
 
157
- # Upload the updated JSON file to S3
158
- upload_to_s3(local_file_name, 'altair.storage', object_name=f"ootb/votes/{local_file_name}")
159
 
160
- return f"Your {vote_type} has been recorded. Thank you!"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
162
- # Use session_id in the handle_vote function
163
- def handle_vote(vote_type, image_path, query, action_generated, session_id):
164
- """Handle vote recording by using the consistent image path."""
165
- if image_path is None:
166
- return "No image uploaded. Please upload an image before voting."
167
- return record_vote(vote_type, image_path, query, action_generated, session_id)
168
 
169
- # Load logo and encode to Base64
170
  with open("./assets/showui.png", "rb") as image_file:
171
  base64_image = base64.b64encode(image_file.read()).decode("utf-8")
172
 
173
- # Define layout and UI
174
  def build_demo(embed_mode, concurrency_count=1):
175
  with gr.Blocks(title="ShowUI Demo", theme=gr.themes.Default()) as demo:
176
- # State to store the consistent image path
177
  state_image_path = gr.State(value=None)
178
  state_session_id = gr.State(value=None)
179
 
@@ -181,15 +176,10 @@ def build_demo(embed_mode, concurrency_count=1):
181
  gr.HTML(
182
  f"""
183
  <div style="text-align: center; margin-bottom: 20px;">
184
- <!-- Image -->
185
  <div style="display: flex; justify-content: center;">
186
  <img src="data:image/png;base64,{base64_image}" alt="ShowUI" width="320" style="margin-bottom: 10px;"/>
187
  </div>
188
-
189
- <!-- Description -->
190
  <p>ShowUI is a lightweight vision-language-action model for GUI agents.</p>
191
-
192
- <!-- Links -->
193
  <div style="display: flex; justify-content: center; gap: 15px; font-size: 20px;">
194
  <a href="https://huggingface.co/showlab/ShowUI-2B" target="_blank">
195
  <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-ShowUI--2B-blue" alt="model"/>
@@ -207,7 +197,6 @@ def build_demo(embed_mode, concurrency_count=1):
207
 
208
  with gr.Row():
209
  with gr.Column(scale=3):
210
- # Input components
211
  imagebox = gr.Image(type="numpy", label="Input Screenshot")
212
  textbox = gr.Textbox(
213
  show_label=True,
@@ -216,7 +205,6 @@ def build_demo(embed_mode, concurrency_count=1):
216
  )
217
  submit_btn = gr.Button(value="Submit", variant="primary")
218
 
219
- # Placeholder examples
220
  gr.Examples(
221
  examples=[
222
  ["./examples/app_store.png", "Download Kindle."],
@@ -234,9 +222,7 @@ def build_demo(embed_mode, concurrency_count=1):
234
  )
235
 
236
  with gr.Column(scale=8):
237
- # Output components
238
  output_img = gr.Image(type="pil", label="Output Image")
239
- # Add a note below the image to explain the red point
240
  gr.HTML(
241
  """
242
  <p><strong>Note:</strong> The <span style="color: red;">red point</span> on the output image represents the predicted clickable coordinates.</p>
@@ -244,39 +230,21 @@ def build_demo(embed_mode, concurrency_count=1):
244
  )
245
  output_coords = gr.Textbox(label="Clickable Coordinates")
246
 
247
- # Buttons for voting, flagging, regenerating, and clearing
248
  with gr.Row(elem_id="action-buttons", equal_height=True):
249
- vote_btn = gr.Button(value="👍 Vote", variant="secondary")
250
- downvote_btn = gr.Button(value="👎 Downvote", variant="secondary")
251
- flag_btn = gr.Button(value="🚩 Flag", variant="secondary")
252
- regenerate_btn = gr.Button(value="🔄 Regenerate", variant="secondary")
253
- clear_btn = gr.Button(value="🗑️ Clear", interactive=True) # Combined Clear button
254
 
255
- # Define button actions
256
  def on_submit(image, query):
257
- """Handle the submit button click."""
258
  if image is None:
259
  raise ValueError("No image provided. Please upload an image before submitting.")
260
 
261
- # Generate consistent image path and store it in the state
262
- result_image, click_coords, image_path, session_id = run_showui(image, query)
263
- return result_image, click_coords, image_path, session_id
264
-
265
- def on_image_upload(image):
266
- """Generate a new session ID when a new image is uploaded."""
267
- if image is None:
268
- raise ValueError("No image provided. Please upload an image.")
269
 
270
- # Generate a new session ID
271
- new_session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
272
- return new_session_id
273
-
274
- imagebox.upload(
275
- on_image_upload,
276
- inputs=imagebox,
277
- outputs=state_session_id,
278
- queue=False
279
- )
280
 
281
  submit_btn.click(
282
  on_submit,
@@ -287,47 +255,26 @@ def build_demo(embed_mode, concurrency_count=1):
287
  clear_btn.click(
288
  lambda: (None, None, None, None, None),
289
  inputs=None,
290
- outputs=[imagebox, textbox, output_img, output_coords, state_image_path, state_session_id], # Clear all outputs
291
  queue=False
292
  )
293
 
294
- regenerate_btn.click(
295
- lambda image, query, state_image_path: run_showui(image, query),
296
- [imagebox, textbox, state_image_path],
297
- [output_img, output_coords, state_image_path],
298
- )
299
-
300
- # Record vote actions without feedback messages
301
- vote_btn.click(
302
- lambda image_path, query, action_generated, session_id: handle_vote(
303
- "upvote", image_path, query, action_generated, session_id
304
- ),
305
- inputs=[state_image_path, textbox, output_coords, state_session_id],
306
  outputs=[],
307
  queue=False
308
  )
309
 
310
  downvote_btn.click(
311
- lambda image_path, query, action_generated, session_id: handle_vote(
312
- "downvote", image_path, query, action_generated, session_id
313
- ),
314
- inputs=[state_image_path, textbox, output_coords, state_session_id],
315
- outputs=[],
316
- queue=False
317
- )
318
-
319
- flag_btn.click(
320
- lambda image_path, query, action_generated, session_id: handle_vote(
321
- "flag", image_path, query, action_generated, session_id
322
- ),
323
- inputs=[state_image_path, textbox, output_coords, state_session_id],
324
  outputs=[],
325
  queue=False
326
  )
327
 
328
  return demo
329
 
330
- # Launch the app
331
  if __name__ == "__main__":
332
  demo = build_demo(embed_mode=False)
333
  demo.queue(api_open=False).launch(
@@ -335,4 +282,4 @@ if __name__ == "__main__":
335
  server_port=7860,
336
  ssr_mode=False,
337
  debug=True,
338
- )
 
66
  img.save(filename)
67
  return os.path.abspath(filename)
68
 
 
69
  def upload_to_s3(file_name, bucket, object_name=None):
70
  """Upload a file to an S3 bucket."""
71
  if object_name is None:
72
  object_name = file_name
73
 
 
74
  s3 = boto3.client('s3')
75
 
76
  try:
 
85
  return False
86
 
87
  @spaces.GPU
88
+ def run_showui(image, query, session_id):
89
  """Main function for inference."""
 
90
  image_path = array_to_image_path(image, session_id)
91
 
 
 
 
92
  messages = [
93
  {
94
  "role": "user",
 
100
  }
101
  ]
102
 
 
103
  global model
104
  model = model.to("cuda")
105
 
 
114
  )
115
  inputs = inputs.to("cuda")
116
 
 
117
  generated_ids = model.generate(**inputs, max_new_tokens=128)
118
  generated_ids_trimmed = [
119
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
 
122
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
123
  )[0]
124
 
 
125
  click_xy = ast.literal_eval(output_text)
 
 
126
  result_image = draw_point(image_path, click_xy, radius=10)
127
+ return result_image, str(click_xy), image_path
128
 
129
def save_and_upload_data(image_path, query, session_id, votes=None):
    """Write the session record to a local JSON file and push it to S3.

    Args:
        image_path: Path of the saved input image; stored in the record as-is.
        query: The user's query text; stored in the record as-is.
        session_id: Identifier used to name the record file
            (``data_{session_id}.json``) and its S3 object key.
        votes: Optional vote counters. A falsy value (``None`` or an empty
            dict) is replaced by zeroed ``upvotes``/``downvotes`` counters.

    Returns:
        The record dict that was written to disk.
    """
    if not votes:
        votes = {"upvotes": 0, "downvotes": 0}

    record = dict(
        image_path=image_path,
        query=query,
        votes=votes,
        timestamp=datetime.now().isoformat(),
    )

    # The same name is used for the local file and, under "ootb/", for the S3 key.
    record_file = f"data_{session_id}.json"
    with open(record_file, "w") as handle:
        json.dump(record, handle)

    s3_key = f"ootb/{record_file}"
    upload_to_s3(record_file, 'altair.storage', object_name=s3_key)

    return record
147
+
148
def update_vote(vote_type, session_id):
    """Increment a vote counter in the session's JSON record and re-upload it.

    Args:
        vote_type: ``"upvote"`` or ``"downvote"``. Any other value leaves both
            counters untouched (the record is still rewritten and uploaded).
        session_id: Identifier used to locate ``data_{session_id}.json``.
            May be ``None`` when the UI state has not been populated yet
            (the state is created with ``value=None`` and reset by Clear).

    Returns:
        A confirmation message, or a notice asking the user to submit first
        when there is no record for this session.
    """
    missing_record_msg = "No result to vote on. Please submit an image before voting."

    # Voting before any submission (or after Clear) arrives with no session id;
    # answer with a message instead of trying to open "data_None.json".
    if session_id is None:
        return missing_record_msg

    local_file_name = f"data_{session_id}.json"

    try:
        with open(local_file_name, "r") as f:
            data = json.load(f)
    except FileNotFoundError:
        # The record for this session was never written or is no longer on disk.
        return missing_record_msg

    counter_key = {"upvote": "upvotes", "downvote": "downvotes"}.get(vote_type)
    if counter_key is not None:
        data["votes"][counter_key] += 1

    with open(local_file_name, "w") as f:
        json.dump(data, f)

    # Re-upload so the S3 copy reflects the updated counters.
    upload_to_s3(local_file_name, 'altair.storage', object_name=f"ootb/{local_file_name}")

    return f"Your {vote_type} has been recorded. Thank you!"
 
 
 
 
 
166
 
 
167
  with open("./assets/showui.png", "rb") as image_file:
168
  base64_image = base64.b64encode(image_file.read()).decode("utf-8")
169
 
 
170
  def build_demo(embed_mode, concurrency_count=1):
171
  with gr.Blocks(title="ShowUI Demo", theme=gr.themes.Default()) as demo:
 
172
  state_image_path = gr.State(value=None)
173
  state_session_id = gr.State(value=None)
174
 
 
176
  gr.HTML(
177
  f"""
178
  <div style="text-align: center; margin-bottom: 20px;">
 
179
  <div style="display: flex; justify-content: center;">
180
  <img src="data:image/png;base64,{base64_image}" alt="ShowUI" width="320" style="margin-bottom: 10px;"/>
181
  </div>
 
 
182
  <p>ShowUI is a lightweight vision-language-action model for GUI agents.</p>
 
 
183
  <div style="display: flex; justify-content: center; gap: 15px; font-size: 20px;">
184
  <a href="https://huggingface.co/showlab/ShowUI-2B" target="_blank">
185
  <img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-ShowUI--2B-blue" alt="model"/>
 
197
 
198
  with gr.Row():
199
  with gr.Column(scale=3):
 
200
  imagebox = gr.Image(type="numpy", label="Input Screenshot")
201
  textbox = gr.Textbox(
202
  show_label=True,
 
205
  )
206
  submit_btn = gr.Button(value="Submit", variant="primary")
207
 
 
208
  gr.Examples(
209
  examples=[
210
  ["./examples/app_store.png", "Download Kindle."],
 
222
  )
223
 
224
  with gr.Column(scale=8):
 
225
  output_img = gr.Image(type="pil", label="Output Image")
 
226
  gr.HTML(
227
  """
228
  <p><strong>Note:</strong> The <span style="color: red;">red point</span> on the output image represents the predicted clickable coordinates.</p>
 
230
  )
231
  output_coords = gr.Textbox(label="Clickable Coordinates")
232
 
 
233
  with gr.Row(elem_id="action-buttons", equal_height=True):
234
+ upvote_btn = gr.Button(value="Looks good!", variant="secondary")
235
+ downvote_btn = gr.Button(value="Too bad!", variant="secondary")
236
+ clear_btn = gr.Button(value="🗑️ Clear", interactive=True)
 
 
237
 
 
238
  def on_submit(image, query):
 
239
  if image is None:
240
  raise ValueError("No image provided. Please upload an image before submitting.")
241
 
242
+ session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
243
+ result_image, click_coords, image_path = run_showui(image, query, session_id)
 
 
 
 
 
 
244
 
245
+ save_and_upload_data(image_path, query, session_id)
246
+
247
+ return result_image, click_coords, image_path, session_id
 
 
 
 
 
 
 
248
 
249
  submit_btn.click(
250
  on_submit,
 
255
  clear_btn.click(
256
  lambda: (None, None, None, None, None),
257
  inputs=None,
258
+ outputs=[imagebox, textbox, output_img, output_coords, state_image_path, state_session_id],
259
  queue=False
260
  )
261
 
262
+ upvote_btn.click(
263
+ lambda session_id: update_vote("upvote", session_id),
264
+ inputs=state_session_id,
 
 
 
 
 
 
 
 
 
265
  outputs=[],
266
  queue=False
267
  )
268
 
269
  downvote_btn.click(
270
+ lambda session_id: update_vote("downvote", session_id),
271
+ inputs=state_session_id,
 
 
 
 
 
 
 
 
 
 
 
272
  outputs=[],
273
  queue=False
274
  )
275
 
276
  return demo
277
 
 
278
  if __name__ == "__main__":
279
  demo = build_demo(embed_mode=False)
280
  demo.queue(api_open=False).launch(
 
282
  server_port=7860,
283
  ssr_mode=False,
284
  debug=True,
285
+ )