|
import os |
|
import importlib
import subprocess
import sys

# Make sure OpenCV (cv2) is importable. If it is missing, try one quiet pip
# install into the *current* interpreter and import again, instead of looping
# forever.
try:
    import cv2
except ImportError:
    print("Package cv2 not found. Attempting installation.")
    # Invoke pip through sys.executable so the package lands in this
    # interpreter's environment. Passing an argument list avoids the shell:
    # a POSIX /bin/sh parses the bash-only `&>` redirect as "run in background",
    # which returned immediately and re-spawned pip on every loop iteration.
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-U", "opencv-python"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
        check=False,
    )
    # The import system caches directory listings; refresh them so the freshly
    # installed package is visible to the retry below.
    importlib.invalidate_caches()
    import cv2  # noqa: F401 -- still raises ImportError if installation failed
|
|
|
import os, cv2, time, math |
|
print("=> Loading libraries...")

# Wall-clock start time for the library-loading phase; the heavy imports that
# follow are timed against it and the elapsed time is reported later in the
# "Libraries loaded" message.
start = time.time()
|
|
|
import requests, torch, argparse |
|
import gradio as gr |
|
from torchvision import transforms |
|
from datasets import load_dataset |
|
from timm.data import create_transform |
|
from timm.models import create_model, load_checkpoint |
|
from pytorch_grad_cam import GradCAM |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
|
|
# Command-line flags. `--local` skips the Hugging Face login, stores example
# images under ../example-imgs and serves on a fixed host/port.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--local",
    action="store_true",
    help="run locally: skip the Hugging Face login and serve on 0.0.0.0:8000",
)
args = parser.parse_args()

if not args.local:
    # NOTE(review): presumably required so the imagenet-1k dataset can be
    # streamed below -- confirm.
    print("=> Logging into huggingface...")
    from huggingface_hub import login

    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        # Fail with an actionable message instead of a bare KeyError.
        raise RuntimeError(
            "The HF_TOKEN environment variable is not set; it is required "
            "unless the demo is started with --local."
        )
    login(token=hf_token)
|
|
|
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"=> Libraries loaded in {time.time()- start:.2f} sec(s).")
print("=> Loading model...")
start = time.time()  # restart the timer for the model-loading phase

# Model / preprocessing configuration.
size = "b"        # model variant suffix, e.g. "tpmlp_b"
img_size = 224    # side length of the square network input
crop_pct = 0.9    # ratio of the center crop to the resized image side
# Per-channel RGB normalization statistics (standard ImageNet values).
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
|
|
# Build the tpmlp variant selected by `size` on the chosen device and load its
# weights (use_ema=True, i.e. timm's EMA weights when the checkpoint has them).
# The checkpoint is looked up one directory above first, then in the working
# directory.
model = create_model(f"tpmlp_{size}").to(device)
ckpt_name = f"tpmlp_{size}.pth.tar"
try:
    load_checkpoint(model, f"../{ckpt_name}", use_ema=True)
except FileNotFoundError:
    load_checkpoint(model, ckpt_name, use_ema=True)
model.eval()  # inference mode for the demo
|
|
|
# ImageNet class names, one per line; labels[i] is the name for class index i.
# A timeout keeps startup from hanging forever on a stalled connection.
response = requests.get("https://git.io/JJkYN", timeout=30)
# Fail loudly on an HTTP error instead of using an error page as label names.
response.raise_for_status()
# splitlines() also strips "\r" and does not yield a trailing empty entry.
labels = response.text.splitlines()
|
|
|
# timm's evaluation transform for the configured input size and crop ratio.
# Uses the shared img_size / crop_pct constants rather than duplicated literals
# so it cannot drift from the manual pipeline.
# NOTE(review): `augs` is not used by the demo in this file, which builds its
# own `resize` + `normalize` steps so the un-normalized image remains available
# for the Grad-CAM overlay.
augs = create_transform(
    input_size=(3, img_size, img_size),
    is_training=False,
    use_prefetcher=False,
    crop_pct=crop_pct,
)
|
|
|
|
|
# Evaluation geometry: resize the short side to floor(img_size / crop_pct),
# center-crop to img_size, then convert to a float tensor via ToTensor.
# Normalization is a separate step so the un-normalized image can still be used
# for the Grad-CAM overlay.
scale_size = math.floor(img_size / crop_pct)
_resize_steps = [
    transforms.Resize(scale_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
]
resize = transforms.Compose(_resize_steps)

normalize = transforms.Normalize(
    mean=torch.tensor(IMAGENET_DEFAULT_MEAN),
    std=torch.tensor(IMAGENET_DEFAULT_STD),
)
|
|
|
def transform(img):
    """Preprocess a PIL image for the model.

    Returns a pair ``(image, tensor)``: ``image`` is the RGB image after
    ``resize`` (resized, center-cropped, converted by ``ToTensor``) and
    ``tensor`` is that same image passed through ``normalize``.
    """
    rgb = img.convert("RGB")
    cropped = resize(rgb)
    return cropped, normalize(cropped)
|
|
|
def predict(inp):
    """Classify an image and build its Grad-CAM overlay.

    Args:
        inp: PIL image from the Gradio input component, or ``None`` when the
            user clicks "Predict" without providing an image.

    Returns:
        ``(confidences, cam_image)``: a ``{label: score}`` dict for ``gr.Label``
        and the RGB overlay produced by ``show_cam_on_image``. Both are ``None``
        when ``inp`` is ``None``, which simply clears the outputs.
    """
    # Guard against an empty input instead of crashing on None.convert().
    if inp is None:
        return None, None

    img, inp = transform(inp)
    inp = inp.unsqueeze(0)  # add a batch dimension of size 1

    # Grad-CAM on model.layers[3] (presumably the last stage -- confirm);
    # return_probs=True makes this GradCAM also return per-class scores.
    with GradCAM(model=model, target_layers=[model.layers[3]], use_cuda=device == "cuda") as cam:
        grayscale_cam, probs = cam(
            input_tensor=inp,
            aug_smooth=False,
            eigen_smooth=False,
            return_probs=True,
        )

    # Take the single item out of the batch.
    grayscale_cam = grayscale_cam[0, :]
    probs = probs[0, :]

    # `img` is channel-first from ToTensor; show_cam_on_image wants HWC numpy.
    cam_image = show_cam_on_image(
        img.permute(1, 2, 0).detach().cpu().numpy(),
        grayscale_cam,
        use_rgb=True,
        image_weight=0.5,
        colormap=cv2.COLORMAP_TWILIGHT_SHIFTED,
    )

    # Pair names with scores positionally; zip stops at the shorter sequence,
    # so a short or blank-line-padded label list cannot raise IndexError.
    confidences = {label: float(score) for label, score in zip(labels, probs)}
    return confidences, cam_image
|
|
|
print(f"=> Model (tpmlp_{size}) loaded in {time.time()- start:.2f} sec(s).")

# Directory the example images are written to and later served from.
base = "../example-imgs" if args.local else "."

print("=> Loading examples.")
# 0-based positions, in streaming order, of the imagenet-1k validation samples
# saved as demo examples.
indices = [0, 2, 7, 9, 10, 11, 12, 14]

ds = load_dataset("imagenet-1k", split="validation", streaming=True)
start = time.time()
os.makedirs(base, exist_ok=True)  # ../example-imgs may not exist yet
wanted = set(indices)
last = max(indices)
for idx, data in enumerate(ds):
    # Save only the requested positions (membership test, not list equality).
    if idx in wanted:
        data["image"].save(f"{base}/{idx}.png")
    # Stop the (potentially long) stream only after the last wanted sample has
    # been handled, so the final index is not skipped.
    if idx >= last:
        break
del ds
print(f"=> Examples loaded in {time.time()- start:.2f} sec(s).")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI: input image, top-4 predictions and Grad-CAM overlay side by side,
# plus Predict/Clear buttons and the saved example images.
with gr.Blocks(theme=gr.themes.Monochrome(font=[gr.themes.GoogleFont("DM Sans"), "sans-serif"])) as demo:
    gr.HTML("""
<h1 align="center">Interactive Demo</h1>
<h2 align="center">CS-Mixer: A Cross-Scale Vision MLP Model with Spatial–Channel Mixing</h2>
<br><br>
""")
    with gr.Row():
        input_image = gr.Image(type="pil", min_width=300, label="Input Image")
        softmax = gr.Label(num_top_classes=4, min_width=200, label="Model Predictions")
        grad_cam = gr.Image(type="numpy", min_width=300, label="Grad-CAM")
    with gr.Row():
        gr.Button("Predict").click(fn=predict, inputs=input_image, outputs=[softmax, grad_cam])
        # Clear the outputs too, so stale predictions / heat-maps are never
        # shown next to an empty input.
        gr.ClearButton([input_image, softmax, grad_cam])
    with gr.Row():
        gr.Examples(
            [f"{base}/{idx}.png" for idx in indices],
            inputs=input_image,
            outputs=[softmax, grad_cam],
            fn=predict,
            run_on_click=True,
        )

# Only `base` is whitelisted for file serving (it holds the example images).
if args.local:
    demo.launch(
        share=False,
        debug=False,
        allowed_paths=[base],
        server_name="0.0.0.0",
        server_port=8000,
    )
else:
    demo.launch(allowed_paths=[base])