"""Nearest-neighbour lookup over the ``kanji`` LanceDB table."""

import lancedb
import pydantic
import torch
from lancedb.pydantic import LanceModel

from config import lancedb_location

db = lancedb.connect(lancedb_location)
table = db.open_table("kanji")


class SearchResult(LanceModel):
    """One row returned by :func:`search_vector`: a kanji and its distance to the query."""

    kanji: str
    # Search results carry the score in a ``_distance`` column; also accept
    # ``distance`` so the model can be built from either spelling.
    distance: float = pydantic.Field(
        validation_alias=pydantic.AliasChoices("distance", "_distance")
    )


def search_vector(query_vector: torch.Tensor, limit: int = 20) -> list[SearchResult]:
    """Return the rows of the ``kanji`` table nearest to ``query_vector``.

    Args:
        query_vector: Query embedding, compared against the table's ``vector``
            column. It may require grad or live on any device; it is detached
            and copied to CPU before being handed to LanceDB.
        limit: Maximum number of results to return.

    Returns:
        Up to ``limit`` results, in the order LanceDB returns them.
    """
    # Tensor.numpy() raises for tensors that require grad or are not on the CPU.
    # detach().cpu() makes the conversion safe and is a no-op for plain CPU tensors.
    query = query_vector.detach().cpu().numpy()

    rows = (
        table.search(query, vector_column_name="vector", query_type="vector")
        .limit(limit)
        .to_list()
    )
    # Rows come back as plain dicts. Validating each one applies the
    # ``_distance`` -> ``distance`` alias defined on SearchResult.
    return [SearchResult.model_validate(row) for row in rows]