File size: 4,473 Bytes
da716ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
""" ConvMixer
"""
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d
from ._registry import register_model, generate_default_cfgs
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
__all__ = ['ConvMixer']
class Residual(nn.Module):
    """Skip-connection wrapper: adds the wrapped module's output to its input.

    Args:
        fn: Module (or any callable registered as a submodule) whose output
            can be added to its own input.
    """

    def __init__(self, fn):
        super().__init__()
        # Attribute name is part of the state_dict key path ("fn.*").
        self.fn = fn

    def forward(self, x):
        """Return ``fn(x) + x``."""
        transformed = self.fn(x)
        return transformed + x
class ConvMixer(nn.Module):
    """ConvMixer classifier.

    Layout, as built in ``__init__``:

    * ``stem``: a strided convolution whose kernel size and stride both equal
      ``patch_size``, followed by the activation and batch norm.
    * ``blocks``: ``depth`` repetitions of a residual depthwise convolution
      (``groups=dim``, ``padding="same"``) + activation + batch norm, then a
      1x1 convolution + activation + batch norm.
    * ``pooling`` -> ``head_drop`` -> ``head``: global pooling with flatten,
      dropout, and a linear classifier (identity when ``num_classes <= 0``).

    Args:
        dim: Channel width used by the stem and every block.
        depth: Number of mixer blocks.
        kernel_size: Kernel size of the depthwise convolution in each block.
        patch_size: Kernel size and stride of the stem convolution.
        in_chans: Number of input channels expected by the stem.
        num_classes: Output size of the classifier; ``<= 0`` disables it.
        global_pool: Pool type forwarded to ``SelectAdaptivePool2d``.
        drop_rate: Dropout probability applied before the classifier.
        act_layer: Zero-argument factory producing the activation module.
        **kwargs: Accepted but not used by this constructor.
    """

    def __init__(
            self,
            dim,
            depth,
            kernel_size=9,
            patch_size=7,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            drop_rate=0.,
            act_layer=nn.GELU,
            **kwargs,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = dim
        self.grad_checkpointing = False

        # Patch embedding: non-overlapping patches because stride == kernel size.
        patch_embed = nn.Conv2d(in_chans, dim, kernel_size=patch_size, stride=patch_size)
        self.stem = nn.Sequential(patch_embed, act_layer(), nn.BatchNorm2d(dim))

        # Build the blocks one at a time; the nesting and construction order are
        # kept exactly so parameter names and initialisation order do not change.
        mixer_layers = []
        for _ in range(depth):
            depthwise = Residual(nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                act_layer(),
                nn.BatchNorm2d(dim),
            ))
            mixer_layers.append(nn.Sequential(
                depthwise,
                nn.Conv2d(dim, dim, kernel_size=1),
                act_layer(),
                nn.BatchNorm2d(dim),
            ))
        self.blocks = nn.Sequential(*mixer_layers)

        self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
        self.head_drop = nn.Dropout(drop_rate)
        if num_classes > 0:
            self.head = nn.Linear(dim, num_classes)
        else:
            self.head = nn.Identity()

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex patterns that group parameters into stem and per-block sets.

        ``coarse`` is accepted but does not change the result.
        """
        return {'stem': r'^stem', 'blocks': r'^blocks\.(\d+)'}

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Turn gradient checkpointing of the block stack on or off."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """Return the classifier module (``self.head``)."""
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the classifier head and, optionally, the pooling layer.

        Args:
            num_classes: New output size; ``<= 0`` installs an identity head.
            global_pool: If not ``None``, the pooling layer is rebuilt with
                this pool type; otherwise the current pooling is kept.
        """
        self.num_classes = num_classes
        if global_pool is not None:
            self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)
        if num_classes > 0:
            self.head = nn.Linear(self.num_features, num_classes)
        else:
            self.head = nn.Identity()

    def forward_features(self, x):
        """Run the stem and the block stack; returns the un-pooled feature map."""
        x = self.stem(x)
        # The scripting check stays inline in the condition so the checkpoint
        # branch is only taken in eager mode.
        if self.grad_checkpointing and not torch.jit.is_scripting():
            return checkpoint_seq(self.blocks, x)
        return self.blocks(x)

    def forward_head(self, x, pre_logits: bool = False):
        """Pool, apply dropout, and classify.

        With ``pre_logits=True`` the classifier is skipped and the pooled
        (post-dropout) features are returned instead.
        """
        x = self.head_drop(self.pooling(x))
        if pre_logits:
            return x
        return self.head(x)

    def forward(self, x):
        """Full forward pass: features followed by the classification head."""
        return self.forward_head(self.forward_features(x))
def _create_convmixer(variant, pretrained=False, **kwargs):
    """Build a ConvMixer through the shared model builder.

    Args:
        variant: Name of the model variant, forwarded to the builder.
        pretrained: Forwarded to the builder unchanged.
        **kwargs: Extra builder / model arguments, forwarded unchanged.

    Returns:
        Whatever ``build_model_with_cfg`` returns for ``ConvMixer``.

    Raises:
        RuntimeError: If a truthy ``features_only`` is requested.
    """
    # ConvMixer in this file defines no feature_info / intermediate-feature API,
    # so a features_only request cannot be honoured. Reject it here with a clear
    # message instead of forwarding it to the builder (which presumably fails
    # later with a less helpful error -- NOTE(review): confirm against builder).
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for ConvMixer models.')

    return build_model_with_cfg(ConvMixer, variant, pretrained, **kwargs)
def _cfg(url='', **kwargs):
    """Return the base pretrained-config dict for ConvMixer variants.

    Args:
        url: Value stored under the ``'url'`` key.
        **kwargs: Extra entries; a key that already exists in the base config
            overrides it, any other key is appended.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.96,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        classifier='head',
        first_conv='stem.0',
    )
    # Caller-supplied entries win over the defaults above.
    cfg.update(kwargs)
    return cfg
# Every pretrained tag below uses the unmodified base config with
# hf_hub_id='timm/' (presumably the hub namespace prefix -- verify against
# generate_default_cfgs). Each tag gets its own freshly built config dict.
default_cfgs = generate_default_cfgs({
    tag: _cfg(hf_hub_id='timm/')
    for tag in (
        'convmixer_1536_20.in1k',
        'convmixer_768_32.in1k',
        'convmixer_1024_20_ks9_p14.in1k',
    )
})
@register_model
def convmixer_1536_20(pretrained=False, **kwargs) -> ConvMixer:
    """ConvMixer with dim 1536, depth 20, depthwise kernel size 9, patch size 7.

    Extra ``kwargs`` are forwarded to the model factory; passing one of the
    fixed architecture keys again raises ``TypeError`` from ``dict(...)``.
    """
    variant_kwargs = dict(
        dim=1536,
        depth=20,
        kernel_size=9,
        patch_size=7,
        **kwargs,
    )
    return _create_convmixer('convmixer_1536_20', pretrained, **variant_kwargs)
@register_model
def convmixer_768_32(pretrained=False, **kwargs) -> ConvMixer:
    """ConvMixer with dim 768, depth 32, depthwise kernel size 7, patch size 7.

    Uses ``nn.ReLU`` instead of the constructor's default activation. Extra
    ``kwargs`` are forwarded to the model factory; passing one of the fixed
    architecture keys again raises ``TypeError`` from ``dict(...)``.
    """
    variant_kwargs = dict(
        dim=768,
        depth=32,
        kernel_size=7,
        patch_size=7,
        act_layer=nn.ReLU,
        **kwargs,
    )
    return _create_convmixer('convmixer_768_32', pretrained, **variant_kwargs)
@register_model
def convmixer_1024_20_ks9_p14(pretrained=False, **kwargs) -> ConvMixer:
    """ConvMixer with dim 1024, depth 20, depthwise kernel size 9, patch size 14.

    Extra ``kwargs`` are forwarded to the model factory; passing one of the
    fixed architecture keys again raises ``TypeError`` from ``dict(...)``.
    """
    variant_kwargs = dict(
        dim=1024,
        depth=20,
        kernel_size=9,
        patch_size=14,
        **kwargs,
    )
    return _create_convmixer('convmixer_1024_20_ks9_p14', pretrained, **variant_kwargs)