# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import torch
from torch import nn
class ClsToken(nn.Module):
    """Learnable CLS token(s), optionally padded with register tokens,
    prepended to a token sequence.

    Args:
        ndim: Embedding dimension of each token.
        num_tokens: Number of CLS tokens to prepend.
        enabled: If False, no parameter is created and ``forward`` is a no-op.
        register_multiple: When > 0, extra "register" tokens are added so the
            total prepended token count is a multiple of this value.
    """

    def __init__(self, ndim: int,
                 num_tokens: int = 1,
                 enabled: bool = True,
                 register_multiple: int = 0,
    ):
        super().__init__()

        self.ndim = ndim
        self.enabled = enabled
        self.num_registers = 0
        self.num_tokens = num_tokens
        if enabled:
            if register_multiple > 0:
                # Pad up to the next multiple of `register_multiple`.
                # FIX: the previous form, `register_multiple - (num_tokens %
                # register_multiple)`, added a full extra block of registers
                # when num_tokens was already an exact multiple; the wrapping
                # modulo yields 0 in that case.
                self.num_registers = (-num_tokens) % register_multiple

            # Standard scaled-normal init for token embeddings (ndim ** -0.5).
            scale = ndim ** -0.5
            self.token = nn.Parameter(torch.randn(num_tokens + self.num_registers, ndim) * scale)
        else:
            self.token = None

        # Total token count prepended by forward() when enabled.
        # NOTE(review): this still equals num_tokens when enabled=False, even
        # though forward() then prepends nothing — preserved as-is in case
        # callers rely on it; verify against call sites.
        self.num_patches = self.num_tokens + self.num_registers

    def disable(self):
        """Drop the token parameter and turn forward() into a pass-through."""
        self.token = None
        self.enabled = False

    def forward(self, x: torch.Tensor):
        """Prepend the learned tokens along dim 1.

        Args:
            x: Input of shape (batch, seq, ndim) — assumed from the cat/expand
               pattern; TODO confirm.

        Returns:
            x unchanged when disabled, otherwise a tensor of shape
            (batch, num_tokens + num_registers + seq, ndim).
        """
        if self.token is None:
            return x

        # Broadcast the (num_patches, ndim) parameter across the batch
        # without copying (expand creates a view).
        token = self.token.unsqueeze(0).expand(x.shape[0], -1, -1)
        x = torch.cat([
            token,
            x,
        ], dim=1)

        return x

    def no_weight_decay(self):
        """Parameter names to exclude from weight decay (token embeddings)."""
        return [
            'token',
        ]