|
import json
import os
import shutil

import gradio as gr
import torch
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer
|
|
|
|
|
def load_data(tokenized_file):
    """Load the tokenized dataset from an uploaded JSON file.

    Args:
        tokenized_file: Either a filesystem path (``str`` / ``os.PathLike``)
            or an object exposing the path through a ``.name`` attribute
            (for example a tempfile-style upload wrapper).

    Returns:
        The parsed JSON content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # Check for real paths first: ``pathlib.Path`` also has a ``.name``
    # attribute, but there it is only the final path component.
    if isinstance(tokenized_file, (str, os.PathLike)):
        path = tokenized_file
    else:
        # NOTE(review): depending on the Gradio version / ``type`` setting the
        # File component may pass a wrapper object instead of a plain path.
        path = tokenized_file.name

    # JSON interchange text is UTF-8; do not rely on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
|
|
def fine_tune_model(tokenized_file, progress=gr.Progress()):
    """Fine-tune ``anferico/bert-for-patents`` on an uploaded tokenized dataset.

    The uploaded JSON is loaded with ``load_data``, converted with
    ``Dataset.from_dict``, split 80/20 into train and evaluation subsets and
    used to train a two-label sequence-classification model. The trained
    model and its tokenizer are saved to ``./fine_tuned_patentbert`` and
    packed into a zip archive.

    Args:
        tokenized_file: The upload handed over by the Gradio ``File`` input,
            or ``None`` when no file is present.
        progress: Gradio progress tracker used to report status to the UI.

    Returns:
        A ``(status_message, archive_path)`` tuple matching the two interface
        outputs. ``archive_path`` is ``None`` when nothing was uploaded.
    """
    # The interface is created with ``live=True``, so this function can be
    # invoked when the upload is cleared; bail out instead of crashing.
    if tokenized_file is None:
        return "Please upload a tokenized JSON file to start fine-tuning.", None

    model_name = 'anferico/bert-for-patents'
    output_dir = './fine_tuned_patentbert'

    tokenized_data = load_data(tokenized_file)

    # NOTE(review): assumes the JSON is a mapping of column name -> list of
    # values, which is what ``Dataset.from_dict`` expects - confirm upstream.
    dataset = Dataset.from_dict(tokenized_data)
    tokenized_datasets = dataset.train_test_split(test_size=0.2)

    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=2,
        # Mixed precision is rejected by TrainingArguments on machines without
        # a supported accelerator, so only request it when CUDA is present.
        fp16=torch.cuda.is_available(),
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        # NOTE(review): recent transformers releases rename this argument to
        # ``eval_strategy`` - confirm against the pinned version.
        evaluation_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['test'],
    )

    progress(0.5, "Fine-tuning the model...")
    trainer.train()
    progress(1.0, "Fine-tuning complete.")

    # Persist model and tokenizer side by side so the archive is self-contained.
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # make_archive returns the path of the file it wrote; hand that to the UI
    # rather than repeating the file name as a separate literal.
    archive_path = shutil.make_archive('fine_tuned_patentbert', 'zip', output_dir)

    return "Model fine-tuned and saved successfully. Download the model using the link below.", archive_path
|
|
|
|
|
# Gradio front end: one JSON upload in; a status message and the zipped
# fine-tuned model out (matching the two values fine_tune_model returns).
iface = gr.Interface(
    title="Fine-Tune Patent BERT Model",
    description="Upload tokenized JSON file to fine-tune the BERT model.",
    fn=fine_tune_model,
    inputs=[gr.File(label="Upload Tokenized Data JSON")],
    outputs=[
        gr.Textbox(label="Processing Information"),
        gr.File(label="Download Fine-Tuned Model"),
    ],
    # NOTE(review): live mode re-runs ``fn`` when the input changes rather
    # than waiting for an explicit submit - confirm this is intended for a
    # long-running training job.
    live=True,
)

# Start the web server as soon as the module is executed.
iface.launch()
|
|