Spaces:
Running
Running
import os | |
from typing import ClassVar | |
# import dotenv | |
import gradio as gr | |
import lancedb | |
import srsly | |
from huggingface_hub import snapshot_download | |
from lancedb.embeddings.base import TextEmbeddingFunction | |
from lancedb.embeddings.registry import register | |
from lancedb.pydantic import LanceModel, Vector | |
from lancedb.rerankers import CohereReranker, ColbertReranker | |
from lancedb.util import attempt_import_or_raise | |
# dotenv.load_dotenv() | |
class CohereEmbeddingFunction_2(TextEmbeddingFunction):
    """LanceDB text-embedding function backed by Cohere's ``embed`` endpoint.

    A single ``cohere.Client`` is created lazily on first use and stored on the
    class (``client`` is a ``ClassVar``), so every instance shares it.
    """

    # Cohere model id, passed as ``model=`` to ``client.embed``.
    name: str = "embed-english-v3.0"
    # Shared, lazily-created ``cohere.Client`` (see ``_init_client``).
    client: ClassVar = None

    def ndims(self):
        """Return the embedding dimensionality.

        ``embed-english-v3.0`` returns 1024-dimensional vectors, which is also
        the width of the ``Vector(1024)`` column in ``ArxivModel``.
        """
        return 1024

    def generate_embeddings(self, texts):
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed

        Returns
        -------
        list
            One embedding per input text, as returned in ``rs.embeddings``.

        Raises
        ------
        KeyError
            On first use, if ``COHERE_API_KEY`` is not set in the environment.
        """
        # TODO retry, rate limit, token limit
        # NOTE(review): "search_document" is presumably also used when LanceDB
        # embeds *queries* through this function (TextEmbeddingFunction's default
        # query path); Cohere recommends input_type="search_query" for queries — confirm.
        self._init_client()
        rs = CohereEmbeddingFunction_2.client.embed(
            texts=texts, model=self.name, input_type="search_document"
        )
        return list(rs.embeddings)

    def _init_client(self):
        """Create the shared Cohere client on first call, using ``COHERE_API_KEY``."""
        if CohereEmbeddingFunction_2.client is None:
            # Import only when a client actually has to be built; later calls
            # reuse the cached client and skip the import check entirely.
            cohere = attempt_import_or_raise("cohere")
            CohereEmbeddingFunction_2.client = cohere.Client(
                os.environ["COHERE_API_KEY"]
            )
# Shared embedding-function instance; ArxivModel below uses it to declare which
# column is embedded (SourceField) and where the embedding lives (VectorField).
COHERE_EMBEDDER = CohereEmbeddingFunction_2.create()


class ArxivModel(LanceModel):
    """LanceDB row schema: a piece of paper text, its embedding, and paper metadata.

    NOTE(review): not referenced elsewhere in the code shown (the table is opened
    from disk); presumably the schema the ``arxiv_zotero_v0`` table was built with — confirm.
    """

    # Text column marked as the embedding source for COHERE_EMBEDDER.
    text: str = COHERE_EMBEDDER.SourceField()
    # 1024-wide vector column populated by COHERE_EMBEDDER.
    vector: Vector(1024) = COHERE_EMBEDDER.VectorField()
    # Presumably the section/chunk title (distinct from paper_title) — confirm.
    title: str
    # Paper title; query_db / _format_results use it as the display title.
    paper_title: str
    # Kind of content stored in `text`; allowed values not visible here — confirm.
    content_type: str
    # arXiv identifier; query_db groups results by it and builds the abs URL from it.
    arxiv_id: str
def download_data(repo_id="rbiswasfc/zotero_db", local_dir="./data"):
    """Snapshot a Hugging Face dataset repository to a local directory.

    The module-level setup reads ``./data/.lancedb_zotero_v0`` and
    ``./data/id_to_abstract.json`` after this runs, so the defaults must keep
    pointing at that repository and directory.

    Parameters
    ----------
    repo_id: str
        Dataset repository to download.
    local_dir: str
        Directory the snapshot files are written into.
    """
    snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=local_dir,
        # .get() instead of [...]: with token=None huggingface_hub falls back to
        # its own credential lookup (or anonymous access for public repos)
        # rather than this line crashing with KeyError when HF_TOKEN is unset.
        token=os.environ.get("HF_TOKEN"),
    )
    print("Data downloaded!")
# Import-time setup; order matters because everything after the download reads
# files from ./data.
download_data()
# App version string (not referenced elsewhere in the code shown).
VERSION = "0.0.0a"
# LanceDB database directory inside the downloaded snapshot.
DB = lancedb.connect("./data/.lancedb_zotero_v0")
# Maps arXiv id -> raw abstract text; looked up by _format_results.
ID_TO_ABSTRACT = srsly.read_json("./data/id_to_abstract.json")
# Rerankers selectable by name through query_db's `reranker` argument.
RERANKERS = {"colbert": ColbertReranker(), "cohere": CohereReranker()}
# Table searched by query_db.
TBL = DB.open_table("arxiv_zotero_v0")
def _clean_abstract(abstract, paper_title):
    """Strip preprocessing residue (heading, repeated title, leading fragment) from an abstract."""
    # these are all ugly hacks because the data preprocessing is poor. to be fixed v soon.
    # Keep only the text after the last "Abstract" heading (unchanged if absent).
    abstract = abstract.rpartition("Abstract\n\n")[2]
    # Keep only the text after the last occurrence of the title. The truthiness
    # guard matters: an empty title would make rpartition/split raise ValueError.
    if paper_title and paper_title in abstract:
        abstract = abstract.rpartition(paper_title)[2]
    if abstract.startswith("\n"):
        abstract = abstract[1:]
    # A paragraph break within the first 20 chars is treated as a short leading
    # fragment; drop everything up to the first break.
    if "\n\n" in abstract[:20]:
        abstract = abstract.partition("\n\n")[2]
    return abstract


def _format_results(arxiv_refs):
    """Turn ``{arxiv_id: paper_title}`` into display records.

    Parameters
    ----------
    arxiv_refs: dict[str, str]
        Maps arXiv id to paper title; the output follows its iteration order.

    Returns
    -------
    list[dict]
        One ``{"title", "url", "abstract"}`` dict per entry. The abstract comes
        from ``ID_TO_ABSTRACT`` (``""`` when missing or null) after cleanup.
    """
    return [
        {
            "title": paper_title,
            "url": f"https://arxiv.org/abs/{arx_id}",
            "abstract": _clean_abstract(
                ID_TO_ABSTRACT.get(arx_id) or "", paper_title
            ),
        }
        for arx_id, paper_title in arxiv_refs.items()
    ]
def query_db(query: str, k: int = 10, reranker: str = "cohere", n_papers: int = 3):
    """Hybrid-search the paper table and return the best-matching papers.

    Parameters
    ----------
    query: str
        Free-text search query. Blank input returns ``[]`` without searching.
    k: int
        Number of rows fetched from the hybrid search (before grouping by paper).
    reranker: str or None
        Key into ``RERANKERS`` ("colbert" or "cohere"), or None to skip reranking.
    n_papers: int
        Number of distinct papers to return.

    Returns
    -------
    list[dict]
        ``_format_results`` records for the top ``n_papers`` papers, ordered by
        the summed ``_relevance_score`` of their rows, highest first.

    Raises
    ------
    KeyError
        If ``reranker`` is not a key of ``RERANKERS``.
    """
    if not query or not query.strip():
        return []

    raw_results = TBL.search(query, query_type="hybrid").limit(k)
    if reranker is not None:
        raw_results = raw_results.rerank(reranker=RERANKERS[reranker])
    ranked_results = raw_results.to_pandas()

    # Several rows can belong to one paper; score each paper by the sum of its
    # rows' relevance and keep the best n_papers.
    paper_scores = (
        ranked_results.groupby("arxiv_id")["_relevance_score"]
        .sum()
        .sort_values(ascending=False)
        .head(n_papers)
    )
    # arxiv_id -> paper_title; for repeated ids the last row wins, as before.
    titles = dict(zip(ranked_results["arxiv_id"], ranked_results["paper_title"]))
    # Build in score order so the returned list reflects the aggregated ranking
    # (iterating the rows would reorder papers by first appearance instead).
    top_results_dict = {arx_id: titles[arx_id] for arx_id in paper_scores.index}
    return _format_results(top_results_dict)
# Minimal search UI: a query box and a button on one row, results shown as JSON.
with gr.Blocks() as demo:
    with gr.Row():
        query = gr.Textbox(placeholder="Enter your query...", label="Query")
        submit_btn = gr.Button(value="Submit")
    output = gr.JSON(label="Search Results")

    # Clicking Submit runs query_db on the textbox text, using its default k and
    # reranker, and renders the returned list in the JSON component.
    submit_btn.click(query_db, query, output)

demo.launch()