"""SIFT keypoint detection and description with OpenCV or pycolmap backends."""

import warnings
from typing import Tuple

import cv2
import numpy as np
import torch
from kornia.color import rgb_to_grayscale
from omegaconf import OmegaConf
from packaging import version

try:
    import pycolmap
except ImportError:
    pycolmap = None

from .. import logger
from ..utils.base_model import BaseModel


def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
    """Deduplicate keypoints that share a pixel and optionally apply NMS.

    Points are rasterized onto an ``(h, w)`` grid. For each occupied pixel
    only the point with the highest score (or scale when ``scores`` is None)
    is kept; remaining ties are broken by the smallest absolute angle.

    Args:
        points: array indexed as ``(N, 2)`` holding (x, y) coordinates.
        scales: per-point scales, used as ranking value if ``scores`` is None.
        angles: per-point orientations.
        image_shape: ``(h, w)`` of the rasterization grid.
        nms_radius: radius of the max-pooling window; NMS is skipped if <= 0.
        scores: optional per-point scores used instead of ``scales``.

    Returns:
        Indices (into the input arrays) of the points that are kept.
    """
    h, w = image_shape
    # Integer pixel indices, reordered from (x, y) to (row, col).
    ij = np.round(points - 0.5).astype(int).T[::-1]

    # Remove duplicate points (identical coordinates).
    # Pick highest scale or score
    s = scales if scores is None else scores
    buffer = np.zeros((h, w))
    np.maximum.at(buffer, tuple(ij), s)
    keep = np.where(buffer[tuple(ij)] == s)[0]

    # Pick lowest angle (arbitrary).
    ij = ij[:, keep]
    buffer[:] = np.inf
    o_abs = np.abs(angles[keep])
    np.minimum.at(buffer, tuple(ij), o_abs)
    mask = buffer[tuple(ij)] == o_abs
    ij = ij[:, mask]
    keep = keep[mask]

    if nms_radius > 0:
        # Apply NMS on the remaining points
        buffer[:] = 0
        buffer[tuple(ij)] = s[keep]  # scores or scale

        local_max = torch.nn.functional.max_pool2d(
            torch.from_numpy(buffer).unsqueeze(0),
            kernel_size=nms_radius * 2 + 1,
            stride=1,
            padding=nms_radius,
        ).squeeze(0)
        is_local_max = buffer == local_max.numpy()
        keep = keep[is_local_max[tuple(ij)]]
    return keep


def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
    """Convert SIFT descriptors to RootSIFT along the last dimension.

    The descriptors are L1-normalized, clipped to ``eps``, square-rooted and
    finally L2-normalized.
    """
    x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
    x.clip_(min=eps).sqrt_()
    return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)


def run_opencv_sift(
    features: cv2.Feature2D, image: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Detect keypoints and compute descriptors using an OpenCV detector.

    Args:
        features: OpenCV based keypoints detector and descriptor
        image: Grayscale image of uint8 data type
    Returns:
        points: (N, 2) float32 array of keypoint (x, y) locations
        scores: (N,) float32 array of responses
        scales: (N,) float32 array of keypoint sizes
        angles: (N,) float32 array of keypoint orientations, in radians
        descriptors: (N, D) array of descriptors
    """
    detections, descriptors = features.detectAndCompute(image, None)
    # reshape keeps a well-formed (0, 2) array when nothing is detected.
    points = np.array([k.pt for k in detections], dtype=np.float32).reshape(-1, 2)
    scores = np.array([k.response for k in detections], dtype=np.float32)
    scales = np.array([k.size for k in detections], dtype=np.float32)
    angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
    if descriptors is None:
        # OpenCV returns None instead of an empty array without detections;
        # substitute an empty array so that callers can index and stack it.
        descriptors = np.zeros((0, features.descriptorSize()), dtype=np.float32)
    return points, scores, scales, angles, descriptors


class SIFT(BaseModel):
    """SIFT extractor backed by either OpenCV or pycolmap."""

    default_conf = {
        "rootsift": True,
        "nms_radius": 0,  # None to disable filtering entirely.
        "max_keypoints": 4096,
        "backend": "opencv",  # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
        "detection_threshold": 0.0066667,  # from COLMAP
        "edge_threshold": 10,
        "first_octave": -1,  # only used by pycolmap, the default of COLMAP
        "num_octaves": 4,
    }

    required_data_keys = ["image"]

    def _init(self, conf):
        """Instantiate the backend detector selected by ``self.conf.backend``.

        Raises:
            ImportError: if a pycolmap backend is requested but not installed.
            ValueError: if the backend name is unknown.
        """
        self.conf = OmegaConf.create(self.conf)
        backend = self.conf.backend
        if backend.startswith("pycolmap"):
            if pycolmap is None:
                raise ImportError(
                    "Cannot find module pycolmap: install it with pip "
                    "or use backend=opencv."
                )
            options = {
                "peak_threshold": self.conf.detection_threshold,
                "edge_threshold": self.conf.edge_threshold,
                "first_octave": self.conf.first_octave,
                "num_octaves": self.conf.num_octaves,
                "normalization": pycolmap.Normalization.L2,  # L1_ROOT is buggy.
            }
            device = (
                "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
            )
            uses_cpu = backend == "pycolmap_cpu" or not pycolmap.has_cuda
            # Compare parsed versions: a plain string comparison would rank
            # e.g. "0.10.0" below "0.5.0".
            is_old = version.parse(pycolmap.__version__) < version.parse("0.5.0")
            if uses_cpu and is_old:
                warnings.warn(
                    "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
                    "consider upgrading pycolmap or use the CUDA version.",
                    stacklevel=1,
                )
            else:
                options["max_num_features"] = self.conf.max_keypoints
            self.sift = pycolmap.Sift(options=options, device=device)
        elif backend == "opencv":
            self.sift = cv2.SIFT_create(
                contrastThreshold=self.conf.detection_threshold,
                nfeatures=self.conf.max_keypoints,
                edgeThreshold=self.conf.edge_threshold,
                nOctaveLayers=self.conf.num_octaves,
            )
        else:
            backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
            # Sorted so that the message does not depend on set ordering.
            raise ValueError(
                f"Unknown backend: {backend} not in "
                f"{{{','.join(sorted(backends))}}}."
            )
        logger.info("Load SIFT model done.")

    def extract_single_image(self, image: torch.Tensor):
        """Extract SIFT features from a single image tensor.

        Args:
            image: tensor whose leading dimension is squeezed away before
                being handed to the backend (presumably ``(1, H, W)`` with
                values in [0, 1] -- confirm against callers).

        Returns:
            Dict of tensors with keys ``keypoints``, ``scales``, ``oris``,
            ``descriptors`` and, when the backend exposes them, ``scores``.
        """
        image_np = image.cpu().numpy().squeeze(0)

        if self.conf.backend.startswith("pycolmap"):
            if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
                detections, descriptors = self.sift.extract(image_np)
                scores = None  # Scores are not exposed by COLMAP anymore.
            else:
                detections, scores, descriptors = self.sift.extract(image_np)
            keypoints = detections[:, :2]  # Keep only (x, y).
            scales, angles = detections[:, -2:].T
            if scores is not None and (
                self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
            ):
                # Set the scores as a combination of abs. response and scale.
                scores = np.abs(scores) * scales
        elif self.conf.backend == "opencv":
            # TODO: Check if opencv keypoints are already in corner convention
            keypoints, scores, scales, angles, descriptors = run_opencv_sift(
                self.sift, (image_np * 255.0).astype(np.uint8)
            )
        pred = {
            "keypoints": keypoints,
            "scales": scales,
            "oris": angles,
            "descriptors": descriptors,
        }
        if scores is not None:
            pred["scores"] = scores

        # sometimes pycolmap returns points outside the image. We remove them
        if self.conf.backend.startswith("pycolmap"):
            is_inside = (
                pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
            ).all(-1)
            pred = {k: v[is_inside] for k, v in pred.items()}

        if self.conf.nms_radius is not None:
            keep = filter_dog_point(
                pred["keypoints"],
                pred["scales"],
                pred["oris"],
                image_np.shape,
                self.conf.nms_radius,
                scores=pred.get("scores"),
            )
            pred = {k: v[keep] for k, v in pred.items()}

        pred = {k: torch.from_numpy(v) for k, v in pred.items()}
        if scores is not None:
            # Keep the k keypoints with highest score
            num_points = self.conf.max_keypoints
            if num_points is not None and len(pred["keypoints"]) > num_points:
                indices = torch.topk(pred["scores"], num_points).indices
                pred = {k: v[indices] for k, v in pred.items()}
        return pred

    def _forward(self, data: dict) -> dict:
        """Extract features for every image of the batch in ``data["image"]``.

        Three-channel images are converted to grayscale first. When
        ``data["image_size"]`` is present, each image is cropped to its
        (w, h) size before extraction. Per-image results are stacked, so all
        images must yield the same number of keypoints.

        Returns:
            Dict of batched tensors on the input device. ``descriptors`` has
            its last two dimensions swapped, and ``keypoint_scores`` is added
            as a copy of ``scores`` when scores are available.
        """
        image = data["image"]
        if image.shape[1] == 3:
            image = rgb_to_grayscale(image)
        device = image.device
        image = image.cpu()
        pred = []
        for k, img in enumerate(image):
            if "image_size" in data:
                # avoid extracting points in padded areas
                w, h = data["image_size"][k]
                img = img[:, :h, :w]
            pred.append(self.extract_single_image(img))
        pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
        if self.conf.rootsift:
            pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
        pred["descriptors"] = pred["descriptors"].permute(0, 2, 1)
        # Scores are missing with pycolmap >= 0.5.0; do not fail on them.
        if "scores" in pred:
            pred["keypoint_scores"] = pred["scores"].clone()
        return pred