First image captioning model for the Russian language: vit-rugpt2-image-captioning
This is an image captioning model trained on a translated (en-ru) version of the COCO2014 dataset.
Model Details
The model was initialized with google/vit-base-patch16-224-in21k for the encoder and sberbank-ai/rugpt3large_based_on_gpt2 for the decoder.
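As an illustration of what such an initialization can look like, here is a minimal sketch that combines the two checkpoints named above with `VisionEncoderDecoderModel.from_encoder_decoder_pretrained`. The helper name `build_captioning_model` and the choice of start and padding tokens are assumptions made for this example; the card does not describe the actual training setup, and the cross-attention weights created this way are untrained until the model is fine-tuned.

```python
ENCODER_CHECKPOINT = "google/vit-base-patch16-224-in21k"
DECODER_CHECKPOINT = "sberbank-ai/rugpt3large_based_on_gpt2"


def build_captioning_model(encoder_name=ENCODER_CHECKPOINT,
                           decoder_name=DECODER_CHECKPOINT):
    """Hypothetical helper: pair a ViT encoder with a GPT-2 style decoder."""
    # Imported lazily so that defining the helper does not load the library.
    from transformers import AutoTokenizer, VisionEncoderDecoderModel

    # Loads both pretrained checkpoints and adds cross-attention to the decoder.
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        encoder_name, decoder_name
    )
    tokenizer = AutoTokenizer.from_pretrained(decoder_name)

    # Assumption: fall back to the EOS token when BOS or PAD is not defined.
    start_id = tokenizer.bos_token_id
    if start_id is None:
        start_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    # generate() needs to know which token starts decoding and which one pads.
    model.config.decoder_start_token_id = start_id
    model.config.pad_token_id = pad_id
    return model, tokenizer
```

Calling `build_captioning_model()` downloads both checkpoints, so it is best run on a machine with enough disk space and memory for the large decoder.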
Metrics on test data
- BLEU: 8.672
- BLEU 1-gram precision: 30.567
- BLEU 2-gram precision: 7.895
- BLEU 3-gram precision: 3.261
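The card does not say which BLEU implementation produced these numbers, so the following is only a minimal single-reference sketch of how the n-gram precisions relate to the combined score. The function names are illustrative, tokenization is plain whitespace splitting, and the sketch returns values between 0 and 1, whereas the figures above appear to be on a 0 to 100 scale.

```python
import math
from collections import Counter


def ngrams(tokens, n):
    """Return all contiguous n-grams of a token list as tuples."""
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def modified_precision(candidate, reference, n):
    """Clipped n-gram precision of a candidate against one reference."""
    cand_counts = Counter(ngrams(candidate, n))
    ref_counts = Counter(ngrams(reference, n))
    total = sum(cand_counts.values())
    if total == 0:
        return 0.0
    # Each candidate n-gram is credited at most as often as it occurs
    # in the reference.
    clipped = sum(min(count, ref_counts[gram])
                  for gram, count in cand_counts.items())
    return clipped / total


def bleu(candidate, reference, max_n=4):
    """Sentence-level BLEU without smoothing; returns (score, precisions)."""
    precisions = [modified_precision(candidate, reference, n)
                  for n in range(1, max_n + 1)]
    if min(precisions) == 0.0:
        return 0.0, precisions
    # Geometric mean of the n-gram precisions.
    log_mean = sum(math.log(p) for p in precisions) / max_n
    # Brevity penalty discourages candidates shorter than the reference.
    if len(candidate) > len(reference):
        penalty = 1.0
    else:
        penalty = math.exp(1.0 - len(reference) / len(candidate))
    return penalty * math.exp(log_mean), precisions


candidate = "a plane on the runway".split()
reference = "a plane sits on the runway".split()
score, precisions = bleu(candidate, reference, max_n=2)
```

Because the score is a geometric mean, a low higher-order precision pulls the combined value down sharply, which is consistent with a 1-gram precision near 30 alongside an overall BLEU under 9.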
Sample running code
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
import torch
from PIL import Image

# Load the captioning model, its image preprocessor, and its tokenizer.
model = VisionEncoderDecoderModel.from_pretrained("vit-rugpt2-image-captioning")
feature_extractor = ViTFeatureExtractor.from_pretrained("vit-rugpt2-image-captioning")
tokenizer = AutoTokenizer.from_pretrained("vit-rugpt2-image-captioning")

# Run on GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Beam search settings used for generation.
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def predict_caption(image_paths):
    images = []
    for image_path in image_paths:
        i_image = Image.open(image_path)
        # The encoder expects three-channel RGB input.
        if i_image.mode != "RGB":
            i_image = i_image.convert(mode="RGB")
        images.append(i_image)

    pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)
    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds

predict_caption(['train2014/COCO_train2014_000000295442.jpg']) # ['Самолет на взлетно-посадочной полосе аэропорта.']
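Since predict_caption opens every image in the list and sends them through the model as one batch, a long list may not fit in memory. One possible workaround is to split the paths into fixed-size chunks, as in this minimal sketch. The names `chunked` and `caption_in_batches` and the default batch size are hypothetical and not part of the model's API; any function that maps a list of paths to a list of captions can be passed as `predict_fn`.

```python
from typing import Callable, Iterator, List, Sequence


def chunked(items: Sequence[str], size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of `items` holding at most `size` elements."""
    if size < 1:
        raise ValueError("size must be at least 1")
    for start in range(0, len(items), size):
        yield items[start:start + size]


def caption_in_batches(
    image_paths: Sequence[str],
    predict_fn: Callable[[List[str]], List[str]],
    batch_size: int = 8,
) -> List[str]:
    """Caption many images by calling `predict_fn` on one chunk at a time."""
    captions: List[str] = []
    for batch in chunked(image_paths, batch_size):
        # Each call only has to hold `batch_size` images in memory.
        captions.extend(predict_fn(list(batch)))
    return captions
```

With the code above, this could be used as `caption_in_batches(paths, predict_caption, batch_size=8)`, where `paths` is a list of image file paths.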
Sample running code using transformers pipeline
from transformers import pipeline
image_to_text = pipeline("image-to-text", model="vit-rugpt2-image-captioning")
image_to_text("train2014/COCO_train2014_000000296754.jpg") # [{'generated_text': 'Человек идет по улице с зонтом.'}]
Contact for any help