""" ConvMixer """ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import SelectAdaptivePool2d from ._registry import register_model, generate_default_cfgs from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq __all__ = ['ConvMixer'] class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x): return self.fn(x) + x class ConvMixer(nn.Module): def __init__( self, dim, depth, kernel_size=9, patch_size=7, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., act_layer=nn.GELU, **kwargs, ): super().__init__() self.num_classes = num_classes self.num_features = dim self.grad_checkpointing = False self.stem = nn.Sequential( nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size), act_layer(), nn.BatchNorm2d(dim) ) self.blocks = nn.Sequential( *[nn.Sequential( Residual(nn.Sequential( nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), act_layer(), nn.BatchNorm2d(dim) )), nn.Conv2d(dim, dim, kernel_size=1), act_layer(), nn.BatchNorm2d(dim) ) for i in range(depth)] ) self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) self.head_drop = nn.Dropout(drop_rate) self.head = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity() @torch.jit.ignore def group_matcher(self, coarse=False): matcher = dict(stem=r'^stem', blocks=r'^blocks\.(\d+)') return matcher @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self): return self.head def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes if global_pool is not None: self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) return x def forward_head(self, x, pre_logits: bool = False): x = self.pooling(x) x = self.head_drop(x) return x if pre_logits else self.head(x) def forward(self, x): x = self.forward_features(x) x = self.forward_head(x) return x def _create_convmixer(variant, pretrained=False, **kwargs): return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs) def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .96, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head', 'first_conv': 'stem.0', **kwargs } default_cfgs = generate_default_cfgs({ 'convmixer_1536_20.in1k': _cfg(hf_hub_id='timm/'), 'convmixer_768_32.in1k': _cfg(hf_hub_id='timm/'), 'convmixer_1024_20_ks9_p14.in1k': _cfg(hf_hub_id='timm/') }) @register_model def convmixer_1536_20(pretrained=False, **kwargs) -> ConvMixer: model_args = dict(dim=1536, depth=20, kernel_size=9, patch_size=7, **kwargs) return _create_convmixer('convmixer_1536_20', pretrained, **model_args) @register_model def convmixer_768_32(pretrained=False, **kwargs) -> ConvMixer: model_args = dict(dim=768, depth=32, kernel_size=7, patch_size=7, act_layer=nn.ReLU, **kwargs) return _create_convmixer('convmixer_768_32', pretrained, **model_args) @register_model def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs) -> ConvMixer: model_args = dict(dim=1024, depth=20, kernel_size=9, patch_size=14, **kwargs) 
return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **model_args)
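

# Usage sketch (not part of the upstream file): builds one of the registered
# variants through the timm factory and runs a forward pass. Assumes `timm` is
# installed and this module is registered as in mainline timm; pretrained
# weights are skipped so no download is needed.
#
#   import torch
#   import timm
#
#   model = timm.create_model('convmixer_768_32', pretrained=False)
#   model.eval()
#   x = torch.randn(1, 3, 224, 224)            # one 224x224 RGB image
#   with torch.no_grad():
#       features = model.forward_features(x)   # (1, 768, 32, 32): 224 / patch_size 7 = 32
#       logits = model(x)                      # (1, 1000) class logits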