adamlu1 commited on
Commit
0375f07
·
1 Parent(s): 2c16cb7
Files changed (1) hide show
  1. app.py +14 -4
app.py CHANGED
@@ -13,8 +13,18 @@ from utils import check_ocr_box, get_yolo_model, get_caption_model_processor, ge
13
  import torch
14
  from PIL import Image
15
 
16
- yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
17
- caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")
 
 
 
 
 
 
 
 
 
 
18
  platform = 'pc'
19
  if platform == 'pc':
20
  draw_bbox_config = {
@@ -51,10 +61,10 @@ MARKDOWN = """
51
  OmniParser is a screen parsing tool to convert general GUI screen to structured elements.
52
  """
53
 
54
- DEVICE = torch.device('cuda')
55
 
56
  # @spaces.GPU
57
- # @torch.inference_mode()
58
  # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
59
  @spaces.GPU(duration=65)
60
  def process(
 
13
  import torch
14
  from PIL import Image
15
 
16
+ # yolo_model = get_yolo_model(model_path='weights/icon_detect/best.pt')
17
+ # caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence")
18
+
19
+ from ultralytics import YOLO
20
+ yolo_model = YOLO('weights/icon_detect/best.pt').to('cuda')
21
+ from transformers import AutoProcessor, AutoModelForCausalLM
22
+ processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)
23
+ model = AutoModelForCausalLM.from_pretrained("weights/icon_caption_florence", torch_dtype=torch.float16, trust_remote_code=True).to('cuda')
24
+ caption_model_processor = {'processor': processor, 'model': model}
25
+ print('finish loading model!!!')
26
+
27
+
28
  platform = 'pc'
29
  if platform == 'pc':
30
  draw_bbox_config = {
 
61
  OmniParser is a screen parsing tool to convert general GUI screen to structured elements.
62
  """
63
 
64
+ # DEVICE = torch.device('cuda')
65
 
66
  # @spaces.GPU
67
+ @torch.inference_mode()
68
  # @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
69
  @spaces.GPU(duration=65)
70
  def process(