Vidushee commited on
Commit
d41e857
·
verified ·
1 Parent(s): 9b9d138

Upload 14 files

Browse files
Zocket_ImageBind.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from imagebind import data
2
+ import torch
3
+ from imagebind.models import imagebind_model
4
+ from imagebind.models.imagebind_model import ModalityType
5
+ import gradio as gr
6
+
7
+ # command = "pip install git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d timm==0.6.7 ftfy regex einops fvcore decord==0.6.0"
8
+ # process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE)
9
+ # process.wait()
10
+ # print(process.returncode) # should print 0 if installation was successful
11
+
12
+
13
+
14
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
15
+
16
+ # Instantiate model
17
+ model = imagebind_model.imagebind_huge(pretrained=True)
18
+ model.eval()
19
+ model.to(device)
20
+
21
+ text_list = ["An Advertisement(branding, text, promotions, lifestyle depiction, contextual cues, and visual composition)","Not an Advertisement"]
22
+ image_paths = []
23
+
24
+
25
+
26
+ with gr.Blocks() as demo:
27
+ image = gr.File()
28
+ image_paths.append(image)
29
+
30
+
31
+ gr.Markdown(
32
+ """
33
+ Zocket ImageBind made AdBind
34
+ """)
35
+
36
+
37
+ inputs = {
38
+ ModalityType.TEXT: data.load_and_transform_text(text_list, device),
39
+ ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
40
+ }
41
+
42
+ with torch.no_grad():
43
+ embeddings = model(inputs)
44
+
45
+ print(
46
+ "Vision x Text: ",
47
+ torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1),
48
+ )
49
+
50
+ out = f"""Output = {torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1)}"""
51
+ gr.Markdown(out)
52
+
53
+
54
+
55
+ demo.launch()
56
+
57
+
58
+ # Load data
__init__.py ADDED
File without changes
bpe/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
imagebind/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from imagebind import data
2
+ from imagebind.models import imagebind_model
3
+ from imagebind.models.imagebind_model import ModalityType
imagebind/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (341 Bytes). View file
 
imagebind/__pycache__/data.cpython-310.pyc ADDED
Binary file (9.23 kB). View file
 
imagebind/data.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import logging
9
+ import math
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torchaudio
14
+ from PIL import Image
15
+ from pytorchvideo import transforms as pv_transforms
16
+ from pytorchvideo.data.clip_sampling import ConstantClipsPerVideoSampler
17
+ from pytorchvideo.data.encoded_video import EncodedVideo
18
+ from torchvision import transforms
19
+ from torchvision.transforms._transforms_video import NormalizeVideo
20
+
21
+ from imagebind.models.multimodal_preprocessors import SimpleTokenizer
22
+
23
+ DEFAULT_AUDIO_FRAME_SHIFT_MS = 10 # in milliseconds
24
+
25
+ BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"
26
+
27
+
28
+ def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
29
+ # Based on https://github.com/YuanGongND/ast/blob/d7d8b4b8e06cdaeb6c843cdb38794c1c7692234c/src/dataloader.py#L102
30
+ waveform -= waveform.mean()
31
+ fbank = torchaudio.compliance.kaldi.fbank(
32
+ waveform,
33
+ htk_compat=True,
34
+ sample_frequency=sample_rate,
35
+ use_energy=False,
36
+ window_type="hanning",
37
+ num_mel_bins=num_mel_bins,
38
+ dither=0.0,
39
+ frame_length=25,
40
+ frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
41
+ )
42
+ # Convert to [mel_bins, num_frames] shape
43
+ fbank = fbank.transpose(0, 1)
44
+ # Pad to target_length
45
+ n_frames = fbank.size(1)
46
+ p = target_length - n_frames
47
+ # if p is too large (say >20%), flash a warning
48
+ if abs(p) / n_frames > 0.2:
49
+ logging.warning(
50
+ "Large gap between audio n_frames(%d) and "
51
+ "target_length (%d). Is the audio_target_length "
52
+ "setting correct?",
53
+ n_frames,
54
+ target_length,
55
+ )
56
+ # cut and pad
57
+ if p > 0:
58
+ fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
59
+ elif p < 0:
60
+ fbank = fbank[:, 0:target_length]
61
+ # Convert to [1, mel_bins, num_frames] shape, essentially like a 1
62
+ # channel image
63
+ fbank = fbank.unsqueeze(0)
64
+ return fbank
65
+
66
+
67
+ def get_clip_timepoints(clip_sampler, duration):
68
+ # Read out all clips in this video
69
+ all_clips_timepoints = []
70
+ is_last_clip = False
71
+ end = 0.0
72
+ while not is_last_clip:
73
+ start, end, _, _, is_last_clip = clip_sampler(end, duration, annotation=None)
74
+ all_clips_timepoints.append((start, end))
75
+ return all_clips_timepoints
76
+
77
+
78
+ def load_and_transform_vision_data(image_paths, device):
79
+ if image_paths is None:
80
+ return None
81
+
82
+ image_outputs = []
83
+
84
+ data_transform = transforms.Compose(
85
+ [
86
+ transforms.Resize(
87
+ 224, interpolation=transforms.InterpolationMode.BICUBIC
88
+ ),
89
+ transforms.CenterCrop(224),
90
+ transforms.ToTensor(),
91
+ transforms.Normalize(
92
+ mean=(0.48145466, 0.4578275, 0.40821073),
93
+ std=(0.26862954, 0.26130258, 0.27577711),
94
+ ),
95
+ ]
96
+ )
97
+
98
+ for image_path in image_paths:
99
+ with open(image_path, "rb") as fopen:
100
+ image = Image.open(fopen).convert("RGB")
101
+
102
+ image = data_transform(image).to(device)
103
+ image_outputs.append(image)
104
+ return torch.stack(image_outputs, dim=0)
105
+
106
+
107
+ def load_and_transform_text(text, device):
108
+ if text is None:
109
+ return None
110
+ tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
111
+ tokens = [tokenizer(t).unsqueeze(0).to(device) for t in text]
112
+ tokens = torch.cat(tokens, dim=0)
113
+ return tokens
114
+
115
+
116
+ def load_and_transform_audio_data(
117
+ audio_paths,
118
+ device,
119
+ num_mel_bins=128,
120
+ target_length=204,
121
+ sample_rate=16000,
122
+ clip_duration=2,
123
+ clips_per_video=3,
124
+ mean=-4.268,
125
+ std=9.138,
126
+ ):
127
+ if audio_paths is None:
128
+ return None
129
+
130
+ audio_outputs = []
131
+ clip_sampler = ConstantClipsPerVideoSampler(
132
+ clip_duration=clip_duration, clips_per_video=clips_per_video
133
+ )
134
+
135
+ for audio_path in audio_paths:
136
+ waveform, sr = torchaudio.load(audio_path)
137
+ if sample_rate != sr:
138
+ waveform = torchaudio.functional.resample(
139
+ waveform, orig_freq=sr, new_freq=sample_rate
140
+ )
141
+ all_clips_timepoints = get_clip_timepoints(
142
+ clip_sampler, waveform.size(1) / sample_rate
143
+ )
144
+ all_clips = []
145
+ for clip_timepoints in all_clips_timepoints:
146
+ waveform_clip = waveform[
147
+ :,
148
+ int(clip_timepoints[0] * sample_rate) : int(
149
+ clip_timepoints[1] * sample_rate
150
+ ),
151
+ ]
152
+ waveform_melspec = waveform2melspec(
153
+ waveform_clip, sample_rate, num_mel_bins, target_length
154
+ )
155
+ all_clips.append(waveform_melspec)
156
+
157
+ normalize = transforms.Normalize(mean=mean, std=std)
158
+ all_clips = [normalize(ac).to(device) for ac in all_clips]
159
+
160
+ all_clips = torch.stack(all_clips, dim=0)
161
+ audio_outputs.append(all_clips)
162
+
163
+ return torch.stack(audio_outputs, dim=0)
164
+
165
+
166
+ def crop_boxes(boxes, x_offset, y_offset):
167
+ """
168
+ Perform crop on the bounding boxes given the offsets.
169
+ Args:
170
+ boxes (ndarray or None): bounding boxes to perform crop. The dimension
171
+ is `num boxes` x 4.
172
+ x_offset (int): cropping offset in the x axis.
173
+ y_offset (int): cropping offset in the y axis.
174
+ Returns:
175
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
176
+ `num boxes` x 4.
177
+ """
178
+ cropped_boxes = boxes.copy()
179
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
180
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
181
+
182
+ return cropped_boxes
183
+
184
+
185
+ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
186
+ """
187
+ Perform uniform spatial sampling on the images and corresponding boxes.
188
+ Args:
189
+ images (tensor): images to perform uniform crop. The dimension is
190
+ `num frames` x `channel` x `height` x `width`.
191
+ size (int): size of height and weight to crop the images.
192
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
193
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
194
+ crop if height is larger than width.
195
+ boxes (ndarray or None): optional. Corresponding boxes to images.
196
+ Dimension is `num boxes` x 4.
197
+ scale_size (int): optinal. If not None, resize the images to scale_size before
198
+ performing any crop.
199
+ Returns:
200
+ cropped (tensor): images with dimension of
201
+ `num frames` x `channel` x `size` x `size`.
202
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
203
+ `num boxes` x 4.
204
+ """
205
+ assert spatial_idx in [0, 1, 2]
206
+ ndim = len(images.shape)
207
+ if ndim == 3:
208
+ images = images.unsqueeze(0)
209
+ height = images.shape[2]
210
+ width = images.shape[3]
211
+
212
+ if scale_size is not None:
213
+ if width <= height:
214
+ width, height = scale_size, int(height / width * scale_size)
215
+ else:
216
+ width, height = int(width / height * scale_size), scale_size
217
+ images = torch.nn.functional.interpolate(
218
+ images,
219
+ size=(height, width),
220
+ mode="bilinear",
221
+ align_corners=False,
222
+ )
223
+
224
+ y_offset = int(math.ceil((height - size) / 2))
225
+ x_offset = int(math.ceil((width - size) / 2))
226
+
227
+ if height > width:
228
+ if spatial_idx == 0:
229
+ y_offset = 0
230
+ elif spatial_idx == 2:
231
+ y_offset = height - size
232
+ else:
233
+ if spatial_idx == 0:
234
+ x_offset = 0
235
+ elif spatial_idx == 2:
236
+ x_offset = width - size
237
+ cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
238
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
239
+ if ndim == 3:
240
+ cropped = cropped.squeeze(0)
241
+ return cropped, cropped_boxes
242
+
243
+
244
+ class SpatialCrop(nn.Module):
245
+ """
246
+ Convert the video into 3 smaller clips spatially. Must be used after the
247
+ temporal crops to get spatial crops, and should be used with
248
+ -2 in the spatial crop at the slowfast augmentation stage (so full
249
+ frames are passed in here). Will return a larger list with the
250
+ 3x spatial crops as well.
251
+ """
252
+
253
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
254
+ super().__init__()
255
+ self.crop_size = crop_size
256
+ if num_crops == 3:
257
+ self.crops_to_ext = [0, 1, 2]
258
+ self.flipped_crops_to_ext = []
259
+ elif num_crops == 1:
260
+ self.crops_to_ext = [1]
261
+ self.flipped_crops_to_ext = []
262
+ else:
263
+ raise NotImplementedError("Nothing else supported yet")
264
+
265
+ def forward(self, videos):
266
+ """
267
+ Args:
268
+ videos: A list of C, T, H, W videos.
269
+ Returns:
270
+ videos: A list with 3x the number of elements. Each video converted
271
+ to C, T, H', W' by spatial cropping.
272
+ """
273
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
274
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
275
+ res = []
276
+ for video in videos:
277
+ for spatial_idx in self.crops_to_ext:
278
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
279
+ if not self.flipped_crops_to_ext:
280
+ continue
281
+ flipped_video = transforms.functional.hflip(video)
282
+ for spatial_idx in self.flipped_crops_to_ext:
283
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
284
+ return res
285
+
286
+
287
+ def load_and_transform_video_data(
288
+ video_paths,
289
+ device,
290
+ clip_duration=2,
291
+ clips_per_video=5,
292
+ sample_rate=16000,
293
+ ):
294
+ if video_paths is None:
295
+ return None
296
+
297
+ video_outputs = []
298
+ video_transform = transforms.Compose(
299
+ [
300
+ pv_transforms.ShortSideScale(224),
301
+ NormalizeVideo(
302
+ mean=(0.48145466, 0.4578275, 0.40821073),
303
+ std=(0.26862954, 0.26130258, 0.27577711),
304
+ ),
305
+ ]
306
+ )
307
+
308
+ clip_sampler = ConstantClipsPerVideoSampler(
309
+ clip_duration=clip_duration, clips_per_video=clips_per_video
310
+ )
311
+ frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration)
312
+
313
+ for video_path in video_paths:
314
+ video = EncodedVideo.from_path(
315
+ video_path,
316
+ decoder="decord",
317
+ decode_audio=False,
318
+ **{"sample_rate": sample_rate},
319
+ )
320
+
321
+ all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration)
322
+
323
+ all_video = []
324
+ for clip_timepoints in all_clips_timepoints:
325
+ # Read the clip, get frames
326
+ clip = video.get_clip(clip_timepoints[0], clip_timepoints[1])
327
+ if clip is None:
328
+ raise ValueError("No clip found")
329
+ video_clip = frame_sampler(clip["video"])
330
+ video_clip = video_clip / 255.0 # since this is float, need 0-1
331
+
332
+ all_video.append(video_clip)
333
+
334
+ all_video = [video_transform(clip) for clip in all_video]
335
+ all_video = SpatialCrop(224, num_crops=3)(all_video)
336
+
337
+ all_video = torch.stack(all_video, dim=0)
338
+ video_outputs.append(all_video)
339
+
340
+ return torch.stack(video_outputs, dim=0).to(device)
imagebind/models/__init__.py ADDED
File without changes
imagebind/models/helpers.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import einops
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
+ class Normalize(nn.Module):
16
+ def __init__(self, dim: int) -> None:
17
+ super().__init__()
18
+ self.dim = dim
19
+
20
+ def forward(self, x):
21
+ return torch.nn.functional.normalize(x, dim=self.dim, p=2)
22
+
23
+
24
+ class LearnableLogitScaling(nn.Module):
25
+ def __init__(
26
+ self,
27
+ logit_scale_init: float = 1 / 0.07,
28
+ learnable: bool = True,
29
+ max_logit_scale: float = 100,
30
+ ) -> None:
31
+ super().__init__()
32
+ self.max_logit_scale = max_logit_scale
33
+ self.logit_scale_init = logit_scale_init
34
+ self.learnable = learnable
35
+ log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
36
+ if learnable:
37
+ self.log_logit_scale = nn.Parameter(log_logit_scale)
38
+ else:
39
+ self.register_buffer("log_logit_scale", log_logit_scale)
40
+
41
+ def forward(self, x):
42
+ return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
43
+
44
+ def extra_repr(self):
45
+ st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}," \
46
+ f" max_logit_scale={self.max_logit_scale}"
47
+ return st
48
+
49
+
50
+ class EinOpsRearrange(nn.Module):
51
+ def __init__(self, rearrange_expr: str, **kwargs) -> None:
52
+ super().__init__()
53
+ self.rearrange_expr = rearrange_expr
54
+ self.kwargs = kwargs
55
+
56
+ def forward(self, x):
57
+ assert isinstance(x, torch.Tensor)
58
+ return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
59
+
60
+
61
+ class VerboseNNModule(nn.Module):
62
+ """
63
+ Wrapper around nn.Module that prints registered buffers and parameter names.
64
+ """
65
+
66
+ @staticmethod
67
+ def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
68
+ st = (
69
+ "("
70
+ + name
71
+ + "): "
72
+ + "tensor("
73
+ + str(tuple(tensor[1].shape))
74
+ + ", requires_grad="
75
+ + str(tensor[1].requires_grad)
76
+ + ")\n"
77
+ )
78
+ return st
79
+
80
+ def extra_repr(self) -> str:
81
+ named_modules = set()
82
+ for p in self.named_modules():
83
+ named_modules.update([p[0]])
84
+ named_modules = list(named_modules)
85
+
86
+ string_repr = ""
87
+ for p in self.named_parameters():
88
+ name = p[0].split(".")[0]
89
+ if name not in named_modules:
90
+ string_repr += self.get_readable_tensor_repr(name, p)
91
+
92
+ for p in self.named_buffers():
93
+ name = p[0].split(".")[0]
94
+ string_repr += self.get_readable_tensor_repr(name, p)
95
+
96
+ return string_repr
97
+
98
+
99
+ def cast_if_src_dtype(
100
+ tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
101
+ ):
102
+ updated = False
103
+ if tensor.dtype == src_dtype:
104
+ tensor = tensor.to(dtype=tgt_dtype)
105
+ updated = True
106
+ return tensor, updated
107
+
108
+
109
+ class QuickGELU(nn.Module):
110
+ # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
111
+ def forward(self, x: torch.Tensor):
112
+ return x * torch.sigmoid(1.702 * x)
113
+
114
+
115
+ class SelectElement(nn.Module):
116
+ def __init__(self, index) -> None:
117
+ super().__init__()
118
+ self.index = index
119
+
120
+ def forward(self, x):
121
+ assert x.ndim >= 3
122
+ return x[:, self.index, ...]
123
+
124
+
125
+ class SelectEOSAndProject(nn.Module):
126
+ """
127
+ Text Pooling used in OpenCLIP
128
+ """
129
+
130
+ def __init__(self, proj: nn.Module) -> None:
131
+ super().__init__()
132
+ self.proj = proj
133
+
134
+ def forward(self, x, seq_len):
135
+ assert x.ndim == 3
136
+ # x is of shape B x L x D
137
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
138
+ x = x[torch.arange(x.shape[0]), seq_len]
139
+ x = self.proj(x)
140
+ return x
imagebind/models/imagebind_model.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+
9
+ import os
10
+ from functools import partial
11
+ from types import SimpleNamespace
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+ from imagebind.models.helpers import (EinOpsRearrange, LearnableLogitScaling, Normalize,
17
+ SelectElement, SelectEOSAndProject)
18
+ from imagebind.models.multimodal_preprocessors import (AudioPreprocessor,
19
+ IMUPreprocessor, PadIm2Video,
20
+ PatchEmbedGeneric,
21
+ RGBDTPreprocessor,
22
+ SpatioTemporalPosEmbeddingHelper,
23
+ TextPreprocessor,
24
+ ThermalPreprocessor)
25
+ from imagebind.models.transformer import MultiheadAttention, SimpleTransformer
26
+
27
+ ModalityType = SimpleNamespace(
28
+ VISION="vision",
29
+ TEXT="text",
30
+ AUDIO="audio",
31
+ THERMAL="thermal",
32
+ DEPTH="depth",
33
+ IMU="imu",
34
+ )
35
+
36
+
37
+ class ImageBindModel(nn.Module):
38
+ def __init__(
39
+ self,
40
+ video_frames=2,
41
+ kernel_size=(2, 14, 14),
42
+ audio_kernel_size=16,
43
+ audio_stride=10,
44
+ out_embed_dim=768,
45
+ vision_embed_dim=1024,
46
+ vision_num_blocks=24,
47
+ vision_num_heads=16,
48
+ audio_embed_dim=768,
49
+ audio_num_blocks=12,
50
+ audio_num_heads=12,
51
+ audio_num_mel_bins=128,
52
+ audio_target_len=204,
53
+ audio_drop_path=0.1,
54
+ text_embed_dim=768,
55
+ text_num_blocks=12,
56
+ text_num_heads=12,
57
+ depth_embed_dim=384,
58
+ depth_kernel_size=16,
59
+ depth_num_blocks=12,
60
+ depth_num_heads=8,
61
+ depth_drop_path=0.0,
62
+ thermal_embed_dim=768,
63
+ thermal_kernel_size=16,
64
+ thermal_num_blocks=12,
65
+ thermal_num_heads=12,
66
+ thermal_drop_path=0.0,
67
+ imu_embed_dim=512,
68
+ imu_kernel_size=8,
69
+ imu_num_blocks=6,
70
+ imu_num_heads=8,
71
+ imu_drop_path=0.7,
72
+ ):
73
+ super().__init__()
74
+
75
+ self.modality_preprocessors = self._create_modality_preprocessors(
76
+ video_frames,
77
+ vision_embed_dim,
78
+ kernel_size,
79
+ text_embed_dim,
80
+ audio_embed_dim,
81
+ audio_kernel_size,
82
+ audio_stride,
83
+ audio_num_mel_bins,
84
+ audio_target_len,
85
+ depth_embed_dim,
86
+ depth_kernel_size,
87
+ thermal_embed_dim,
88
+ thermal_kernel_size,
89
+ imu_embed_dim,
90
+ )
91
+
92
+ self.modality_trunks = self._create_modality_trunks(
93
+ vision_embed_dim,
94
+ vision_num_blocks,
95
+ vision_num_heads,
96
+ text_embed_dim,
97
+ text_num_blocks,
98
+ text_num_heads,
99
+ audio_embed_dim,
100
+ audio_num_blocks,
101
+ audio_num_heads,
102
+ audio_drop_path,
103
+ depth_embed_dim,
104
+ depth_num_blocks,
105
+ depth_num_heads,
106
+ depth_drop_path,
107
+ thermal_embed_dim,
108
+ thermal_num_blocks,
109
+ thermal_num_heads,
110
+ thermal_drop_path,
111
+ imu_embed_dim,
112
+ imu_num_blocks,
113
+ imu_num_heads,
114
+ imu_drop_path,
115
+ )
116
+
117
+ self.modality_heads = self._create_modality_heads(
118
+ out_embed_dim,
119
+ vision_embed_dim,
120
+ text_embed_dim,
121
+ audio_embed_dim,
122
+ depth_embed_dim,
123
+ thermal_embed_dim,
124
+ imu_embed_dim,
125
+ )
126
+
127
+ self.modality_postprocessors = self._create_modality_postprocessors(
128
+ out_embed_dim
129
+ )
130
+
131
+ def _create_modality_preprocessors(
132
+ self,
133
+ video_frames=2,
134
+ vision_embed_dim=1024,
135
+ kernel_size=(2, 14, 14),
136
+ text_embed_dim=768,
137
+ audio_embed_dim=768,
138
+ audio_kernel_size=16,
139
+ audio_stride=10,
140
+ audio_num_mel_bins=128,
141
+ audio_target_len=204,
142
+ depth_embed_dim=768,
143
+ depth_kernel_size=16,
144
+ thermal_embed_dim=768,
145
+ thermal_kernel_size=16,
146
+ imu_embed_dim=512,
147
+ ):
148
+ rgbt_stem = PatchEmbedGeneric(
149
+ proj_stem=[
150
+ PadIm2Video(pad_type="repeat", ntimes=2),
151
+ nn.Conv3d(
152
+ in_channels=3,
153
+ kernel_size=kernel_size,
154
+ out_channels=vision_embed_dim,
155
+ stride=kernel_size,
156
+ bias=False,
157
+ ),
158
+ ]
159
+ )
160
+ rgbt_preprocessor = RGBDTPreprocessor(
161
+ img_size=[3, video_frames, 224, 224],
162
+ num_cls_tokens=1,
163
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
164
+ rgbt_stem=rgbt_stem,
165
+ depth_stem=None,
166
+ )
167
+
168
+ text_preprocessor = TextPreprocessor(
169
+ context_length=77,
170
+ vocab_size=49408,
171
+ embed_dim=text_embed_dim,
172
+ causal_masking=True,
173
+ )
174
+
175
+ audio_stem = PatchEmbedGeneric(
176
+ proj_stem=[
177
+ nn.Conv2d(
178
+ in_channels=1,
179
+ kernel_size=audio_kernel_size,
180
+ stride=audio_stride,
181
+ out_channels=audio_embed_dim,
182
+ bias=False,
183
+ ),
184
+ ],
185
+ norm_layer=nn.LayerNorm(normalized_shape=audio_embed_dim),
186
+ )
187
+ audio_preprocessor = AudioPreprocessor(
188
+ img_size=[1, audio_num_mel_bins, audio_target_len],
189
+ num_cls_tokens=1,
190
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
191
+ audio_stem=audio_stem,
192
+ )
193
+
194
+ depth_stem = PatchEmbedGeneric(
195
+ [
196
+ nn.Conv2d(
197
+ kernel_size=depth_kernel_size,
198
+ in_channels=1,
199
+ out_channels=depth_embed_dim,
200
+ stride=depth_kernel_size,
201
+ bias=False,
202
+ ),
203
+ ],
204
+ norm_layer=nn.LayerNorm(normalized_shape=depth_embed_dim),
205
+ )
206
+
207
+ depth_preprocessor = RGBDTPreprocessor(
208
+ img_size=[1, 224, 224],
209
+ num_cls_tokens=1,
210
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
211
+ rgbt_stem=None,
212
+ depth_stem=depth_stem,
213
+ )
214
+
215
+ thermal_stem = PatchEmbedGeneric(
216
+ [
217
+ nn.Conv2d(
218
+ kernel_size=thermal_kernel_size,
219
+ in_channels=1,
220
+ out_channels=thermal_embed_dim,
221
+ stride=thermal_kernel_size,
222
+ bias=False,
223
+ ),
224
+ ],
225
+ norm_layer=nn.LayerNorm(normalized_shape=thermal_embed_dim),
226
+ )
227
+ thermal_preprocessor = ThermalPreprocessor(
228
+ img_size=[1, 224, 224],
229
+ num_cls_tokens=1,
230
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
231
+ thermal_stem=thermal_stem,
232
+ )
233
+
234
+ imu_stem = PatchEmbedGeneric(
235
+ [
236
+ nn.Linear(
237
+ in_features=48,
238
+ out_features=imu_embed_dim,
239
+ bias=False,
240
+ ),
241
+ ],
242
+ norm_layer=nn.LayerNorm(normalized_shape=imu_embed_dim),
243
+ )
244
+
245
+ imu_preprocessor = IMUPreprocessor(
246
+ img_size=[6, 2000],
247
+ num_cls_tokens=1,
248
+ kernel_size=8,
249
+ embed_dim=imu_embed_dim,
250
+ pos_embed_fn=partial(SpatioTemporalPosEmbeddingHelper, learnable=True),
251
+ imu_stem=imu_stem,
252
+ )
253
+
254
+ modality_preprocessors = {
255
+ ModalityType.VISION: rgbt_preprocessor,
256
+ ModalityType.TEXT: text_preprocessor,
257
+ ModalityType.AUDIO: audio_preprocessor,
258
+ ModalityType.DEPTH: depth_preprocessor,
259
+ ModalityType.THERMAL: thermal_preprocessor,
260
+ ModalityType.IMU: imu_preprocessor,
261
+ }
262
+
263
+ return nn.ModuleDict(modality_preprocessors)
264
+
265
+ def _create_modality_trunks(
266
+ self,
267
+ vision_embed_dim=1024,
268
+ vision_num_blocks=24,
269
+ vision_num_heads=16,
270
+ text_embed_dim=768,
271
+ text_num_blocks=12,
272
+ text_num_heads=12,
273
+ audio_embed_dim=768,
274
+ audio_num_blocks=12,
275
+ audio_num_heads=12,
276
+ audio_drop_path=0.0,
277
+ depth_embed_dim=768,
278
+ depth_num_blocks=12,
279
+ depth_num_heads=12,
280
+ depth_drop_path=0.0,
281
+ thermal_embed_dim=768,
282
+ thermal_num_blocks=12,
283
+ thermal_num_heads=12,
284
+ thermal_drop_path=0.0,
285
+ imu_embed_dim=512,
286
+ imu_num_blocks=6,
287
+ imu_num_heads=8,
288
+ imu_drop_path=0.7,
289
+ ):
290
+ def instantiate_trunk(
291
+ embed_dim, num_blocks, num_heads, pre_transformer_ln, add_bias_kv, drop_path
292
+ ):
293
+ return SimpleTransformer(
294
+ embed_dim=embed_dim,
295
+ num_blocks=num_blocks,
296
+ ffn_dropout_rate=0.0,
297
+ drop_path_rate=drop_path,
298
+ attn_target=partial(
299
+ MultiheadAttention,
300
+ embed_dim=embed_dim,
301
+ num_heads=num_heads,
302
+ bias=True,
303
+ add_bias_kv=add_bias_kv,
304
+ ),
305
+ pre_transformer_layer=nn.Sequential(
306
+ nn.LayerNorm(embed_dim, eps=1e-6)
307
+ if pre_transformer_ln
308
+ else nn.Identity(),
309
+ EinOpsRearrange("b l d -> l b d"),
310
+ ),
311
+ post_transformer_layer=EinOpsRearrange("l b d -> b l d"),
312
+ )
313
+
314
+ modality_trunks = {}
315
+ modality_trunks[ModalityType.VISION] = instantiate_trunk(
316
+ vision_embed_dim,
317
+ vision_num_blocks,
318
+ vision_num_heads,
319
+ pre_transformer_ln=True,
320
+ add_bias_kv=False,
321
+ drop_path=0.0,
322
+ )
323
+ modality_trunks[ModalityType.TEXT] = instantiate_trunk(
324
+ text_embed_dim,
325
+ text_num_blocks,
326
+ text_num_heads,
327
+ pre_transformer_ln=False,
328
+ add_bias_kv=False,
329
+ drop_path=0.0,
330
+ )
331
+ modality_trunks[ModalityType.AUDIO] = instantiate_trunk(
332
+ audio_embed_dim,
333
+ audio_num_blocks,
334
+ audio_num_heads,
335
+ pre_transformer_ln=False,
336
+ add_bias_kv=True,
337
+ drop_path=audio_drop_path,
338
+ )
339
+ modality_trunks[ModalityType.DEPTH] = instantiate_trunk(
340
+ depth_embed_dim,
341
+ depth_num_blocks,
342
+ depth_num_heads,
343
+ pre_transformer_ln=False,
344
+ add_bias_kv=True,
345
+ drop_path=depth_drop_path,
346
+ )
347
+ modality_trunks[ModalityType.THERMAL] = instantiate_trunk(
348
+ thermal_embed_dim,
349
+ thermal_num_blocks,
350
+ thermal_num_heads,
351
+ pre_transformer_ln=False,
352
+ add_bias_kv=True,
353
+ drop_path=thermal_drop_path,
354
+ )
355
+ modality_trunks[ModalityType.IMU] = instantiate_trunk(
356
+ imu_embed_dim,
357
+ imu_num_blocks,
358
+ imu_num_heads,
359
+ pre_transformer_ln=False,
360
+ add_bias_kv=True,
361
+ drop_path=imu_drop_path,
362
+ )
363
+
364
+ return nn.ModuleDict(modality_trunks)
365
+
366
+ def _create_modality_heads(
367
+ self,
368
+ out_embed_dim,
369
+ vision_embed_dim,
370
+ text_embed_dim,
371
+ audio_embed_dim,
372
+ depth_embed_dim,
373
+ thermal_embed_dim,
374
+ imu_embed_dim,
375
+ ):
376
+ modality_heads = {}
377
+
378
+ modality_heads[ModalityType.VISION] = nn.Sequential(
379
+ nn.LayerNorm(normalized_shape=vision_embed_dim, eps=1e-6),
380
+ SelectElement(index=0),
381
+ nn.Linear(vision_embed_dim, out_embed_dim, bias=False),
382
+ )
383
+
384
+ modality_heads[ModalityType.TEXT] = SelectEOSAndProject(
385
+ proj=nn.Sequential(
386
+ nn.LayerNorm(normalized_shape=text_embed_dim, eps=1e-6),
387
+ nn.Linear(text_embed_dim, out_embed_dim, bias=False),
388
+ )
389
+ )
390
+
391
+ modality_heads[ModalityType.AUDIO] = nn.Sequential(
392
+ nn.LayerNorm(normalized_shape=audio_embed_dim, eps=1e-6),
393
+ SelectElement(index=0),
394
+ nn.Linear(audio_embed_dim, out_embed_dim, bias=False),
395
+ )
396
+
397
+ modality_heads[ModalityType.DEPTH] = nn.Sequential(
398
+ nn.LayerNorm(normalized_shape=depth_embed_dim, eps=1e-6),
399
+ SelectElement(index=0),
400
+ nn.Linear(depth_embed_dim, out_embed_dim, bias=False),
401
+ )
402
+
403
+ modality_heads[ModalityType.THERMAL] = nn.Sequential(
404
+ nn.LayerNorm(normalized_shape=thermal_embed_dim, eps=1e-6),
405
+ SelectElement(index=0),
406
+ nn.Linear(thermal_embed_dim, out_embed_dim, bias=False),
407
+ )
408
+
409
+ modality_heads[ModalityType.IMU] = nn.Sequential(
410
+ nn.LayerNorm(normalized_shape=imu_embed_dim, eps=1e-6),
411
+ SelectElement(index=0),
412
+ nn.Dropout(p=0.5),
413
+ nn.Linear(imu_embed_dim, out_embed_dim, bias=False),
414
+ )
415
+
416
+ return nn.ModuleDict(modality_heads)
417
+
418
+ def _create_modality_postprocessors(self, out_embed_dim):
419
+ modality_postprocessors = {}
420
+
421
+ modality_postprocessors[ModalityType.VISION] = Normalize(dim=-1)
422
+ modality_postprocessors[ModalityType.TEXT] = nn.Sequential(
423
+ Normalize(dim=-1), LearnableLogitScaling(learnable=True)
424
+ )
425
+ modality_postprocessors[ModalityType.AUDIO] = nn.Sequential(
426
+ Normalize(dim=-1),
427
+ LearnableLogitScaling(logit_scale_init=20.0, learnable=False),
428
+ )
429
+ modality_postprocessors[ModalityType.DEPTH] = nn.Sequential(
430
+ Normalize(dim=-1),
431
+ LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
432
+ )
433
+ modality_postprocessors[ModalityType.THERMAL] = nn.Sequential(
434
+ Normalize(dim=-1),
435
+ LearnableLogitScaling(logit_scale_init=10.0, learnable=False),
436
+ )
437
+ modality_postprocessors[ModalityType.IMU] = nn.Sequential(
438
+ Normalize(dim=-1),
439
+ LearnableLogitScaling(logit_scale_init=5.0, learnable=False),
440
+ )
441
+
442
+ return nn.ModuleDict(modality_postprocessors)
443
+
444
+ def forward(self, inputs):
445
+ outputs = {}
446
+ for modality_key, modality_value in inputs.items():
447
+ reduce_list = (
448
+ modality_value.ndim >= 5
449
+ ) # Audio and Video inputs consist of multiple clips
450
+ if reduce_list:
451
+ B, S = modality_value.shape[:2]
452
+ modality_value = modality_value.reshape(
453
+ B * S, *modality_value.shape[2:]
454
+ )
455
+
456
+ if modality_value is not None:
457
+ modality_value = self.modality_preprocessors[modality_key](
458
+ **{modality_key: modality_value}
459
+ )
460
+ trunk_inputs = modality_value["trunk"]
461
+ head_inputs = modality_value["head"]
462
+ modality_value = self.modality_trunks[modality_key](**trunk_inputs)
463
+ modality_value = self.modality_heads[modality_key](
464
+ modality_value, **head_inputs
465
+ )
466
+ modality_value = self.modality_postprocessors[modality_key](
467
+ modality_value
468
+ )
469
+
470
+ if reduce_list:
471
+ modality_value = modality_value.reshape(B, S, -1)
472
+ modality_value = modality_value.mean(dim=1)
473
+
474
+ outputs[modality_key] = modality_value
475
+
476
+ return outputs
477
+
478
+
479
+ def imagebind_huge(pretrained=False):
480
+ model = ImageBindModel(
481
+ vision_embed_dim=1280,
482
+ vision_num_blocks=32,
483
+ vision_num_heads=16,
484
+ text_embed_dim=1024,
485
+ text_num_blocks=24,
486
+ text_num_heads=16,
487
+ out_embed_dim=1024,
488
+ audio_drop_path=0.1,
489
+ imu_drop_path=0.7,
490
+ )
491
+
492
+ if pretrained:
493
+ if not os.path.exists(".checkpoints/imagebind_huge.pth"):
494
+ print(
495
+ "Downloading imagebind weights to .checkpoints/imagebind_huge.pth ..."
496
+ )
497
+ os.makedirs(".checkpoints", exist_ok=True)
498
+ torch.hub.download_url_to_file(
499
+ "https://dl.fbaipublicfiles.com/imagebind/imagebind_huge.pth",
500
+ ".checkpoints/imagebind_huge.pth",
501
+ progress=True,
502
+ )
503
+
504
+ model.load_state_dict(torch.load(".checkpoints/imagebind_huge.pth"))
505
+
506
+ return model
imagebind/models/multimodal_preprocessors.py ADDED
@@ -0,0 +1,685 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ import gzip
9
+ import html
10
+ import io
11
+ import math
12
+ from functools import lru_cache
13
+ from typing import Callable, List, Optional, Tuple
14
+
15
+ import ftfy
16
+ import numpy as np
17
+ import regex as re
18
+ import torch
19
+ import torch.nn as nn
20
+ from iopath.common.file_io import g_pathmgr
21
+ from timm.models.layers import trunc_normal_
22
+
23
+ from imagebind.models.helpers import VerboseNNModule, cast_if_src_dtype
24
+
25
+
26
+ def get_sinusoid_encoding_table(n_position, d_hid):
27
+ """Sinusoid position encoding table"""
28
+
29
+ # TODO: make it with torch instead of numpy
30
+ def get_position_angle_vec(position):
31
+ return [
32
+ position / np.power(10000, 2 * (hid_j // 2) / d_hid)
33
+ for hid_j in range(d_hid)
34
+ ]
35
+
36
+ sinusoid_table = np.array(
37
+ [get_position_angle_vec(pos_i) for pos_i in range(n_position)]
38
+ )
39
+ sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
40
+ sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
41
+
42
+ return torch.FloatTensor(sinusoid_table).unsqueeze(0)
43
+
44
+
45
+ def interpolate_pos_encoding_2d(target_spatial_size, pos_embed):
46
+ N = pos_embed.shape[1]
47
+ if N == target_spatial_size:
48
+ return pos_embed
49
+ dim = pos_embed.shape[-1]
50
+ # nn.functional.interpolate doesn't work with bfloat16 so we cast to float32
51
+ pos_embed, updated = cast_if_src_dtype(pos_embed, torch.bfloat16, torch.float32)
52
+ pos_embed = nn.functional.interpolate(
53
+ pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(
54
+ 0, 3, 1, 2
55
+ ),
56
+ scale_factor=math.sqrt(target_spatial_size / N),
57
+ mode="bicubic",
58
+ )
59
+ if updated:
60
+ pos_embed, _ = cast_if_src_dtype(pos_embed, torch.float32, torch.bfloat16)
61
+ pos_embed = pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
62
+ return pos_embed
63
+
64
+
65
+ def interpolate_pos_encoding(
66
+ npatch_per_img,
67
+ pos_embed,
68
+ patches_layout,
69
+ input_shape=None,
70
+ first_patch_idx=1,
71
+ ):
72
+ assert first_patch_idx == 0 or first_patch_idx == 1, "there is 1 CLS token or none"
73
+ N = pos_embed.shape[1] - first_patch_idx # since it's 1 if cls_token exists
74
+ if npatch_per_img == N:
75
+ return pos_embed
76
+
77
+ assert (
78
+ patches_layout[-1] == patches_layout[-2]
79
+ ), "Interpolation of pos embed not supported for non-square layouts"
80
+
81
+ class_emb = pos_embed[:, :first_patch_idx]
82
+ pos_embed = pos_embed[:, first_patch_idx:]
83
+
84
+ if input_shape is None or patches_layout[0] == 1:
85
+ # simple 2D pos embedding, no temporal component
86
+ pos_embed = interpolate_pos_encoding_2d(npatch_per_img, pos_embed)
87
+ elif patches_layout[0] > 1:
88
+ # pos embed has a temporal component
89
+ assert len(input_shape) == 4, "temporal interpolation not supported"
90
+ # we only support 2D interpolation in this case
91
+ num_frames = patches_layout[0]
92
+ num_spatial_tokens = patches_layout[1] * patches_layout[2]
93
+ pos_embed = pos_embed.view(1, num_frames, num_spatial_tokens, -1)
94
+ # interpolate embedding for zeroth frame
95
+ pos_embed = interpolate_pos_encoding_2d(
96
+ npatch_per_img, pos_embed[0, 0, ...].unsqueeze(0)
97
+ )
98
+ else:
99
+ raise ValueError("This type of interpolation isn't implemented")
100
+
101
+ return torch.cat((class_emb, pos_embed), dim=1)
102
+
103
+
104
+ def _get_pos_embedding(
105
+ npatch_per_img,
106
+ pos_embed,
107
+ patches_layout,
108
+ input_shape,
109
+ first_patch_idx=1,
110
+ ):
111
+ pos_embed = interpolate_pos_encoding(
112
+ npatch_per_img,
113
+ pos_embed,
114
+ patches_layout,
115
+ input_shape=input_shape,
116
+ first_patch_idx=first_patch_idx,
117
+ )
118
+ return pos_embed
119
+
120
+
121
+ class PatchEmbedGeneric(nn.Module):
122
+ """
123
+ PatchEmbed from Hydra
124
+ """
125
+
126
+ def __init__(self, proj_stem, norm_layer: Optional[nn.Module] = None):
127
+ super().__init__()
128
+
129
+ if len(proj_stem) > 1:
130
+ self.proj = nn.Sequential(*proj_stem)
131
+ else:
132
+ # Special case to be able to load pre-trained models that were
133
+ # trained with a standard stem
134
+ self.proj = proj_stem[0]
135
+ self.norm_layer = norm_layer
136
+
137
+ def get_patch_layout(self, img_size):
138
+ with torch.no_grad():
139
+ dummy_img = torch.zeros(
140
+ [
141
+ 1,
142
+ ]
143
+ + img_size
144
+ )
145
+ dummy_out = self.proj(dummy_img)
146
+ embed_dim = dummy_out.shape[1]
147
+ patches_layout = tuple(dummy_out.shape[2:])
148
+ num_patches = np.prod(patches_layout)
149
+ return patches_layout, num_patches, embed_dim
150
+
151
+ def forward(self, x):
152
+ x = self.proj(x)
153
+ # B C (T) H W -> B (T)HW C
154
+ x = x.flatten(2).transpose(1, 2)
155
+ if self.norm_layer is not None:
156
+ x = self.norm_layer(x)
157
+ return x
158
+
159
+
160
+ class SpatioTemporalPosEmbeddingHelper(VerboseNNModule):
161
+ def __init__(
162
+ self,
163
+ patches_layout: List,
164
+ num_patches: int,
165
+ num_cls_tokens: int,
166
+ embed_dim: int,
167
+ learnable: bool,
168
+ ) -> None:
169
+ super().__init__()
170
+ self.num_cls_tokens = num_cls_tokens
171
+ self.patches_layout = patches_layout
172
+ self.num_patches = num_patches
173
+ self.num_tokens = num_cls_tokens + num_patches
174
+ self.learnable = learnable
175
+ if self.learnable:
176
+ self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
177
+ trunc_normal_(self.pos_embed, std=0.02)
178
+ else:
179
+ self.register_buffer(
180
+ "pos_embed", get_sinusoid_encoding_table(self.num_tokens, embed_dim)
181
+ )
182
+
183
+ def get_pos_embedding(self, vision_input, all_vision_tokens):
184
+ input_shape = vision_input.shape
185
+ pos_embed = _get_pos_embedding(
186
+ all_vision_tokens.size(1) - self.num_cls_tokens,
187
+ pos_embed=self.pos_embed,
188
+ patches_layout=self.patches_layout,
189
+ input_shape=input_shape,
190
+ first_patch_idx=self.num_cls_tokens,
191
+ )
192
+ return pos_embed
193
+
194
+
195
+ class RGBDTPreprocessor(VerboseNNModule):
196
+ def __init__(
197
+ self,
198
+ rgbt_stem: PatchEmbedGeneric,
199
+ depth_stem: Optional[PatchEmbedGeneric],
200
+ img_size: Tuple = (3, 224, 224),
201
+ num_cls_tokens: int = 1,
202
+ pos_embed_fn: Optional[Callable] = None,
203
+ use_type_embed: bool = False,
204
+ init_param_style: str = "openclip",
205
+ ) -> None:
206
+ super().__init__()
207
+ stem = rgbt_stem if rgbt_stem is not None else depth_stem
208
+ (
209
+ self.patches_layout,
210
+ self.num_patches,
211
+ self.embed_dim,
212
+ ) = stem.get_patch_layout(img_size)
213
+ self.rgbt_stem = rgbt_stem
214
+ self.depth_stem = depth_stem
215
+ self.use_pos_embed = pos_embed_fn is not None
216
+ self.use_type_embed = use_type_embed
217
+ self.num_cls_tokens = num_cls_tokens
218
+
219
+ if self.use_pos_embed:
220
+ self.pos_embedding_helper = pos_embed_fn(
221
+ patches_layout=self.patches_layout,
222
+ num_cls_tokens=num_cls_tokens,
223
+ num_patches=self.num_patches,
224
+ embed_dim=self.embed_dim,
225
+ )
226
+ if self.num_cls_tokens > 0:
227
+ self.cls_token = nn.Parameter(
228
+ torch.zeros(1, self.num_cls_tokens, self.embed_dim)
229
+ )
230
+ if self.use_type_embed:
231
+ self.type_embed = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
232
+
233
+ self.init_parameters(init_param_style)
234
+
235
+ @torch.no_grad()
236
+ def init_parameters(self, init_param_style):
237
+ if init_param_style == "openclip":
238
+ # OpenCLIP style initialization
239
+ scale = self.embed_dim**-0.5
240
+ if self.use_pos_embed:
241
+ nn.init.normal_(self.pos_embedding_helper.pos_embed)
242
+ self.pos_embedding_helper.pos_embed *= scale
243
+
244
+ if self.num_cls_tokens > 0:
245
+ nn.init.normal_(self.cls_token)
246
+ self.cls_token *= scale
247
+ elif init_param_style == "vit":
248
+ self.cls_token.data.fill_(0)
249
+ else:
250
+ raise ValueError(f"Unknown init {init_param_style}")
251
+
252
+ if self.use_type_embed:
253
+ nn.init.normal_(self.type_embed)
254
+
255
+ def tokenize_input_and_cls_pos(self, input, stem, mask):
256
+ # tokens is of shape B x L x D
257
+ tokens = stem(input)
258
+ assert tokens.ndim == 3
259
+ assert tokens.shape[2] == self.embed_dim
260
+ B = tokens.shape[0]
261
+ if self.num_cls_tokens > 0:
262
+ class_tokens = self.cls_token.expand(
263
+ B, -1, -1
264
+ ) # stole class_tokens impl from Phil Wang, thanks
265
+ tokens = torch.cat((class_tokens, tokens), dim=1)
266
+ if self.use_pos_embed:
267
+ pos_embed = self.pos_embedding_helper.get_pos_embedding(input, tokens)
268
+ tokens = tokens + pos_embed
269
+ if self.use_type_embed:
270
+ tokens = tokens + self.type_embed.expand(B, -1, -1)
271
+ return tokens
272
+
273
+ def forward(self, vision=None, depth=None, patch_mask=None):
274
+ if patch_mask is not None:
275
+ raise NotImplementedError()
276
+
277
+ if vision is not None:
278
+ vision_tokens = self.tokenize_input_and_cls_pos(
279
+ vision, self.rgbt_stem, patch_mask
280
+ )
281
+
282
+ if depth is not None:
283
+ depth_tokens = self.tokenize_input_and_cls_pos(
284
+ depth, self.depth_stem, patch_mask
285
+ )
286
+
287
+ # aggregate tokens
288
+ if vision is not None and depth is not None:
289
+ final_tokens = vision_tokens + depth_tokens
290
+ else:
291
+ final_tokens = vision_tokens if vision is not None else depth_tokens
292
+ return_dict = {
293
+ "trunk": {
294
+ "tokens": final_tokens,
295
+ },
296
+ "head": {},
297
+ }
298
+ return return_dict
299
+
300
+
301
+ class AudioPreprocessor(RGBDTPreprocessor):
302
+ def __init__(self, audio_stem: PatchEmbedGeneric, **kwargs) -> None:
303
+ super().__init__(rgbt_stem=audio_stem, depth_stem=None, **kwargs)
304
+
305
+ def forward(self, audio=None):
306
+ return super().forward(vision=audio)
307
+
308
+
309
+ class ThermalPreprocessor(RGBDTPreprocessor):
310
+ def __init__(self, thermal_stem: PatchEmbedGeneric, **kwargs) -> None:
311
+ super().__init__(rgbt_stem=thermal_stem, depth_stem=None, **kwargs)
312
+
313
+ def forward(self, thermal=None):
314
+ return super().forward(vision=thermal)
315
+
316
+
317
+ def build_causal_attention_mask(context_length):
318
+ # lazily create causal attention mask, with full attention between the vision tokens
319
+ # pytorch uses additive attention mask; fill with -inf
320
+ mask = torch.empty(context_length, context_length, requires_grad=False)
321
+ mask.fill_(float("-inf"))
322
+ mask.triu_(1) # zero out the lower diagonal
323
+ return mask
324
+
325
+
326
+ class TextPreprocessor(VerboseNNModule):
327
+ def __init__(
328
+ self,
329
+ vocab_size: int,
330
+ context_length: int,
331
+ embed_dim: int,
332
+ causal_masking: bool,
333
+ supply_seq_len_to_head: bool = True,
334
+ num_cls_tokens: int = 0,
335
+ init_param_style: str = "openclip",
336
+ ) -> None:
337
+ super().__init__()
338
+ self.vocab_size = vocab_size
339
+ self.context_length = context_length
340
+ self.token_embedding = nn.Embedding(vocab_size, embed_dim)
341
+ self.pos_embed = nn.Parameter(
342
+ torch.empty(1, self.context_length + num_cls_tokens, embed_dim)
343
+ )
344
+ self.causal_masking = causal_masking
345
+ if self.causal_masking:
346
+ mask = build_causal_attention_mask(self.context_length)
347
+ # register the mask as a buffer so it can be moved to the right device
348
+ self.register_buffer("mask", mask)
349
+
350
+ self.supply_seq_len_to_head = supply_seq_len_to_head
351
+ self.num_cls_tokens = num_cls_tokens
352
+ self.embed_dim = embed_dim
353
+ if num_cls_tokens > 0:
354
+ assert self.causal_masking is False, "Masking + CLS token isn't implemented"
355
+ self.cls_token = nn.Parameter(
356
+ torch.zeros(1, self.num_cls_tokens, embed_dim)
357
+ )
358
+
359
+ self.init_parameters(init_param_style)
360
+
361
+ @torch.no_grad()
362
+ def init_parameters(self, init_param_style="openclip"):
363
+ # OpenCLIP style initialization
364
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
365
+ nn.init.normal_(self.pos_embed, std=0.01)
366
+
367
+ if init_param_style == "openclip":
368
+ # OpenCLIP style initialization
369
+ scale = self.embed_dim**-0.5
370
+ if self.num_cls_tokens > 0:
371
+ nn.init.normal_(self.cls_token)
372
+ self.cls_token *= scale
373
+ elif init_param_style == "vit":
374
+ self.cls_token.data.fill_(0)
375
+ else:
376
+ raise ValueError(f"Unknown init {init_param_style}")
377
+
378
+ def forward(self, text):
379
+ # text tokens are of shape B x L x D
380
+ text_tokens = self.token_embedding(text)
381
+ # concat CLS tokens if any
382
+ if self.num_cls_tokens > 0:
383
+ B = text_tokens.shape[0]
384
+ class_tokens = self.cls_token.expand(
385
+ B, -1, -1
386
+ ) # stole class_tokens impl from Phil Wang, thanks
387
+ text_tokens = torch.cat((class_tokens, text_tokens), dim=1)
388
+ text_tokens = text_tokens + self.pos_embed
389
+ return_dict = {
390
+ "trunk": {
391
+ "tokens": text_tokens,
392
+ },
393
+ "head": {},
394
+ }
395
+ # Compute sequence length after adding CLS tokens
396
+ if self.supply_seq_len_to_head:
397
+ text_lengths = text.argmax(dim=-1)
398
+ return_dict["head"] = {
399
+ "seq_len": text_lengths,
400
+ }
401
+ if self.causal_masking:
402
+ return_dict["trunk"].update({"attn_mask": self.mask})
403
+ return return_dict
404
+
405
+
406
+ class Im2Video(nn.Module):
407
+ """Convert an image into a trivial video."""
408
+
409
+ def __init__(self, time_dim=2):
410
+ super().__init__()
411
+ self.time_dim = time_dim
412
+
413
+ def forward(self, x):
414
+ if x.ndim == 4:
415
+ # B, C, H, W -> B, C, T, H, W
416
+ return x.unsqueeze(self.time_dim)
417
+ elif x.ndim == 5:
418
+ return x
419
+ else:
420
+ raise ValueError(f"Dimension incorrect {x.shape}")
421
+
422
+
423
+ class PadIm2Video(Im2Video):
424
+ def __init__(self, ntimes, pad_type, time_dim=2):
425
+ super().__init__(time_dim=time_dim)
426
+ assert ntimes > 0
427
+ assert pad_type in ["zero", "repeat"]
428
+ self.ntimes = ntimes
429
+ self.pad_type = pad_type
430
+
431
+ def forward(self, x):
432
+ x = super().forward(x)
433
+ if x.shape[self.time_dim] == 1:
434
+ if self.pad_type == "repeat":
435
+ new_shape = [1] * len(x.shape)
436
+ new_shape[self.time_dim] = self.ntimes
437
+ x = x.repeat(new_shape)
438
+ elif self.pad_type == "zero":
439
+ padarg = [0, 0] * len(x.shape)
440
+ padarg[2 * self.time_dim + 1] = self.ntimes - x.shape[self.time_dim]
441
+ x = nn.functional.pad(x, padarg)
442
+ return x
443
+
444
+
445
+ # Modified from github.com/openai/CLIP
446
+ @lru_cache()
447
+ def bytes_to_unicode():
448
+ """
449
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
450
+ The reversible bpe codes work on unicode strings.
451
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
452
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
453
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
454
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
455
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
456
+ """
457
+ bs = (
458
+ list(range(ord("!"), ord("~") + 1))
459
+ + list(range(ord("¡"), ord("¬") + 1))
460
+ + list(range(ord("®"), ord("ÿ") + 1))
461
+ )
462
+ cs = bs[:]
463
+ n = 0
464
+ for b in range(2**8):
465
+ if b not in bs:
466
+ bs.append(b)
467
+ cs.append(2**8 + n)
468
+ n += 1
469
+ cs = [chr(n) for n in cs]
470
+ return dict(zip(bs, cs))
471
+
472
+
473
+ def get_pairs(word):
474
+ """Return set of symbol pairs in a word.
475
+ Word is represented as tuple of symbols (symbols being variable-length strings).
476
+ """
477
+ pairs = set()
478
+ prev_char = word[0]
479
+ for char in word[1:]:
480
+ pairs.add((prev_char, char))
481
+ prev_char = char
482
+ return pairs
483
+
484
+
485
+ def basic_clean(text):
486
+ text = ftfy.fix_text(text)
487
+ text = html.unescape(html.unescape(text))
488
+ return text.strip()
489
+
490
+
491
+ def whitespace_clean(text):
492
+ text = re.sub(r"\s+", " ", text)
493
+ text = text.strip()
494
+ return text
495
+
496
+
497
+ class SimpleTokenizer(object):
498
+ def __init__(self, bpe_path: str, context_length=77):
499
+ self.byte_encoder = bytes_to_unicode()
500
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
501
+
502
+ with g_pathmgr.open(bpe_path, "rb") as fh:
503
+ bpe_bytes = io.BytesIO(fh.read())
504
+ merges: List[str] = gzip.open(bpe_bytes).read().decode("utf-8").split("\n")
505
+ merges = merges[1 : 49152 - 256 - 2 + 1]
506
+ merges: List[Tuple[str, ...]] = [tuple(merge.split()) for merge in merges]
507
+ vocab = list(bytes_to_unicode().values())
508
+ vocab = vocab + [v + "</w>" for v in vocab]
509
+ for merge in merges:
510
+ vocab.append("".join(merge))
511
+ vocab.extend(["<|startoftext|>", "<|endoftext|>"])
512
+ self.encoder = dict(zip(vocab, range(len(vocab))))
513
+ self.decoder = {v: k for k, v in self.encoder.items()}
514
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
515
+ self.cache = {
516
+ "<|startoftext|>": "<|startoftext|>",
517
+ "<|endoftext|>": "<|endoftext|>",
518
+ }
519
+ self.pat = re.compile(
520
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
521
+ re.IGNORECASE,
522
+ )
523
+ self.context_length = context_length
524
+
525
+ def bpe(self, token):
526
+ if token in self.cache:
527
+ return self.cache[token]
528
+ word = tuple(token[:-1]) + (token[-1] + "</w>",)
529
+ pairs = get_pairs(word)
530
+
531
+ if not pairs:
532
+ return token + "</w>"
533
+
534
+ while True:
535
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
536
+ if bigram not in self.bpe_ranks:
537
+ break
538
+ first, second = bigram
539
+ new_word = []
540
+ i = 0
541
+ while i < len(word):
542
+ try:
543
+ j = word.index(first, i)
544
+ new_word.extend(word[i:j])
545
+ i = j
546
+ except:
547
+ new_word.extend(word[i:])
548
+ break
549
+
550
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
551
+ new_word.append(first + second)
552
+ i += 2
553
+ else:
554
+ new_word.append(word[i])
555
+ i += 1
556
+ new_word = tuple(new_word)
557
+ word = new_word
558
+ if len(word) == 1:
559
+ break
560
+ else:
561
+ pairs = get_pairs(word)
562
+ word = " ".join(word)
563
+ self.cache[token] = word
564
+ return word
565
+
566
+ def encode(self, text):
567
+ bpe_tokens = []
568
+ text = whitespace_clean(basic_clean(text)).lower()
569
+ for token in re.findall(self.pat, text):
570
+ token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
571
+ bpe_tokens.extend(
572
+ self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
573
+ )
574
+ return bpe_tokens
575
+
576
+ def decode(self, tokens):
577
+ text = "".join([self.decoder[token] for token in tokens])
578
+ text = (
579
+ bytearray([self.byte_decoder[c] for c in text])
580
+ .decode("utf-8", errors="replace")
581
+ .replace("</w>", " ")
582
+ )
583
+ return text
584
+
585
+ def __call__(self, texts, context_length=None):
586
+ if not context_length:
587
+ context_length = self.context_length
588
+
589
+ if isinstance(texts, str):
590
+ texts = [texts]
591
+
592
+ sot_token = self.encoder["<|startoftext|>"]
593
+ eot_token = self.encoder["<|endoftext|>"]
594
+ all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts]
595
+ result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
596
+
597
+ for i, tokens in enumerate(all_tokens):
598
+ tokens = tokens[:context_length]
599
+ result[i, : len(tokens)] = torch.tensor(tokens)
600
+
601
+ if len(result) == 1:
602
+ return result[0]
603
+ return result
604
+
605
+
606
+ class IMUPreprocessor(VerboseNNModule):
607
+ def __init__(
608
+ self,
609
+ kernel_size: int,
610
+ imu_stem: PatchEmbedGeneric,
611
+ embed_dim: int,
612
+ img_size: Tuple = (6, 2000),
613
+ num_cls_tokens: int = 1,
614
+ pos_embed_fn: Optional[Callable] = None,
615
+ init_param_style: str = "openclip",
616
+ ) -> None:
617
+ super().__init__()
618
+ self.imu_stem = imu_stem
619
+ self.embed_dim = embed_dim
620
+ self.use_pos_embed = pos_embed_fn is not None
621
+ self.num_cls_tokens = num_cls_tokens
622
+ self.kernel_size = kernel_size
623
+ self.pos_embed = nn.Parameter(
624
+ torch.empty(1, (img_size[1] // kernel_size) + num_cls_tokens, embed_dim)
625
+ )
626
+
627
+ if self.num_cls_tokens > 0:
628
+ self.cls_token = nn.Parameter(
629
+ torch.zeros(1, self.num_cls_tokens, self.embed_dim)
630
+ )
631
+
632
+ self.init_parameters(init_param_style)
633
+
634
+ @torch.no_grad()
635
+ def init_parameters(self, init_param_style):
636
+ nn.init.normal_(self.pos_embed, std=0.01)
637
+
638
+ if init_param_style == "openclip":
639
+ # OpenCLIP style initialization
640
+ scale = self.embed_dim**-0.5
641
+
642
+ if self.num_cls_tokens > 0:
643
+ nn.init.normal_(self.cls_token)
644
+ self.cls_token *= scale
645
+ elif init_param_style == "vit":
646
+ self.cls_token.data.fill_(0)
647
+ else:
648
+ raise ValueError(f"Unknown init {init_param_style}")
649
+
650
+ def tokenize_input_and_cls_pos(self, input, stem):
651
+ # tokens is of shape B x L x D
652
+ tokens = stem.norm_layer(stem.proj(input))
653
+ assert tokens.ndim == 3
654
+ assert tokens.shape[2] == self.embed_dim
655
+ B = tokens.shape[0]
656
+ if self.num_cls_tokens > 0:
657
+ class_tokens = self.cls_token.expand(
658
+ B, -1, -1
659
+ ) # stole class_tokens impl from Phil Wang, thanks
660
+ tokens = torch.cat((class_tokens, tokens), dim=1)
661
+ if self.use_pos_embed:
662
+ tokens = tokens + self.pos_embed
663
+ return tokens
664
+
665
+ def forward(self, imu):
666
+ # Patchify
667
+ imu = imu.unfold(
668
+ -1,
669
+ self.kernel_size,
670
+ self.kernel_size,
671
+ ).permute(0, 2, 1, 3)
672
+ imu = imu.reshape(imu.size(0), imu.size(1), -1)
673
+
674
+ imu_tokens = self.tokenize_input_and_cls_pos(
675
+ imu,
676
+ self.imu_stem,
677
+ )
678
+
679
+ return_dict = {
680
+ "trunk": {
681
+ "tokens": imu_tokens,
682
+ },
683
+ "head": {},
684
+ }
685
+ return return_dict
imagebind/models/transformer.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # Code modified from
9
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py ;
10
+ # https://github.com/facebookresearch/deit/blob/main/models.py
11
+ # and https://github.com/facebookresearch/vissl/blob/main/vissl/models/trunks/vision_transformer.py
12
+
13
+
14
+ from functools import partial
15
+ from typing import Callable, List, Optional
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint as checkpoint
20
+ from timm.models.layers import DropPath, trunc_normal_
21
+
22
+
23
+ class Attention(nn.Module):
24
+ def __init__(
25
+ self,
26
+ dim,
27
+ num_heads=8,
28
+ qkv_bias=False,
29
+ qk_scale=None,
30
+ attn_drop=0.0,
31
+ proj_drop=0.0,
32
+ ):
33
+ super().__init__()
34
+ self.num_heads = num_heads
35
+ head_dim = dim // num_heads
36
+ # NOTE scale factor was wrong in my original version,
37
+ # can set manually to be compat with prev weights
38
+ self.scale = qk_scale or head_dim**-0.5
39
+
40
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
41
+ self.attn_drop = nn.Dropout(attn_drop)
42
+ self.proj = nn.Linear(dim, dim)
43
+ self.proj_drop = nn.Dropout(proj_drop)
44
+
45
+ def forward(self, x):
46
+ B, N, C = x.shape
47
+ qkv = (
48
+ self.qkv(x)
49
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
50
+ .permute(2, 0, 3, 1, 4)
51
+ )
52
+ q, k, v = (
53
+ qkv[0],
54
+ qkv[1],
55
+ qkv[2],
56
+ ) # make torchscript happy (cannot use tensor as tuple)
57
+
58
+ attn = (q @ k.transpose(-2, -1)) * self.scale
59
+ attn = attn.softmax(dim=-1)
60
+ attn = self.attn_drop(attn)
61
+
62
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
63
+ x = self.proj(x)
64
+ x = self.proj_drop(x)
65
+ return x
66
+
67
+
68
+ class Mlp(nn.Module):
69
+ def __init__(
70
+ self,
71
+ in_features,
72
+ hidden_features=None,
73
+ out_features=None,
74
+ act_layer=nn.GELU,
75
+ drop=0.0,
76
+ ):
77
+ super().__init__()
78
+ out_features = out_features or in_features
79
+ hidden_features = hidden_features or in_features
80
+ self.fc1 = nn.Linear(in_features, hidden_features)
81
+ self.act = act_layer()
82
+ self.fc2 = nn.Linear(hidden_features, out_features)
83
+ self.drop = nn.Dropout(drop)
84
+
85
+ def forward(self, x):
86
+ x = self.fc1(x)
87
+ x = self.act(x)
88
+ x = self.drop(x)
89
+ x = self.fc2(x)
90
+ x = self.drop(x)
91
+ return x
92
+
93
+
94
+ class MultiheadAttention(nn.MultiheadAttention):
95
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
96
+ return super().forward(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
97
+
98
+
99
+ class ViTAttention(Attention):
100
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
101
+ assert attn_mask is None
102
+ return super().forward(x)
103
+
104
+
105
+ class BlockWithMasking(nn.Module):
106
+ def __init__(
107
+ self,
108
+ dim: int,
109
+ attn_target: Callable,
110
+ mlp_ratio: int = 4,
111
+ act_layer: Callable = nn.GELU,
112
+ norm_layer: Callable = nn.LayerNorm,
113
+ ffn_dropout_rate: float = 0.0,
114
+ drop_path: float = 0.0,
115
+ layer_scale_type: Optional[str] = None,
116
+ layer_scale_init_value: float = 1e-4,
117
+ ):
118
+ super().__init__()
119
+
120
+ assert not isinstance(
121
+ attn_target, nn.Module
122
+ ), "attn_target should be a Callable. Otherwise attn_target is shared across blocks!"
123
+ self.attn = attn_target()
124
+ if drop_path > 0.0:
125
+ self.drop_path = DropPath(drop_path)
126
+ else:
127
+ self.drop_path = nn.Identity()
128
+ self.norm_1 = norm_layer(dim)
129
+ mlp_hidden_dim = int(mlp_ratio * dim)
130
+ self.mlp = Mlp(
131
+ in_features=dim,
132
+ hidden_features=mlp_hidden_dim,
133
+ act_layer=act_layer,
134
+ drop=ffn_dropout_rate,
135
+ )
136
+ self.norm_2 = norm_layer(dim)
137
+ self.layer_scale_type = layer_scale_type
138
+ if self.layer_scale_type is not None:
139
+ assert self.layer_scale_type in [
140
+ "per_channel",
141
+ "scalar",
142
+ ], f"Found Layer scale type {self.layer_scale_type}"
143
+ if self.layer_scale_type == "per_channel":
144
+ # one gamma value per channel
145
+ gamma_shape = [1, 1, dim]
146
+ elif self.layer_scale_type == "scalar":
147
+ # single gamma value for all channels
148
+ gamma_shape = [1, 1, 1]
149
+ # two gammas: for each part of the fwd in the encoder
150
+ self.layer_scale_gamma1 = nn.Parameter(
151
+ torch.ones(size=gamma_shape) * layer_scale_init_value,
152
+ requires_grad=True,
153
+ )
154
+ self.layer_scale_gamma2 = nn.Parameter(
155
+ torch.ones(size=gamma_shape) * layer_scale_init_value,
156
+ requires_grad=True,
157
+ )
158
+
159
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
160
+ if self.layer_scale_type is None:
161
+ x = x + self.drop_path(self.attn(self.norm_1(x), attn_mask))
162
+ x = x + self.drop_path(self.mlp(self.norm_2(x)))
163
+ else:
164
+ x = (
165
+ x
166
+ + self.drop_path(self.attn(self.norm_1(x), attn_mask))
167
+ * self.layer_scale_gamma1
168
+ )
169
+ x = x + self.drop_path(self.mlp(self.norm_2(x))) * self.layer_scale_gamma2
170
+ return x
171
+
172
+
173
+ _LAYER_NORM = partial(nn.LayerNorm, eps=1e-6)
174
+
175
+
176
+ class SimpleTransformer(nn.Module):
177
+ def __init__(
178
+ self,
179
+ attn_target: Callable,
180
+ embed_dim: int,
181
+ num_blocks: int,
182
+ block: Callable = BlockWithMasking,
183
+ pre_transformer_layer: Optional[Callable] = None,
184
+ post_transformer_layer: Optional[Callable] = None,
185
+ drop_path_rate: float = 0.0,
186
+ drop_path_type: str = "progressive",
187
+ norm_layer: Callable = _LAYER_NORM,
188
+ mlp_ratio: int = 4,
189
+ ffn_dropout_rate: float = 0.0,
190
+ layer_scale_type: Optional[str] = None, # from cait; possible values are None, "per_channel", "scalar"
191
+ layer_scale_init_value: float = 1e-4, # from cait; float
192
+ weight_init_style: str = "jax", # possible values jax or pytorch
193
+ ):
194
+ """
195
+ Simple Transformer with the following features
196
+ 1. Supports masked attention
197
+ 2. Supports DropPath
198
+ 3. Supports LayerScale
199
+ 4. Supports Dropout in Attention and FFN
200
+ 5. Makes few assumptions about the input except that it is a Tensor
201
+ """
202
+ super().__init__()
203
+ self.pre_transformer_layer = pre_transformer_layer
204
+ if drop_path_type == "progressive":
205
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_blocks)]
206
+ elif drop_path_type == "uniform":
207
+ dpr = [drop_path_rate for i in range(num_blocks)]
208
+ else:
209
+ raise ValueError(f"Unknown drop_path_type: {drop_path_type}")
210
+
211
+ self.blocks = nn.Sequential(
212
+ *[
213
+ block(
214
+ dim=embed_dim,
215
+ attn_target=attn_target,
216
+ mlp_ratio=mlp_ratio,
217
+ ffn_dropout_rate=ffn_dropout_rate,
218
+ drop_path=dpr[i],
219
+ norm_layer=norm_layer,
220
+ layer_scale_type=layer_scale_type,
221
+ layer_scale_init_value=layer_scale_init_value,
222
+ )
223
+ for i in range(num_blocks)
224
+ ]
225
+ )
226
+ self.post_transformer_layer = post_transformer_layer
227
+ self.weight_init_style = weight_init_style
228
+ self.apply(self._init_weights)
229
+
230
+ def _init_weights(self, m):
231
+ if isinstance(m, nn.Linear):
232
+ if self.weight_init_style == "jax":
233
+ # Based on MAE and official Jax ViT implementation
234
+ torch.nn.init.xavier_uniform_(m.weight)
235
+ elif self.weight_init_style == "pytorch":
236
+ # PyTorch ViT uses trunc_normal_
237
+ trunc_normal_(m.weight, std=0.02)
238
+
239
+ if m.bias is not None:
240
+ nn.init.constant_(m.bias, 0)
241
+ elif isinstance(m, (nn.LayerNorm)):
242
+ nn.init.constant_(m.bias, 0)
243
+ nn.init.constant_(m.weight, 1.0)
244
+
245
+ def forward(
246
+ self,
247
+ tokens: torch.Tensor,
248
+ attn_mask: torch.Tensor = None,
249
+ use_checkpoint: bool = False,
250
+ checkpoint_every_n: int = 1,
251
+ checkpoint_blk_ids: Optional[List[int]] = None,
252
+ ):
253
+ """
254
+ Inputs
255
+ - tokens: data of shape N x L x D (or L x N x D depending on the attention implementation)
256
+ - attn: mask of shape L x L
257
+
258
+ Output
259
+ - x: data of shape N x L x D (or L x N x D depending on the attention implementation)
260
+ """
261
+ if self.pre_transformer_layer:
262
+ tokens = self.pre_transformer_layer(tokens)
263
+ if use_checkpoint and checkpoint_blk_ids is None:
264
+ checkpoint_blk_ids = [
265
+ blk_id
266
+ for blk_id in range(len(self.blocks))
267
+ if blk_id % checkpoint_every_n == 0
268
+ ]
269
+ if checkpoint_blk_ids:
270
+ checkpoint_blk_ids = set(checkpoint_blk_ids)
271
+ for blk_id, blk in enumerate(self.blocks):
272
+ if use_checkpoint and blk_id in checkpoint_blk_ids:
273
+ tokens = checkpoint.checkpoint(
274
+ blk, tokens, attn_mask, use_reentrant=False
275
+ )
276
+ else:
277
+ tokens = blk(tokens, attn_mask=attn_mask)
278
+ if self.post_transformer_layer:
279
+ tokens = self.post_transformer_layer(tokens)
280
+ return tokens
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.13.0
2
+ torchvision==0.14.0
3
+ torchaudio==0.13.0
4
+ pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d
5
+ timm==0.6.7
6
+ ftfy
7
+ regex
8
+ einops
9
+ fvcore
10
+ eva-decord==0.6.1
11
+ iopath
12
+ numpy>=1.19
13
+ matplotlib
14
+ types-regex
15
+ mayavi
16
+ cartopy
setup.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from setuptools import setup, find_packages
2
+
3
+ with open('requirements.txt') as f:
4
+ required = f.read().splitlines()
5
+
6
+ setup(
7
+ name='imagebind',
8
+ version='0.1.0',
9
+ packages=find_packages(),
10
+ description='A brief description of the package',
11
+ long_description=open('README.md').read(),
12
+ long_description_content_type="text/markdown",
13
+ url='https://github.com/facebookresearch/ImageBind',
14
+ classifiers=[
15
+ 'Programming Language :: Python :: 3',
16
+ 'License :: Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International',
17
+ ],
18
+ install_requires=required,
19
+ dependency_links=['https://download.pytorch.org/whl/cu113'],
20
+ )