# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from typing import Union, Tuple
from types import MethodType

import torch
from torch import nn

from timm.models import VisionTransformer, checkpoint_seq

from .radio_vit_patch_generator import ViTPatchGenerator


def _forward_cpe(self: VisionTransformer, x: torch.Tensor) -> torch.Tensor:
    """Replacement ``forward_features`` that :func:`enable_cpe` binds onto a model.

    Sends the input through ``self.patch_generator``, then the transformer
    blocks, then the final norm. ``self.patch_generator`` stands in for the
    model's original patch-embedding path, which ``enable_cpe`` removes.

    Args:
        self: The ``VisionTransformer`` this function is bound to.
        x: Input batch. Presumably images of shape (B, C, H, W); TODO confirm
            against ``ViTPatchGenerator.forward``.

    Returns:
        The output of ``self.norm`` applied to the block outputs.
    """
    x = self.patch_generator(x)
    # Checkpointing is used only outside TorchScript. Scripted modules take
    # the plain sequential path.
    if self.grad_checkpointing and not torch.jit.is_scripting():
        x = checkpoint_seq(self.blocks, x)
    else:
        x = self.blocks(x)
    x = self.norm(x)
    return x


def _round_to_patch_multiple(size: int, patch_size: int) -> int:
    """Round ``size`` to the nearest multiple of ``patch_size``.

    Uses the built-in ``round``, so exact halves go to the even multiple
    count. For example, with a patch size of 16, both 24 and 40 become 32.
    """
    return int(round(size / patch_size) * patch_size)


def enable_cpe(model: nn.Module,
               max_img_size: Union[int, Tuple[int, int]] = 1024,
               num_cls_tokens: int = 1,
               pos_dropout: float = 0.1,
               register_multiple: int = 0,
):
    """Convert a timm ``VisionTransformer`` in place to use a ``ViTPatchGenerator``.

    The model's patch size, embedding dimension, input size, patch-norm
    setting and class-token presence are read and used to build a
    ``ViTPatchGenerator``, which is attached as ``model.patch_generator``.

    The original ``patch_embed``, ``cls_token``, ``pos_embed`` and
    ``pos_drop`` attributes are then set to ``None``. Finally,
    ``model.forward_features`` is rebound to :func:`_forward_cpe`.

    Args:
        model: Model to modify. It must be a timm ``VisionTransformer``.
        max_img_size: Maximum input size, given either as a single int or as
            a 2-tuple. Each value is rounded to the nearest multiple of the
            patch size before it is handed to the generator.
        num_cls_tokens: Forwarded to ``ViTPatchGenerator``. Also stored as
            ``model.num_cls_tokens``.
        pos_dropout: Forwarded to ``ViTPatchGenerator``.
        register_multiple: Forwarded to ``ViTPatchGenerator``. The resulting
            ``patch_generator.num_registers`` is stored on the model.

    Raises:
        ValueError: If ``model`` is not a ``VisionTransformer``.
    """
    if not isinstance(model, VisionTransformer):
        raise ValueError("CPE only support for VisionTransformer models!")

    # Square patches are assumed: only the first patch dimension is used.
    patch_size = model.patch_embed.patch_size[0]
    embed_dim = model.embed_dim
    input_dims = model.patch_embed.img_size
    # Keep patch normalization only if the original embedding actually had a norm layer.
    normalize_patches = not isinstance(model.patch_embed.norm, nn.Identity)
    cls_token = model.cls_token is not None

    # The signature advertises a tuple form, so round each dimension in that
    # case. A tuple previously crashed on the int-only division.
    # NOTE(review): this relies on ViTPatchGenerator accepting a tuple for
    # max_input_dims, as the annotation implies. Confirm.
    if isinstance(max_img_size, (tuple, list)):
        max_img_size = tuple(
            _round_to_patch_multiple(dim, patch_size) for dim in max_img_size
        )
    else:
        max_img_size = _round_to_patch_multiple(max_img_size, patch_size)

    patch_generator = ViTPatchGenerator(
        patch_size=patch_size,
        embed_dim=embed_dim,
        input_dims=input_dims,
        normalize_patches=normalize_patches,
        cls_token=cls_token,
        max_input_dims=max_img_size,
        pos_dropout=pos_dropout,
        num_cls_tokens=num_cls_tokens,
        register_multiple=register_multiple,
    )

    model.patch_generator = patch_generator
    # Drop the components the patch generator replaces, so their parameters
    # no longer hang off the model.
    model.patch_embed = None
    model.cls_token = None
    model.pos_embed = None
    model.pos_drop = None
    model.num_cls_tokens = num_cls_tokens
    model.num_registers = patch_generator.num_registers

    # Bind onto this instance only; other VisionTransformer instances keep
    # the stock forward_features.
    model.forward_features = MethodType(_forward_cpe, model)