selamw committed on
Commit
33ab70a
•
1 Parent(s): d083d25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -141
app.py CHANGED
@@ -5,12 +5,9 @@ import spaces
5
  import torch
6
  import os
7
 
8
- from transformers import AutoProcessor, AutoModelForCausalLM
9
-
10
-
11
 
12
  access_token = os.getenv('HF_token')
13
- model_id = "selamw/BirdWatcher2"
14
  bnb_config = BitsAndBytesConfig(load_in_8bit=True)
15
 
16
 
@@ -44,31 +41,13 @@ def convert_to_markdown(input_text):
44
 
45
  @spaces.GPU
46
  def infer_fin_pali(image, question):
47
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
-
49
- # model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, token=access_token)
50
- # processor = PaliGemmaProcessor.from_pretrained(model_id, token=access_token)
51
-
52
-
53
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
54
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
55
 
56
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, trust_remote_code=True, quantization_config=bnb_config,token=access_token).to(device)
57
- processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True, token=access_token)
58
- ###
59
 
60
- # model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
61
- # processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
62
-
63
- # prompt = "<OD>"
64
 
65
- # url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true"
66
- # image = Image.open(requests.get(url, stream=True).raw)
67
-
68
- inputs = processor(text=question, images=image, return_tensors="pt").to(device, torch_dtype)
69
-
70
- ######
71
- # inputs = processor(images=image, text=question, return_tensors="pt").to(device)
72
 
73
  predictions = model.generate(**inputs, max_new_tokens=512)
74
  decoded_output = processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
@@ -131,118 +110,4 @@ with gr.Blocks(css=css) as demo:
131
  label='Examples 👇'
132
  )
133
 
134
- demo.launch(debug=True, share=True)
135
-
136
- # import gradio as gr
137
- # from PIL import Image
138
- # from transformers import BitsAndBytesConfig, PaliGemmaForConditionalGeneration, PaliGemmaProcessor
139
- # import spaces
140
- # import torch
141
- # import os
142
-
143
-
144
- # access_token = os.getenv('HF_token')
145
- # model_id = "selamw/BirdWatcher"
146
- # bnb_config = BitsAndBytesConfig(load_in_8bit=True)
147
-
148
-
149
- # def convert_to_markdown(input_text):
150
- # """Converts bird information text to Markdown format,
151
- # making specific keywords bold and adding headings.
152
- # Args:
153
- # input_text (str): The input text containing bird information.
154
- # Returns:
155
- # str: The formatted Markdown text.
156
- # """
157
-
158
- # bold_words = ['Look:', 'Cool Fact!:', 'Habitat:', 'Food:', 'Birdie Behaviors:']
159
-
160
- # # Split into title and content based on the first ":", handling extra whitespace
161
- # if ":" in input_text:
162
- # title, content = map(str.strip, input_text.split(":", 1))
163
- # else:
164
- # title = input_text
165
- # content = ""
166
-
167
- # # Bold the keywords
168
- # for word in bold_words:
169
- # content = content.replace(word, f'\n\n**{word}')
170
-
171
- # # Construct the Markdown output with headings
172
- # formatted_output = f"**{title}**{content}"
173
-
174
- # return formatted_output.strip()
175
-
176
-
177
- # @spaces.GPU
178
- # def infer_fin_pali(image, question):
179
- # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
180
-
181
- # model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, token=access_token)
182
- # processor = PaliGemmaProcessor.from_pretrained(model_id, token=access_token)
183
-
184
-
185
- # inputs = processor(images=image, text=question, return_tensors="pt").to(device)
186
-
187
- # predictions = model.generate(**inputs, max_new_tokens=512)
188
- # decoded_output = processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
189
-
190
- # # Ensure proper Markdown formatting
191
- # formatted_output = convert_to_markdown(decoded_output)
192
-
193
- # return formatted_output
194
-
195
-
196
- # css = """
197
- # #mkd {
198
- # height: 500px;
199
- # overflow: auto;
200
- # border: 1px solid #ccc;
201
- # }
202
- # h1 {
203
- # text-align: center;
204
- # }
205
- # h3 {
206
- # text-align: center;
207
- # }
208
- # h2 {
209
- # text-align: center;
210
- # }
211
- # span.gray-text {
212
- # color: gray;
213
- # }
214
- # """
215
-
216
- # with gr.Blocks(css=css) as demo:
217
- # gr.HTML("<h1>🦩 BirdWatcher 🦜</h1>")
218
- # gr.HTML("<h3>[Powered by Fine-tuned PaliGemma]</h3>")
219
- # gr.HTML("<h3>Upload an image of a bird, and the model will generate a detailed description of its species.</h3>")
220
- # gr.HTML("<p style='text-align: center;'>(There are over 11,000 bird species in the world, and this model was fine-tuned with over 500)</p>")
221
-
222
- # with gr.Tab(label="Bird Identification"):
223
- # with gr.Row():
224
- # input_img = gr.Image(label="Input Bird Image")
225
- # with gr.Column():
226
- # with gr.Row():
227
- # question = gr.Text(label="Default Prompt", value="Describe this bird species", elem_id="default-prompt", interactive=True)
228
- # with gr.Row():
229
- # submit_btn = gr.Button(value="Run")
230
- # with gr.Row():
231
- # output = gr.Markdown(label="Response") # Use Markdown component to display output
232
-
233
- # submit_btn.click(infer_fin_pali, [input_img, question], [output])
234
-
235
- # gr.Examples(
236
- # [["01.jpg", "Describe this bird species"],
237
- # ["02.jpg", "Describe this bird species"],
238
- # ["03.jpg", "Describe this bird species"],
239
- # ["04.jpg", "Describe this bird species"],
240
- # ["05.jpg", "Describe this bird species"],
241
- # ["06.jpg", "Describe this bird species"]],
242
- # inputs=[input_img, question],
243
- # outputs=[output],
244
- # fn=infer_fin_pali,
245
- # label='Examples 👇'
246
- # )
247
-
248
- # demo.launch(debug=True, share=True)
 
5
  import torch
6
  import os
7
 
 
 
 
8
 
9
  access_token = os.getenv('HF_token')
10
+ model_id = "selamw/BirdWatcher"
11
  bnb_config = BitsAndBytesConfig(load_in_8bit=True)
12
 
13
 
 
41
 
42
  @spaces.GPU
43
  def infer_fin_pali(image, question):
44
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
45
 
46
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, token=access_token)
47
+ processor = PaliGemmaProcessor.from_pretrained(model_id, token=access_token)
 
48
 
 
 
 
 
49
 
50
+ inputs = processor(images=image, text=question, return_tensors="pt").to(device)
 
 
 
 
 
 
51
 
52
  predictions = model.generate(**inputs, max_new_tokens=512)
53
  decoded_output = processor.decode(predictions[0], skip_special_tokens=True)[len(question):].lstrip("\n")
 
110
  label='Examples 👇'
111
  )
112
 
113
+ demo.launch(debug=True, share=True)