TimmWrapper
Overview
Helper class to enable loading timm models to be used with the transformers library and its autoclasses.
>>> import torch
>>> from PIL import Image
>>> from urllib.request import urlopen
>>> from transformers import AutoModelForImageClassification, AutoImageProcessor
>>> # Load image
>>> image = Image.open(urlopen(
... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
... ))
>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
>>> # Preprocess image
>>> inputs = image_processor(image)
>>> # Forward pass
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # Get top 5 predictions
>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
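The final line scales the softmax probabilities to percentages and keeps the five largest per image. Below is a dependency-free sketch of that computation for a single row of logits. The logits and the eight-class setup are made up; the real example runs torch.topk on a (batch_size, num_classes) tensor.

```python
import math


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def topk_percent(row, k=5):
    """Return (percentages, indices) of the k largest softmax probabilities, largest first."""
    probs = [p * 100 for p in softmax(row)]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [probs[i] for i in order], order


# Toy logits for one image over 8 hypothetical classes
logits = [1.0, 3.5, 0.2, 2.0, 4.1, -1.0, 0.0, 2.7]
top_probs, top_idx = topk_percent(logits, k=5)
print(top_idx)  # [4, 1, 7, 3, 0]
```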
TimmWrapperConfig
class transformers.TimmWrapperConfig
( initializer_range: float = 0.02, do_pooling: bool = True, **kwargs )
This is the configuration class to store the configuration for a timm backbone TimmWrapper.
It is used to instantiate a timm model according to the specified arguments, defining the model.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
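The config loads the ImageNet label descriptions and stores them in the id2label attribute. For default ImageNet models, the label2id attribute is set to None because some label descriptions collide (the same description is used for more than one class), so a reverse mapping from description to index would be ambiguous.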
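A short illustration of why inverting id2label is unsafe when descriptions repeat. The four-entry mapping below is a toy assumption, not the real ImageNet table. The commonly used ImageNet-1k class-name list does reuse the name "crane" for both a bird and a machine, which is the kind of collision mimicked here.

```python
# Toy id2label mapping (hypothetical); "crane" appears twice on purpose
id2label = {0: "goldfish", 1: "crane", 2: "tench", 3: "crane"}

# Naively inverting the mapping keeps only the last id seen for each repeated name
label2id = {label: idx for idx, label in id2label.items()}

print(label2id["crane"])             # 3: the id 1 is silently lost
print(len(label2id), len(id2label))  # 3 4
```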
TimmWrapperImageProcessor
class transformers.TimmWrapperImageProcessor
( pretrained_cfg: Dict[str, Any], architecture: Optional[str] = None, **kwargs )
Wrapper class for timm models to be used within transformers.
preprocess
( images: Union[PIL.Image.Image, numpy.ndarray, torch.Tensor, List[PIL.Image.Image], List[numpy.ndarray], List[torch.Tensor]], return_tensors: Union[str, TensorType, None] = 'pt' )
Preprocess an image or batch of images.
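To illustrate what "an image or batch of images" means for the output layout, here is a toy sketch built around a hypothetical to_pixel_values helper that operates on nested lists. A single height x width x channels image is treated as a batch of one, and every image is rearranged to channels-first, giving the (batch_size, num_channels, height, width) layout used for pixel_values below. It deliberately skips resizing and normalization. In the real processor those steps are presumably driven by the checkpoint's pretrained_cfg.

```python
from typing import List, Sequence, Union

# Toy image: height x width x channels nested lists (the layout PIL/NumPy images use)
ToyImage = List[List[List[float]]]


def to_pixel_values(images: Union[ToyImage, Sequence[ToyImage]]) -> List[List[List[List[float]]]]:
    """Hypothetical helper: return a (batch, channels, height, width) nested list.

    A single image becomes a batch of one; no resizing or normalization is done.
    """
    # A single image holds a float at depth 3; a batch holds another list there.
    if not isinstance(images[0][0][0], list):
        images = [images]
    batch = []
    for img in images:
        height, width, channels = len(img), len(img[0]), len(img[0][0])
        # HWC -> CHW: move the channel axis to the front
        batch.append([
            [[img[y][x][c] for x in range(width)] for y in range(height)]
            for c in range(channels)
        ])
    return batch


def shape(nested) -> tuple:
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# A 2x3 RGB image with every pixel set to (0.1, 0.5, 0.9)
image = [[[0.1, 0.5, 0.9] for _ in range(3)] for _ in range(2)]
print(shape(to_pixel_values(image)))           # (1, 3, 2, 3)
print(shape(to_pixel_values([image, image])))  # (2, 3, 2, 3)
```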
TimmWrapperModel
Wrapper class for timm models to be used in transformers.
forward
( pixel_values: FloatTensor, output_attentions: Optional[bool] = None, output_hidden_states: Union[bool, List[int], None] = None, return_dict: Optional[bool] = None, do_pooling: Optional[bool] = None, **kwargs ) → transformers.models.timm_wrapper.modeling_timm_wrapper.TimmWrapperModelOutput or tuple(torch.FloatTensor)
Parameters
- pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): Pixel values. Pixel values can be obtained using AutoImageProcessor. See TimmWrapperImageProcessor.preprocess() for details.
- output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models.
- output_hidden_states (bool or List[int], optional): Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models.
- return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
- **kwargs: Additional keyword arguments passed along to the timm model forward.
- do_pooling (bool, optional): Whether or not to pool the last_hidden_state in TimmWrapperModel. If None is passed, the do_pooling value from the config is used.
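For a convolutional checkpoint such as timm/resnet50.a1_in1k, pooling typically means global average pooling over the spatial grid, turning a (batch_size, channels, height, width) feature map into (batch_size, channels). The sketch below assumes average pooling and uses toy nested lists; other timm architectures may pool differently (for example by taking a class token), so treat it as an illustration, not the wrapper's implementation.

```python
from typing import List


def global_avg_pool(features: List[List[List[List[float]]]]) -> List[List[float]]:
    """Average a (batch, channels, height, width) map over height and width -> (batch, channels)."""
    pooled = []
    for sample in features:
        pooled.append([
            sum(sum(row) for row in channel) / (len(channel) * len(channel[0]))
            for channel in sample
        ])
    return pooled


# One sample, two channels, each on a 2x2 spatial grid
last_hidden_state = [[[[1.0, 2.0], [3.0, 4.0]], [[0.0, 0.0], [0.0, 8.0]]]]
print(global_avg_pool(last_hidden_state))  # [[2.5, 2.0]]
```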
Returns
transformers.models.timm_wrapper.modeling_timm_wrapper.TimmWrapperModelOutput or tuple(torch.FloatTensor)
A transformers.models.timm_wrapper.modeling_timm_wrapper.TimmWrapperModelOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (TimmWrapperConfig) and inputs.
- last_hidden_state (torch.FloatTensor): The last hidden state of the model, output before applying the classification head.
- pooler_output (torch.FloatTensor, optional): The pooled output derived from the last hidden state, if applicable.
- hidden_states (tuple(torch.FloatTensor), optional): A tuple containing the intermediate hidden states of the model at the output of each layer or specified layers. Returned if output_hidden_states=True is set or if config.output_hidden_states=True.
- attentions (tuple(torch.FloatTensor), optional): A tuple containing the intermediate attention weights of the model at the output of each layer. Returned if output_attentions=True is set or if config.output_attentions=True. Note: currently, timm models do not support attentions output.
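When return_dict=False, the same fields come back as a plain tuple. Following the usual transformers ModelOutput convention, fields that are None are left out rather than kept as placeholders, so the position of each element depends on which outputs are present (pooler_output, for instance, is optional and may be None when pooling is disabled). The dataclass below is a hypothetical stand-in that mimics this behavior; it is not the real TimmWrapperModelOutput.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional, Tuple


@dataclass
class ToyModelOutput:
    """Hypothetical stand-in for TimmWrapperModelOutput, for illustration only."""
    last_hidden_state: Any = None
    pooler_output: Optional[Any] = None
    hidden_states: Optional[Tuple[Any, ...]] = None
    attentions: Optional[Tuple[Any, ...]] = None

    def to_tuple(self) -> Tuple[Any, ...]:
        # Mimic the transformers convention: fields that are None are skipped
        return tuple(
            getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None
        )


pooled = ToyModelOutput(last_hidden_state="lhs", pooler_output="pooled")
unpooled = ToyModelOutput(last_hidden_state="lhs")
print(pooled.to_tuple())    # ('lhs', 'pooled')
print(unpooled.to_tuple())  # ('lhs',)
```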
The TimmWrapperModel forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Examples:
>>> import torch
>>> from PIL import Image
>>> from urllib.request import urlopen
>>> from transformers import AutoModel, AutoImageProcessor
>>> # Load image
>>> image = Image.open(urlopen(
... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
... ))
>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModel.from_pretrained(checkpoint).eval()
>>> # Preprocess image
>>> inputs = image_processor(image)
>>> # Forward pass
>>> with torch.no_grad():
... outputs = model(**inputs)
>>> # Get pooled output
>>> pooled_output = outputs.pooler_output
>>> # Get last hidden state
>>> last_hidden_state = outputs.last_hidden_state
TimmWrapperForImageClassification
Wrapper class for timm models to be used in transformers for image classification.
forward
( pixel_values: FloatTensor, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Union[bool, List[int], None] = None, return_dict: Optional[bool] = None, **kwargs ) → transformers.modeling_outputs.ImageClassifierOutput or tuple(torch.FloatTensor)
Parameters
- pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): Pixel values. Pixel values can be obtained using AutoImageProcessor. See TimmWrapperImageProcessor.preprocess() for details.
- output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. Not compatible with timm wrapped models.
- output_hidden_states (bool or List[int], optional): Whether or not to return the hidden states of all layers. Not compatible with timm wrapped models.
- return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
- **kwargs: Additional keyword arguments passed along to the timm model forward.
- labels (torch.LongTensor of shape (batch_size,), optional): Labels for computing the image classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1, a regression loss is computed (mean squared error); if config.num_labels > 1, a classification loss is computed (cross-entropy).
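The label rule above can be sketched in plain Python. The toy_loss function is a simplified, hypothetical stand-in, not the library's implementation: it covers only the two cases described (mean squared error when there is a single output, cross-entropy over the logits otherwise) and averages over the batch.

```python
import math
from typing import List


def toy_loss(logits: List[List[float]], labels: List[float], num_labels: int) -> float:
    """Hypothetical sketch: MSE if num_labels == 1, else mean cross-entropy."""
    if num_labels == 1:
        # Regression: mean squared error between the single score and the target
        return sum((row[0] - y) ** 2 for row, y in zip(logits, labels)) / len(labels)
    # Classification: mean of (log-sum-exp of the row minus the true-class logit)
    total = 0.0
    for row, y in zip(logits, labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[int(y)]
    return total / len(labels)


print(toy_loss([[2.0], [0.0]], [1.0, 1.0], num_labels=1))  # 1.0
print(toy_loss([[0.0, 0.0]], [0], num_labels=2))           # ln 2, about 0.6931
```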
Returns
transformers.modeling_outputs.ImageClassifierOutput or tuple(torch.FloatTensor)
A transformers.modeling_outputs.ImageClassifierOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (TimmWrapperConfig) and inputs.
- loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided): Classification (or regression if config.num_labels==1) loss.
- logits (torch.FloatTensor of shape (batch_size, config.num_labels)): Classification (or regression if config.num_labels==1) scores (before SoftMax).
- hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True): Tuple of torch.FloatTensor (one for the output of the embeddings, if the model has an embedding layer, plus one for the output of each stage) of shape (batch_size, sequence_length, hidden_size). Hidden states (also called feature maps) of the model at the output of each stage.
- attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True): Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, patch_size, sequence_length). Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
The TimmWrapperForImageClassification forward method overrides the __call__ special method.
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the pre- and post-processing steps while the latter silently ignores them.
Examples:
>>> import torch
>>> from PIL import Image
>>> from urllib.request import urlopen
>>> from transformers import AutoModelForImageClassification, AutoImageProcessor
>>> # Load image
>>> image = Image.open(urlopen(
... 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
... ))
>>> # Load model and image processor
>>> checkpoint = "timm/resnet50.a1_in1k"
>>> image_processor = AutoImageProcessor.from_pretrained(checkpoint)
>>> model = AutoModelForImageClassification.from_pretrained(checkpoint).eval()
>>> # Preprocess image
>>> inputs = image_processor(image)
>>> # Forward pass
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # Get top 5 predictions
>>> top5_probabilities, top5_class_indices = torch.topk(logits.softmax(dim=1) * 100, k=5)
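To print readable predictions, the indices can be looked up in model.config.id2label. A small sketch with a hypothetical describe_top_k helper: in real use one would pass top5_probabilities[0].tolist(), top5_class_indices[0].tolist() and model.config.id2label. The values and the truncated mapping below are made up.

```python
from typing import Dict, List


def describe_top_k(probabilities: List[float], indices: List[int], id2label: Dict[int, str]) -> List[str]:
    """Hypothetical helper: pair each top-k percentage with its human-readable label."""
    return [f"{id2label[i]}: {p:.1f}%" for p, i in zip(probabilities, indices)]


# Made-up stand-ins for the model outputs and a truncated id2label mapping
id2label = {0: "tench", 1: "goldfish", 2: "great white shark", 3: "tiger shark", 4: "hammerhead"}
lines = describe_top_k([61.2, 20.3, 9.8], [3, 0, 1], id2label)
for line in lines:
    print(line)
```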