# MatchAnything / imcui / hloc / pairs_from_retrieval.py
# XingyiHe's picture
# init commit
# 3040ac4
import argparse
import collections.abc as collections
from pathlib import Path
from typing import Optional
import h5py
import numpy as np
import torch
from . import logger
from .utils.io import list_h5_names
from .utils.parsers import parse_image_lists
from .utils.read_write_model import read_images_binary
def parse_names(prefix, names, names_all):
    """Select which image names to use out of all available names.

    The selection is resolved in this order:

    1. If ``prefix`` is given, keep every entry of ``names_all`` that starts
       with it; ``names`` is ignored in that case.
    2. Otherwise, if ``names`` is given, it is either a path to an image list
       file (read with ``parse_image_lists``) or an iterable of names, which
       is copied into a list.
    3. Otherwise ``names_all`` is returned unchanged.

    Args:
        prefix: A prefix string, an iterable of prefix strings, or None.
        names: A ``str``/``Path`` pointing to a list file, an iterable of
            names, or None.
        names_all: All candidate names; used for prefix filtering and as the
            fallback selection.

    Returns:
        The selected names.

    Raises:
        ValueError: If no entry of ``names_all`` matches ``prefix``, or if
            ``names`` is neither a path nor an iterable.
    """
    if prefix is not None:
        if not isinstance(prefix, str):
            # str.startswith accepts a single str or a tuple of str only.
            prefix = tuple(prefix)
        names = [n for n in names_all if n.startswith(prefix)]
        if not names:
            raise ValueError(f"Could not find any image with the prefix `{prefix}`.")
    elif names is not None:
        # The path check must come first: a str is itself an Iterable.
        if isinstance(names, (str, Path)):
            names = parse_image_lists(names)
        elif isinstance(names, collections.Iterable):
            names = list(names)
        else:
            raise ValueError(
                f"Unknown type of image list: {names}. "
                "Provide either a list or a path to a list file."
            )
    else:
        names = names_all
    return names
def get_descriptors(names, path, name2idx=None, key="global_descriptor"):
    """Load one descriptor per name from HDF5 and stack them into a tensor.

    Args:
        names: Names to load. Each name is looked up as ``fd[name][key]`` in
            the opened HDF5 file.
        path: When ``name2idx`` is None, the path of a single HDF5 file.
            Otherwise an indexable collection of HDF5 file paths.
        name2idx: Optional mapping from a name to the index (into ``path``)
            of the file that contains it.
        key: Name of the dataset to read for each name.

    Returns:
        A float tensor obtained by stacking the loaded arrays along a new
        first axis, in the same order as ``names``.
    """
    if name2idx is None:
        with h5py.File(str(path), "r", libver="latest") as fd:
            desc = [fd[n][key].__array__() for n in names]
    else:
        # Group the requested names by file so that every file is opened only
        # once instead of once per name; positions are kept so the output
        # order still follows `names`.
        names = list(names)
        entries_per_file = {}
        for pos, n in enumerate(names):
            entries_per_file.setdefault(name2idx[n], []).append((pos, n))
        desc = [None] * len(names)
        for file_idx, entries in entries_per_file.items():
            with h5py.File(str(path[file_idx]), "r", libver="latest") as fd:
                for pos, n in entries:
                    desc[pos] = fd[n][key].__array__()
    return torch.from_numpy(np.stack(desc, 0)).float()
def pairs_from_score_matrix(
    scores: torch.Tensor,
    invalid: np.ndarray,
    num_select: int,
    min_score: Optional[float] = None,
):
    """Pick, for every row, the best-scoring valid columns.

    Args:
        scores: 2D score matrix (a torch tensor or a NumPy array). NOTE: it is
            modified in place, masked entries are overwritten with ``-inf``.
        invalid: Boolean NumPy array with the same shape as ``scores``; True
            marks entries that must never be selected. It is left untouched.
        num_select: Maximum number of columns to select per row. Values larger
            than the number of columns are clamped.
        min_score: If given, entries scoring strictly below it are also
            treated as invalid.

    Returns:
        A list of ``(row, column)`` index pairs, grouped by row and ordered by
        decreasing score within each row. Rows with fewer than ``num_select``
        valid entries contribute fewer pairs.
    """
    assert scores.shape == invalid.shape
    if isinstance(scores, np.ndarray):
        scores = torch.from_numpy(scores)
    invalid = torch.from_numpy(invalid).to(scores.device)
    if min_score is not None:
        # Out-of-place OR: on CPU the tensor shares memory with the caller's
        # NumPy mask, so `|=` would silently overwrite it.
        invalid = invalid | (scores < min_score)
    scores.masked_fill_(invalid, float("-inf"))

    # torch.topk raises when k exceeds the size of the reduced dimension.
    num_select = min(num_select, scores.shape[1])
    topk = torch.topk(scores, num_select, dim=1)
    indices = topk.indices.cpu().numpy()
    # Masked entries carry -inf and are dropped here.
    valid = topk.values.isfinite().cpu().numpy()

    return [(i, indices[i, j]) for i, j in zip(*np.where(valid))]
def main(
    descriptors,
    output,
    num_matched,
    query_prefix=None,
    query_list=None,
    db_prefix=None,
    db_list=None,
    db_model=None,
    db_descriptors=None,
):
    """Write a file of (query, database) image pairs found by retrieval.

    Descriptors of query and database images are compared with a dot product
    and, for each query, up to ``num_matched`` database images with the
    highest non-negative similarity are kept. A query is never paired with a
    database image that has the same name.

    Args:
        descriptors: Path to the HDF5 file with the query descriptors. Also
            used for the database when ``db_descriptors`` is None.
        output: Path of the text file to write; one ``query db`` pair per line.
        num_matched: Maximum number of database images per query. Clamped to
            the number of database images.
        query_prefix: Prefix(es) selecting the query names (see ``parse_names``).
        query_list: List, or path to a list file, of query names.
        db_prefix: Prefix(es) selecting the database names.
        db_list: List, or path to a list file, of database names.
        db_model: Directory containing an ``images.bin`` file; when given, the
            database names are read from it and ``db_prefix``/``db_list`` are
            ignored.
        db_descriptors: A single path or a list of paths to the HDF5 files
            with the database descriptors.

    Raises:
        ValueError: If no database image is selected.
    """
    logger.info("Extracting image pairs from a retrieval database.")

    # We handle multiple reference feature files.
    # We only assume that names are unique among them and map names to files.
    if db_descriptors is None:
        db_descriptors = descriptors
    if isinstance(db_descriptors, (Path, str)):
        db_descriptors = [db_descriptors]
    name2db = {n: i for i, p in enumerate(db_descriptors) for n in list_h5_names(p)}
    db_names_h5 = list(name2db.keys())
    query_names_h5 = list_h5_names(descriptors)

    if db_model:
        # Accept a plain string as well as a Path, like the other path arguments.
        images = read_images_binary(Path(db_model) / "images.bin")
        db_names = [i.name for i in images.values()]
    else:
        db_names = parse_names(db_prefix, db_list, db_names_h5)
    if len(db_names) == 0:
        raise ValueError("Could not find any database image.")
    query_names = parse_names(query_prefix, query_list, query_names_h5)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    db_desc = get_descriptors(db_names, db_descriptors, name2db)
    query_desc = get_descriptors(query_names, descriptors)
    sim = torch.einsum("id,jd->ij", query_desc.to(device), db_desc.to(device))

    # Avoid self-matching
    is_self = np.array(query_names)[:, None] == np.array(db_names)[None]
    # A top-k selection cannot return more entries than there are database
    # images, so cap the request instead of failing on small databases.
    num_select = min(num_matched, len(db_names))
    pairs = pairs_from_score_matrix(sim, is_self, num_select, min_score=0)
    pairs = [(query_names[i], db_names[j]) for i, j in pairs]

    logger.info(f"Found {len(pairs)} pairs.")
    with open(output, "w") as f:
        f.write("\n".join(" ".join([i, j]) for i, j in pairs))
if __name__ == "__main__":
    # Command-line options, declared in the order they are registered with
    # the parser. Each destination name matches a parameter of `main`, so the
    # parsed namespace can be forwarded directly as keyword arguments.
    cli_options = [
        ("--descriptors", {"type": Path, "required": True}),
        ("--output", {"type": Path, "required": True}),
        ("--num_matched", {"type": int, "required": True}),
        ("--query_prefix", {"type": str, "nargs": "+"}),
        ("--query_list", {"type": Path}),
        ("--db_prefix", {"type": str, "nargs": "+"}),
        ("--db_list", {"type": Path}),
        ("--db_model", {"type": Path}),
        ("--db_descriptors", {"type": Path}),
    ]
    cli = argparse.ArgumentParser()
    for flag, options in cli_options:
        cli.add_argument(flag, **options)
    main(**vars(cli.parse_args()))