Spaces:
Running
on
Zero
Running
on
Zero
import subprocess | |
import sys | |
from pathlib import Path | |
import torch | |
import torchvision.transforms as tvf | |
from .. import logger | |
from ..utils.base_model import BaseModel | |
fire_path = Path(__file__).parent / "../../third_party/fire" | |
sys.path.append(str(fire_path)) | |
import fire_network | |
EPS = 1e-6 | |
class FIRe(BaseModel): | |
default_conf = { | |
"global": True, | |
"asmk": False, | |
"model_name": "fire_SfM_120k.pth", | |
"scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], # default params | |
"features_num": 1000, | |
"asmk_name": "asmk_codebook.bin", | |
"config_name": "eval_fire.yml", | |
} | |
required_inputs = ["image"] | |
# Models exported using | |
fire_models = { | |
"fire_SfM_120k.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/official/fire.pth", | |
"fire_imagenet.pth": "http://download.europe.naverlabs.com/ComputerVision/FIRe/pretraining/fire_imagenet.pth", | |
} | |
def _init(self, conf): | |
assert conf["model_name"] in self.fire_models.keys() | |
# Config paths | |
model_path = fire_path / "model" / conf["model_name"] | |
config_path = fire_path / conf["config_name"] # noqa: F841 | |
asmk_bin_path = fire_path / "model" / conf["asmk_name"] # noqa: F841 | |
# Download the model. | |
if not model_path.exists(): | |
model_path.parent.mkdir(exist_ok=True) | |
link = self.fire_models[conf["model_name"]] | |
cmd = ["wget", "--quiet", link, "-O", str(model_path)] | |
logger.info(f"Downloading the FIRe model with `{cmd}`.") | |
subprocess.run(cmd, check=True) | |
logger.info("Loading fire model...") | |
# Load net | |
state = torch.load(model_path) | |
state["net_params"]["pretrained"] = None | |
net = fire_network.init_network(**state["net_params"]) | |
net.load_state_dict(state["state_dict"]) | |
self.net = net | |
self.norm_rgb = tvf.Normalize( | |
**dict(zip(["mean", "std"], net.runtime["mean_std"])) | |
) | |
# params | |
self.scales = conf["scales"] | |
self.features_num = conf["features_num"] | |
def _forward(self, data): | |
image = self.norm_rgb(data["image"]) | |
local_desc = self.net.forward_local( | |
image, features_num=self.features_num, scales=self.scales | |
) | |
logger.info(f"output[0].shape = {local_desc[0].shape}\n") | |
return { | |
# 'global_descriptor': desc | |
"local_descriptor": local_desc | |
} | |