|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional, List, Union, NamedTuple |
|
import torch |
|
from transformers import PretrainedConfig |
|
|
|
from .radio_common import RESOURCE_MAP, DEFAULT_VERSION |
|
|
|
from .radio_model import Resolution |
|
|
|
class RADIOConfig(PretrainedConfig):
    """Pretrained Hugging Face configuration for RADIO models.

    Stores the model ``args`` together with patch/resolution metadata.
    Metadata not supplied explicitly is taken from the ``RESOURCE_MAP``
    entry selected by ``version``.
    """

    def __init__(
        self,
        args: Optional[dict] = None,
        version: Optional[str] = DEFAULT_VERSION,
        patch_size: Optional[int] = None,
        max_resolution: Optional[int] = None,
        preferred_resolution: Optional[Resolution] = None,
        adaptor_names: Optional[Union[str, List[str]]] = None,
        vitdet_window_size: Optional[int] = None,
        **kwargs,
    ):
        """Build the configuration.

        Args:
            args: Model arguments. A shallow copy is stored; any ``"dtype"``
                or ``"amp_dtype"`` entry in the copy is replaced by the text
                after the last ``"."`` of its ``str()`` form
                (e.g. ``torch.float32`` -> ``"float32"``). The caller's dict
                is not modified.
            version: Key into ``RESOURCE_MAP``. ``None`` selects
                ``DEFAULT_VERSION``.
            patch_size: Used when truthy, otherwise ``resource.patch_size``.
            max_resolution: Used when truthy, otherwise
                ``resource.max_resolution``.
            preferred_resolution: Used when truthy, otherwise
                ``resource.preferred_resolution``.
            adaptor_names: Stored as-is.
            vitdet_window_size: Stored as-is.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.

        Raises:
            KeyError: If ``version`` is not a key of ``RESOURCE_MAP``.
        """
        if args is not None:
            # Convert a copy so the caller's dict keeps its original
            # (non-string) dtype objects.
            args = dict(args)
            for field in ("dtype", "amp_dtype"):
                if field in args:
                    # Store the dtype by name so the config stays
                    # serializable: str(torch.float32) is "torch.float32",
                    # which becomes "float32"; a plain "bfloat16" string
                    # has no "." and is kept unchanged.
                    args[field] = str(args[field]).split(".")[-1]
        self.args = args

        # The annotation allows None; treat it as "use the default version".
        if version is None:
            version = DEFAULT_VERSION
        self.version = version
        resource = RESOURCE_MAP[version]

        # Explicit values win; falsy ones (None, 0, ...) fall back to the
        # per-version resource defaults.
        self.patch_size = patch_size or resource.patch_size
        self.max_resolution = max_resolution or resource.max_resolution
        self.preferred_resolution = (
            preferred_resolution or resource.preferred_resolution
        )
        self.adaptor_names = adaptor_names
        self.vitdet_window_size = vitdet_window_size
        super().__init__(**kwargs)