# api.py
import copy
import warnings
from pathlib import Path
from typing import Any, Dict, Optional

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

from ..hloc import extract_features, logger, match_dense, match_features
from ..hloc.utils.viz import add_text, plot_keypoints
from ..ui.utils import filter_matches, get_feature_model, get_model
from ..ui.viz import display_matches, fig2im, plot_images

warnings.simplefilter("ignore")


class ImageMatchingAPI(torch.nn.Module):
    """Wrap a feature extractor and a matcher behind a single two-image API.

    The configuration decides between two pipelines:

    * dense (``conf["dense"]`` is truthy): a single dense matcher is run on
      the image pair; no separate extractor is created.
    * sparse: features are extracted from each image and then matched.

    Raw matches can optionally be filtered with RANSAC (``conf["ransac"]``).
    """

    # Defaults applied when the caller's configuration omits the
    # corresponding "ransac" entries.
    default_conf = {
        "ransac": {
            "enable": True,
            "estimator": "poselib",
            "geometry": "homography",
            "method": "RANSAC",
            "reproj_threshold": 3,
            "confidence": 0.9999,
            "max_iter": 10000,
        },
    }

    def __init__(
        self,
        conf: Optional[dict] = None,
        device: str = "cpu",
        detect_threshold: float = 0.015,
        max_keypoints: int = 1024,
        match_threshold: float = 0.2,
    ) -> None:
        """
        Initializes an instance of the ImageMatchingAPI class.

        Args:
            conf (dict, optional): Configuration parameters. It must provide
                the "dense" and "matcher" entries and, for sparse pipelines,
                the "feature" entry. The dictionary is deep-copied, so the
                caller's object is never modified. Entries under "ransac"
                are merged key by key over ``default_conf["ransac"]``.
            device (str, optional): The device to use for computation.
                Defaults to "cpu".
            detect_threshold (float, optional): The threshold for detecting
                keypoints (sparse pipelines only). Defaults to 0.015.
            max_keypoints (int, optional): The maximum number of keypoints
                to extract (sparse pipelines only). Defaults to 1024.
            match_threshold (float, optional): The threshold for matching
                keypoints (written to the matcher configuration of dense
                pipelines only). Defaults to 0.2.

        Returns:
            None
        """
        super().__init__()
        self.device = device

        # Work on private copies: _updata_config writes thresholds into the
        # nested dictionaries, which must not leak into the caller's config
        # or into the class-level defaults.
        user_conf = copy.deepcopy(conf) if conf else {}
        merged = copy.deepcopy(self.default_conf)
        ransac_overrides = user_conf.pop("ransac", None)
        merged.update(user_conf)
        if isinstance(ransac_overrides, dict):
            # Partial overrides keep the remaining default RANSAC settings.
            merged["ransac"].update(ransac_overrides)
        elif ransac_overrides is not None:
            merged["ransac"] = ransac_overrides
        self.conf = merged

        self._updata_config(detect_threshold, max_keypoints, match_threshold)
        self._init_models()

        # Also covers indexed devices such as "cuda:0".
        if str(device).startswith("cuda"):
            memory_allocated = torch.cuda.memory_allocated(device)
            memory_reserved = torch.cuda.memory_reserved(device)
            logger.info(f"GPU memory allocated: {memory_allocated / 1024**2:.3f} MB")
            logger.info(f"GPU memory reserved: {memory_reserved / 1024**2:.3f} MB")
        self.pred = None

    def parse_match_config(self, conf):
        """Resolve model names in ``conf`` to the registered configurations.

        Args:
            conf (dict): Configuration with a "dense" flag and nested
                ``["matcher"]["model"]["name"]`` (plus
                ``["feature"]["model"]["name"]`` for sparse pipelines).

        Returns:
            dict: A copy of ``conf`` whose "matcher" (and "feature") entries
            are replaced by the result of looking the model name up in
            ``match_dense.confs`` / ``match_features.confs`` /
            ``extract_features.confs``. Unknown names resolve to ``None``.
        """
        if conf["dense"]:
            return {
                **conf,
                "matcher": match_dense.confs.get(conf["matcher"]["model"]["name"]),
                "dense": True,
            }
        return {
            **conf,
            "feature": extract_features.confs.get(conf["feature"]["model"]["name"]),
            "matcher": match_features.confs.get(conf["matcher"]["model"]["name"]),
            "dense": False,
        }

    def _updata_config(
        self,
        detect_threshold: float = 0.015,
        max_keypoints: int = 1024,
        match_threshold: float = 0.2,
    ):
        """Write the runtime thresholds into ``self.conf``.

        Dense pipelines receive ``match_threshold`` in the matcher model
        configuration; sparse pipelines receive ``max_keypoints`` and
        ``detect_threshold`` in the feature model configuration. Also sets
        ``self.dense``, ``self.match_conf`` and, for sparse pipelines,
        ``self.extract_conf``.
        """
        self.dense = self.conf["dense"]
        if self.conf["dense"]:
            try:
                self.conf["matcher"]["model"]["match_threshold"] = match_threshold
            except TypeError as e:
                # Best effort: a matcher entry that is not subscriptable
                # (e.g. None) is reported but does not abort construction.
                logger.error(e)
        else:
            self.conf["feature"]["model"]["max_keypoints"] = max_keypoints
            self.conf["feature"]["model"]["keypoint_threshold"] = detect_threshold
            self.extract_conf = self.conf["feature"]

        self.match_conf = self.conf["matcher"]

    def _init_models(self):
        """Instantiate the matcher and, for sparse pipelines, the extractor."""
        # initialize matcher
        self.matcher = get_model(self.match_conf)
        # initialize extractor (dense matchers do not use one)
        if self.dense:
            self.extractor = None
        else:
            self.extractor = get_feature_model(self.conf["feature"])

    def _forward(self, img0, img1):
        """Run the configured pipeline on one image pair.

        Dense pipelines call ``match_dense.match_images``; sparse pipelines
        extract features from both images and pass them to
        ``match_features.match_images``.

        Returns:
            The prediction returned by the respective ``match_images`` call.
        """
        if self.dense:
            return match_dense.match_images(
                self.matcher,
                img0,
                img1,
                self.match_conf["preprocessing"],
                device=self.device,
            )

        pred0 = extract_features.extract(
            self.extractor, img0, self.extract_conf["preprocessing"]
        )
        pred1 = extract_features.extract(
            self.extractor, img1, self.extract_conf["preprocessing"]
        )
        return match_features.match_images(self.matcher, pred0, pred1)

    def _convert_pred(self, pred):
        """Convert tensor-valued entries of ``pred`` to NumPy arrays.

        Tensors are moved to the CPU, detached, and reduced to their first
        element along the leading axis. For list values, the first element
        is converted the same way. All other values pass through unchanged.
        """
        ret = {
            k: v.cpu().detach()[0].numpy() if isinstance(v, torch.Tensor) else v
            for k, v in pred.items()
        }
        ret = {
            k: v[0].cpu().detach().numpy() if isinstance(v, list) else v
            for k, v in ret.items()
        }
        return ret

    @torch.inference_mode()
    def extract(self, img0: np.ndarray, **kwargs) -> Dict[str, np.ndarray]:
        """Extract features from a single image.

        Only usable with sparse pipelines: dense configurations do not
        create an extractor.

        Args:
            img0 (np.ndarray): image
            **kwargs: Optional settings:
                max_keypoints (default 512) and keypoint_threshold
                (default 0.0) are written to the extractor configuration;
                binarize (default False) thresholds the descriptors at 0
                and stores them as uint8.

        Returns:
            Dict[str, np.ndarray]: feature dict, with "keypoints_orig" added
            (keypoints rescaled by original_size / size) and "descriptors"
            transposed.
        """
        # setting params
        self.extractor.conf["max_keypoints"] = kwargs.get("max_keypoints", 512)
        self.extractor.conf["keypoint_threshold"] = kwargs.get(
            "keypoint_threshold", 0.0
        )
        pred = extract_features.extract(
            self.extractor, img0, self.extract_conf["preprocessing"]
        )
        pred = self._convert_pred(pred)

        # back to origin scale; the +0.5 / -0.5 shift is applied around the
        # scaling exactly as in the original implementation
        s0 = pred["original_size"] / pred["size"]
        pred["keypoints_orig"] = (
            match_features.scale_keypoints(pred["keypoints"] + 0.5, s0) - 0.5
        )
        # TODO: rotate back
        binarize = kwargs.get("binarize", False)
        if binarize:
            assert "descriptors" in pred
            pred["descriptors"] = (pred["descriptors"] > 0).astype(np.uint8)
        pred["descriptors"] = pred["descriptors"].T  # N x DIM
        return pred

    @torch.inference_mode()
    def forward(
        self,
        img0: np.ndarray,
        img1: np.ndarray,
    ) -> Dict[str, np.ndarray]:
        """
        Forward pass of the image matching API.

        Args:
            img0: A 3D NumPy array of shape (H, W, C) representing the first
                image. Values are in the range [0, 1] and are in RGB mode.
            img1: A 3D NumPy array of shape (H, W, C) representing the second
                image. Values are in the range [0, 1] and are in RGB mode.

        Returns:
            A dictionary containing the following keys (as produced by the
            underlying matching and filtering helpers):
                - image0_orig: The original image 0.
                - image1_orig: The original image 1.
                - keypoints0_orig: The keypoints detected in image 0.
                - keypoints1_orig: The keypoints detected in image 1.
                - mkeypoints0_orig: The raw matches between image 0 and image 1.
                - mkeypoints1_orig: The raw matches between image 1 and image 0.
                - mmkeypoints0_orig: The RANSAC inliers in image 0.
                - mmkeypoints1_orig: The RANSAC inliers in image 1.
                - mconf: The confidence scores for the raw matches.
                - mmconf: The confidence scores for the RANSAC inliers.

            The result is also stored in ``self.pred`` for ``visualize``.
        """
        # Take as input a pair of images (not a batch)
        assert isinstance(img0, np.ndarray)
        assert isinstance(img1, np.ndarray)
        self.pred = self._forward(img0, img1)
        if self.conf["ransac"]["enable"]:
            self.pred = self._geometry_check(self.pred)
        return self.pred

    def _geometry_check(
        self,
        pred: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Filter matches using RANSAC, configured by ``self.conf["ransac"]``.

        Args:
            pred (Dict[str, Any]): dict of matches, including original
                keypoints. See :func:`filter_matches` for the expected keys.

        Returns:
            Dict[str, Any]: filtered matches
        """
        ransac_conf = self.conf["ransac"]
        return filter_matches(
            pred,
            ransac_method=ransac_conf["method"],
            ransac_reproj_threshold=ransac_conf["reproj_threshold"],
            ransac_confidence=ransac_conf["confidence"],
            ransac_max_iter=ransac_conf["max_iter"],
        )

    def visualize(
        self,
        log_path: Optional[Path] = None,
    ) -> None:
        """
        Visualize the matches of the most recent ``forward`` call.

        ``forward`` must have been called beforehand so that ``self.pred``
        holds a prediction.

        Args:
            log_path (Path, optional): The directory to save the images.
                Created if it does not exist. When None, the figures are
                rendered but nothing is written. Defaults to None.

        Returns:
            None
        """
        if self.conf["dense"]:
            postfix = str(self.conf["matcher"]["model"]["name"])
        else:
            postfix = "{}_{}".format(
                str(self.conf["feature"]["model"]["name"]),
                str(self.conf["matcher"]["model"]["name"]),
            )

        pred: Dict[str, Any] = self.pred
        image0: np.ndarray = pred["image0_orig"]
        image1: np.ndarray = pred["image1_orig"]

        # plot images with detected keypoints (when available)
        titles = [
            "Image 0 - Keypoints",
            "Image 1 - Keypoints",
        ]
        output_keypoints: np.ndarray = plot_images(
            [image0, image1], titles=titles, dpi=300
        )
        if "keypoints0_orig" in pred and "keypoints1_orig" in pred:
            plot_keypoints([pred["keypoints0_orig"], pred["keypoints1_orig"]])
            text: str = (
                f"# keypoints0: {len(pred['keypoints0_orig'])} \n"
                + f"# keypoints1: {len(pred['keypoints1_orig'])}"
            )
            add_text(0, text, fs=15)
        output_keypoints = fig2im(output_keypoints)

        # plot images with raw matches
        titles = [
            "Image 0 - Raw matched keypoints",
            "Image 1 - Raw matched keypoints",
        ]
        output_matches_raw, _num_matches_raw = display_matches(
            pred, titles=titles, tag="KPTS_RAW"
        )

        # plot images with ransac matches
        titles = [
            "Image 0 - Ransac matched keypoints",
            "Image 1 - Ransac matched keypoints",
        ]
        output_matches_ransac, _num_matches_ransac = display_matches(
            pred, titles=titles, tag="KPTS_RANSAC"
        )

        if log_path is not None:
            # Accept plain strings too, and make sure the directory exists:
            # cv2.imwrite reports a missing directory only via its return
            # value instead of raising.
            log_dir = Path(log_path)
            log_dir.mkdir(parents=True, exist_ok=True)
            outputs = {
                "img_keypoints": output_keypoints,
                "img_matches_raw": output_matches_raw,
                "img_matches_ransac": output_matches_ransac,
            }
            for stem, image in outputs.items():
                out_path = log_dir / f"{stem}_{postfix}.png"
                written = cv2.imwrite(
                    str(out_path),
                    image[:, :, ::-1].copy(),  # RGB -> BGR
                )
                if not written:
                    logger.error(f"Failed to write visualization to {out_path}")
        plt.close("all")