Spaces:
Running
Running
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
import open_clip | |
from datasets import Dataset | |
import os | |
# Work around the "OMP: Error #15" duplicate-OpenMP-runtime abort, which occurs
# when more than one OpenMP runtime gets loaded into the process. Intel
# documents this flag as an unsupported workaround that may hide real problems.
# NOTE(review): the runtime may already be initialised by the torch import
# above; if the abort persists, set this before importing torch.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Load the BioCLIP model and its matching image preprocessing transform.
model, processor = open_clip.create_model_from_pretrained('hf-hub:imageomics/bioclip')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# open_clip hands back models in training mode; switch to eval so dropout /
# stochastic-depth style layers are disabled for inference.
model.eval()

# Load the pre-computed reference embeddings saved with `Dataset.save_to_disk`.
embedding_path = "./data/embeddings_bioclip_False"
ds = Dataset.load_from_disk(embedding_path)

# Attach the saved FAISS indexes so `ds.get_nearest_examples` can query them by name.
cosine_faiss_path = os.path.join(embedding_path, "embeddings_cosine.faiss")
l2_faiss_path = os.path.join(embedding_path, "embeddings_l2.faiss")
ds.load_faiss_index("embeddings_cosine", cosine_faiss_path)
ds.load_faiss_index("embeddings_l2", l2_faiss_path)
def majority_vote(classes, scores=None):
    """Return the class label with the largest total vote weight.

    Args:
        classes: sequence of class labels, one per neighbour.
        scores: optional per-neighbour vote weights aligned with ``classes``.
            When omitted, every neighbour contributes a weight of 1.

    Returns:
        The winning label. Ties go to the label that sorts first, matching the
        sorted ordering ``np.unique`` produced in the original implementation.

    Raises:
        ValueError: if ``classes`` is empty.
    """
    if len(classes) == 0:
        raise ValueError("majority_vote() needs at least one class label")
    if scores is None:
        # Plain integer weights. np.ones_like(classes) would produce an array of
        # the *string* "1" for string labels and break the summation below.
        scores = [1] * len(classes)
    # Seed the tally in sorted label order so `max` resolves ties deterministically.
    class_weights = {cls: 0 for cls in sorted(set(classes))}
    for cls, weight in zip(classes, scores):
        class_weights[cls] += weight
    return max(class_weights, key=class_weights.get)
def classify_example(example, index="embeddings_l2", k=10, vote_scores=True,
                     lower_is_better=None):
    """Classify an embedded example by a k-nearest-neighbour vote against ``ds``.

    Args:
        example: mapping whose ``"embeddings"`` entry holds a single query
            vector (anything ``np.asarray`` accepts, e.g. a CPU tensor).
        index: name of a FAISS index attached to ``ds``.
        k: number of neighbours to retrieve.
        vote_scores: weight each neighbour's vote by its retrieval score
            instead of giving every neighbour one vote.
        lower_is_better: whether smaller scores mean closer neighbours, as for
            L2 distances. ``None`` infers it from the index name: any name
            containing ``"l2"`` is treated as a distance index.

    Returns:
        ``(prediction, class_labels, file_paths)``, where ``class_labels`` and
        ``file_paths`` describe the retrieved neighbours in retrieval order.
    """
    features = np.asarray(example["embeddings"], dtype=np.float32)
    scores, nearest = ds.get_nearest_examples(index, features, k)
    label_names = ds.features["label"].names
    class_labels = [label_names[c] for c in nearest["label"]]

    if not vote_scores:
        return majority_vote(class_labels), class_labels, nearest["file"]

    if lower_is_better is None:
        lower_is_better = "l2" in index.lower()
    weights = np.asarray(scores, dtype=np.float64)
    if lower_is_better:
        # L2 indexes return distances, so closer neighbours have *smaller*
        # scores. Invert them so near neighbours dominate the vote; the epsilon
        # keeps an exact match (distance 0) finite.
        weights = 1.0 / (weights + 1e-6)
    prediction = majority_vote(class_labels, weights)
    return prediction, class_labels, nearest["file"]
def embed_image(image: Image.Image):
    """Encode a PIL image with the BioCLIP image encoder.

    The image goes through the model's preprocessing transform, gains a
    leading batch axis, and is encoded with gradient tracking disabled.

    Returns:
        ``{"embeddings": tensor}`` holding the image features on the CPU.
    """
    batch = processor(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = model.encode_image(batch)
    return {"embeddings": features.cpu()}
def predict(image):
    """Gradio handler: classify ``image`` and summarise its nearest neighbours.

    Args:
        image: the uploaded PIL image, or ``None`` when the user submits
            without uploading anything.

    Returns:
        A 3-tuple of strings for the output textboxes: the predicted class,
        the labels of the top-3 neighbours, and their file paths. All three
        are empty when no image was supplied.
    """
    if image is None:
        # Gradio passes None for an empty image input; return blank outputs
        # instead of failing inside embed_image.
        return "", "", ""
    embedding = embed_image(image)
    prediction, class_labels, file_paths = classify_example(embedding)
    return (
        str(prediction),
        ", ".join(class_labels[:3]),
        ", ".join(str(path) for path in file_paths[:3]),
    )
# Single-image web UI: upload a picture and get the kNN prediction plus the
# labels and file paths of the three closest reference images.
image_input = gr.Image(type="pil")
result_boxes = [
    gr.Textbox(label="Prediction"),
    gr.Textbox(label="Top 3 Classes"),
    gr.Textbox(label="Top 3 File Paths"),
]
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=result_boxes,
    title="BioClip Image Classification",
    description="Upload an image to get a prediction using the BioClip model.",
)
iface.launch()