# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | |
# | |
# NVIDIA CORPORATION and its licensors retain all intellectual property | |
# and proprietary rights in and to this software, related documentation | |
# and any modifications thereto. Any use, reproduction, disclosure or | |
# distribution of this software and related documentation without an express | |
# license agreement from NVIDIA CORPORATION is strictly prohibited. | |
from argparse import Namespace | |
from typing import Dict, Any | |
import torch | |
from .radio_adaptor_generic import GenericAdaptor, AdaptorBase | |
# Type aliases used in the adaptor factory signatures below.
# String-keyed mapping of tensors (presumably a model state dict — confirm
# against callers of create_adaptor).
state_t = Dict[str, torch.Tensor]
# Generic string-keyed mapping, used for adaptor configuration.
dict_t = Dict[str, Any]
class AdaptorRegistry:
    """Name-keyed registry of adaptor factory functions.

    Factories are added with the :meth:`register_adaptor` decorator and
    invoked by :meth:`create_adaptor`. Looking up a name that has no
    registered factory is not an error: ``GenericAdaptor`` is used instead.
    """

    def __init__(self):
        # Maps adaptor name -> factory callable invoked as
        # factory(main_config, adaptor_config, state).
        self._registry: Dict[str, Any] = {}

    def register_adaptor(self, name):
        """Return a decorator that registers a factory under ``name``.

        The decorated function is stored in the registry and returned
        unchanged, so it remains directly callable.

        Args:
            name: Key under which the factory is registered.

        Returns:
            A decorator taking the factory function.

        Raises:
            ValueError: When the decorator is applied and ``name`` is already
                registered. The existing entry is left untouched.
        """
        def decorator(factory_function):
            if name in self._registry:
                raise ValueError(f"Adaptor '{name}' already registered")
            self._registry[name] = factory_function
            return factory_function

        return decorator

    def create_adaptor(
        self,
        name,
        main_config: Namespace,
        adaptor_config: dict_t,
        state: state_t,
    ) -> AdaptorBase:
        """Build the adaptor registered under ``name``.

        Args:
            name: Registry key to look up.
            main_config: Passed through unchanged to the factory.
            adaptor_config: Passed through unchanged to the factory.
            state: Passed through unchanged to the factory.

        Returns:
            Whatever the registered factory returns. If ``name`` is not
            registered, ``GenericAdaptor(main_config, adaptor_config, state)``
            is returned instead.
        """
        try:
            factory = self._registry[name]
        except KeyError:
            # Unknown names deliberately fall back to the generic adaptor.
            return GenericAdaptor(main_config, adaptor_config, state)
        return factory(main_config, adaptor_config, state)
# Shared module-level registry. Factories are added with the
# ``@adaptor_registry.register_adaptor(name)`` decorator and built with
# ``adaptor_registry.create_adaptor(...)``.
adaptor_registry = AdaptorRegistry()