Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
import numpy as np | |
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation | |
from PIL import Image | |
# Load the segmentation model and its matching preprocessor. Both come from
# one checkpoint id so they cannot silently get out of sync.
MODEL_ID = "facebook/maskformer-swin-tiny-ade"
device = torch.device("cpu")

model = MaskFormerForInstanceSegmentation.from_pretrained(MODEL_ID)
model = model.to(device)
model.eval()  # inference mode: training-only behavior is turned off

preprocessor = MaskFormerFeatureExtractor.from_pretrained(MODEL_ID)
# Gradio query function
def semantic_map_from_logits(class_queries_logits, masks_queries_logits, target_size):
    """Combine MaskFormer query logits into a per-pixel class-id map.

    Mirrors ``MaskFormerImageProcessor.post_process_semantic_segmentation``.
    Each query's class probabilities weight that query's sigmoid mask, with
    the trailing "no object" column dropped. The weighted masks are summed
    per class, bilinearly resized to ``target_size``, and reduced with argmax.

    Args:
        class_queries_logits: tensor of shape (batch, num_queries, num_classes + 1).
        masks_queries_logits: tensor of shape (batch, num_queries, h, w).
        target_size: ``(height, width)`` of the returned map.

    Returns:
        Integer tensor of shape (batch, height, width) holding class ids.
    """
    # The last class column is MaskFormer's "no object" slot; drop it.
    class_probs = class_queries_logits.softmax(dim=-1)[..., :-1]
    mask_probs = masks_queries_logits.sigmoid()
    # Per-class score = sum over queries of P(class | query) * mask(query).
    scores = torch.einsum("bqc,bqhw->bchw", class_probs, mask_probs)
    # Masks are predicted at reduced resolution; bring them to the image size.
    scores = torch.nn.functional.interpolate(
        scores,
        size=tuple(int(s) for s in target_size),
        mode="bilinear",
        align_corners=False,
    )
    return scores.argmax(dim=1)


def query_image(img):
    """Segment *img* and return a black/white mask for one class.

    Args:
        img: the picture from the ``gr.Image`` input (a NumPy array by
            default, or a PIL image). ``None`` when nothing was uploaded.

    Returns:
        An RGB ``PIL.Image`` at the input's resolution, white where the model
        predicted ``rule_class_id`` and black elsewhere. Returns ``None`` if
        *img* is ``None``.
    """
    if img is None:
        return None

    # Output resolution: PIL exposes width/height, arrays are (H, W[, C]).
    if isinstance(img, Image.Image):
        height, width = img.height, img.width
    else:
        height, width = np.asarray(img).shape[:2]

    # Preprocess the image and run the model without tracking gradients.
    inputs = preprocessor(images=img, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # MaskFormer outputs have no `.logits`; they predict per-query class
    # logits plus per-query masks, which must be combined into a class map.
    class_map = semantic_map_from_logits(
        outputs.class_queries_logits,
        outputs.masks_queries_logits,
        (height, width),
    )[0].cpu().numpy()

    # Binary mask for the "ruler" class only (id taken from the original code).
    # NOTE(review): this checkpoint is trained on ADE20K; confirm that id 1
    # is the intended class via model.config.id2label.
    rule_class_id = 1
    rule_mask = (class_map == rule_class_id).astype(np.uint8) * 255

    # Replicate the mask across three channels to display it as an RGB image.
    mask_image = np.stack([rule_mask] * 3, axis=-1)
    return Image.fromarray(mask_image)
# Build the Gradio interface: one image in, one mask image out.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="Rule Segmentation Demo",
    description="Please upload an image to see rule segmentation",
)

# Launch only when run as a script (e.g. `python app.py`). Importing the
# module (tests, Gradio's reload CLI) then does not start a server.
if __name__ == "__main__":
    demo.launch()