bhlewis's picture
Update app.py
ba4060a verified
import gradio as gr
import json
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer
from datasets import Dataset
import shutil
import os
# Load tokenized data
def load_data(tokenized_file):
with open(tokenized_file.name, 'r') as f:
tokenized_data = json.load(f)
return tokenized_data
# Fine-tune the model
def fine_tune_model(tokenized_file, progress=gr.Progress()):
tokenized_data = load_data(tokenized_file)
# Convert tokenized data to Dataset
dataset = Dataset.from_dict(tokenized_data)
# Split the dataset into train and validation sets
tokenized_datasets = dataset.train_test_split(test_size=0.2)
model = AutoModelForSequenceClassification.from_pretrained('anferico/bert-for-patents', num_labels=2)
tokenizer = AutoTokenizer.from_pretrained('anferico/bert-for-patents')
training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=3,
per_device_train_batch_size=8, # Reduce batch size
per_device_eval_batch_size=16,
gradient_accumulation_steps=2, # Use gradient accumulation
fp16=True, # Enable mixed precision training
warmup_steps=500,
weight_decay=0.01,
logging_dir='./logs',
logging_steps=10,
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets['train'],
eval_dataset=tokenized_datasets['test'],
)
progress(0.5, "Fine-tuning the model...")
trainer.train()
progress(1.0, "Fine-tuning complete.")
model.save_pretrained('./fine_tuned_patentbert')
tokenizer.save_pretrained('./fine_tuned_patentbert')
# Create a zip file of the fine-tuned model
shutil.make_archive('fine_tuned_patentbert', 'zip', './fine_tuned_patentbert')
return "Model fine-tuned and saved successfully. Download the model using the link below.", "fine_tuned_patentbert.zip"
# Create Gradio interface
iface = gr.Interface(
fn=fine_tune_model,
inputs=[
gr.File(label="Upload Tokenized Data JSON")
],
outputs=[
gr.Textbox(label="Processing Information"),
gr.File(label="Download Fine-Tuned Model")
],
title="Fine-Tune Patent BERT Model",
description="Upload tokenized JSON file to fine-tune the BERT model.",
live=True # Enable live updates for progress
)
# Launch the interface
iface.launch()