sayakpaul HF staff commited on
Commit
ccecbb2
·
1 Parent(s): 008e1c0

fix: output grid and caching.

Browse files
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import timm
@@ -6,7 +8,6 @@ from timm import create_model
6
  from timm.models.layers import PatchEmbed
7
  from torchvision.models.feature_extraction import create_feature_extractor
8
  from torchvision.transforms import functional as F
9
- import glob
10
 
11
  CAIT_MODEL = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()
12
  TRANSFORM = timm.data.create_transform(
@@ -73,6 +74,7 @@ def generate_plot(processed_map):
73
  fig.tight_layout()
74
  return fig
75
 
 
76
  def serialize_images(processed_map):
77
  """Serializes attention maps."""
78
  print(f"Number of maps: {processed_map.shape[0]}")
@@ -94,7 +96,7 @@ def generate_class_attn_map(image, block_id=0):
94
 
95
  block_key = f"blocks_token_only.{block_id}.attn.softmax"
96
  processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
97
-
98
  serialize_images(processed_cls_attn_map)
99
  all_attn_img_paths = sorted(glob.glob("attention_map_*.png"))
100
  print(f"Number of images: {len(all_attn_img_paths)}")
@@ -107,10 +109,10 @@ article = "Class attention maps as investigated in [Going deeper with Image Tran
107
  iface = gr.Interface(
108
  generate_class_attn_map,
109
  inputs=[
110
- gr.inputs.Image(type="pil", label="Input Image"),
111
  gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
112
  ],
113
- outputs=gr.Gallery().style(grid=[2], height="auto"),
114
  title=title,
115
  article=article,
116
  allow_flagging="never",
 
1
+ import glob
2
+
3
  import gradio as gr
4
  import matplotlib.pyplot as plt
5
  import timm
 
8
  from timm.models.layers import PatchEmbed
9
  from torchvision.models.feature_extraction import create_feature_extractor
10
  from torchvision.transforms import functional as F
 
11
 
12
  CAIT_MODEL = create_model("cait_xxs24_224.fb_dist_in1k", pretrained=True).eval()
13
  TRANSFORM = timm.data.create_transform(
 
74
  fig.tight_layout()
75
  return fig
76
 
77
+
78
  def serialize_images(processed_map):
79
  """Serializes attention maps."""
80
  print(f"Number of maps: {processed_map.shape[0]}")
 
96
 
97
  block_key = f"blocks_token_only.{block_id}.attn.softmax"
98
  processed_cls_attn_map = get_cls_attention_map(image_tensor, out, block_key)
99
+
100
  serialize_images(processed_cls_attn_map)
101
  all_attn_img_paths = sorted(glob.glob("attention_map_*.png"))
102
  print(f"Number of images: {len(all_attn_img_paths)}")
 
109
  iface = gr.Interface(
110
  generate_class_attn_map,
111
  inputs=[
112
+ gr.Image(type="pil", label="Input Image"),
113
  gr.Slider(0, 1, value=0, step=1, label="Block ID", info="Transformer Block ID"),
114
  ],
115
+ outputs=gr.Gallery().style(columns=2, height="auto", object_fit="scale-down"),
116
  title=title,
117
  article=article,
118
  allow_flagging="never",
gradio_cached_examples/14/log.csv DELETED
@@ -1,2 +0,0 @@
1
- output,flag,username,timestamp
2
- /Users/sayakpaul/Downloads/class-attention-map/gradio_cached_examples/14/output/24ed4fad-3279-4814-ba76-b4c411c673a0,,,2023-06-11 11:55:03.515035
 
 
 
gradio_cached_examples/14/output/24ed4fad-3279-4814-ba76-b4c411c673a0/76269f58a7c390191fe41c6e016b4904749cd456/attention_map_i.png DELETED
Binary file (29.3 kB)
 
gradio_cached_examples/14/output/24ed4fad-3279-4814-ba76-b4c411c673a0/captions.json DELETED
@@ -1 +0,0 @@
1
- {"/Users/sayakpaul/Downloads/class-attention-map/gradio_cached_examples/14/output/24ed4fad-3279-4814-ba76-b4c411c673a0/76269f58a7c390191fe41c6e016b4904749cd456/attention_map_i.png": null}