"""Interactive Gradio demo for the CS-Mixer (``tpmlp``) image classifier.

On start-up the script loads a timm checkpoint, downloads the class labels,
exports a handful of ImageNet validation images to use as clickable examples,
and then serves a Gradio UI showing the top predictions next to a Grad-CAM
overlay.

Pass ``--local`` to skip the Hugging Face login, keep the examples in
``../example-imgs`` and bind the server to ``0.0.0.0:8000``.
"""

import importlib
import math
import os
import subprocess
import sys
import time

# OpenCV is only needed for the colormap constant used in predict(). Install
# it on demand, exactly once: the previous shell one-liner relied on the
# bash-only `&>` redirect, which a POSIX `sh` treats as "run in background",
# and it retried forever when the installation failed.
try:
    import cv2
except ImportError:
    print("Package cv2 not found. Attepting installation.")
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-U", "opencv-python"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        check=True,  # fail loudly instead of looping on a broken install
    )
    # Make the freshly installed package visible to the running interpreter.
    importlib.invalidate_caches()
    import cv2

# The heavy third-party imports are deliberately placed here (not at the top
# of the file) so that their load time can be measured and reported.
print("=> Loading libraries...")
start = time.time()
import requests
import torch
import argparse
import gradio as gr
from torchvision import transforms
from datasets import load_dataset
from timm.data import create_transform
from timm.models import create_model, load_checkpoint
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

parser = argparse.ArgumentParser()
parser.add_argument("--local", action='store_true')
args = parser.parse_args()

if not args.local:
    print("=> Logging into huggingface...")
    from huggingface_hub import login
    login(token=os.environ["HF_TOKEN"])

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"=> Libraries loaded in {time.time()- start:.2f} sec(s).")

print("=> Loading model...")
start = time.time()
size = "b"
img_size = 224
crop_pct = 0.9
NUM_CLASSES = 1000  # number of entries reported by predict()
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# NOTE(review): nothing visible in this file registers "tpmlp_*" with timm;
# presumably the installed timm build provides it - confirm.
model = create_model(f"tpmlp_{size}").to(device)
# The checkpoint lives one directory up in a local checkout and next to the
# script otherwise; try both locations.
# NOTE(review): the third positional argument is presumably timm's `use_ema`
# flag - confirm against the installed timm version.
try:
    load_checkpoint(model, f"../tpmlp_{size}.pth.tar", True)
except FileNotFoundError:
    load_checkpoint(model, f"tpmlp_{size}.pth.tar", True)
model.eval()

# Class names, one per line; the position in the list is the class index.
response = requests.get("https://git.io/JJkYN", timeout=30)
# Stop here on an HTTP error rather than using an error page as label names.
response.raise_for_status()
labels = response.text.split("\n")

# Currently unused: predict() goes through transform() below, which also
# needs the un-normalised image for the Grad-CAM overlay.
augs = create_transform(
    input_size=(3, 224, 224),
    is_training=False,
    use_prefetcher=False,
    crop_pct=0.9,
)

# Resize the shorter side so that the centre crop covers `crop_pct` of it.
scale_size = math.floor(img_size / crop_pct)
resize = transforms.Compose([
    transforms.Resize(scale_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
])
normalize = transforms.Normalize(
    mean=torch.tensor(IMAGENET_DEFAULT_MEAN),
    std=torch.tensor(IMAGENET_DEFAULT_STD),
)


def transform(img):
    """Preprocess a PIL image for the model.

    The image is converted to RGB, resized, centre-cropped to
    ``img_size`` x ``img_size`` and turned into a tensor.

    Returns:
        A pair ``(img, tensor)``: the cropped image tensor before
        normalisation (used for the overlay) and the same tensor normalised
        with the ImageNet mean/std (used as model input).
    """
    img = resize(img.convert("RGB"))
    tensor = normalize(img)
    return img, tensor


def predict(inp):
    """Classify one image and render its Grad-CAM overlay.

    Args:
        inp: PIL image, as delivered by the Gradio image input.

    Returns:
        A pair ``(confidences, cam_image)``: a dict mapping each of the first
        ``NUM_CLASSES`` labels to its score, and the overlay image produced
        by ``show_cam_on_image``.
    """
    img, inp = transform(inp)
    batch = inp.unsqueeze(0)  # add the batch dimension

    # NOTE(review): `use_cuda` and `return_probs` depend on the installed
    # pytorch_grad_cam build - confirm they are supported there.
    with GradCAM(model=model,
                 target_layers=[model.layers[3]],
                 use_cuda=device == "cuda") as cam:
        grayscale_cam, probs = cam(input_tensor=batch,
                                   aug_smooth=False,
                                   eigen_smooth=False,
                                   return_probs=True)

    # Here grayscale_cam has only one image in the batch
    grayscale_cam = grayscale_cam[0, :]
    probs = probs[0, :]

    # show_cam_on_image takes a channels-last array, hence the permute.
    cam_image = show_cam_on_image(img.permute(1, 2, 0).detach().cpu().numpy(),
                                  grayscale_cam,
                                  use_rgb=True,
                                  image_weight=0.5,
                                  colormap=cv2.COLORMAP_TWILIGHT_SHIFTED)
    confidences = {labels[i]: float(probs[i]) for i in range(NUM_CLASSES)}
    return confidences, cam_image


print(f"=> Model (tpmlp_{size}) loaded in {time.time()- start:.2f} sec(s).")

base = "../example-imgs" if args.local else "."
print("=> Loading examples.")
# Positions in the ImageNet-1k validation stream that are used as examples.
indices = [
    0,   # Coucal
    2,   # Volcano
    7,   # Sombrero
    9,   # Balance beam
    10,  # Sulphur-crested cockatoo
    11,  # Shower cap
    12,  # Petri dish INCORRECTLY CLASSIFIED as lens
    14,  # Angora rabbit
]
ds = load_dataset("imagenet-1k", split="validation", streaming=True)
start = time.time()
os.makedirs(base, exist_ok=True)
wanted = set(indices)
last_wanted = max(indices)
for idx, data in enumerate(ds):
    # Membership test: comparing the counter with the list itself was always
    # False, so no example image was ever written.
    if idx in wanted:
        data['image'].save(f"{base}/{idx}.png")
    # Stop only after the last wanted image has been handled; this also
    # avoids streaming records that are not needed.
    if idx >= last_wanted:
        break
del ds
print(f"=> Examples loaded in {time.time()- start:.2f} sec(s).")

# Earlier gr.Interface-based UI, kept for reference:
# demo = gr.Interface(
#     fn=predict,
#     inputs=gr.inputs.Image(type="pil"),
#     outputs=[gr.outputs.Label(num_top_classes=4), gr.outputs.Image(type="numpy")],
#     examples=[f"../example-imgs/{idx}.png" for idx in indices],
# )

with gr.Blocks(theme=gr.themes.Monochrome(font=[gr.themes.GoogleFont("DM Sans"), "sans-serif"])) as demo:
    # The literal below is emitted verbatim; its lines intentionally start at
    # column 0 so that the string content stays unchanged.
    gr.HTML("""

Interactive Demo

CS-Mixer: A Cross-Scale Vision MLP Model with Spatial–Channel Mixing



""")
    with gr.Row():
        input_image = gr.Image(type="pil", min_width=300, label="Input Image")
        softmax = gr.Label(num_top_classes=4, min_width=200, label="Model Predictions")
        grad_cam = gr.Image(type="numpy", min_width=300, label="Grad-CAM")
    with gr.Row():
        gr.Button("Predict").click(fn=predict, inputs=input_image, outputs=[softmax, grad_cam])
        gr.ClearButton(input_image)
    with gr.Row():
        gr.Examples([f"{base}/{idx}.png" for idx in indices],
                    inputs=input_image,
                    outputs=[softmax, grad_cam],
                    fn=predict,
                    run_on_click=True)

if args.local:
    # NOTE(review): the original line breaks were lost; it is assumed that
    # `server_port` is active and all SSL options are commented out - confirm.
    demo.launch(
        share=False,
        debug=False,
        allowed_paths=[f"{base}"],
        server_name="0.0.0.0",
        # ssl_verify=False,
        server_port=8000,
        # ssl_certfile="/workspace/openssl/cert.pem",
        # ssl_keyfile="/workspace/openssl/key.pem"
    )
else:
    demo.launch(allowed_paths=[f"{base}"])