pszemraj committed on
Commit
8049894
·
verified ·
1 Parent(s): e47cccf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -5,13 +5,13 @@ import gradio as gr
5
  import polars as pl
6
  from datasets import load_dataset
7
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
8
- from model2vec import StaticModel
9
 
10
  global df
11
 
12
- # Load a model from the HuggingFace hub (in this case the potion-base-8M model)
13
- model_name = "minishlab/potion-base-8M"
14
- model = StaticModel.from_pretrained(model_name)
15
 
16
 
17
  def get_iframe(hub_repo_id):
@@ -58,7 +58,7 @@ def vectorize_dataset(hub_repo_id: str, split: str, column: str):
58
  gr.Info("Vectorizing dataset...")
59
  ds = load_dataset(hub_repo_id)
60
  df = ds[split].to_polars()
61
- embeddings = model.encode(df[column].cast(str), max_length=512)
62
  return embeddings
63
 
64
 
@@ -73,7 +73,7 @@ def run_query(hub_repo_id: str, query: str, split: str, column: str):
73
  query=f"""
74
  SELECT *
75
  FROM df
76
- ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[256])
77
  LIMIT 5
78
  """
79
  ).to_df()
@@ -165,4 +165,4 @@ with gr.Blocks() as demo:
165
  outputs=results_output,
166
  )
167
 
168
- demo.launch()
 
5
  import polars as pl
6
  from datasets import load_dataset
7
  from gradio_huggingfacehub_search import HuggingfaceHubSearch
8
+ from sentence_transformers import SentenceTransformer
9
 
10
  global df
11
 
12
+ # Load the static embeddings model from HuggingFace hub
13
+ model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
14
+ model = SentenceTransformer(model_name, device="cpu")
15
 
16
 
17
  def get_iframe(hub_repo_id):
 
58
  gr.Info("Vectorizing dataset...")
59
  ds = load_dataset(hub_repo_id)
60
  df = ds[split].to_polars()
61
+ embeddings = model.encode(df[column].cast(str).to_list(), show_progress_bar=True)
62
  return embeddings
63
 
64
 
 
73
  query=f"""
74
  SELECT *
75
  FROM df
76
+ ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[1024])
77
  LIMIT 5
78
  """
79
  ).to_df()
 
165
  outputs=results_output,
166
  )
167
 
168
+ demo.launch()