DDingcheol committed on
Commit
b28441d
·
1 Parent(s): 7063ff1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py CHANGED
@@ -1,4 +1,81 @@
1
  #허깅페이스에서 돌아갈 수 있도록 바꾸어 보았음
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  import gradio as gr
4
  import torch
 
1
  #허깅페이스에서 돌아갈 수 있도록 바꾸어 보았음
2
+ import torch
3
+ from transformers import BertTokenizerFast, BertForQuestionAnswering, Trainer, TrainingArguments
4
+ from datasets import load_dataset
5
+ from collections import defaultdict
6
+
7
+ # ๋ฐ์ดํ„ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
8
+ dataset_load = load_dataset('Multimodal-Fatima/OK-VQA_train')
9
+ dataset = dataset_load['train'].select(range(300))
10
+
11
+ # 불필요한 특성 선택
12
+ selected_features = ['image', 'answers', 'question']
13
+ selected_dataset = dataset.map(lambda ex: {feature: ex[feature] for feature in selected_features})
14
+
15
+ # 소프트 인코딩
16
+ answers_to_id = defaultdict(lambda: len(answers_to_id))
17
+ selected_dataset = selected_dataset.map(lambda ex: {
18
+ 'answers': [answers_to_id[ans] for ans in ex['answers']],
19
+ 'question': ex['question'],
20
+ 'image': ex['image']
21
+ })
22
+
23
+ id_to_answers = {v: k for k, v in answers_to_id.items()}
24
+ id_to_labels = {k: ex['answers'] for k, ex in enumerate(selected_dataset)}
25
+
26
+ selected_dataset = selected_dataset.map(lambda ex: {'answers': id_to_labels.get(ex['answers'][0]),
27
+ 'question': ex['question'],
28
+ 'image': ex['image']})
29
+
30
+ flattened_features = []
31
+
32
+ for ex in selected_dataset:
33
+ flattened_example = {
34
+ 'answers': ex['answers'],
35
+ 'question': ex['question'],
36
+ 'image': ex['image'],
37
+ }
38
+ flattened_features.append(flattened_example)
39
+
40
+ # ๋ชจ๋ธ ๊ฐ€์ ธ์˜ค๊ธฐ
41
+ from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
42
+
43
+ model_name = 'microsoft/git-base-vqav2'
44
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
45
+
46
+ # Trainer๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ ํ•™์Šต
47
+ tokenizer = BertTokenizerFast.from_pretrained('bert-base-multilingual-cased')
48
+
49
+ def preprocess_function(examples):
50
+ tokenized_inputs = tokenizer(examples['question'], truncation=True, padding=True)
51
+ return {
52
+ 'input_ids': tokenized_inputs['input_ids'],
53
+ 'attention_mask': tokenized_inputs['attention_mask'],
54
+ 'pixel_values': [(4, 3, 244, 244)] * len(tokenized_inputs['input_ids']),
55
+ 'pixel_mask': [1] * len(tokenized_inputs['input_ids']),
56
+ 'labels': [[label] for label in examples['answers']]
57
+ }
58
+
59
+ dataset = load_dataset("Multimodal-Fatima/OK-VQA_train")['train'].select(range(300))
60
+ ok_vqa_dataset = dataset.map(preprocess_function, batched=True)
61
+ ok_vqa_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'pixel_values', 'pixel_mask', 'labels'])
62
+
63
+ training_args = TrainingArguments(
64
+ output_dir='./results',
65
+ num_train_epochs=20,
66
+ per_device_train_batch_size=4,
67
+ logging_steps=500,
68
+ )
69
+
70
+ trainer = Trainer(
71
+ model=model,
72
+ args=training_args,
73
+ train_dataset=ok_vqa_dataset
74
+ )
75
+
76
+ # ๋ชจ๋ธ ํ•™์Šต
77
+ trainer.train()
78
+
79
 
80
  import gradio as gr
81
  import torch