sunana commited on
Commit
4fbf139
·
1 Parent(s): cb0ba4a

Update MT.py

Browse files
Files changed (1) hide show
  1. MT.py +5 -4
MT.py CHANGED
@@ -285,10 +285,11 @@ class FeatureTransformer(nn.Module):
285
  for i in range(self.num_layers):
286
  value, att, attn_viz = self.layers(att=att, value=value, shape=[h, w], iteration=i)
287
  value_decode = self.normalize(torch.square(self.re_proj(value))) # map to motion energy, Do use normalization here
288
-
289
- attn_viz_list.append(attn_viz.reshape(b, h, w, h, w))
290
- attn_list.append(att.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
291
- feature_list.append(value_decode.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
 
292
  return feature_list, attn_list, attn_viz_list
293
 
294
  def forward_save_mem(self, feature0, add_position_embedding=True):
 
285
  for i in range(self.num_layers):
286
  value, att, attn_viz = self.layers(att=att, value=value, shape=[h, w], iteration=i)
287
  value_decode = self.normalize(torch.square(self.re_proj(value))) # map to motion energy, Do use normalization here
288
+ if i % 2 == 0:
289
+ attn_viz_list.append(attn_viz.reshape(b, h, w, h, w))
290
+ attn_list.append(att.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
291
+ feature_list.append(value_decode.view(b, h, w, c).permute(0, 3, 1, 2).contiguous())
292
+
293
  return feature_list, attn_list, attn_viz_list
294
 
295
  def forward_save_mem(self, feature0, add_position_embedding=True):