pszemraj committed on
Commit
c28c93a
·
verified ·
1 Parent(s): 9b8a77d

use num_proc for loading

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  from functools import lru_cache
2
 
3
  import duckdb
@@ -14,7 +15,7 @@ model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
14
  model = SentenceTransformer(
15
  model_name,
16
  device="cpu",
17
- tokenizer_kwargs={"model_max_length": 512},
18
  )
19
 
20
 
@@ -35,7 +36,7 @@ def get_iframe(hub_repo_id):
35
 
36
  def load_dataset_from_hub(hub_repo_id: str):
37
  gr.Info(message="Loading dataset...")
38
- ds = load_dataset(hub_repo_id)
39
 
40
 
41
  def get_columns(hub_repo_id: str, split: str):
@@ -50,7 +51,7 @@ def get_columns(hub_repo_id: str, split: str):
50
 
51
 
52
  def get_splits(hub_repo_id: str):
53
- ds = load_dataset(hub_repo_id)
54
  splits = list(ds.keys())
55
  return gr.Dropdown(
56
  choices=splits, value=splits[0], label="Select a split", visible=True
@@ -60,7 +61,7 @@ def get_splits(hub_repo_id: str):
60
  @lru_cache
61
  def vectorize_dataset(hub_repo_id: str, split: str, column: str):
62
  gr.Info("Vectorizing dataset...")
63
- ds = load_dataset(hub_repo_id)
64
  df = ds[split].to_polars()
65
  embeddings = model.encode(df[column].cast(str).to_list(), show_progress_bar=True, batch_size=128)
66
  return embeddings
@@ -68,7 +69,7 @@ def vectorize_dataset(hub_repo_id: str, split: str, column: str):
68
 
69
  def run_query(hub_repo_id: str, query: str, split: str, column: str):
70
  embeddings = vectorize_dataset(hub_repo_id, split, column)
71
- ds = load_dataset(hub_repo_id)
72
  df = ds[split].to_polars()
73
  df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
74
  try:
 
1
+ import os
2
  from functools import lru_cache
3
 
4
  import duckdb
 
15
  model = SentenceTransformer(
16
  model_name,
17
  device="cpu",
18
+ tokenizer_kwargs={"model_max_length": 512}, # arbitrary for this model, here to keep things fast
19
  )
20
 
21
 
 
36
 
37
  def load_dataset_from_hub(hub_repo_id: str):
38
  gr.Info(message="Loading dataset...")
39
+ ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
40
 
41
 
42
  def get_columns(hub_repo_id: str, split: str):
 
51
 
52
 
53
  def get_splits(hub_repo_id: str):
54
+ ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
55
  splits = list(ds.keys())
56
  return gr.Dropdown(
57
  choices=splits, value=splits[0], label="Select a split", visible=True
 
61
  @lru_cache
62
  def vectorize_dataset(hub_repo_id: str, split: str, column: str):
63
  gr.Info("Vectorizing dataset...")
64
+ ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
65
  df = ds[split].to_polars()
66
  embeddings = model.encode(df[column].cast(str).to_list(), show_progress_bar=True, batch_size=128)
67
  return embeddings
 
69
 
70
  def run_query(hub_repo_id: str, query: str, split: str, column: str):
71
  embeddings = vectorize_dataset(hub_repo_id, split, column)
72
+ ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
73
  df = ds[split].to_polars()
74
  df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
75
  try: