DDingcheol committed on
Commit 7063ff1 · 1 Parent(s): 7172545

Update app.py

Files changed (1)
  1. app.py +38 -73
app.py CHANGED
@@ -1,78 +1,43 @@
  # Changed so that it can run on Hugging Face
 
  import torch
- from transformers import BertTokenizerFast, BertForQuestionAnswering, Trainer, TrainingArguments
- from datasets import load_dataset
- from collections import defaultdict
-
- # Load the dataset
- dataset_load = load_dataset('Multimodal-Fatima/OK-VQA_train')
- dataset = dataset_load['train'].select(range(300))
-
- # Keep only the features we need
- selected_features = ['image', 'answers', 'question']
- selected_dataset = dataset.map(lambda ex: {feature: ex[feature] for feature in selected_features})
-
- # Soft encoding: give each answer string an integer ID (see the note after this block)
- answers_to_id = defaultdict(lambda: len(answers_to_id))
- selected_dataset = selected_dataset.map(lambda ex: {
-     'answers': [answers_to_id[ans] for ans in ex['answers']],
-     'question': ex['question'],
-     'image': ex['image']
- })
-
- id_to_answers = {v: k for k, v in answers_to_id.items()}
- id_to_labels = {k: ex['answers'] for k, ex in enumerate(selected_dataset)}
-
- selected_dataset = selected_dataset.map(lambda ex: {'answers': id_to_labels.get(ex['answers'][0]),
-                                                     'question': ex['question'],
-                                                     'image': ex['image']})
-
- flattened_features = []
-
- for ex in selected_dataset:
-     flattened_example = {
-         'answers': ex['answers'],
-         'question': ex['question'],
-         'image': ex['image'],
-     }
-     flattened_features.append(flattened_example)
-
- # Load the model
- from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
-
- model_name = 'microsoft/git-base-vqav2'
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
-
- # Train the model with Trainer
- tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
-
- def preprocess_function(examples):
-     tokenized_inputs = tokenizer(examples['question'], truncation=True, padding=True)
-     return {
-         'input_ids': tokenized_inputs['input_ids'],
-         'attention_mask': tokenized_inputs['attention_mask'],
-         # placeholder shape tuples, not real image tensors
-         'pixel_values': [(4, 3, 244, 244)] * len(tokenized_inputs['input_ids']),
-         'pixel_mask': [1] * len(tokenized_inputs['input_ids']),
-         'labels': [[label] for label in examples['answers']]
-     }
-
- dataset = load_dataset("Multimodal-Fatima/OK-VQA_train")['train'].select(range(300))
- ok_vqa_dataset = dataset.map(preprocess_function, batched=True)
- ok_vqa_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'])
-
- training_args = TrainingArguments(
-     output_dir='./results',
-     num_train_epochs=20,
-     per_device_train_batch_size=4,
-     logging_steps=500,
- )
-
- trainer = Trainer(
-     model=model,
-     args=training_args,
-     train_dataset=ok_vqa_dataset
  )
 
- # Train the model
- trainer.train()
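Note on the removed soft-encoding step: the defaultdict(lambda: len(answers_to_id)) idiom hands out integer IDs in order of first appearance. A minimal standalone illustration (the example answers are made up):

    from collections import defaultdict

    # Each previously unseen key gets the next free integer as its ID
    answers_to_id = defaultdict(lambda: len(answers_to_id))
    ids = [answers_to_id[a] for a in ["cat", "dog", "cat"]]
    print(ids)                  # [0, 1, 0]
    print(dict(answers_to_id))  # {'cat': 0, 'dog': 1}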
 
  # Changed so that it can run on Hugging Face
 
+ import gradio as gr
  import torch
+ from transformers import BertTokenizer, BertForSequenceClassification
+
+ # Initialize the model and load its weights
+ model_name = 'microsoft/git-base-vqav2'  # name of the model to use
+ model = BertForSequenceClassification.from_pretrained(model_name)
+ tokenizer = BertTokenizer.from_pretrained(model_name)
+
+ # Define the prediction function
+ def predict_answer(image, question):
+     inputs = tokenizer(question, return_tensors='pt')
+     input_ids = inputs['input_ids']
+     attention_mask = inputs['attention_mask']
+
+     # Image-related processing is still missing: preprocessing of the input
+     # image must be added here, so the prediction below currently uses the
+     # question text only (see the sketch after the diff)
+
+     # Pass the inputs to the model to run a prediction
+     with torch.no_grad():
+         outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+
+     # Take the label ID with the highest probability from the prediction
+     predicted_label_id = torch.argmax(outputs.logits).item()
+     # NOTE: id_to_label_fn (mapping a label ID back to an answer string)
+     # is not defined anywhere in this script and still has to be supplied
+     predicted_label = id_to_label_fn(predicted_label_id)
+
+     return predicted_label
+
+ iface = gr.Interface(
+     fn=predict_answer,
+     inputs=["image", "text"],
+     outputs="text",
+     title="Visual Question Answering",
+     description="Input an image and a question to get the model's answer.",
+     examples=[
+         ["https://your_image_url.jpg", "What is shown in the image?"]
+     ]
  )
 
+ iface.launch()
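The committed script leaves two gaps: the input image is never preprocessed, and id_to_label_fn is undefined. Below is a minimal sketch of one way to close both at once, assuming the microsoft/git-base-vqav2 checkpoint is used generatively via AutoProcessor and AutoModelForCausalLM (the usage shown in the transformers GIT documentation) rather than through BertForSequenceClassification; the gr.Image/gr.Textbox input components are likewise this sketch's choice, not the commit's:

    import gradio as gr
    import torch
    from PIL import Image
    from transformers import AutoProcessor, AutoModelForCausalLM

    model_name = 'microsoft/git-base-vqav2'
    processor = AutoProcessor.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    model.eval()

    def predict_answer(image: Image.Image, question: str) -> str:
        # The image preprocessing the commit leaves as a TODO:
        # resize/normalize the image into pixel_values
        pixel_values = processor(images=image, return_tensors='pt').pixel_values

        # GIT expects the question tokens prefixed with the CLS token
        question_ids = processor(text=question, add_special_tokens=False).input_ids
        input_ids = torch.tensor([processor.tokenizer.cls_token_id] + question_ids).unsqueeze(0)

        with torch.no_grad():
            generated_ids = model.generate(pixel_values=pixel_values,
                                           input_ids=input_ids,
                                           max_length=50)

        # The generated sequence echoes the question before the answer;
        # decoding it whole keeps the sketch simple
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    iface = gr.Interface(
        fn=predict_answer,
        inputs=[gr.Image(type='pil'), gr.Textbox(label='Question')],
        outputs='text',
        title='Visual Question Answering',
    )
    iface.launch()

Because the answer is generated as text rather than classified, this variant needs no id_to_label_fn at all.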