use num_proc for loading
app.py CHANGED
@@ -1,3 +1,4 @@
+import os
 from functools import lru_cache
 
 import duckdb
@@ -14,7 +15,7 @@ model_name = "sentence-transformers/static-retrieval-mrl-en-v1"
 model = SentenceTransformer(
     model_name,
     device="cpu",
-    tokenizer_kwargs={"model_max_length": 512},
+    tokenizer_kwargs={"model_max_length": 512},  # arbitrary for this model, here to keep things fast
 )
 
 
@@ -35,7 +36,7 @@ def get_iframe(hub_repo_id):
 
 def load_dataset_from_hub(hub_repo_id: str):
     gr.Info(message="Loading dataset...")
-    ds = load_dataset(hub_repo_id)
+    ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
 
 
 def get_columns(hub_repo_id: str, split: str):
@@ -50,7 +51,7 @@ def get_columns(hub_repo_id: str, split: str):
 
 
 def get_splits(hub_repo_id: str):
-    ds = load_dataset(hub_repo_id)
+    ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
     splits = list(ds.keys())
     return gr.Dropdown(
         choices=splits, value=splits[0], label="Select a split", visible=True
@@ -60,7 +61,7 @@ def get_splits(hub_repo_id: str):
 @lru_cache
 def vectorize_dataset(hub_repo_id: str, split: str, column: str):
     gr.Info("Vectorizing dataset...")
-    ds = load_dataset(hub_repo_id)
+    ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
     df = ds[split].to_polars()
     embeddings = model.encode(df[column].cast(str).to_list(), show_progress_bar=True, batch_size=128)
     return embeddings
@@ -68,7 +69,7 @@ def vectorize_dataset(hub_repo_id: str, split: str, column: str):
 
 def run_query(hub_repo_id: str, query: str, split: str, column: str):
     embeddings = vectorize_dataset(hub_repo_id, split, column)
-    ds = load_dataset(hub_repo_id)
+    ds = load_dataset(hub_repo_id, num_proc=os.cpu_count())
     df = ds[split].to_polars()
     df = df.with_columns(pl.Series(embeddings).alias("embeddings"))
     try:
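
For reference, `num_proc` in `datasets.load_dataset` parallelizes download and preparation of the dataset across worker processes. A minimal sketch of the pattern this commit applies (the repo id below is a placeholder, not the one used by this Space):

```python
import os

from datasets import load_dataset

# Spread download and preparation across all available CPU cores.
# os.cpu_count() can oversubscribe on shared machines; a fixed
# worker count (e.g. num_proc=4) is a reasonable alternative there.
ds = load_dataset("stanfordnlp/imdb", num_proc=os.cpu_count())  # placeholder repo id
print(ds)
```

After the first load, `datasets` serves the same repo id from its local cache, so the repeated `load_dataset` call sites in app.py should only pay the download cost once.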
|