# imgSegmentation/app.py — author: Margaritamawyin
# Commit 43cf32d: "Create app.py"
import gradio as gr
import torch
import numpy as np
from transformers import MaskFormerFeatureExtractor, MaskFormerForInstanceSegmentation
from PIL import Image
# Load the segmentation model and its matching preprocessor once, at import time,
# so every Gradio request reuses the same instances. Inference runs on the CPU.
MODEL_CHECKPOINT = "facebook/maskformer-swin-tiny-ade"
device = torch.device("cpu")
preprocessor = MaskFormerFeatureExtractor.from_pretrained(MODEL_CHECKPOINT)
model = MaskFormerForInstanceSegmentation.from_pretrained(MODEL_CHECKPOINT)
model = model.to(device)
model.eval()  # inference mode (disables dropout / training-only behaviour)
# Gradio callback: segment the uploaded image and return a binary mask for one class.
def query_image(img):
    """Segment *img* with MaskFormer and return a black/white mask of the "ruler" class.

    Args:
        img: The image supplied by the ``gr.Image`` input component (a NumPy
            array with the default component settings), or ``None`` when the
            user submits without uploading anything.

    Returns:
        A ``PIL.Image`` with the same height and width as *img*. Pixels whose
        predicted class equals ``rule_class_id`` are white (255); all others
        are black (0). Returns ``None`` if no image was provided.

    Note:
        Relies on the processor's ``post_process_semantic_segmentation``
        accepting ``target_sizes``. This is the current Transformers API;
        very old releases used a different signature.
    """
    # Gradio passes None when the input is empty; clear the output instead of crashing.
    if img is None:
        return None

    # Convert the image to model inputs, place them on the model's device, and
    # run a forward pass without tracking gradients.
    inputs = preprocessor(images=img, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # MaskFormer's output has no `.logits` field. It returns per-query class
    # logits and per-query mask logits instead. The processor combines them
    # into one class id per pixel, resized back to the original (height, width).
    height, width = np.asarray(img).shape[:2]
    class_map = preprocessor.post_process_semantic_segmentation(
        outputs, target_sizes=[(height, width)]
    )[0]
    class_map = class_map.cpu().numpy()

    # Binary mask for the "ruler" class (id carried over from the original code).
    # NOTE(review): check model.config.id2label to confirm that id 1 really is
    # the intended class for the checkpoint loaded above — TODO confirm.
    rule_class_id = 1
    rule_mask = (class_map == rule_class_id).astype(np.uint8)

    # Replicate the 0/1 mask across three channels so it displays as an RGB image.
    mask_image = np.stack([rule_mask] * 3, axis=-1)
    return Image.fromarray((mask_image * 255).astype(np.uint8))
# Build the Gradio UI: one image input and one image output, wired to query_image.
demo = gr.Interface(
    query_image,
    inputs=[gr.Image()],
    outputs="image",
    title="Rule Segmentation Demo",
    description="Please upload an image to see rule segmentation",
)

# Start the web server only when this file is executed as a script
# (e.g. `python app.py`). Importing the module then just defines `demo`
# and has no server-starting side effect.
if __name__ == "__main__":
    demo.launch()