|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from argparse import Namespace |
|
|
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
|
|
from .radio_adaptor_registry import adaptor_registry, dict_t, state_t |
|
|
|
from .radio_adaptor_generic import GenericAdaptor |
|
|
|
|
|
class OpenCLIP_RADIO(GenericAdaptor):
    """Adaptor that wraps an OpenCLIP model and exposes only its text encoder.

    The OpenCLIP checkpoint is chosen by two required keys of
    ``adaptor_config``:

    * ``'model'``: the OpenCLIP architecture name. It is used both for the
      model and for the tokenizer.
    * ``'pretrained'``: the pretrained tag or weights passed to
      ``open_clip.create_model_from_pretrained``.

    After loading, the OpenCLIP vision tower is discarded. In this class
    only ``encode_text`` uses the wrapped model.
    """

    def __init__(self, main_config: Namespace, adaptor_config: dict_t, state: state_t):
        """Build the adaptor and load the OpenCLIP text model and tokenizer.

        Args:
            main_config: Global configuration, forwarded to ``GenericAdaptor``.
            adaptor_config: Adaptor settings. Must contain ``'model'`` and
                ``'pretrained'``, otherwise ``KeyError`` is raised.
            state: Adaptor state, forwarded to ``GenericAdaptor``.
        """
        super().__init__(main_config, adaptor_config, state)

        # Imported lazily so the dependency is only required when this
        # adaptor is actually constructed.
        import open_clip

        arch_name = adaptor_config['model']
        weights_tag = adaptor_config['pretrained']

        # return_transform=False: only the model object is wanted, not the
        # image preprocessing transform.
        clip_model = open_clip.create_model_from_pretrained(
            model_name=arch_name,
            pretrained=weights_tag,
            return_transform=False,
        )
        # Drop the image tower so its parameters are no longer referenced
        # by this module; only the text side is used here.
        clip_model.visual = None
        self.oc_model = clip_model

        self.tokenizer = open_clip.get_tokenizer(model_name=arch_name)

    def encode_text(self, text, normalize: bool = False):
        """Run the wrapped OpenCLIP text encoder.

        Args:
            text: Input for ``oc_model.encode_text``. Presumably token ids
                produced by ``self.tokenizer``; TODO confirm against callers.
            normalize: Forwarded unchanged to ``oc_model.encode_text``.

        Returns:
            Whatever ``oc_model.encode_text`` returns.
        """
        text_model = self.oc_model
        return text_model.encode_text(text, normalize=normalize)
|
|
|
|
|
@adaptor_registry.register_adaptor("open_clip")
def create_open_clip_adaptor(main_config: Namespace, adaptor_config: dict_t, state: state_t):
    """Factory registered in ``adaptor_registry`` under the key ``"open_clip"``.

    Args:
        main_config: Global configuration, passed through to the adaptor.
        adaptor_config: Adaptor settings. ``OpenCLIP_RADIO`` reads
            ``'model'`` and ``'pretrained'`` from it.
        state: Adaptor state, passed through to the adaptor.

    Returns:
        A newly constructed ``OpenCLIP_RADIO`` instance.
    """
    adaptor = OpenCLIP_RADIO(
        main_config=main_config,
        adaptor_config=adaptor_config,
        state=state,
    )
    return adaptor
|
|