File size: 843 Bytes
be12cc9
63a1db6
 
 
 
be12cc9
63a1db6
be12cc9
63a1db6
 
be12cc9
63a1db6
be12cc9
63a1db6
be12cc9
63a1db6
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import lancedb
from lancedb.pydantic import LanceModel
import pydantic
# import time

from config import lancedb_location

# Module-level LanceDB handles, created once when this module is imported.
# ``table`` is the "kanji" table that search_vector() queries, so importing
# this module depends on a database at ``lancedb_location`` that contains a
# "kanji" table (open_table fails otherwise).
db = lancedb.connect(lancedb_location)
table = db.open_table("kanji")

class SearchResult(LanceModel):
    """A single hit from a vector search over the ``kanji`` table.

    Instances are validated from the row dicts that LanceDB returns. Any
    extra keys in those rows (e.g. the stored vector) are ignored by
    pydantic's default behaviour.
    """

    # The ``kanji`` value of the matched row.
    kanji: str
    # Distance between the query and the matched row's vector. On input it
    # may be supplied as either ``distance`` or LanceDB's ``_distance`` key.
    distance: float = pydantic.Field(
        validation_alias=pydantic.AliasChoices("distance", "_distance"),
    )

def search_vector(query_vector: torch.Tensor, limit: int = 20) -> list[SearchResult]:
    """Find the rows of the ``kanji`` table nearest to ``query_vector``.

    Args:
        query_vector: Query embedding. It may live on any device and may
            require grad. It is detached and copied to CPU before being
            passed to LanceDB. Its length presumably has to match the
            table's ``vector`` column — TODO confirm against the schema.
        limit: Maximum number of results to return.

    Returns:
        Up to ``limit`` matches, ordered as LanceDB returns them. Each
        row is parsed into a ``SearchResult``, and LanceDB's ``_distance``
        key is mapped to ``distance``.
    """
    # ``Tensor.numpy()`` raises for CUDA tensors and for tensors that
    # require grad, so convert to a plain CPU array first. For a CPU tensor
    # without grad this is a no-op view, exactly as before.
    query = query_vector.detach().cpu().numpy()
    rows = (
        table.search(query, vector_column_name="vector", query_type="vector")
        .limit(limit)
        .to_list()
    )
    return [SearchResult.model_validate(row) for row in rows]