""" PyTorch Feature Extraction Helpers

A collection of classes, functions, modules to help extract features from models
and provide a common interface for describing them.

The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py

Hacked together by / Copyright 2020 Ross Wightman
"""
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from timm.layers import Format


__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']


class FeatureInfo:
    """ Feature metadata helper.

    Holds a list of per-feature dicts, each with at least 'num_chs' (channel count),
    'reduction' (cumulative input reduction / output stride), and 'module' (module name),
    along with the out_indices selecting which features are exposed.
    """

    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
        prev_reduction = 1
        for fi in feature_info:
            # validate the minimal metadata each feature dict must provide
            assert 'num_chs' in fi and fi['num_chs'] > 0
            assert 'reduction' in fi and fi['reduction'] >= prev_reduction
            prev_reduction = fi['reduction']
            assert 'module' in fi
        self.out_indices = out_indices
        self.info = feature_info

    def from_other(self, out_indices: Tuple[int]):
        return FeatureInfo(deepcopy(self.info), out_indices)

    def get(self, key, idx=None):
        """ Get value by key at specified index (indices)
        if idx is None, returns value for key at each output index
        if idx is an integer, return value for that feature module index (ignoring output indices)
        if idx is a list/tuple, return value for each module index (ignoring output indices)
        """
        if idx is None:
            return [self.info[i][key] for i in self.out_indices]
        if isinstance(idx, (tuple, list)):
            return [self.info[i][key] for i in idx]
        else:
            return self.info[idx][key]

    def get_dicts(self, keys=None, idx=None):
        """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
        """
        if idx is None:
            if keys is None:
                return [self.info[i] for i in self.out_indices]
            else:
                return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
        if isinstance(idx, (tuple, list)):
            return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
        else:
            return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}

    def channels(self, idx=None):
        """ feature channels accessor
        """
        return self.get('num_chs', idx)

    def reduction(self, idx=None):
        """ feature reduction (output stride) accessor
        """
        return self.get('reduction', idx)

    def module_name(self, idx=None):
        """ feature module name accessor
        """
        return self.get('module', idx)

    def __getitem__(self, item):
        return self.info[item]

    def __len__(self):
        return len(self.info)
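

# Example (illustrative sketch, not part of the library surface): building and querying
# a FeatureInfo directly. The metadata values below are hypothetical.
#
#   info = FeatureInfo(
#       feature_info=[
#           dict(num_chs=64, reduction=4, module='layer1'),
#           dict(num_chs=128, reduction=8, module='layer2'),
#           dict(num_chs=256, reduction=16, module='layer3'),
#       ],
#       out_indices=(1, 2),
#   )
#   info.channels()      # -> [128, 256], values at the selected out_indices
#   info.reduction()     # -> [8, 16]
#   info.module_name(0)  # -> 'layer1', an integer idx addresses the module index directly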


class FeatureHooks:
    """ Feature Hook Helper

    This module helps with the setup and extraction of hooks for extracting features from
    internal nodes in a model by node name.

    FIXME This works well in eager Python but needs redesign for torchscript.
    """

    def __init__(
            self,
            hooks: Sequence[Dict],
            named_modules: dict,
            out_map: Sequence[Union[int, str]] = None,
            default_hook_type: str = 'forward',
    ):
        # setup feature hooks
        self._feature_outputs = defaultdict(OrderedDict)
        modules = {k: v for k, v in named_modules}
        for i, h in enumerate(hooks):
            hook_name = h['module']
            m = modules[hook_name]
            hook_id = out_map[i] if out_map else hook_name
            hook_fn = partial(self._collect_output_hook, hook_id)
            hook_type = h.get('hook_type', default_hook_type)
            if hook_type == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            elif hook_type == 'forward':
                m.register_forward_hook(hook_fn)
            else:
                assert False, f"Unsupported hook type: {hook_type}"

    def _collect_output_hook(self, hook_id, *args):
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        self._feature_outputs[x.device][hook_id] = x

    def get_output(self, device) -> Dict[str, torch.Tensor]:
        output = self._feature_outputs[device]
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output
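

# Example (illustrative sketch): attaching FeatureHooks to a plain nn.Sequential. The
# child names '0' and '2' are simply the default names nn.Sequential assigns.
#
#   net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 16, 3))
#   hooks = FeatureHooks(
#       hooks=[dict(module='0'), dict(module='2')],
#       named_modules=net.named_modules(),
#   )
#   y = net(torch.randn(1, 3, 32, 32))
#   feats = hooks.get_output(y.device)  # OrderedDict keyed by module name ('0', '2')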


def _module_list(module, flatten_sequential=False):
    # a yield/iter would be better for this but wouldn't be compatible with torchscript
    ml = []
    for name, module in module.named_children():
        if flatten_sequential and isinstance(module, nn.Sequential):
            # first level of Sequential containers is flattened into the containing model
            for child_name, child_module in module.named_children():
                combined = [name, child_name]
                ml.append(('_'.join(combined), '.'.join(combined), child_module))
        else:
            ml.append((name, name, module))
    return ml
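

# Example (illustrative): how flatten_sequential renames children of a top-level
# nn.Sequential so they can be re-registered under flat, attribute-safe names.
#
#   body = nn.Module()
#   body.stem = nn.Conv2d(3, 8, 3)
#   body.stages = nn.Sequential(nn.Conv2d(8, 16, 3), nn.Conv2d(16, 32, 3))
#   _module_list(body, flatten_sequential=True)
#   # -> [('stem', 'stem', ...), ('stages_0', 'stages.0', ...), ('stages_1', 'stages.1', ...)]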


def _get_feature_info(net, out_indices):
    feature_info = getattr(net, 'feature_info')
    if isinstance(feature_info, FeatureInfo):
        return feature_info.from_other(out_indices)
    elif isinstance(feature_info, (list, tuple)):
        return FeatureInfo(net.feature_info, out_indices)
    else:
        assert False, "Provided feature_info is not valid"


def _get_return_layers(feature_info, out_map):
    module_names = feature_info.module_name()
    return_layers = {}
    for i, name in enumerate(module_names):
        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
    return return_layers


class FeatureDictNet(nn.ModuleDict):
    """ Feature extractor with OrderedDict return

    Wrap a model and extract features as specified by the out indices. The network is
    partially re-built from contained modules.

    There is a strong assumption that the modules have been registered into the model in the same
    order as they are used. There should be no reuse of the same nn.Module more than once, including
    trivial modules like `self.relu = nn.ReLU`.

    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module, with the name `model.features.1` being changed to `model.features_1`.
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            out_map: Sequence[Union[int, str]] = None,
            output_fmt: str = 'NCHW',
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            output_fmt: Expected output format of the extracted features (e.g. 'NCHW').
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super(FeatureDictNet, self).__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        self.output_fmt = Format(output_fmt)
        self.concat = feature_concat
        self.grad_checkpointing = False
        self.return_layers = {}

        return_layers = _get_return_layers(self.feature_info, out_map)
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        layers = OrderedDict()
        for new_name, old_name, module in modules:
            layers[new_name] = module
            if old_name in remaining:
                # return id has to be consistently str type for torchscript
                self.return_layers[new_name] = str(return_layers[old_name])
                remaining.remove(old_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'
        self.update(layers)

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def _collect(self, x) -> Dict[str, torch.Tensor]:
        out = OrderedDict()
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skip checkpointing of the first module because a gradient is needed at the input.
                # Skip checkpointing of the last module because networks with in-place ops might
                # fail with checkpointing enabled.
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)

            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If the tapped output is a tuple or list, concat or select the first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out

    def forward(self, x) -> Dict[str, torch.Tensor]:
        return self._collect(x)
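

# Example (illustrative sketch): wrapping a backbone whose stages are registered in forward
# order and which exposes a compatible `feature_info` attribute. `MyBackbone` is hypothetical.
#
#   model = MyBackbone()  # must provide model.feature_info (list of dicts or FeatureInfo)
#   fd = FeatureDictNet(model, out_indices=(0, 1, 2), flatten_sequential=True)
#   out = fd(torch.randn(1, 3, 224, 224))
#   for k, v in out.items():
#       print(k, v.shape)  # keys are str(out_index) unless out_map is provided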


class FeatureListNet(FeatureDictNet):
    """ Feature extractor with list return

    A specialization of FeatureDictNet that always returns features as a list (values() of dict).
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            output_fmt: str = 'NCHW',
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            output_fmt: Expected output format of the extracted features (e.g. 'NCHW').
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super().__init__(
            model,
            out_indices=out_indices,
            output_fmt=output_fmt,
            feature_concat=feature_concat,
            flatten_sequential=flatten_sequential,
        )

    def forward(self, x) -> List[torch.Tensor]:
        return list(self._collect(x).values())
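

# Example (illustrative): within timm, this is the class typically returned by
# create_model(..., features_only=True), so the common usage looks like:
#
#   import timm
#   model = timm.create_model('resnet50', features_only=True, out_indices=(1, 2, 3))
#   feats = model(torch.randn(1, 3, 224, 224))  # list of 3 feature tensors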


class FeatureHookNet(nn.ModuleDict):
    """ FeatureHookNet

    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.

    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
    network in any way.

    If `no_rewrite` is False, the model will be re-written as in the
    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.

    FIXME this does not currently work with Torchscript, see FeatureHooks class
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            out_map: Sequence[Union[int, str]] = None,
            return_dict: bool = False,
            output_fmt: str = 'NCHW',
            no_rewrite: bool = False,
            flatten_sequential: bool = False,
            default_hook_type: str = 'forward',
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            return_dict: Output features as a dict.
            output_fmt: Expected output format of the extracted features (e.g. 'NCHW').
            no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
                flatten_sequential arg must also be False if this is set True.
            flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
            default_hook_type: The default hook type to use if not specified in model.feature_info.
        """
        super().__init__()
        assert not torch.jit.is_scripting()
        self.feature_info = _get_feature_info(model, out_indices)
        self.return_dict = return_dict
        self.output_fmt = Format(output_fmt)
        self.grad_checkpointing = False

        layers = OrderedDict()
        hooks = []
        if no_rewrite:
            assert not flatten_sequential
            if hasattr(model, 'reset_classifier'):
                # make sure the unused classifier head is removed
                model.reset_classifier(0)
            layers['body'] = model
            hooks.extend(self.feature_info.get_dicts())
        else:
            modules = _module_list(model, flatten_sequential=flatten_sequential)
            remaining = {
                f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
                for f in self.feature_info.get_dicts()
            }
            for new_name, old_name, module in modules:
                layers[new_name] = module
                for fn, fm in module.named_modules(prefix=old_name):
                    if fn in remaining:
                        hooks.append(dict(module=fn, hook_type=remaining[fn]))
                        del remaining[fn]
                if not remaining:
                    break
            assert not remaining, f'Return layers ({remaining}) are not present in model'
        self.update(layers)
        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def forward(self, x):
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skip checkpointing of the first module because a gradient is needed at the input.
                # Skip checkpointing of the last module because networks with in-place ops might
                # fail with checkpointing enabled.
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)
        out = self.hooks.get_output(x.device)
        return out if self.return_dict else list(out.values())
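

# Example (illustrative sketch): hook-based extraction without re-writing the wrapped model.
# `model` is hypothetical; it must expose `feature_info` metadata and ideally a
# `reset_classifier()` method so the unused head can be removed.
#
#   fh = FeatureHookNet(model, out_indices=(0, 1, 2), no_rewrite=True)
#   feats = fh(torch.randn(1, 3, 224, 224))  # list of tensors, or a dict if return_dict=True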