h-siyuan commited on
Commit
3bd6fba
·
verified ·
1 Parent(s): f0783a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -32
app.py CHANGED
@@ -9,9 +9,10 @@ from qwen_vl_utils import process_vision_info
9
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
10
  import ast
11
  import os
12
- from datetime import datetime
13
  import numpy as np
14
  from huggingface_hub import hf_hub_download, list_repo_files
 
 
15
 
16
  # Define constants
17
  DESCRIPTION = "[ShowUI Demo](https://huggingface.co/showlab/ShowUI-2B)"
@@ -35,8 +36,7 @@ for file in files:
35
  print(f"Downloaded {file} to {file_path}")
36
 
37
  model = Qwen2VLForConditionalGeneration.from_pretrained(
38
- "./showui-2b",
39
- # "showlab/ShowUI-2B",
40
  torch_dtype=torch.bfloat16,
41
  device_map="cpu",
42
  )
@@ -57,20 +57,43 @@ def draw_point(image_input, point=None, radius=5):
57
  ImageDraw.Draw(image).ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
58
  return image
59
 
60
- def array_to_image_path(image_array):
61
  """Save the uploaded image and return its path."""
62
  if image_array is None:
63
  raise ValueError("No image provided. Please upload an image before submitting.")
64
  img = Image.fromarray(np.uint8(image_array))
65
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
66
- filename = f"image_{timestamp}.png"
67
  img.save(filename)
68
  return os.path.abspath(filename)
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  @spaces.GPU
71
  def run_showui(image, query):
72
  """Main function for inference."""
73
- image_path = array_to_image_path(image)
 
 
 
 
74
 
75
  messages = [
76
  {
@@ -84,9 +107,7 @@ def run_showui(image, query):
84
  ]
85
 
86
  # Prepare inputs for the model
87
-
88
  global model
89
-
90
  model = model.to("cuda")
91
 
92
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
@@ -114,11 +135,11 @@ def run_showui(image, query):
114
 
115
  # Draw the point on the image
116
  result_image = draw_point(image_path, click_xy, radius=10)
117
- return result_image, str(click_xy)
118
 
119
- # Function to record votes
120
- def record_vote(vote_type, image_path, query, action_generated):
121
- """Record a vote in a JSON file."""
122
  vote_data = {
123
  "vote_type": vote_type,
124
  "image_path": image_path,
@@ -126,27 +147,35 @@ def record_vote(vote_type, image_path, query, action_generated):
126
  "action_generated": action_generated,
127
  "timestamp": datetime.now().isoformat()
128
  }
129
- with open("votes.json", "a") as f:
 
 
 
 
130
  f.write(json.dumps(vote_data) + "\n")
 
 
 
 
131
  return f"Your {vote_type} has been recorded. Thank you!"
132
 
133
- # Helper function to handle vote recording
134
- def handle_vote(vote_type, image_path, query, action_generated):
135
  """Handle vote recording by using the consistent image path."""
136
  if image_path is None:
137
  return "No image uploaded. Please upload an image before voting."
138
- return record_vote(vote_type, image_path, query, action_generated)
139
 
140
  # Load logo and encode to Base64
141
  with open("./assets/showui.png", "rb") as image_file:
142
  base64_image = base64.b64encode(image_file.read()).decode("utf-8")
143
 
144
-
145
  # Define layout and UI
146
  def build_demo(embed_mode, concurrency_count=1):
147
  with gr.Blocks(title="ShowUI Demo", theme=gr.themes.Default()) as demo:
148
  # State to store the consistent image path
149
  state_image_path = gr.State(value=None)
 
150
 
151
  if not embed_mode:
152
  gr.HTML(
@@ -230,19 +259,35 @@ def build_demo(embed_mode, concurrency_count=1):
230
  raise ValueError("No image provided. Please upload an image before submitting.")
231
 
232
  # Generate consistent image path and store it in the state
233
- image_path = array_to_image_path(image)
234
- return run_showui(image, query) + (image_path,)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  submit_btn.click(
237
  on_submit,
238
  [imagebox, textbox],
239
- [output_img, output_coords, state_image_path],
240
  )
241
 
242
  clear_btn.click(
243
  lambda: (None, None, None, None, None),
244
  inputs=None,
245
- outputs=[imagebox, textbox, output_img, output_coords, state_image_path], # Clear all outputs
246
  queue=False
247
  )
248
 
@@ -254,33 +299,34 @@ def build_demo(embed_mode, concurrency_count=1):
254
 
255
  # Record vote actions without feedback messages
256
  vote_btn.click(
257
- lambda image_path, query, action_generated: handle_vote(
258
- "upvote", image_path, query, action_generated
259
  ),
260
- inputs=[state_image_path, textbox, output_coords],
261
  outputs=[],
262
  queue=False
263
  )
264
 
265
  downvote_btn.click(
266
- lambda image_path, query, action_generated: handle_vote(
267
- "downvote", image_path, query, action_generated
268
  ),
269
- inputs=[state_image_path, textbox, output_coords],
270
  outputs=[],
271
  queue=False
272
  )
273
 
274
  flag_btn.click(
275
- lambda image_path, query, action_generated: handle_vote(
276
- "flag", image_path, query, action_generated
277
  ),
278
- inputs=[state_image_path, textbox, output_coords],
279
  outputs=[],
280
  queue=False
281
  )
282
 
283
  return demo
 
284
  # Launch the app
285
  if __name__ == "__main__":
286
  demo = build_demo(embed_mode=False)
@@ -289,4 +335,4 @@ if __name__ == "__main__":
289
  server_port=7860,
290
  ssr_mode=False,
291
  debug=True,
292
- )
 
9
  from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
10
  import ast
11
  import os
 
12
  import numpy as np
13
  from huggingface_hub import hf_hub_download, list_repo_files
14
+ import boto3
15
+ from botocore.exceptions import NoCredentialsError
16
 
17
  # Define constants
18
  DESCRIPTION = "[ShowUI Demo](https://huggingface.co/showlab/ShowUI-2B)"
 
36
  print(f"Downloaded {file} to {file_path}")
37
 
38
  model = Qwen2VLForConditionalGeneration.from_pretrained(
39
+ destination_folder,
 
40
  torch_dtype=torch.bfloat16,
41
  device_map="cpu",
42
  )
 
57
  ImageDraw.Draw(image).ellipse((x - radius, y - radius, x + radius, y + radius), fill='red')
58
  return image
59
 
60
+ def array_to_image_path(image_array, session_id):
61
  """Save the uploaded image and return its path."""
62
  if image_array is None:
63
  raise ValueError("No image provided. Please upload an image before submitting.")
64
  img = Image.fromarray(np.uint8(image_array))
65
+ filename = f"image_{session_id}.png"
 
66
  img.save(filename)
67
  return os.path.abspath(filename)
68
 
69
+ # Function to upload the file to S3
70
+ def upload_to_s3(file_name, bucket, object_name=None):
71
+ """Upload a file to an S3 bucket."""
72
+ if object_name is None:
73
+ object_name = file_name
74
+
75
+ # Create an S3 client
76
+ s3 = boto3.client('s3')
77
+
78
+ try:
79
+ s3.upload_file(file_name, bucket, object_name)
80
+ print(f"Uploaded {file_name} to {bucket}/{object_name}.")
81
+ return True
82
+ except FileNotFoundError:
83
+ print(f"The file {file_name} was not found.")
84
+ return False
85
+ except NoCredentialsError:
86
+ print("Credentials not available.")
87
+ return False
88
+
89
  @spaces.GPU
90
  def run_showui(image, query):
91
  """Main function for inference."""
92
+ session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
93
+ image_path = array_to_image_path(image, session_id)
94
+
95
+ # Upload the image to S3
96
+ upload_to_s3(image_path, 'altair.storage', object_name=f"ootb/images/{os.path.basename(image_path)}")
97
 
98
  messages = [
99
  {
 
107
  ]
108
 
109
  # Prepare inputs for the model
 
110
  global model
 
111
  model = model.to("cuda")
112
 
113
  text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
135
 
136
  # Draw the point on the image
137
  result_image = draw_point(image_path, click_xy, radius=10)
138
+ return result_image, str(click_xy), session_id
139
 
140
+ # Modify the record_vote function
141
+ def record_vote(vote_type, image_path, query, action_generated, session_id):
142
+ """Record a vote in a JSON file and upload to S3."""
143
  vote_data = {
144
  "vote_type": vote_type,
145
  "image_path": image_path,
 
147
  "action_generated": action_generated,
148
  "timestamp": datetime.now().isoformat()
149
  }
150
+
151
+ local_file_name = f"votes_{session_id}.json"
152
+
153
+ # Append vote data to the local JSON file
154
+ with open(local_file_name, "a") as f:
155
  f.write(json.dumps(vote_data) + "\n")
156
+
157
+ # Upload the updated JSON file to S3
158
+ upload_to_s3(local_file_name, 'altair.storage', object_name=f"ootb/votes/{local_file_name}")
159
+
160
  return f"Your {vote_type} has been recorded. Thank you!"
161
 
162
+ # Use session_id in the handle_vote function
163
+ def handle_vote(vote_type, image_path, query, action_generated, session_id):
164
  """Handle vote recording by using the consistent image path."""
165
  if image_path is None:
166
  return "No image uploaded. Please upload an image before voting."
167
+ return record_vote(vote_type, image_path, query, action_generated, session_id)
168
 
169
  # Load logo and encode to Base64
170
  with open("./assets/showui.png", "rb") as image_file:
171
  base64_image = base64.b64encode(image_file.read()).decode("utf-8")
172
 
 
173
  # Define layout and UI
174
  def build_demo(embed_mode, concurrency_count=1):
175
  with gr.Blocks(title="ShowUI Demo", theme=gr.themes.Default()) as demo:
176
  # State to store the consistent image path
177
  state_image_path = gr.State(value=None)
178
+ state_session_id = gr.State(value=None)
179
 
180
  if not embed_mode:
181
  gr.HTML(
 
259
  raise ValueError("No image provided. Please upload an image before submitting.")
260
 
261
  # Generate consistent image path and store it in the state
262
+ result_image, click_coords, session_id = run_showui(image, query)
263
+ return result_image, click_coords, session_id
264
+
265
+ def on_image_upload(image):
266
+ """Generate a new session ID when a new image is uploaded."""
267
+ if image is None:
268
+ raise ValueError("No image provided. Please upload an image.")
269
+
270
+ # Generate a new session ID
271
+ new_session_id = datetime.now().strftime("%Y%m%d_%H%M%S")
272
+ return new_session_id
273
+
274
+ imagebox.upload(
275
+ on_image_upload,
276
+ inputs=imagebox,
277
+ outputs=state_session_id,
278
+ queue=False
279
+ )
280
 
281
  submit_btn.click(
282
  on_submit,
283
  [imagebox, textbox],
284
+ [output_img, output_coords, state_session_id],
285
  )
286
 
287
  clear_btn.click(
288
  lambda: (None, None, None, None, None),
289
  inputs=None,
290
+ outputs=[imagebox, textbox, output_img, output_coords, state_session_id], # Clear all outputs
291
  queue=False
292
  )
293
 
 
299
 
300
  # Record vote actions without feedback messages
301
  vote_btn.click(
302
+ lambda image_path, query, action_generated, session_id: handle_vote(
303
+ "upvote", image_path, query, action_generated, session_id
304
  ),
305
+ inputs=[state_image_path, textbox, output_coords, state_session_id],
306
  outputs=[],
307
  queue=False
308
  )
309
 
310
  downvote_btn.click(
311
+ lambda image_path, query, action_generated, session_id: handle_vote(
312
+ "downvote", image_path, query, action_generated, session_id
313
  ),
314
+ inputs=[state_image_path, textbox, output_coords, state_session_id],
315
  outputs=[],
316
  queue=False
317
  )
318
 
319
  flag_btn.click(
320
+ lambda image_path, query, action_generated, session_id: handle_vote(
321
+ "flag", image_path, query, action_generated, session_id
322
  ),
323
+ inputs=[state_image_path, textbox, output_coords, state_session_id],
324
  outputs=[],
325
  queue=False
326
  )
327
 
328
  return demo
329
+
330
  # Launch the app
331
  if __name__ == "__main__":
332
  demo = build_demo(embed_mode=False)
 
335
  server_port=7860,
336
  ssr_mode=False,
337
  debug=True,
338
+ )