# eswardivi's picture
# Update app.py
# 46d2261 verified
import functools
import os
import time
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModel,
    AutoTokenizer,
)

# import spaces
# Choose the compute device once at import time: CUDA when PyTorch reports
# it as available, otherwise the CPU. The choice is announced on stdout.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")
else:
    print("Using CPU")
def mean_pooling(model_output, attention_mask):
    """Average the token embeddings of each sequence, ignoring masked positions.

    Args:
        model_output: Indexable model result whose first element holds the
            token embeddings (presumably shaped ``(batch, seq_len, hidden)``
            -- confirm against the model in use).
        attention_mask: Mask with one entry per token; zero entries are
            excluded from the average.

    Returns:
        The masked sum over dimension 1 divided by the number of unmasked
        tokens. The divisor is clamped to at least ``1e-9`` so a fully
        masked row yields zeros instead of dividing by zero.
    """
    hidden_states = model_output[0]
    # Broadcast the mask across the embedding dimension and make it float so
    # it can both weight the embeddings and be summed into a token count.
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    masked_sum = torch.sum(hidden_states * mask, 1)
    token_count = torch.clamp(mask.sum(1), min=1e-9)
    return masked_sum / token_count
def cls_pooling(model_output):
    """Return the embedding at position 0 of every sequence.

    Args:
        model_output: Indexable model result whose first element holds the
            token embeddings.

    Returns:
        The slice ``[:, 0]`` of the token embeddings, i.e. the first token's
        vector for each row of the batch.
    """
    token_embeddings = model_output[0]
    return token_embeddings[:, 0]
@functools.lru_cache(maxsize=1)
def _load_model(model_id):
    """Load the tokenizer and model for ``model_id``, reusing the last load.

    Only the most recently requested checkpoint is kept (``maxsize=1``), so
    the two embeddings computed for one similarity request share a single
    load while memory use stays bounded to one model.

    Args:
        model_id: Model identifier passed to ``from_pretrained``.

    Returns:
        A ``(tokenizer, model)`` tuple with the model already moved to the
        module-level ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Use half precision only on the GPU. On CPU, float16 inference is slow
    # and not every operator has a half-precision kernel in every PyTorch
    # release, so fall back to float32 there.
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    model = AutoModel.from_pretrained(model_id, torch_dtype=dtype)
    return tokenizer, model.to(device)


# @spaces.GPU
def get_embedding(text, use_mean_pooling, model_id):
    """Embed ``text`` with the given model and pooling strategy.

    Args:
        text: Input passed to the tokenizer; it is padded and truncated to
            at most 512 tokens.
        use_mean_pooling: When true, average the token embeddings weighted
            by the attention mask; otherwise take the first token's
            embedding.
        model_id: Model identifier to load (cached across calls).

    Returns:
        The pooled embedding tensor, located on ``device``.
    """
    tokenizer, model = _load_model(model_id)
    inputs = tokenizer(
        text, return_tensors="pt", padding=True, truncation=True, max_length=512
    )
    # Inputs must live on the same device as the model weights.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    # Inference only: skip gradient bookkeeping.
    with torch.no_grad():
        model_output = model(**inputs)
    if use_mean_pooling:
        return mean_pooling(model_output, inputs["attention_mask"])
    return cls_pooling(model_output)
def get_similarity(text1, text2, pooling_method, model_id):
    """Compute the cosine similarity between the embeddings of two texts.

    Args:
        text1: First text to compare.
        text2: Second text to compare.
        pooling_method: ``"Use Mean Pooling"`` selects mean pooling; any
            other value falls back to first-token (CLS) pooling.
        model_id: Model identifier forwarded to ``get_embedding``.

    Returns:
        The cosine similarity as a Python float.
    """
    use_mean_pooling = pooling_method == "Use Mean Pooling"
    # Embed both texts with identical settings, first text first.
    first, second = (
        get_embedding(text, use_mean_pooling, model_id) for text in (text1, text2)
    )
    score = torch.nn.functional.cosine_similarity(first, second)
    return score.item()
# Build the web UI around get_similarity and start serving it.
# Inputs, in the order get_similarity expects them: the two texts, the
# pooling method, and the model identifier. Each example row follows the
# same order.
demo = gr.Interface(
    fn=get_similarity,
    inputs=[
        gr.Textbox(lines=7, label="Text 1"),
        gr.Textbox(lines=7, label="Text 2"),
        # get_similarity treats any value other than "Use Mean Pooling" as
        # a request for CLS pooling.
        gr.Dropdown(
            choices=["Use Mean Pooling", "Use CLS"],
            value="Use Mean Pooling",
            label="Pooling Method",
            info="Mean Pooling: Averages all token embeddings (better for semantic similarity)\nCLS Pooling: Uses only the [CLS] token embedding (faster, might miss context)",
        ),
        # The selected identifier is handed to from_pretrained as-is.
        gr.Dropdown(
            choices=[
                "nomic-ai/modernbert-embed-base",
                "tasksource/ModernBERT-base-embed",
                "tasksource/ModernBERT-base-nli",
                "joe32140/ModernBERT-large-msmarco",
                "answerdotai/ModernBERT-large",
                "answerdotai/ModernBERT-base",
            ],
            value="answerdotai/ModernBERT-large",
            label="Model",
            info="Choose between the variants of ModernBERT \nMight take a few seconds to load the model",
        ),
    ],
    outputs=gr.Textbox(label="Similarity"),
    title="ModernBERT Similarity Demo",
    description="Compute the similarity between two texts using ModernBERT. Choose between different pooling strategies for embedding generation.",
    examples=[
        [
            "The quick brown fox jumps over the lazy dog",
            "A swift brown fox leaps above a sleeping canine",
            "Use Mean Pooling",
            "answerdotai/ModernBERT-large",
        ],
        [
            "I love programming in Python",
            "I hate coding with Python",
            "Use Mean Pooling",
            "joe32140/ModernBERT-large-msmarco",
        ],
        [
            "The weather is beautiful today",
            "Machine learning models are improving rapidly",
            "Use Mean Pooling",
            "tasksource/ModernBERT-base-embed",
        ],
        [
            "def calculate_sum(a, b):\n return a + b",
            "def add_numbers(x, y):\n result = x + y\n return result",
            "Use Mean Pooling",
            "tasksource/ModernBERT-base-nli",
        ],
    ],
)
demo.launch()