Spaces:
Sleeping
Sleeping
Commit
·
a2f004f
1
Parent(s):
c67b794
fix
Browse files- src/backend.py +15 -18
src/backend.py
CHANGED
@@ -133,33 +133,30 @@ def pattern_match(patterns, source_list):
|
|
133 |
|
134 |
def _backend_routine():
|
135 |
# List only the text classification models
|
136 |
-
rl_models =
|
137 |
logger.info(f"Found {len(rl_models)} RL models")
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
pending_models = list(set(rl_models) - set(evaluated_models))
|
143 |
-
pending_and_compatible_models = []
|
144 |
-
for repo_id, sha in pending_models:
|
145 |
-
try:
|
146 |
-
siblings = API.model_info(repo_id, revision="main").siblings
|
147 |
-
except Exception:
|
148 |
-
continue
|
149 |
-
filenames = [sib.rfilename for sib in siblings]
|
150 |
if "agent.pt" in filenames:
|
151 |
-
|
152 |
|
153 |
-
logger.info(f"Found {len(
|
|
|
|
|
|
|
|
|
|
|
154 |
|
155 |
-
if len(
|
156 |
return None
|
157 |
|
158 |
# Shuffle the dataset
|
159 |
-
random.shuffle(
|
160 |
|
161 |
# Select a random model
|
162 |
-
repo_id, sha =
|
163 |
user_id, model_id = repo_id.split("/")
|
164 |
row = {"model_id": model_id, "user_id": user_id, "sha": sha}
|
165 |
|
|
|
133 |
|
134 |
def _backend_routine():
|
135 |
# List only the text classification models
|
136 |
+
rl_models = API.list_models(filter=["reinforcement-learning"])
|
137 |
logger.info(f"Found {len(rl_models)} RL models")
|
138 |
+
|
139 |
+
compatible_models = []
|
140 |
+
for model in rl_models:
|
141 |
+
filenames = [sib.rfilename for sib in model.siblings]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
if "agent.pt" in filenames:
|
143 |
+
compatible_models.append((model.modelId, model.sha))
|
144 |
|
145 |
+
logger.info(f"Found {len(compatible_models)} compatible models")
|
146 |
+
|
147 |
+
dataset = load_dataset(RESULTS_REPO, split="train", download_mode="force_redownload", verification_mode="no_checks")
|
148 |
+
evaluated_models = [("/".join([x["user_id"], x["model_id"]]), x["sha"]) for x in dataset]
|
149 |
+
pending_models = list(set(compatible_models) - set(evaluated_models))
|
150 |
+
logger.info(f"Found {len(pending_models)} pending models")
|
151 |
|
152 |
+
if len(pending_models) == 0:
|
153 |
return None
|
154 |
|
155 |
# Shuffle the dataset
|
156 |
+
random.shuffle(pending_models)
|
157 |
|
158 |
# Select a random model
|
159 |
+
repo_id, sha = pending_models.pop()
|
160 |
user_id, model_id = repo_id.split("/")
|
161 |
row = {"model_id": model_id, "user_id": user_id, "sha": sha}
|
162 |
|