MatchAnything / imcui /hloc /localize_sfm.py
XingyiHe's picture
init commit
3040ac4
import argparse
import pickle
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Union
import numpy as np
import pycolmap
from tqdm import tqdm
from . import logger
from .utils.io import get_keypoints, get_matches
from .utils.parsers import parse_image_lists, parse_retrieval
def do_covisibility_clustering(
    frame_ids: List[int], reconstruction: "pycolmap.Reconstruction"
):
    """Split ``frame_ids`` into connected components of the covisibility graph.

    Two images are linked when one of them has a 2D point attached to a 3D
    point whose track also contains the other image. Only images listed in
    ``frame_ids`` are considered; links to any other image are ignored.

    Args:
        frame_ids: Image ids, used as keys into ``reconstruction.images``.
            Duplicates are tolerated (an id is only clustered once).
        reconstruction: Reference model. Only ``images[...].points2D`` and
            ``points3D[...].track.elements`` are read.

    Returns:
        A list of clusters (each a list of image ids), largest cluster first.
        Clusters of equal size keep the order in which they were discovered.
    """
    # Hoisted out of the traversal: the candidate set is needed for every
    # visited frame and never changes.
    candidates = set(frame_ids)

    clusters = []
    visited = set()
    for frame_id in frame_ids:
        # Already assigned to an earlier component.
        if frame_id in visited:
            continue

        # Flood-fill a new component starting from this frame.
        component = []
        queue = {frame_id}
        while queue:
            exploration_frame = queue.pop()
            if exploration_frame in visited:
                continue
            visited.add(exploration_frame)
            component.append(exploration_frame)

            observed = reconstruction.images[exploration_frame].points2D
            # Every image that appears in the track of a 3D point seen here.
            connected_frames = {
                obs.image_id
                for p2D in observed
                if p2D.has_point3D()
                for obs in reconstruction.points3D[p2D.point3D_id].track.elements
            }
            # Keep only unvisited candidates.
            queue |= (connected_frames & candidates) - visited
        clusters.append(component)

    return sorted(clusters, key=len, reverse=True)
class QueryLocalizer:
    """Runs absolute pose estimation for a query against a reference model."""

    def __init__(self, reconstruction, config=None):
        """Keep a handle on the model and on the solver options.

        Args:
            reconstruction: Reference model; its ``points3D`` mapping is read
                by :meth:`localize`.
            config: Optional dict. The keys ``"estimation"`` and
                ``"refinement"`` are forwarded to pycolmap; a falsy value is
                replaced by an empty dict.
        """
        self.reconstruction = reconstruction
        self.config = config if config else {}

    def localize(self, points2D_all, points2D_idxs, points3D_id, query_camera):
        """Estimate the query pose from 2D-3D correspondences.

        Args:
            points2D_all: All query keypoints; indexed with ``points2D_idxs``.
            points2D_idxs: Indices into ``points2D_all``, one per correspondence.
            points3D_id: Ids of the matching 3D points, aligned with
                ``points2D_idxs``.
            query_camera: Camera passed through to pycolmap.

        Returns:
            Whatever ``pycolmap.absolute_pose_estimation`` returns, unchanged.
        """
        selected_2d = points2D_all[points2D_idxs]
        landmarks = self.reconstruction.points3D
        selected_3d = [landmarks[pid].xyz for pid in points3D_id]
        options = self.config
        return pycolmap.absolute_pose_estimation(
            selected_2d,
            selected_3d,
            query_camera,
            estimation_options=options.get("estimation", {}),
            refinement_options=options.get("refinement", {}),
        )
def pose_from_cluster(
    localizer: QueryLocalizer,
    qname: str,
    query_camera: pycolmap.Camera,
    db_ids: List[int],
    features_path: Path,
    matches_path: Path,
    **kwargs,
):
    """Localize one query using the 2D-3D matches from a set of database images.

    For each database image, the query-to-database matches are loaded and only
    those landing on a keypoint with a 3D point are kept. The resulting
    2D-3D pairs (deduplicated per query keypoint) are handed to
    ``localizer.localize``.

    Args:
        localizer: Provides ``reconstruction`` and ``localize``.
        qname: Query image name, as stored in the feature and match files.
        query_camera: Camera of the query; attached to a successful result.
        db_ids: Ids of the database images to match against.
        features_path: File read by ``get_keypoints``.
        matches_path: File read by ``get_matches``.
        **kwargs: Forwarded to ``localizer.localize``.

    Returns:
        ``(ret, log)`` where ``ret`` is the localizer output (with a
        ``"camera"`` entry added when it is not ``None``) and ``log`` is a dict
        of diagnostics about the correspondences that were used.
    """
    query_kps = get_keypoints(features_path, qname)
    query_kps += 0.5  # COLMAP coordinates

    # query keypoint index -> {3D point id -> positions in db_ids observing it}.
    # Insertion order matters: it defines the order of the 2D-3D pairs below.
    tracks = {}
    total_matches = 0
    for db_pos, db_id in enumerate(db_ids):
        image = localizer.reconstruction.images[db_id]
        if image.num_points3D == 0:
            logger.debug(f"No 3D points found for {image.name}.")
            continue

        # 3D point id for every keypoint of the database image, -1 if none.
        landmark_of_kp = np.array(
            [kp.point3D_id if kp.has_point3D() else -1 for kp in image.points2D]
        )

        pairs, _ = get_matches(matches_path, qname, image.name)
        has_landmark = landmark_of_kp[pairs[:, 1]] != -1
        pairs = pairs[has_landmark]
        total_matches += len(pairs)

        for query_idx, db_kp_idx in pairs:
            landmark = landmark_of_kp[db_kp_idx]
            # Each (query keypoint, 3D point) pair is registered once; later
            # sightings only extend its list of observing images.
            tracks.setdefault(query_idx, {}).setdefault(landmark, []).append(db_pos)

    # Flatten into aligned lists, one entry per unique 2D-3D pair.
    mkp_idxs = []
    mp3d_ids = []
    mkp_to_3D_to_db = []  # mostly for logging and post-processing
    for query_idx, per_landmark in tracks.items():
        for landmark, seen_in in per_landmark.items():
            mkp_idxs.append(query_idx)
            mp3d_ids.append(landmark)
            mkp_to_3D_to_db.append((landmark, seen_in))

    ret = localizer.localize(query_kps, mkp_idxs, mp3d_ids, query_camera, **kwargs)
    if ret is not None:
        ret["camera"] = query_camera

    log = {
        "db": db_ids,
        "PnP_ret": ret,
        "keypoints_query": query_kps[mkp_idxs],
        "points3D_ids": mp3d_ids,
        "points3D_xyz": None,  # we don't log xyz anymore because of file size
        "num_matches": total_matches,
        "keypoint_index_to_db": (mkp_idxs, mkp_to_3D_to_db),
    }
    return ret, log
def main(
    reference_sfm: Union[Path, pycolmap.Reconstruction],
    queries: Path,
    retrieval: Path,
    features: Path,
    matches: Path,
    results: Path,
    ransac_thresh: float = 12,
    covisibility_clustering: bool = False,
    prepend_camera_name: bool = False,
    config: Dict = None,
):
    """Localize every query image against a reference SfM model.

    Poses are written to ``results`` as one text line per query
    (``name q0 q1 q2 q3 t0 t1 t2``) and diagnostics are pickled to
    ``f"{results}_logs.pkl"``.

    Args:
        reference_sfm: A loaded reconstruction, or a path from which one is
            loaded.
        queries: Query list, parsed with ``parse_image_lists`` (with
            intrinsics).
        retrieval: Retrieval file, parsed with ``parse_retrieval``; maps each
            query name to database image names.
        features: Local feature file; must exist.
        matches: Match file; must exist.
        results: Output path for the pose file.
        ransac_thresh: Default ``estimation.ransac.max_error``. Ignored if
            ``config`` supplies its own ``"estimation"`` entry, because
            top-level keys of ``config`` replace the defaults wholesale.
        covisibility_clustering: If true, retrieved images are clustered with
            ``do_covisibility_clustering``, each cluster is tried separately,
            and the result with the most inliers is kept.
        prepend_camera_name: If true, the parent directory component of the
            query name is kept in the output name.
        config: Optional options dict merged over the defaults.

    Raises:
        AssertionError: If ``retrieval``, ``features`` or ``matches`` does not
            exist.
    """
    assert retrieval.exists(), retrieval
    assert features.exists(), features
    assert matches.exists(), matches

    queries = parse_image_lists(queries, with_intrinsics=True)
    retrieval_dict = parse_retrieval(retrieval)

    logger.info("Reading the 3D model...")
    if not isinstance(reference_sfm, pycolmap.Reconstruction):
        reference_sfm = pycolmap.Reconstruction(reference_sfm)
    db_name_to_id = {img.name: i for i, img in reference_sfm.images.items()}

    config = {
        "estimation": {"ransac": {"max_error": ransac_thresh}},
        **(config or {}),
    }
    localizer = QueryLocalizer(reference_sfm, config)

    cam_from_world = {}
    logs = {
        "features": features,
        "matches": matches,
        "retrieval": retrieval,
        "loc": {},
    }
    logger.info("Starting localization...")
    for qname, qcam in tqdm(queries):
        if qname not in retrieval_dict:
            logger.warning(f"No images retrieved for query image {qname}. Skipping...")
            continue
        db_names = retrieval_dict[qname]
        db_ids = []
        for n in db_names:
            if n not in db_name_to_id:
                logger.warning(f"Image {n} was retrieved but not in database")
                continue
            db_ids.append(db_name_to_id[n])

        if covisibility_clustering:
            clusters = do_covisibility_clustering(db_ids, reference_sfm)
            best_inliers = 0
            best_cluster = None
            logs_clusters = []
            for i, cluster_ids in enumerate(clusters):
                ret, log = pose_from_cluster(
                    localizer, qname, qcam, cluster_ids, features, matches
                )
                if ret is not None and ret["num_inliers"] > best_inliers:
                    best_cluster = i
                    best_inliers = ret["num_inliers"]
                logs_clusters.append(log)
            # No fallback here: a query without any successful cluster gets
            # no pose.
            if best_cluster is not None:
                ret = logs_clusters[best_cluster]["PnP_ret"]
                cam_from_world[qname] = ret["cam_from_world"]
            logs["loc"][qname] = {
                "db": db_ids,
                "best_cluster": best_cluster,
                "log_clusters": logs_clusters,
                "covisibility_clustering": covisibility_clustering,
            }
        else:
            ret, log = pose_from_cluster(
                localizer, qname, qcam, db_ids, features, matches
            )
            if ret is not None:
                cam_from_world[qname] = ret["cam_from_world"]
            elif db_ids:
                # Pose estimation failed: fall back to the pose of the first
                # retrieved database image.
                closest = reference_sfm.images[db_ids[0]]
                cam_from_world[qname] = closest.cam_from_world
            else:
                # None of the retrieved names exist in the model, so there is
                # no database pose to fall back to (indexing db_ids[0] would
                # raise IndexError).
                logger.warning(
                    f"No retrieved image of query {qname} is in the database; "
                    "no pose written."
                )
            log["covisibility_clustering"] = covisibility_clustering
            logs["loc"][qname] = log

    logger.info(f"Localized {len(cam_from_world)} / {len(queries)} images.")
    logger.info(f"Writing poses to {results}...")
    with open(results, "w") as f:
        for query, t in cam_from_world.items():
            # Quaternion components are reordered before writing, presumably
            # xyzw -> wxyz; confirm against pycolmap's rotation type.
            qvec = " ".join(map(str, t.rotation.quat[[3, 0, 1, 2]]))
            tvec = " ".join(map(str, t.translation))
            name = query.split("/")[-1]
            if prepend_camera_name:
                name = query.split("/")[-2] + "/" + name
            f.write(f"{name} {qvec} {tvec}\n")

    logs_path = f"{results}_logs.pkl"
    logger.info(f"Writing logs to {logs_path}...")
    # TODO: Resolve pickling issue with pycolmap objects.
    with open(logs_path, "wb") as f:
        pickle.dump(logs, f)
    logger.info("Done!")
if __name__ == "__main__":
    # Command-line entry point: each flag maps one-to-one onto a keyword
    # argument of main().
    cli = argparse.ArgumentParser()

    # Required input and output paths.
    cli.add_argument("--reference_sfm", type=Path, required=True)
    cli.add_argument("--queries", type=Path, required=True)
    cli.add_argument("--features", type=Path, required=True)
    cli.add_argument("--matches", type=Path, required=True)
    cli.add_argument("--retrieval", type=Path, required=True)
    cli.add_argument("--results", type=Path, required=True)

    # Optional tuning knobs.
    cli.add_argument("--ransac_thresh", type=float, default=12.0)
    cli.add_argument("--covisibility_clustering", action="store_true")
    cli.add_argument("--prepend_camera_name", action="store_true")

    cli_args = cli.parse_args()
    main(**vars(cli_args))