Spaces:
Runtime error
Runtime error
yuantao-infini-ai
commited on
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +3 -9
- __init__.py +0 -0
- __pycache__/__init__.cpython-310.pyc +0 -0
- __pycache__/__init__.cpython-311.pyc +0 -0
- __pycache__/api_provider.cpython-310.pyc +0 -0
- __pycache__/base_model_worker.cpython-310.pyc +0 -0
- __pycache__/cli.cpython-310.pyc +0 -0
- __pycache__/cli.cpython-311.pyc +0 -0
- __pycache__/controller.cpython-310.pyc +0 -0
- __pycache__/gradio_web_server.cpython-310.pyc +0 -0
- __pycache__/inference.cpython-310.pyc +0 -0
- __pycache__/model_worker.cpython-310.pyc +0 -0
- __pycache__/test_message.cpython-310.pyc +0 -0
- api_provider.py +130 -0
- base_model_worker.py +239 -0
- cli.py +313 -0
- controller.py +348 -0
- gateway/README.md +57 -0
- gateway/nginx.conf +97 -0
- gradio_block_arena_anony.py +608 -0
- gradio_block_arena_named.py +458 -0
- gradio_web_server.py +883 -0
- gradio_web_server_multi.py +270 -0
- huggingface_api.py +73 -0
- huggingface_api_worker.py +391 -0
- inference.py +596 -0
- launch_all_serve.py +284 -0
- model_worker.py +363 -0
- monitor/basic_stats.py +210 -0
- monitor/clean_battle_data.py +269 -0
- monitor/clean_chat_data.py +171 -0
- monitor/dataset_release_scripts/arena_33k/count_unique_users.py +25 -0
- monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py +155 -0
- monitor/dataset_release_scripts/arena_33k/merge_field.py +25 -0
- monitor/dataset_release_scripts/arena_33k/sample.py +32 -0
- monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py +9 -0
- monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py +13 -0
- monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py +119 -0
- monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py +148 -0
- monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py +27 -0
- monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md +23 -0
- monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py +45 -0
- monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh +18 -0
- monitor/dataset_release_scripts/lmsys_chat_1m/sample.py +32 -0
- monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py +17 -0
- monitor/elo_analysis.py +303 -0
- monitor/inspect_conv.py +87 -0
- monitor/intersect_conv_file.py +25 -0
- monitor/leaderboard_csv_to_html.py +51 -0
- monitor/monitor.py +313 -0
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: indigo
|
5 |
-
colorTo: purple
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: demo_test
|
3 |
+
app_file: gradio_web_server.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 3.45.0
|
|
|
|
|
6 |
---
|
|
|
|
__init__.py
ADDED
File without changes
|
__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (168 Bytes). View file
|
|
__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (184 Bytes). View file
|
|
__pycache__/api_provider.cpython-310.pyc
ADDED
Binary file (2.69 kB). View file
|
|
__pycache__/base_model_worker.cpython-310.pyc
ADDED
Binary file (7.01 kB). View file
|
|
__pycache__/cli.cpython-310.pyc
ADDED
Binary file (9 kB). View file
|
|
__pycache__/cli.cpython-311.pyc
ADDED
Binary file (15.6 kB). View file
|
|
__pycache__/controller.cpython-310.pyc
ADDED
Binary file (9.35 kB). View file
|
|
__pycache__/gradio_web_server.cpython-310.pyc
ADDED
Binary file (20.6 kB). View file
|
|
__pycache__/inference.cpython-310.pyc
ADDED
Binary file (11.5 kB). View file
|
|
__pycache__/model_worker.cpython-310.pyc
ADDED
Binary file (9.37 kB). View file
|
|
__pycache__/test_message.cpython-310.pyc
ADDED
Binary file (2.22 kB). View file
|
|
api_provider.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Call API providers."""
|
2 |
+
|
3 |
+
import os
|
4 |
+
import random
|
5 |
+
import time
|
6 |
+
|
7 |
+
from fastchat.utils import build_logger
|
8 |
+
from fastchat.constants import WORKER_API_TIMEOUT
|
9 |
+
|
10 |
+
|
11 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
12 |
+
|
13 |
+
|
14 |
+
def openai_api_stream_iter(
|
15 |
+
model_name,
|
16 |
+
messages,
|
17 |
+
temperature,
|
18 |
+
top_p,
|
19 |
+
max_new_tokens,
|
20 |
+
api_base=None,
|
21 |
+
api_key=None,
|
22 |
+
):
|
23 |
+
import openai
|
24 |
+
|
25 |
+
openai.api_base = api_base or "https://api.openai.com/v1"
|
26 |
+
openai.api_key = api_key or os.environ["OPENAI_API_KEY"]
|
27 |
+
if model_name == "gpt-4-turbo":
|
28 |
+
model_name = "gpt-4-1106-preview"
|
29 |
+
|
30 |
+
# Make requests
|
31 |
+
gen_params = {
|
32 |
+
"model": model_name,
|
33 |
+
"prompt": messages,
|
34 |
+
"temperature": temperature,
|
35 |
+
"top_p": top_p,
|
36 |
+
"max_new_tokens": max_new_tokens,
|
37 |
+
}
|
38 |
+
logger.info(f"==== request ====\n{gen_params}")
|
39 |
+
|
40 |
+
res = openai.ChatCompletion.create(
|
41 |
+
model=model_name,
|
42 |
+
messages=messages,
|
43 |
+
temperature=temperature,
|
44 |
+
max_tokens=max_new_tokens,
|
45 |
+
stream=True,
|
46 |
+
)
|
47 |
+
text = ""
|
48 |
+
for chunk in res:
|
49 |
+
text += chunk["choices"][0]["delta"].get("content", "")
|
50 |
+
data = {
|
51 |
+
"text": text,
|
52 |
+
"error_code": 0,
|
53 |
+
}
|
54 |
+
yield data
|
55 |
+
|
56 |
+
|
57 |
+
def anthropic_api_stream_iter(model_name, prompt, temperature, top_p, max_new_tokens):
|
58 |
+
import anthropic
|
59 |
+
|
60 |
+
c = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
|
61 |
+
|
62 |
+
# Make requests
|
63 |
+
gen_params = {
|
64 |
+
"model": model_name,
|
65 |
+
"prompt": prompt,
|
66 |
+
"temperature": temperature,
|
67 |
+
"top_p": top_p,
|
68 |
+
"max_new_tokens": max_new_tokens,
|
69 |
+
}
|
70 |
+
logger.info(f"==== request ====\n{gen_params}")
|
71 |
+
|
72 |
+
res = c.completions.create(
|
73 |
+
prompt=prompt,
|
74 |
+
stop_sequences=[anthropic.HUMAN_PROMPT],
|
75 |
+
max_tokens_to_sample=max_new_tokens,
|
76 |
+
temperature=temperature,
|
77 |
+
top_p=top_p,
|
78 |
+
model=model_name,
|
79 |
+
stream=True,
|
80 |
+
)
|
81 |
+
text = ""
|
82 |
+
for chunk in res:
|
83 |
+
text += chunk.completion
|
84 |
+
data = {
|
85 |
+
"text": text,
|
86 |
+
"error_code": 0,
|
87 |
+
}
|
88 |
+
yield data
|
89 |
+
|
90 |
+
|
91 |
+
def init_palm_chat(model_name):
|
92 |
+
import vertexai # pip3 install google-cloud-aiplatform
|
93 |
+
from vertexai.preview.language_models import ChatModel
|
94 |
+
|
95 |
+
project_id = os.environ["GCP_PROJECT_ID"]
|
96 |
+
location = "us-central1"
|
97 |
+
vertexai.init(project=project_id, location=location)
|
98 |
+
|
99 |
+
chat_model = ChatModel.from_pretrained(model_name)
|
100 |
+
chat = chat_model.start_chat(examples=[])
|
101 |
+
return chat
|
102 |
+
|
103 |
+
|
104 |
+
def palm_api_stream_iter(chat, message, temperature, top_p, max_new_tokens):
|
105 |
+
parameters = {
|
106 |
+
"temperature": temperature,
|
107 |
+
"top_p": top_p,
|
108 |
+
"max_output_tokens": max_new_tokens,
|
109 |
+
}
|
110 |
+
gen_params = {
|
111 |
+
"model": "palm-2",
|
112 |
+
"prompt": message,
|
113 |
+
}
|
114 |
+
gen_params.update(parameters)
|
115 |
+
logger.info(f"==== request ====\n{gen_params}")
|
116 |
+
|
117 |
+
response = chat.send_message(message, **parameters)
|
118 |
+
content = response.text
|
119 |
+
|
120 |
+
pos = 0
|
121 |
+
while pos < len(content):
|
122 |
+
# This is a fancy way to simulate token generation latency combined
|
123 |
+
# with a Poisson process.
|
124 |
+
pos += random.randint(10, 20)
|
125 |
+
time.sleep(random.expovariate(50))
|
126 |
+
data = {
|
127 |
+
"text": content[:pos],
|
128 |
+
"error_code": 0,
|
129 |
+
}
|
130 |
+
yield data
|
base_model_worker.py
ADDED
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import threading
|
3 |
+
import time
|
4 |
+
from typing import List
|
5 |
+
|
6 |
+
from fastapi import FastAPI, Request, BackgroundTasks
|
7 |
+
from fastapi.responses import StreamingResponse, JSONResponse
|
8 |
+
import requests
|
9 |
+
|
10 |
+
from fastchat.constants import WORKER_HEART_BEAT_INTERVAL
|
11 |
+
from fastchat.conversation import Conversation
|
12 |
+
from fastchat.utils import pretty_print_semaphore, build_logger
|
13 |
+
|
14 |
+
|
15 |
+
worker = None
|
16 |
+
logger = None
|
17 |
+
|
18 |
+
app = FastAPI()
|
19 |
+
|
20 |
+
|
21 |
+
def heart_beat_worker(obj):
|
22 |
+
while True:
|
23 |
+
time.sleep(WORKER_HEART_BEAT_INTERVAL)
|
24 |
+
obj.send_heart_beat()
|
25 |
+
|
26 |
+
|
27 |
+
class BaseModelWorker:
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
controller_addr: str,
|
31 |
+
worker_addr: str,
|
32 |
+
worker_id: str,
|
33 |
+
model_path: str,
|
34 |
+
model_names: List[str],
|
35 |
+
limit_worker_concurrency: int,
|
36 |
+
conv_template: str = None,
|
37 |
+
):
|
38 |
+
global logger, worker
|
39 |
+
|
40 |
+
self.controller_addr = controller_addr
|
41 |
+
self.worker_addr = worker_addr
|
42 |
+
self.worker_id = worker_id
|
43 |
+
if model_path.endswith("/"):
|
44 |
+
model_path = model_path[:-1]
|
45 |
+
self.model_names = model_names or [model_path.split("/")[-1]]
|
46 |
+
self.limit_worker_concurrency = limit_worker_concurrency
|
47 |
+
self.conv = self.make_conv_template(conv_template, model_path)
|
48 |
+
self.conv.sep_style = int(self.conv.sep_style)
|
49 |
+
self.tokenizer = None
|
50 |
+
self.context_len = None
|
51 |
+
self.call_ct = 0
|
52 |
+
self.semaphore = None
|
53 |
+
|
54 |
+
self.heart_beat_thread = None
|
55 |
+
|
56 |
+
if logger is None:
|
57 |
+
logger = build_logger("model_worker", f"model_worker_{self.worker_id}.log")
|
58 |
+
if worker is None:
|
59 |
+
worker = self
|
60 |
+
|
61 |
+
def make_conv_template(
|
62 |
+
self,
|
63 |
+
conv_template: str = None,
|
64 |
+
model_path: str = None,
|
65 |
+
) -> Conversation:
|
66 |
+
"""
|
67 |
+
can be overrided to costomize the conversation template for different model workers.
|
68 |
+
"""
|
69 |
+
from fastchat.conversation import get_conv_template
|
70 |
+
from fastchat.model.model_adapter import get_conversation_template
|
71 |
+
|
72 |
+
if conv_template:
|
73 |
+
conv = get_conv_template(conv_template)
|
74 |
+
else:
|
75 |
+
conv = get_conversation_template(model_path)
|
76 |
+
print(conv)
|
77 |
+
return conv
|
78 |
+
|
79 |
+
def init_heart_beat(self):
|
80 |
+
self.register_to_controller()
|
81 |
+
self.heart_beat_thread = threading.Thread(
|
82 |
+
target=heart_beat_worker,
|
83 |
+
args=(self,),
|
84 |
+
daemon=True,
|
85 |
+
)
|
86 |
+
self.heart_beat_thread.start()
|
87 |
+
|
88 |
+
def register_to_controller(self):
|
89 |
+
logger.info("Register to controller")
|
90 |
+
|
91 |
+
url = self.controller_addr + "/register_worker"
|
92 |
+
data = {
|
93 |
+
"worker_name": self.worker_addr,
|
94 |
+
"check_heart_beat": True,
|
95 |
+
"worker_status": self.get_status(),
|
96 |
+
}
|
97 |
+
r = requests.post(url, json=data)
|
98 |
+
assert r.status_code == 200
|
99 |
+
|
100 |
+
def send_heart_beat(self):
|
101 |
+
logger.info(
|
102 |
+
f"Send heart beat. Models: {self.model_names}. "
|
103 |
+
f"Semaphore: {pretty_print_semaphore(self.semaphore)}. "
|
104 |
+
f"call_ct: {self.call_ct}. "
|
105 |
+
f"worker_id: {self.worker_id}. "
|
106 |
+
)
|
107 |
+
|
108 |
+
url = self.controller_addr + "/receive_heart_beat"
|
109 |
+
|
110 |
+
while True:
|
111 |
+
try:
|
112 |
+
ret = requests.post(
|
113 |
+
url,
|
114 |
+
json={
|
115 |
+
"worker_name": self.worker_addr,
|
116 |
+
"queue_length": self.get_queue_length(),
|
117 |
+
},
|
118 |
+
timeout=5,
|
119 |
+
)
|
120 |
+
exist = ret.json()["exist"]
|
121 |
+
break
|
122 |
+
except (requests.exceptions.RequestException, KeyError) as e:
|
123 |
+
logger.error(f"heart beat error: {e}")
|
124 |
+
time.sleep(5)
|
125 |
+
|
126 |
+
if not exist:
|
127 |
+
self.register_to_controller()
|
128 |
+
|
129 |
+
def get_queue_length(self):
|
130 |
+
if (
|
131 |
+
self.semaphore is None
|
132 |
+
or self.semaphore._value is None
|
133 |
+
or self.semaphore._waiters is None
|
134 |
+
):
|
135 |
+
return 0
|
136 |
+
else:
|
137 |
+
return (
|
138 |
+
self.limit_worker_concurrency
|
139 |
+
- self.semaphore._value
|
140 |
+
+ len(self.semaphore._waiters)
|
141 |
+
)
|
142 |
+
|
143 |
+
def get_status(self):
|
144 |
+
return {
|
145 |
+
"model_names": self.model_names,
|
146 |
+
"speed": 1,
|
147 |
+
"queue_length": self.get_queue_length(),
|
148 |
+
}
|
149 |
+
|
150 |
+
def count_token(self, params):
|
151 |
+
prompt = params["prompt"]
|
152 |
+
|
153 |
+
try:
|
154 |
+
input_ids = self.tokenizer(prompt).input_ids
|
155 |
+
input_echo_len = len(input_ids)
|
156 |
+
except TypeError:
|
157 |
+
input_echo_len = self.tokenizer.num_tokens(prompt)
|
158 |
+
|
159 |
+
ret = {
|
160 |
+
"count": input_echo_len,
|
161 |
+
"error_code": 0,
|
162 |
+
}
|
163 |
+
return ret
|
164 |
+
|
165 |
+
def get_conv_template(self):
|
166 |
+
return {"conv": self.conv}
|
167 |
+
|
168 |
+
def generate_stream_gate(self, params):
|
169 |
+
raise NotImplementedError
|
170 |
+
|
171 |
+
def generate_gate(self, params):
|
172 |
+
raise NotImplementedError
|
173 |
+
|
174 |
+
def get_embeddings(self, params):
|
175 |
+
raise NotImplementedError
|
176 |
+
|
177 |
+
|
178 |
+
def release_worker_semaphore():
|
179 |
+
worker.semaphore.release()
|
180 |
+
|
181 |
+
|
182 |
+
def acquire_worker_semaphore():
|
183 |
+
if worker.semaphore is None:
|
184 |
+
worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
|
185 |
+
return worker.semaphore.acquire()
|
186 |
+
|
187 |
+
|
188 |
+
def create_background_tasks():
|
189 |
+
background_tasks = BackgroundTasks()
|
190 |
+
background_tasks.add_task(release_worker_semaphore)
|
191 |
+
return background_tasks
|
192 |
+
|
193 |
+
|
194 |
+
@app.post("/worker_generate_stream")
|
195 |
+
async def api_generate_stream(request: Request):
|
196 |
+
params = await request.json()
|
197 |
+
await acquire_worker_semaphore()
|
198 |
+
generator = worker.generate_stream_gate(params)
|
199 |
+
background_tasks = create_background_tasks()
|
200 |
+
return StreamingResponse(generator, background=background_tasks)
|
201 |
+
|
202 |
+
|
203 |
+
@app.post("/worker_generate")
|
204 |
+
async def api_generate(request: Request):
|
205 |
+
params = await request.json()
|
206 |
+
await acquire_worker_semaphore()
|
207 |
+
output = await asyncio.to_thread(worker.generate_gate, params)
|
208 |
+
release_worker_semaphore()
|
209 |
+
return JSONResponse(output)
|
210 |
+
|
211 |
+
|
212 |
+
@app.post("/worker_get_embeddings")
|
213 |
+
async def api_get_embeddings(request: Request):
|
214 |
+
params = await request.json()
|
215 |
+
await acquire_worker_semaphore()
|
216 |
+
embedding = worker.get_embeddings(params)
|
217 |
+
release_worker_semaphore()
|
218 |
+
return JSONResponse(content=embedding)
|
219 |
+
|
220 |
+
|
221 |
+
@app.post("/worker_get_status")
|
222 |
+
async def api_get_status(request: Request):
|
223 |
+
return worker.get_status()
|
224 |
+
|
225 |
+
|
226 |
+
@app.post("/count_token")
|
227 |
+
async def api_count_token(request: Request):
|
228 |
+
params = await request.json()
|
229 |
+
return worker.count_token(params)
|
230 |
+
|
231 |
+
|
232 |
+
@app.post("/worker_get_conv_template")
|
233 |
+
async def api_get_conv(request: Request):
|
234 |
+
return worker.get_conv_template()
|
235 |
+
|
236 |
+
|
237 |
+
@app.post("/model_details")
|
238 |
+
async def api_model_details(request: Request):
|
239 |
+
return {"context_length": worker.context_len}
|
cli.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Chat with a model with command line interface.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5
|
6 |
+
python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0
|
7 |
+
|
8 |
+
Other commands:
|
9 |
+
- Type "!!exit" or an empty line to exit.
|
10 |
+
- Type "!!reset" to start a new conversation.
|
11 |
+
- Type "!!remove" to remove the last prompt.
|
12 |
+
- Type "!!regen" to regenerate the last message.
|
13 |
+
- Type "!!save <filename>" to save the conversation history to a json file.
|
14 |
+
- Type "!!load <filename>" to load a conversation history from a json file.
|
15 |
+
"""
|
16 |
+
import argparse
|
17 |
+
import os
|
18 |
+
import re
|
19 |
+
import sys
|
20 |
+
|
21 |
+
from prompt_toolkit import PromptSession
|
22 |
+
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
23 |
+
from prompt_toolkit.completion import WordCompleter
|
24 |
+
from prompt_toolkit.history import InMemoryHistory
|
25 |
+
from prompt_toolkit.key_binding import KeyBindings
|
26 |
+
from rich.console import Console
|
27 |
+
from rich.live import Live
|
28 |
+
from rich.markdown import Markdown
|
29 |
+
import torch
|
30 |
+
|
31 |
+
from fastchat.model.model_adapter import add_model_args
|
32 |
+
from fastchat.modules.awq import AWQConfig
|
33 |
+
from fastchat.modules.exllama import ExllamaConfig
|
34 |
+
from fastchat.modules.xfastertransformer import XftConfig
|
35 |
+
from fastchat.modules.gptq import GptqConfig
|
36 |
+
from fastchat.serve.inference import ChatIO, chat_loop
|
37 |
+
from fastchat.utils import str_to_torch_dtype
|
38 |
+
|
39 |
+
|
40 |
+
class SimpleChatIO(ChatIO):
|
41 |
+
def __init__(self, multiline: bool = False, prefix: str = ''):
|
42 |
+
self._multiline = multiline
|
43 |
+
self.prefix = prefix
|
44 |
+
|
45 |
+
def prompt_for_input(self, role) -> str:
|
46 |
+
if not self._multiline:
|
47 |
+
return input(f"{role}: {self.prefix}")
|
48 |
+
|
49 |
+
prompt_data = []
|
50 |
+
line = input(f"{role} [ctrl-d/z on empty line to end]: ")
|
51 |
+
while True:
|
52 |
+
prompt_data.append(line.strip())
|
53 |
+
try:
|
54 |
+
line = input()
|
55 |
+
except EOFError as e:
|
56 |
+
break
|
57 |
+
return f"\n{self.prefix}".join(prompt_data)
|
58 |
+
|
59 |
+
def prompt_for_output(self, role: str):
|
60 |
+
print(f"{role}: ", end="", flush=True)
|
61 |
+
|
62 |
+
def stream_output(self, output_stream):
|
63 |
+
pre = 0
|
64 |
+
for outputs in output_stream:
|
65 |
+
output_text = outputs["text"]
|
66 |
+
output_text = output_text.strip().split(" ")
|
67 |
+
now = len(output_text) - 1
|
68 |
+
if now > pre:
|
69 |
+
print(" ".join(output_text[pre:now]), end=" ", flush=True)
|
70 |
+
pre = now
|
71 |
+
print(" ".join(output_text[pre:]), flush=True)
|
72 |
+
return " ".join(output_text)
|
73 |
+
|
74 |
+
def print_output(self, text: str):
|
75 |
+
print(text)
|
76 |
+
|
77 |
+
|
78 |
+
class RichChatIO(ChatIO):
|
79 |
+
bindings = KeyBindings()
|
80 |
+
|
81 |
+
@bindings.add("escape", "enter")
|
82 |
+
def _(event):
|
83 |
+
event.app.current_buffer.newline()
|
84 |
+
|
85 |
+
def __init__(self, multiline: bool = False, mouse: bool = False):
|
86 |
+
self._prompt_session = PromptSession(history=InMemoryHistory())
|
87 |
+
self._completer = WordCompleter(
|
88 |
+
words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
|
89 |
+
pattern=re.compile("$"),
|
90 |
+
)
|
91 |
+
self._console = Console()
|
92 |
+
self._multiline = multiline
|
93 |
+
self._mouse = mouse
|
94 |
+
|
95 |
+
def prompt_for_input(self, role) -> str:
|
96 |
+
self._console.print(f"[bold]{role}:")
|
97 |
+
# TODO(suquark): multiline input has some issues. fix it later.
|
98 |
+
prompt_input = self._prompt_session.prompt(
|
99 |
+
completer=self._completer,
|
100 |
+
multiline=False,
|
101 |
+
mouse_support=self._mouse,
|
102 |
+
auto_suggest=AutoSuggestFromHistory(),
|
103 |
+
key_bindings=self.bindings if self._multiline else None,
|
104 |
+
)
|
105 |
+
self._console.print()
|
106 |
+
return prompt_input
|
107 |
+
|
108 |
+
def prompt_for_output(self, role: str):
|
109 |
+
self._console.print(f"[bold]{role.replace('/', '|')}:")
|
110 |
+
|
111 |
+
def stream_output(self, output_stream):
|
112 |
+
"""Stream output from a role."""
|
113 |
+
# TODO(suquark): the console flickers when there is a code block
|
114 |
+
# above it. We need to cut off "live" when a code block is done.
|
115 |
+
|
116 |
+
# Create a Live context for updating the console output
|
117 |
+
with Live(console=self._console, refresh_per_second=4) as live:
|
118 |
+
# Read lines from the stream
|
119 |
+
for outputs in output_stream:
|
120 |
+
if not outputs:
|
121 |
+
continue
|
122 |
+
text = outputs["text"]
|
123 |
+
# Render the accumulated text as Markdown
|
124 |
+
# NOTE: this is a workaround for the rendering "unstandard markdown"
|
125 |
+
# in rich. The chatbots output treat "\n" as a new line for
|
126 |
+
# better compatibility with real-world text. However, rendering
|
127 |
+
# in markdown would break the format. It is because standard markdown
|
128 |
+
# treat a single "\n" in normal text as a space.
|
129 |
+
# Our workaround is adding two spaces at the end of each line.
|
130 |
+
# This is not a perfect solution, as it would
|
131 |
+
# introduce trailing spaces (only) in code block, but it works well
|
132 |
+
# especially for console output, because in general the console does not
|
133 |
+
# care about trailing spaces.
|
134 |
+
lines = []
|
135 |
+
for line in text.splitlines():
|
136 |
+
lines.append(line)
|
137 |
+
if line.startswith("```"):
|
138 |
+
# Code block marker - do not add trailing spaces, as it would
|
139 |
+
# break the syntax highlighting
|
140 |
+
lines.append("\n")
|
141 |
+
else:
|
142 |
+
lines.append(" \n")
|
143 |
+
markdown = Markdown("".join(lines))
|
144 |
+
# Update the Live console output
|
145 |
+
live.update(markdown)
|
146 |
+
self._console.print()
|
147 |
+
return text
|
148 |
+
|
149 |
+
def print_output(self, text: str):
|
150 |
+
self.stream_output([{"text": text}])
|
151 |
+
|
152 |
+
|
153 |
+
class ProgrammaticChatIO(ChatIO):
|
154 |
+
def prompt_for_input(self, role) -> str:
|
155 |
+
contents = ""
|
156 |
+
# `end_sequence` signals the end of a message. It is unlikely to occur in
|
157 |
+
# message content.
|
158 |
+
end_sequence = " __END_OF_A_MESSAGE_47582648__\n"
|
159 |
+
len_end = len(end_sequence)
|
160 |
+
while True:
|
161 |
+
if len(contents) >= len_end:
|
162 |
+
last_chars = contents[-len_end:]
|
163 |
+
if last_chars == end_sequence:
|
164 |
+
break
|
165 |
+
try:
|
166 |
+
char = sys.stdin.read(1)
|
167 |
+
contents = contents + char
|
168 |
+
except EOFError:
|
169 |
+
continue
|
170 |
+
contents = contents[:-len_end]
|
171 |
+
print(f"[!OP:{role}]: {contents}", flush=True)
|
172 |
+
return contents
|
173 |
+
|
174 |
+
def prompt_for_output(self, role: str):
|
175 |
+
print(f"[!OP:{role}]: ", end="", flush=True)
|
176 |
+
|
177 |
+
def stream_output(self, output_stream):
|
178 |
+
pre = 0
|
179 |
+
for outputs in output_stream:
|
180 |
+
output_text = outputs["text"]
|
181 |
+
output_text = output_text.strip().split(" ")
|
182 |
+
now = len(output_text) - 1
|
183 |
+
if now > pre:
|
184 |
+
print(" ".join(output_text[pre:now]), end=" ", flush=True)
|
185 |
+
pre = now
|
186 |
+
print(" ".join(output_text[pre:]), flush=True)
|
187 |
+
return " ".join(output_text)
|
188 |
+
|
189 |
+
def print_output(self, text: str):
|
190 |
+
print(text)
|
191 |
+
|
192 |
+
|
193 |
+
def main(args):
|
194 |
+
if args.gpus:
|
195 |
+
if len(args.gpus.split(",")) < args.num_gpus:
|
196 |
+
raise ValueError(
|
197 |
+
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
|
198 |
+
)
|
199 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
200 |
+
os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
|
201 |
+
if args.enable_exllama:
|
202 |
+
exllama_config = ExllamaConfig(
|
203 |
+
max_seq_len=args.exllama_max_seq_len,
|
204 |
+
gpu_split=args.exllama_gpu_split,
|
205 |
+
)
|
206 |
+
else:
|
207 |
+
exllama_config = None
|
208 |
+
if args.enable_xft:
|
209 |
+
xft_config = XftConfig(
|
210 |
+
max_seq_len=args.xft_max_seq_len,
|
211 |
+
data_type=args.xft_dtype,
|
212 |
+
)
|
213 |
+
if args.device != "cpu":
|
214 |
+
print("xFasterTransformer now is only support CPUs. Reset device to CPU")
|
215 |
+
args.device = "cpu"
|
216 |
+
else:
|
217 |
+
xft_config = None
|
218 |
+
if args.style == "simple":
|
219 |
+
chatio = SimpleChatIO(args.multiline)
|
220 |
+
elif args.style == "rich":
|
221 |
+
chatio = RichChatIO(args.multiline, args.mouse)
|
222 |
+
elif args.style == "programmatic":
|
223 |
+
chatio = ProgrammaticChatIO()
|
224 |
+
else:
|
225 |
+
raise ValueError(f"Invalid style for console: {args.style}")
|
226 |
+
try:
|
227 |
+
if args.upload_file_path:
|
228 |
+
prefix = open(args.upload_file_path, 'r').read()
|
229 |
+
args.conv_system_msg = prefix[:20000]
|
230 |
+
chat_loop(
|
231 |
+
args.model_path,
|
232 |
+
args.device,
|
233 |
+
args.num_gpus,
|
234 |
+
args.max_gpu_memory,
|
235 |
+
str_to_torch_dtype(args.dtype),
|
236 |
+
args.load_8bit,
|
237 |
+
args.cpu_offloading,
|
238 |
+
args.conv_template,
|
239 |
+
args.conv_system_msg,
|
240 |
+
args.temperature,
|
241 |
+
args.repetition_penalty,
|
242 |
+
args.max_new_tokens,
|
243 |
+
chatio,
|
244 |
+
gptq_config=GptqConfig(
|
245 |
+
ckpt=args.gptq_ckpt or args.model_path,
|
246 |
+
wbits=args.gptq_wbits,
|
247 |
+
groupsize=args.gptq_groupsize,
|
248 |
+
act_order=args.gptq_act_order,
|
249 |
+
),
|
250 |
+
awq_config=AWQConfig(
|
251 |
+
ckpt=args.awq_ckpt or args.model_path,
|
252 |
+
wbits=args.awq_wbits,
|
253 |
+
groupsize=args.awq_groupsize,
|
254 |
+
),
|
255 |
+
exllama_config=exllama_config,
|
256 |
+
xft_config=xft_config,
|
257 |
+
revision=args.revision,
|
258 |
+
judge_sent_end=args.judge_sent_end,
|
259 |
+
debug=args.debug,
|
260 |
+
history=not args.no_history,
|
261 |
+
)
|
262 |
+
except KeyboardInterrupt:
|
263 |
+
print("exit...")
|
264 |
+
|
265 |
+
|
266 |
+
if __name__ == "__main__":
|
267 |
+
parser = argparse.ArgumentParser()
|
268 |
+
add_model_args(parser)
|
269 |
+
parser.add_argument(
|
270 |
+
"--conv-template", type=str, default=None, help="Conversation prompt template."
|
271 |
+
)
|
272 |
+
parser.add_argument(
|
273 |
+
"--conv-system-msg", type=str, default=None, help="Conversation system message."
|
274 |
+
)
|
275 |
+
parser.add_argument("--temperature", type=float, default=0.7)
|
276 |
+
parser.add_argument("--repetition_penalty", type=float, default=1.0)
|
277 |
+
parser.add_argument("--max-new-tokens", type=int, default=512)
|
278 |
+
parser.add_argument("--no-history", action="store_true")
|
279 |
+
parser.add_argument(
|
280 |
+
"--style",
|
281 |
+
type=str,
|
282 |
+
default="simple",
|
283 |
+
choices=["simple", "rich", "programmatic"],
|
284 |
+
help="Display style.",
|
285 |
+
)
|
286 |
+
parser.add_argument(
|
287 |
+
"--multiline",
|
288 |
+
action="store_true",
|
289 |
+
help="Enable multiline input. Use ESC+Enter for newline.",
|
290 |
+
)
|
291 |
+
parser.add_argument(
|
292 |
+
"--mouse",
|
293 |
+
action="store_true",
|
294 |
+
help="[Rich Style]: Enable mouse support for cursor positioning.",
|
295 |
+
)
|
296 |
+
parser.add_argument(
|
297 |
+
"--judge-sent-end",
|
298 |
+
action="store_true",
|
299 |
+
help="Whether enable the correction logic that interrupts the output of sentences due to EOS.",
|
300 |
+
)
|
301 |
+
parser.add_argument(
|
302 |
+
"--debug",
|
303 |
+
action="store_true",
|
304 |
+
help="Print useful debug information (e.g., prompts)",
|
305 |
+
)
|
306 |
+
parser.add_argument(
|
307 |
+
"--upload-file-path",
|
308 |
+
type=str,
|
309 |
+
default="",
|
310 |
+
help="upload long txt for summary.",
|
311 |
+
)
|
312 |
+
args = parser.parse_args()
|
313 |
+
main(args)
|
controller.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A controller manages distributed workers.
|
3 |
+
It sends worker addresses to clients.
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import asyncio
|
7 |
+
import dataclasses
|
8 |
+
from enum import Enum, auto
|
9 |
+
import json
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import time
|
13 |
+
from typing import List, Union
|
14 |
+
import threading
|
15 |
+
|
16 |
+
from fastapi import FastAPI, Request
|
17 |
+
from fastapi.responses import StreamingResponse
|
18 |
+
import numpy as np
|
19 |
+
import requests
|
20 |
+
import uvicorn
|
21 |
+
|
22 |
+
from fastchat.constants import (
|
23 |
+
CONTROLLER_HEART_BEAT_EXPIRATION,
|
24 |
+
WORKER_API_TIMEOUT,
|
25 |
+
ErrorCode,
|
26 |
+
SERVER_ERROR_MSG,
|
27 |
+
)
|
28 |
+
from fastchat.utils import build_logger
|
29 |
+
|
30 |
+
|
31 |
+
logger = build_logger("controller", "controller.log")
|
32 |
+
|
33 |
+
|
34 |
+
class DispatchMethod(Enum):
|
35 |
+
LOTTERY = auto()
|
36 |
+
SHORTEST_QUEUE = auto()
|
37 |
+
|
38 |
+
@classmethod
|
39 |
+
def from_str(cls, name):
|
40 |
+
if name == "lottery":
|
41 |
+
return cls.LOTTERY
|
42 |
+
elif name == "shortest_queue":
|
43 |
+
return cls.SHORTEST_QUEUE
|
44 |
+
else:
|
45 |
+
raise ValueError(f"Invalid dispatch method")
|
46 |
+
|
47 |
+
|
48 |
+
@dataclasses.dataclass
|
49 |
+
class WorkerInfo:
|
50 |
+
model_names: List[str]
|
51 |
+
speed: int
|
52 |
+
queue_length: int
|
53 |
+
check_heart_beat: bool
|
54 |
+
last_heart_beat: str
|
55 |
+
|
56 |
+
|
57 |
+
def heart_beat_controller(controller):
|
58 |
+
while True:
|
59 |
+
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
|
60 |
+
controller.remove_stale_workers_by_expiration()
|
61 |
+
|
62 |
+
|
63 |
+
class Controller:
|
64 |
+
def __init__(self, dispatch_method: str):
|
65 |
+
# Dict[str -> WorkerInfo]
|
66 |
+
self.worker_info = {}
|
67 |
+
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
|
68 |
+
|
69 |
+
self.heart_beat_thread = threading.Thread(
|
70 |
+
target=heart_beat_controller, args=(self,)
|
71 |
+
)
|
72 |
+
self.heart_beat_thread.start()
|
73 |
+
|
74 |
+
def register_worker(
|
75 |
+
self, worker_name: str, check_heart_beat: bool, worker_status: dict
|
76 |
+
):
|
77 |
+
if worker_name not in self.worker_info:
|
78 |
+
logger.info(f"Register a new worker: {worker_name}")
|
79 |
+
else:
|
80 |
+
logger.info(f"Register an existing worker: {worker_name}")
|
81 |
+
|
82 |
+
if not worker_status:
|
83 |
+
worker_status = self.get_worker_status(worker_name)
|
84 |
+
if not worker_status:
|
85 |
+
return False
|
86 |
+
|
87 |
+
self.worker_info[worker_name] = WorkerInfo(
|
88 |
+
worker_status["model_names"],
|
89 |
+
worker_status["speed"],
|
90 |
+
worker_status["queue_length"],
|
91 |
+
check_heart_beat,
|
92 |
+
time.time(),
|
93 |
+
)
|
94 |
+
|
95 |
+
logger.info(f"Register done: {worker_name}, {worker_status}")
|
96 |
+
return True
|
97 |
+
|
98 |
+
def get_worker_status(self, worker_name: str):
|
99 |
+
try:
|
100 |
+
r = requests.post(worker_name + "/worker_get_status", timeout=5)
|
101 |
+
except requests.exceptions.RequestException as e:
|
102 |
+
logger.error(f"Get status fails: {worker_name}, {e}")
|
103 |
+
return None
|
104 |
+
|
105 |
+
if r.status_code != 200:
|
106 |
+
logger.error(f"Get status fails: {worker_name}, {r}")
|
107 |
+
return None
|
108 |
+
|
109 |
+
return r.json()
|
110 |
+
|
111 |
+
def remove_worker(self, worker_name: str):
|
112 |
+
del self.worker_info[worker_name]
|
113 |
+
|
114 |
+
def refresh_all_workers(self):
|
115 |
+
old_info = dict(self.worker_info)
|
116 |
+
self.worker_info = {}
|
117 |
+
|
118 |
+
for w_name, w_info in old_info.items():
|
119 |
+
if not self.register_worker(w_name, w_info.check_heart_beat, None):
|
120 |
+
logger.info(f"Remove stale worker: {w_name}")
|
121 |
+
|
122 |
+
def list_models(self):
|
123 |
+
model_names = set()
|
124 |
+
|
125 |
+
for w_name, w_info in self.worker_info.items():
|
126 |
+
model_names.update(w_info.model_names)
|
127 |
+
|
128 |
+
return list(model_names)
|
129 |
+
|
130 |
+
def get_worker_address(self, model_name: str):
|
131 |
+
if self.dispatch_method == DispatchMethod.LOTTERY:
|
132 |
+
worker_names = []
|
133 |
+
worker_speeds = []
|
134 |
+
for w_name, w_info in self.worker_info.items():
|
135 |
+
if model_name in w_info.model_names:
|
136 |
+
worker_names.append(w_name)
|
137 |
+
worker_speeds.append(w_info.speed)
|
138 |
+
worker_speeds = np.array(worker_speeds, dtype=np.float32)
|
139 |
+
norm = np.sum(worker_speeds)
|
140 |
+
if norm < 1e-4:
|
141 |
+
return ""
|
142 |
+
worker_speeds = worker_speeds / norm
|
143 |
+
if True: # Directly return address
|
144 |
+
pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
|
145 |
+
worker_name = worker_names[pt]
|
146 |
+
return worker_name
|
147 |
+
|
148 |
+
# Check status before returning
|
149 |
+
while True:
|
150 |
+
pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
|
151 |
+
worker_name = worker_names[pt]
|
152 |
+
|
153 |
+
if self.get_worker_status(worker_name):
|
154 |
+
break
|
155 |
+
else:
|
156 |
+
self.remove_worker(worker_name)
|
157 |
+
worker_speeds[pt] = 0
|
158 |
+
norm = np.sum(worker_speeds)
|
159 |
+
if norm < 1e-4:
|
160 |
+
return ""
|
161 |
+
worker_speeds = worker_speeds / norm
|
162 |
+
continue
|
163 |
+
return worker_name
|
164 |
+
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
|
165 |
+
worker_names = []
|
166 |
+
worker_qlen = []
|
167 |
+
for w_name, w_info in self.worker_info.items():
|
168 |
+
if model_name in w_info.model_names:
|
169 |
+
worker_names.append(w_name)
|
170 |
+
worker_qlen.append(w_info.queue_length / w_info.speed)
|
171 |
+
if len(worker_names) == 0:
|
172 |
+
return ""
|
173 |
+
min_index = np.argmin(worker_qlen)
|
174 |
+
w_name = worker_names[min_index]
|
175 |
+
self.worker_info[w_name].queue_length += 1
|
176 |
+
logger.info(
|
177 |
+
f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
|
178 |
+
)
|
179 |
+
return w_name
|
180 |
+
else:
|
181 |
+
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
|
182 |
+
|
183 |
+
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
184 |
+
if worker_name not in self.worker_info:
|
185 |
+
logger.info(f"Receive unknown heart beat. {worker_name}")
|
186 |
+
return False
|
187 |
+
|
188 |
+
self.worker_info[worker_name].queue_length = queue_length
|
189 |
+
self.worker_info[worker_name].last_heart_beat = time.time()
|
190 |
+
logger.info(f"Receive heart beat. {worker_name}")
|
191 |
+
return True
|
192 |
+
|
193 |
+
def remove_stale_workers_by_expiration(self):
|
194 |
+
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
195 |
+
to_delete = []
|
196 |
+
for worker_name, w_info in self.worker_info.items():
|
197 |
+
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
198 |
+
to_delete.append(worker_name)
|
199 |
+
|
200 |
+
for worker_name in to_delete:
|
201 |
+
self.remove_worker(worker_name)
|
202 |
+
|
203 |
+
def handle_no_worker(self, params):
|
204 |
+
logger.info(f"no worker: {params['model']}")
|
205 |
+
ret = {
|
206 |
+
"text": SERVER_ERROR_MSG,
|
207 |
+
"error_code": ErrorCode.CONTROLLER_NO_WORKER,
|
208 |
+
}
|
209 |
+
return json.dumps(ret).encode() + b"\0"
|
210 |
+
|
211 |
+
def handle_worker_timeout(self, worker_address):
|
212 |
+
logger.info(f"worker timeout: {worker_address}")
|
213 |
+
ret = {
|
214 |
+
"text": SERVER_ERROR_MSG,
|
215 |
+
"error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT,
|
216 |
+
}
|
217 |
+
return json.dumps(ret).encode() + b"\0"
|
218 |
+
|
219 |
+
# Let the controller act as a worker to achieve hierarchical
|
220 |
+
# management. This can be used to connect isolated sub networks.
|
221 |
+
def worker_api_get_status(self):
|
222 |
+
model_names = set()
|
223 |
+
speed = 0
|
224 |
+
queue_length = 0
|
225 |
+
|
226 |
+
for w_name in self.worker_info:
|
227 |
+
worker_status = self.get_worker_status(w_name)
|
228 |
+
if worker_status is not None:
|
229 |
+
model_names.update(worker_status["model_names"])
|
230 |
+
speed += worker_status["speed"]
|
231 |
+
queue_length += worker_status["queue_length"]
|
232 |
+
|
233 |
+
model_names = sorted(list(model_names))
|
234 |
+
return {
|
235 |
+
"model_names": model_names,
|
236 |
+
"speed": speed,
|
237 |
+
"queue_length": queue_length,
|
238 |
+
}
|
239 |
+
|
240 |
+
def worker_api_generate_stream(self, params):
|
241 |
+
worker_addr = self.get_worker_address(params["model"])
|
242 |
+
if not worker_addr:
|
243 |
+
yield self.handle_no_worker(params)
|
244 |
+
|
245 |
+
try:
|
246 |
+
response = requests.post(
|
247 |
+
worker_addr + "/worker_generate_stream",
|
248 |
+
json=params,
|
249 |
+
stream=True,
|
250 |
+
timeout=WORKER_API_TIMEOUT,
|
251 |
+
)
|
252 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
253 |
+
if chunk:
|
254 |
+
yield chunk + b"\0"
|
255 |
+
except requests.exceptions.RequestException as e:
|
256 |
+
yield self.handle_worker_timeout(worker_addr)
|
257 |
+
|
258 |
+
|
259 |
+
app = FastAPI()
|
260 |
+
|
261 |
+
|
262 |
+
@app.post("/register_worker")
|
263 |
+
async def register_worker(request: Request):
|
264 |
+
data = await request.json()
|
265 |
+
controller.register_worker(
|
266 |
+
data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
|
267 |
+
)
|
268 |
+
|
269 |
+
|
270 |
+
@app.post("/refresh_all_workers")
|
271 |
+
async def refresh_all_workers():
|
272 |
+
models = controller.refresh_all_workers()
|
273 |
+
|
274 |
+
|
275 |
+
@app.post("/list_models")
|
276 |
+
async def list_models():
|
277 |
+
models = controller.list_models()
|
278 |
+
return {"models": models}
|
279 |
+
|
280 |
+
|
281 |
+
@app.post("/get_worker_address")
|
282 |
+
async def get_worker_address(request: Request):
|
283 |
+
data = await request.json()
|
284 |
+
addr = controller.get_worker_address(data["model"])
|
285 |
+
return {"address": addr}
|
286 |
+
|
287 |
+
|
288 |
+
@app.post("/receive_heart_beat")
|
289 |
+
async def receive_heart_beat(request: Request):
|
290 |
+
data = await request.json()
|
291 |
+
exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
|
292 |
+
return {"exist": exist}
|
293 |
+
|
294 |
+
|
295 |
+
@app.post("/worker_generate_stream")
|
296 |
+
async def worker_api_generate_stream(request: Request):
|
297 |
+
params = await request.json()
|
298 |
+
generator = controller.worker_api_generate_stream(params)
|
299 |
+
return StreamingResponse(generator)
|
300 |
+
|
301 |
+
|
302 |
+
@app.post("/worker_get_status")
|
303 |
+
async def worker_api_get_status(request: Request):
|
304 |
+
return controller.worker_api_get_status()
|
305 |
+
|
306 |
+
|
307 |
+
@app.get("/test_connection")
|
308 |
+
async def worker_api_get_status(request: Request):
|
309 |
+
return "success"
|
310 |
+
|
311 |
+
|
312 |
+
def create_controller():
|
313 |
+
parser = argparse.ArgumentParser()
|
314 |
+
parser.add_argument("--host", type=str, default="localhost")
|
315 |
+
parser.add_argument("--port", type=int, default=21001)
|
316 |
+
parser.add_argument(
|
317 |
+
"--dispatch-method",
|
318 |
+
type=str,
|
319 |
+
choices=["lottery", "shortest_queue"],
|
320 |
+
default="shortest_queue",
|
321 |
+
)
|
322 |
+
parser.add_argument(
|
323 |
+
"--ssl",
|
324 |
+
action="store_true",
|
325 |
+
required=False,
|
326 |
+
default=False,
|
327 |
+
help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
|
328 |
+
)
|
329 |
+
args = parser.parse_args()
|
330 |
+
logger.info(f"args: {args}")
|
331 |
+
|
332 |
+
controller = Controller(args.dispatch_method)
|
333 |
+
return args, controller
|
334 |
+
|
335 |
+
|
336 |
+
if __name__ == "__main__":
|
337 |
+
args, controller = create_controller()
|
338 |
+
if args.ssl:
|
339 |
+
uvicorn.run(
|
340 |
+
app,
|
341 |
+
host=args.host,
|
342 |
+
port=args.port,
|
343 |
+
log_level="info",
|
344 |
+
ssl_keyfile=os.environ["SSL_KEYFILE"],
|
345 |
+
ssl_certfile=os.environ["SSL_CERTFILE"],
|
346 |
+
)
|
347 |
+
else:
|
348 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
gateway/README.md
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# fastchat Nginx Gateway
|
2 |
+
|
3 |
+
## Purpose of the Gateway
|
4 |
+
|
5 |
+
The Nginx gateway serves the following purposes:
|
6 |
+
|
7 |
+
1. Protects Gradio servers by acting as a firewall.
|
8 |
+
2. Facilitates dynamic mounting and unmounting of Gradio servers.
|
9 |
+
3. Provides load balancing for Gradio servers.
|
10 |
+
4. Offers additional security features, such as total connection limit.
|
11 |
+
5. Reduces attack surface by requiring only a single public port to be exposed for serving.
|
12 |
+
|
13 |
+
## Deployment and Updating of the Gateway
|
14 |
+
|
15 |
+
### Installing Nginx
|
16 |
+
|
17 |
+
On Debian-based distributions (e.g., Ubuntu):
|
18 |
+
|
19 |
+
```bash
|
20 |
+
sudo apt update
|
21 |
+
sudo apt install nginx
|
22 |
+
```
|
23 |
+
On Red Hat-based distributions (e.g., CentOS, Fedora):
|
24 |
+
|
25 |
+
```bash
|
26 |
+
sudo yum install epel-release
|
27 |
+
sudo yum install nginx
|
28 |
+
```
|
29 |
+
|
30 |
+
### Deployment
|
31 |
+
|
32 |
+
Copy `nginx.conf` to `/etc/nginx/nginx.conf` (need sudo permission).
|
33 |
+
|
34 |
+
Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server.
|
35 |
+
|
36 |
+
Modify `upstream websocket` to configure Gradio servers behind the gateway.
|
37 |
+
|
38 |
+
Lastly, update Nginx.
|
39 |
+
|
40 |
+
|
41 |
+
### HTTPS Deployment with a Public Domain URL
|
42 |
+
|
43 |
+
Make sure you obtain the HTTPS certificate and the private key used to generate the certificate.
|
44 |
+
|
45 |
+
Fill the addresses to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields.
|
46 |
+
|
47 |
+
If you have your own domain url to serve the chatbot, replace the chat.lmsys.org url with your own domain url.
|
48 |
+
|
49 |
+
### Updating
|
50 |
+
|
51 |
+
Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service:
|
52 |
+
|
53 |
+
```bash
|
54 |
+
sudo nginx -t # check `/etc/nginx/nginx.conf`
|
55 |
+
sudo systemctl reload nginx # restart Nginx service to load the new config
|
56 |
+
sudo systemctl status nginx # check the status of the Nginx service. It should be active (running).
|
57 |
+
```
|
gateway/nginx.conf
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
user www-data;
|
2 |
+
worker_processes auto;
|
3 |
+
pid /run/nginx.pid;
|
4 |
+
include /etc/nginx/modules-enabled/*.conf;
|
5 |
+
|
6 |
+
events {
|
7 |
+
worker_connections 1024; # maximum number of connections that a worker process can handle concurrently
|
8 |
+
# multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle
|
9 |
+
|
10 |
+
}
|
11 |
+
|
12 |
+
http {
|
13 |
+
##
|
14 |
+
# Basic Settings
|
15 |
+
##
|
16 |
+
|
17 |
+
sendfile on; # enable sendfile for performance optimization
|
18 |
+
tcp_nopush on; # enable TCP no-pushing
|
19 |
+
tcp_nodelay on; # enable TCP no-delay
|
20 |
+
keepalive_timeout 65; # sets the timeout for keep-alive connections
|
21 |
+
types_hash_max_size 2048; # maximum size of the types hash table
|
22 |
+
# server_tokens off; # disable server token (i.e., server signature) in response headers to improve security
|
23 |
+
|
24 |
+
# server_names_hash_bucket_size 64;
|
25 |
+
# server_name_in_redirect off;
|
26 |
+
|
27 |
+
include /etc/nginx/mime.types; # include MIME types file
|
28 |
+
default_type application/octet-stream; # default MIME type for unknown file types
|
29 |
+
|
30 |
+
##
|
31 |
+
# SSL Settings
|
32 |
+
##
|
33 |
+
|
34 |
+
ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use
|
35 |
+
ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers
|
36 |
+
|
37 |
+
##
|
38 |
+
# Logging Settings
|
39 |
+
##
|
40 |
+
|
41 |
+
access_log /var/log/nginx/access.log; # path to access log file
|
42 |
+
error_log /var/log/nginx/error.log; # path to error log file
|
43 |
+
|
44 |
+
##
|
45 |
+
# Gzip Settings
|
46 |
+
##
|
47 |
+
gzip on; # enable Gzip compression
|
48 |
+
|
49 |
+
##
|
50 |
+
# Virtual Host Configs
|
51 |
+
##
|
52 |
+
|
53 |
+
include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory
|
54 |
+
include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files
|
55 |
+
|
56 |
+
# WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/
|
57 |
+
map $http_upgrade $connection_upgrade {
|
58 |
+
default upgrade;
|
59 |
+
'' close;
|
60 |
+
}
|
61 |
+
|
62 |
+
upstream websocket {
|
63 |
+
ip_hash; # load balancing by IP to guarantee session persistence
|
64 |
+
server localhost:7860; # The port should be the gradio web server port
|
65 |
+
# server localhost:7861; # extra gradio server if more than one
|
66 |
+
}
|
67 |
+
|
68 |
+
limit_conn_status 429;
|
69 |
+
limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP
|
70 |
+
limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server
|
71 |
+
|
72 |
+
server {
|
73 |
+
listen 443 ssl; # the listening port of our server
|
74 |
+
ssl_certificate [PATH_TO_SSL_CERT];
|
75 |
+
ssl_certificate_key [PATH_TO_PRIVATE_KEY];
|
76 |
+
server_name chat.lmsys.org; # replace the url with your own domain url
|
77 |
+
limit_conn perserver 1024; # connections per server
|
78 |
+
location / {
|
79 |
+
proxy_pass http://websocket; # proxy all requests to the defined upstream server
|
80 |
+
limit_conn perip 5; # connections per IP
|
81 |
+
proxy_set_header Host $host; # set the Host header for the upstream server
|
82 |
+
proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server
|
83 |
+
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header
|
84 |
+
proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication
|
85 |
+
proxy_set_header Upgrade $http_upgrade;
|
86 |
+
proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication
|
87 |
+
}
|
88 |
+
}
|
89 |
+
|
90 |
+
# the following block routes all HTTP traffic to HTTPS via nginx
|
91 |
+
server {
|
92 |
+
listen 80;
|
93 |
+
server_name chat.lmsys.org;
|
94 |
+
return 301 https://chat.lmsys.org$request_uri;
|
95 |
+
}
|
96 |
+
|
97 |
+
}
|
gradio_block_arena_anony.py
ADDED
@@ -0,0 +1,608 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Chatbot Arena (battle) tab.
|
3 |
+
Users chat with two anonymous models.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from fastchat.constants import (
|
13 |
+
MODERATION_MSG,
|
14 |
+
CONVERSATION_LIMIT_MSG,
|
15 |
+
SLOW_MODEL_MSG,
|
16 |
+
INPUT_CHAR_LEN_LIMIT,
|
17 |
+
CONVERSATION_TURN_LIMIT,
|
18 |
+
)
|
19 |
+
from fastchat.model.model_adapter import get_conversation_template
|
20 |
+
from fastchat.serve.gradio_block_arena_named import flash_buttons
|
21 |
+
from fastchat.serve.gradio_web_server import (
|
22 |
+
State,
|
23 |
+
bot_response,
|
24 |
+
get_conv_log_filename,
|
25 |
+
no_change_btn,
|
26 |
+
enable_btn,
|
27 |
+
disable_btn,
|
28 |
+
invisible_btn,
|
29 |
+
acknowledgment_md,
|
30 |
+
ip_expiration_dict,
|
31 |
+
get_ip,
|
32 |
+
)
|
33 |
+
from fastchat.utils import (
|
34 |
+
build_logger,
|
35 |
+
moderation_filter,
|
36 |
+
)
|
37 |
+
|
38 |
+
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
|
39 |
+
|
40 |
+
num_sides = 2
|
41 |
+
enable_moderation = False
|
42 |
+
anony_names = ["", ""]
|
43 |
+
models = []
|
44 |
+
|
45 |
+
|
46 |
+
def set_global_vars_anony(enable_moderation_):
|
47 |
+
global enable_moderation
|
48 |
+
enable_moderation = enable_moderation_
|
49 |
+
|
50 |
+
|
51 |
+
def load_demo_side_by_side_anony(models_, url_params):
|
52 |
+
global models
|
53 |
+
models = models_
|
54 |
+
|
55 |
+
states = (None,) * num_sides
|
56 |
+
selector_updates = (
|
57 |
+
gr.Markdown.update(visible=True),
|
58 |
+
gr.Markdown.update(visible=True),
|
59 |
+
)
|
60 |
+
|
61 |
+
return states + selector_updates
|
62 |
+
|
63 |
+
|
64 |
+
def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
|
65 |
+
with open(get_conv_log_filename(), "a") as fout:
|
66 |
+
data = {
|
67 |
+
"tstamp": round(time.time(), 4),
|
68 |
+
"type": vote_type,
|
69 |
+
"models": [x for x in model_selectors],
|
70 |
+
"states": [x.dict() for x in states],
|
71 |
+
"ip": get_ip(request),
|
72 |
+
}
|
73 |
+
fout.write(json.dumps(data) + "\n")
|
74 |
+
|
75 |
+
if ":" not in model_selectors[0]:
|
76 |
+
for i in range(15):
|
77 |
+
names = (
|
78 |
+
"### Model A: " + states[0].model_name,
|
79 |
+
"### Model B: " + states[1].model_name,
|
80 |
+
)
|
81 |
+
yield names + ("",) + (disable_btn,) * 4
|
82 |
+
time.sleep(0.2)
|
83 |
+
else:
|
84 |
+
names = (
|
85 |
+
"### Model A: " + states[0].model_name,
|
86 |
+
"### Model B: " + states[1].model_name,
|
87 |
+
)
|
88 |
+
yield names + ("",) + (disable_btn,) * 4
|
89 |
+
|
90 |
+
|
91 |
+
def leftvote_last_response(
|
92 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
93 |
+
):
|
94 |
+
logger.info(f"leftvote (anony). ip: {get_ip(request)}")
|
95 |
+
for x in vote_last_response(
|
96 |
+
[state0, state1], "leftvote", [model_selector0, model_selector1], request
|
97 |
+
):
|
98 |
+
yield x
|
99 |
+
|
100 |
+
|
101 |
+
def rightvote_last_response(
|
102 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
103 |
+
):
|
104 |
+
logger.info(f"rightvote (anony). ip: {get_ip(request)}")
|
105 |
+
for x in vote_last_response(
|
106 |
+
[state0, state1], "rightvote", [model_selector0, model_selector1], request
|
107 |
+
):
|
108 |
+
yield x
|
109 |
+
|
110 |
+
|
111 |
+
def tievote_last_response(
|
112 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
113 |
+
):
|
114 |
+
logger.info(f"tievote (anony). ip: {get_ip(request)}")
|
115 |
+
for x in vote_last_response(
|
116 |
+
[state0, state1], "tievote", [model_selector0, model_selector1], request
|
117 |
+
):
|
118 |
+
yield x
|
119 |
+
|
120 |
+
|
121 |
+
def bothbad_vote_last_response(
|
122 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
123 |
+
):
|
124 |
+
logger.info(f"bothbad_vote (anony). ip: {get_ip(request)}")
|
125 |
+
for x in vote_last_response(
|
126 |
+
[state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
|
127 |
+
):
|
128 |
+
yield x
|
129 |
+
|
130 |
+
|
131 |
+
def regenerate(state0, state1, request: gr.Request):
|
132 |
+
logger.info(f"regenerate (anony). ip: {get_ip(request)}")
|
133 |
+
states = [state0, state1]
|
134 |
+
for i in range(num_sides):
|
135 |
+
states[i].conv.update_last_message(None)
|
136 |
+
return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
|
137 |
+
|
138 |
+
|
139 |
+
def clear_history(request: gr.Request):
|
140 |
+
logger.info(f"clear_history (anony). ip: {get_ip(request)}")
|
141 |
+
return (
|
142 |
+
[None] * num_sides
|
143 |
+
+ [None] * num_sides
|
144 |
+
+ anony_names
|
145 |
+
+ [""]
|
146 |
+
+ [invisible_btn] * 4
|
147 |
+
+ [disable_btn] * 2
|
148 |
+
+ [""]
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
|
153 |
+
logger.info(f"share (anony). ip: {get_ip(request)}")
|
154 |
+
if state0 is not None and state1 is not None:
|
155 |
+
vote_last_response(
|
156 |
+
[state0, state1], "share", [model_selector0, model_selector1], request
|
157 |
+
)
|
158 |
+
|
159 |
+
|
160 |
+
SAMPLING_WEIGHTS = {
|
161 |
+
# tier 0
|
162 |
+
"gpt-4": 4,
|
163 |
+
"gpt-4-turbo": 4,
|
164 |
+
"gpt-3.5-turbo": 2,
|
165 |
+
"gpt-3.5-turbo-1106": 2,
|
166 |
+
"claude-2": 8,
|
167 |
+
"claude-1": 2,
|
168 |
+
"claude-instant-1": 8,
|
169 |
+
"zephyr-7b-beta": 2,
|
170 |
+
"openchat-3.5": 2,
|
171 |
+
# tier 1
|
172 |
+
"deluxe-chat-v1.1": 2,
|
173 |
+
"palm-2": 1.5,
|
174 |
+
"llama-2-70b-chat": 1.5,
|
175 |
+
"llama-2-13b-chat": 1.5,
|
176 |
+
"codellama-34b-instruct": 1.5,
|
177 |
+
"vicuna-33b": 8,
|
178 |
+
"vicuna-13b": 1.5,
|
179 |
+
"wizardlm-70b": 1.5,
|
180 |
+
"wizardlm-13b": 1.5,
|
181 |
+
"qwen-14b-chat": 1.5,
|
182 |
+
"mistral-7b-instruct": 1.5,
|
183 |
+
# tier 2
|
184 |
+
"vicuna-7b": 1.0,
|
185 |
+
"llama-2-7b-chat": 1.0,
|
186 |
+
"chatglm2-6b": 1.0,
|
187 |
+
# deprecated
|
188 |
+
"zephyr-7b-alpha": 1.5,
|
189 |
+
"codellama-13b-instruct": 1.0,
|
190 |
+
"mpt-30b-chat": 1.5,
|
191 |
+
"guanaco-33b": 1.0,
|
192 |
+
"fastchat-t5-3b": 0.5,
|
193 |
+
"alpaca-13b": 0.5,
|
194 |
+
"mpt-7b-chat": 0.1,
|
195 |
+
"oasst-pythia-12b": 0.1,
|
196 |
+
"RWKV-4-Raven-14B": 0.1,
|
197 |
+
"gpt4all-13b-snoozy": 0.1,
|
198 |
+
"koala-13b": 0.1,
|
199 |
+
"stablelm-tuned-alpha-7b": 0.1,
|
200 |
+
"dolly-v2-12b": 0.1,
|
201 |
+
"llama-13b": 0.1,
|
202 |
+
"chatglm-6b": 0.5,
|
203 |
+
"deluxe-chat-v1": 4,
|
204 |
+
}
|
205 |
+
|
206 |
+
# target model sampling weights will be boosted.
|
207 |
+
BATTLE_TARGETS = {
|
208 |
+
"gpt-4": {"claude-2"},
|
209 |
+
"gpt-4-turbo": {"gpt-4", "gpt-3.5-turbo"},
|
210 |
+
"gpt-3.5-turbo": {"claude-instant-1", "gpt-4", "claude-2"},
|
211 |
+
"claude-2": {"gpt-4", "gpt-3.5-turbo", "claude-1"},
|
212 |
+
"claude-1": {"claude-2", "gpt-4", "gpt-3.5-turbo"},
|
213 |
+
"claude-instant-1": {"gpt-3.5-turbo", "claude-2"},
|
214 |
+
"deluxe-chat-v1.1": {"gpt-4"},
|
215 |
+
"openchat-3.5": {"gpt-3.5-turbo", "llama-2-70b-chat", "zephyr-7b-beta"},
|
216 |
+
"qwen-14b-chat": {"vicuna-13b", "llama-2-13b-chat", "llama-2-70b-chat"},
|
217 |
+
"zephyr-7b-alpha": {"mistral-7b-instruct", "llama-2-13b-chat"},
|
218 |
+
"zephyr-7b-beta": {
|
219 |
+
"mistral-7b-instruct",
|
220 |
+
"llama-2-13b-chat",
|
221 |
+
"llama-2-7b-chat",
|
222 |
+
"wizardlm-13b",
|
223 |
+
},
|
224 |
+
"llama-2-70b-chat": {"gpt-3.5-turbo", "vicuna-33b", "claude-instant-1"},
|
225 |
+
"llama-2-13b-chat": {"mistral-7b-instruct", "vicuna-13b", "llama-2-70b-chat"},
|
226 |
+
"llama-2-7b-chat": {"mistral-7b-instruct", "vicuna-7b", "llama-2-13b-chat"},
|
227 |
+
"mistral-7b-instruct": {
|
228 |
+
"llama-2-7b-chat",
|
229 |
+
"llama-2-13b-chat",
|
230 |
+
"llama-2-70b-chat",
|
231 |
+
},
|
232 |
+
"vicuna-33b": {"llama-2-70b-chat", "gpt-3.5-turbo", "claude-instant-1"},
|
233 |
+
"vicuna-13b": {"llama-2-13b-chat", "llama-2-70b-chat"},
|
234 |
+
"vicuna-7b": {"llama-2-7b-chat", "mistral-7b-instruct", "llama-2-13b-chat"},
|
235 |
+
"wizardlm-70b": {"gpt-3.5-turbo", "vicuna-33b", "claude-instant-1"},
|
236 |
+
"palm-2": {"llama-2-13b-chat", "gpt-3.5-turbo"},
|
237 |
+
}
|
238 |
+
|
239 |
+
SAMPLING_BOOST_MODELS = ["openchat-3.5", "gpt-4-turbo", "gpt-3.5-turbo-1106"]
|
240 |
+
|
241 |
+
# outage models won't be sampled.
|
242 |
+
OUTAGE_MODELS = []
|
243 |
+
|
244 |
+
|
245 |
+
def get_sample_weight(model):
|
246 |
+
if model in OUTAGE_MODELS:
|
247 |
+
return 0
|
248 |
+
weight = SAMPLING_WEIGHTS.get(model, 1.0)
|
249 |
+
if model in SAMPLING_BOOST_MODELS:
|
250 |
+
weight *= 5
|
251 |
+
return weight
|
252 |
+
|
253 |
+
|
254 |
+
def get_battle_pair():
|
255 |
+
if len(models) == 1:
|
256 |
+
return models[0], models[0]
|
257 |
+
|
258 |
+
model_weights = []
|
259 |
+
for model in models:
|
260 |
+
weight = get_sample_weight(model)
|
261 |
+
model_weights.append(weight)
|
262 |
+
total_weight = np.sum(model_weights)
|
263 |
+
model_weights = model_weights / total_weight
|
264 |
+
chosen_idx = np.random.choice(len(models), p=model_weights)
|
265 |
+
chosen_model = models[chosen_idx]
|
266 |
+
|
267 |
+
rival_models = []
|
268 |
+
rival_weights = []
|
269 |
+
for model in models:
|
270 |
+
if model == chosen_model:
|
271 |
+
continue
|
272 |
+
weight = get_sample_weight(model)
|
273 |
+
if (
|
274 |
+
weight != 0
|
275 |
+
and chosen_model in BATTLE_TARGETS
|
276 |
+
and model in BATTLE_TARGETS[chosen_model]
|
277 |
+
):
|
278 |
+
# boost to 50% chance
|
279 |
+
weight = total_weight / len(BATTLE_TARGETS[chosen_model])
|
280 |
+
rival_models.append(model)
|
281 |
+
rival_weights.append(weight)
|
282 |
+
# for p, w in zip(rival_models, rival_weights):
|
283 |
+
# print(p, w)
|
284 |
+
rival_weights = rival_weights / np.sum(rival_weights)
|
285 |
+
rival_idx = np.random.choice(len(rival_models), p=rival_weights)
|
286 |
+
rival_model = rival_models[rival_idx]
|
287 |
+
|
288 |
+
swap = np.random.randint(2)
|
289 |
+
if swap == 0:
|
290 |
+
return chosen_model, rival_model
|
291 |
+
else:
|
292 |
+
return rival_model, chosen_model
|
293 |
+
|
294 |
+
|
295 |
+
def add_text(
|
296 |
+
state0, state1, model_selector0, model_selector1, text, request: gr.Request
|
297 |
+
):
|
298 |
+
ip = get_ip(request)
|
299 |
+
logger.info(f"add_text (anony). ip: {ip}. len: {len(text)}")
|
300 |
+
states = [state0, state1]
|
301 |
+
model_selectors = [model_selector0, model_selector1]
|
302 |
+
|
303 |
+
# Init states if necessary
|
304 |
+
if states[0] is None:
|
305 |
+
assert states[1] is None
|
306 |
+
|
307 |
+
model_left, model_right = get_battle_pair()
|
308 |
+
states = [
|
309 |
+
State(model_left),
|
310 |
+
State(model_right),
|
311 |
+
]
|
312 |
+
|
313 |
+
if len(text) <= 0:
|
314 |
+
for i in range(num_sides):
|
315 |
+
states[i].skip_next = True
|
316 |
+
return (
|
317 |
+
states
|
318 |
+
+ [x.to_gradio_chatbot() for x in states]
|
319 |
+
+ [""]
|
320 |
+
+ [
|
321 |
+
no_change_btn,
|
322 |
+
]
|
323 |
+
* 6
|
324 |
+
+ [""]
|
325 |
+
)
|
326 |
+
|
327 |
+
model_list = [states[i].model_name for i in range(num_sides)]
|
328 |
+
flagged = moderation_filter(text, model_list)
|
329 |
+
if flagged:
|
330 |
+
logger.info(f"violate moderation (anony). ip: {ip}. text: {text}")
|
331 |
+
# overwrite the original text
|
332 |
+
text = MODERATION_MSG
|
333 |
+
|
334 |
+
conv = states[0].conv
|
335 |
+
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
|
336 |
+
logger.info(f"conversation turn limit. ip: {get_ip(request)}. text: {text}")
|
337 |
+
for i in range(num_sides):
|
338 |
+
states[i].skip_next = True
|
339 |
+
return (
|
340 |
+
states
|
341 |
+
+ [x.to_gradio_chatbot() for x in states]
|
342 |
+
+ [CONVERSATION_LIMIT_MSG]
|
343 |
+
+ [
|
344 |
+
no_change_btn,
|
345 |
+
]
|
346 |
+
* 6
|
347 |
+
+ [""]
|
348 |
+
)
|
349 |
+
|
350 |
+
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
|
351 |
+
for i in range(num_sides):
|
352 |
+
states[i].conv.append_message(states[i].conv.roles[0], text)
|
353 |
+
states[i].conv.append_message(states[i].conv.roles[1], None)
|
354 |
+
states[i].skip_next = False
|
355 |
+
|
356 |
+
slow_model_msg = ""
|
357 |
+
for i in range(num_sides):
|
358 |
+
if "deluxe" in states[i].model_name:
|
359 |
+
slow_model_msg = SLOW_MODEL_MSG
|
360 |
+
return (
|
361 |
+
states
|
362 |
+
+ [x.to_gradio_chatbot() for x in states]
|
363 |
+
+ [""]
|
364 |
+
+ [
|
365 |
+
disable_btn,
|
366 |
+
]
|
367 |
+
* 6
|
368 |
+
+ [slow_model_msg]
|
369 |
+
)
|
370 |
+
|
371 |
+
|
372 |
+
def bot_response_multi(
|
373 |
+
state0,
|
374 |
+
state1,
|
375 |
+
temperature,
|
376 |
+
top_p,
|
377 |
+
max_new_tokens,
|
378 |
+
request: gr.Request,
|
379 |
+
):
|
380 |
+
logger.info(f"bot_response_multi (anony). ip: {get_ip(request)}")
|
381 |
+
|
382 |
+
if state0 is None or state0.skip_next:
|
383 |
+
# This generate call is skipped due to invalid inputs
|
384 |
+
yield (
|
385 |
+
state0,
|
386 |
+
state1,
|
387 |
+
state0.to_gradio_chatbot(),
|
388 |
+
state1.to_gradio_chatbot(),
|
389 |
+
) + (no_change_btn,) * 6
|
390 |
+
return
|
391 |
+
|
392 |
+
states = [state0, state1]
|
393 |
+
gen = []
|
394 |
+
for i in range(num_sides):
|
395 |
+
gen.append(
|
396 |
+
bot_response(
|
397 |
+
states[i],
|
398 |
+
temperature,
|
399 |
+
top_p,
|
400 |
+
max_new_tokens,
|
401 |
+
request,
|
402 |
+
)
|
403 |
+
)
|
404 |
+
|
405 |
+
chatbots = [None] * num_sides
|
406 |
+
while True:
|
407 |
+
stop = True
|
408 |
+
for i in range(num_sides):
|
409 |
+
try:
|
410 |
+
ret = next(gen[i])
|
411 |
+
states[i], chatbots[i] = ret[0], ret[1]
|
412 |
+
stop = False
|
413 |
+
except StopIteration:
|
414 |
+
pass
|
415 |
+
yield states + chatbots + [disable_btn] * 6
|
416 |
+
if stop:
|
417 |
+
break
|
418 |
+
|
419 |
+
|
420 |
+
def build_side_by_side_ui_anony(models):
|
421 |
+
notice_markdown = """
|
422 |
+
# ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild
|
423 |
+
| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
|
424 |
+
|
425 |
+
## 📜 Rules
|
426 |
+
- Ask any question to two anonymous models (e.g., ChatGPT, Claude, Llama) and vote for the better one!
|
427 |
+
- You can continue chatting until you identify a winner.
|
428 |
+
- Vote won't be counted if model identity is revealed during conversation.
|
429 |
+
|
430 |
+
## 🏆 Arena Elo [Leaderboard](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard)
|
431 |
+
We use **100K** human votes to compile an Elo-based LLM leaderboard.
|
432 |
+
Find out who is the 🥇LLM Champion!
|
433 |
+
|
434 |
+
## 👇 Chat now!
|
435 |
+
|
436 |
+
"""
|
437 |
+
|
438 |
+
states = [gr.State() for _ in range(num_sides)]
|
439 |
+
model_selectors = [None] * num_sides
|
440 |
+
chatbots = [None] * num_sides
|
441 |
+
|
442 |
+
gr.Markdown(notice_markdown, elem_id="notice_markdown")
|
443 |
+
|
444 |
+
with gr.Box(elem_id="share-region-anony"):
|
445 |
+
with gr.Row():
|
446 |
+
for i in range(num_sides):
|
447 |
+
label = "Model A" if i == 0 else "Model B"
|
448 |
+
with gr.Column():
|
449 |
+
chatbots[i] = gr.Chatbot(
|
450 |
+
label=label, elem_id=f"chatbot", height=550
|
451 |
+
)
|
452 |
+
|
453 |
+
with gr.Row():
|
454 |
+
for i in range(num_sides):
|
455 |
+
with gr.Column():
|
456 |
+
model_selectors[i] = gr.Markdown(anony_names[i])
|
457 |
+
with gr.Row():
|
458 |
+
slow_warning = gr.Markdown("", elem_id="notice_markdown")
|
459 |
+
|
460 |
+
with gr.Row():
|
461 |
+
leftvote_btn = gr.Button(
|
462 |
+
value="👈 A is better", visible=False, interactive=False
|
463 |
+
)
|
464 |
+
rightvote_btn = gr.Button(
|
465 |
+
value="👉 B is better", visible=False, interactive=False
|
466 |
+
)
|
467 |
+
tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
|
468 |
+
bothbad_btn = gr.Button(
|
469 |
+
value="👎 Both are bad", visible=False, interactive=False
|
470 |
+
)
|
471 |
+
|
472 |
+
with gr.Row():
|
473 |
+
with gr.Column(scale=20):
|
474 |
+
textbox = gr.Textbox(
|
475 |
+
show_label=False,
|
476 |
+
placeholder="👉 Enter your prompt and press ENTER",
|
477 |
+
container=False,
|
478 |
+
elem_id="input_box",
|
479 |
+
)
|
480 |
+
with gr.Column(scale=1, min_width=50):
|
481 |
+
send_btn = gr.Button(value="Send", variant="primary")
|
482 |
+
|
483 |
+
with gr.Row() as button_row:
|
484 |
+
clear_btn = gr.Button(value="🎲 New Round", interactive=False)
|
485 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
486 |
+
share_btn = gr.Button(value="📷 Share")
|
487 |
+
|
488 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
489 |
+
temperature = gr.Slider(
|
490 |
+
minimum=0.0,
|
491 |
+
maximum=1.0,
|
492 |
+
value=0.7,
|
493 |
+
step=0.1,
|
494 |
+
interactive=True,
|
495 |
+
label="Temperature",
|
496 |
+
)
|
497 |
+
top_p = gr.Slider(
|
498 |
+
minimum=0.0,
|
499 |
+
maximum=1.0,
|
500 |
+
value=1.0,
|
501 |
+
step=0.1,
|
502 |
+
interactive=True,
|
503 |
+
label="Top P",
|
504 |
+
)
|
505 |
+
max_output_tokens = gr.Slider(
|
506 |
+
minimum=16,
|
507 |
+
maximum=1024,
|
508 |
+
value=512,
|
509 |
+
step=64,
|
510 |
+
interactive=True,
|
511 |
+
label="Max output tokens",
|
512 |
+
)
|
513 |
+
|
514 |
+
gr.Markdown(acknowledgment_md)
|
515 |
+
|
516 |
+
# Register listeners
|
517 |
+
btn_list = [
|
518 |
+
leftvote_btn,
|
519 |
+
rightvote_btn,
|
520 |
+
tie_btn,
|
521 |
+
bothbad_btn,
|
522 |
+
regenerate_btn,
|
523 |
+
clear_btn,
|
524 |
+
]
|
525 |
+
leftvote_btn.click(
|
526 |
+
leftvote_last_response,
|
527 |
+
states + model_selectors,
|
528 |
+
model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
529 |
+
)
|
530 |
+
rightvote_btn.click(
|
531 |
+
rightvote_last_response,
|
532 |
+
states + model_selectors,
|
533 |
+
model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
534 |
+
)
|
535 |
+
tie_btn.click(
|
536 |
+
tievote_last_response,
|
537 |
+
states + model_selectors,
|
538 |
+
model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
539 |
+
)
|
540 |
+
bothbad_btn.click(
|
541 |
+
bothbad_vote_last_response,
|
542 |
+
states + model_selectors,
|
543 |
+
model_selectors + [textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
544 |
+
)
|
545 |
+
regenerate_btn.click(
|
546 |
+
regenerate, states, states + chatbots + [textbox] + btn_list
|
547 |
+
).then(
|
548 |
+
bot_response_multi,
|
549 |
+
states + [temperature, top_p, max_output_tokens],
|
550 |
+
states + chatbots + btn_list,
|
551 |
+
).then(
|
552 |
+
flash_buttons, [], btn_list
|
553 |
+
)
|
554 |
+
clear_btn.click(
|
555 |
+
clear_history,
|
556 |
+
None,
|
557 |
+
states + chatbots + model_selectors + [textbox] + btn_list + [slow_warning],
|
558 |
+
)
|
559 |
+
|
560 |
+
share_js = """
|
561 |
+
function (a, b, c, d) {
|
562 |
+
const captureElement = document.querySelector('#share-region-anony');
|
563 |
+
html2canvas(captureElement)
|
564 |
+
.then(canvas => {
|
565 |
+
canvas.style.display = 'none'
|
566 |
+
document.body.appendChild(canvas)
|
567 |
+
return canvas
|
568 |
+
})
|
569 |
+
.then(canvas => {
|
570 |
+
const image = canvas.toDataURL('image/png')
|
571 |
+
const a = document.createElement('a')
|
572 |
+
a.setAttribute('download', 'chatbot-arena.png')
|
573 |
+
a.setAttribute('href', image)
|
574 |
+
a.click()
|
575 |
+
canvas.remove()
|
576 |
+
});
|
577 |
+
return [a, b, c, d];
|
578 |
+
}
|
579 |
+
"""
|
580 |
+
share_btn.click(share_click, states + model_selectors, [], _js=share_js)
|
581 |
+
|
582 |
+
textbox.submit(
|
583 |
+
add_text,
|
584 |
+
states + model_selectors + [textbox],
|
585 |
+
states + chatbots + [textbox] + btn_list + [slow_warning],
|
586 |
+
).then(
|
587 |
+
bot_response_multi,
|
588 |
+
states + [temperature, top_p, max_output_tokens],
|
589 |
+
states + chatbots + btn_list,
|
590 |
+
).then(
|
591 |
+
flash_buttons,
|
592 |
+
[],
|
593 |
+
btn_list,
|
594 |
+
)
|
595 |
+
|
596 |
+
send_btn.click(
|
597 |
+
add_text,
|
598 |
+
states + model_selectors + [textbox],
|
599 |
+
states + chatbots + [textbox] + btn_list,
|
600 |
+
).then(
|
601 |
+
bot_response_multi,
|
602 |
+
states + [temperature, top_p, max_output_tokens],
|
603 |
+
states + chatbots + btn_list,
|
604 |
+
).then(
|
605 |
+
flash_buttons, [], btn_list
|
606 |
+
)
|
607 |
+
|
608 |
+
return states + model_selectors
|
gradio_block_arena_named.py
ADDED
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Chatbot Arena (side-by-side) tab.
|
3 |
+
Users chat with two chosen models.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import json
|
7 |
+
import time
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from fastchat.constants import (
|
13 |
+
MODERATION_MSG,
|
14 |
+
CONVERSATION_LIMIT_MSG,
|
15 |
+
INPUT_CHAR_LEN_LIMIT,
|
16 |
+
CONVERSATION_TURN_LIMIT,
|
17 |
+
)
|
18 |
+
from fastchat.model.model_adapter import get_conversation_template
|
19 |
+
from fastchat.serve.gradio_web_server import (
|
20 |
+
State,
|
21 |
+
bot_response,
|
22 |
+
get_conv_log_filename,
|
23 |
+
no_change_btn,
|
24 |
+
enable_btn,
|
25 |
+
disable_btn,
|
26 |
+
invisible_btn,
|
27 |
+
acknowledgment_md,
|
28 |
+
get_model_description_md,
|
29 |
+
ip_expiration_dict,
|
30 |
+
get_ip,
|
31 |
+
)
|
32 |
+
from fastchat.utils import (
|
33 |
+
build_logger,
|
34 |
+
moderation_filter,
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
|
39 |
+
|
40 |
+
num_sides = 2
|
41 |
+
enable_moderation = False
|
42 |
+
|
43 |
+
|
44 |
+
def set_global_vars_named(enable_moderation_):
|
45 |
+
global enable_moderation
|
46 |
+
enable_moderation = enable_moderation_
|
47 |
+
|
48 |
+
|
49 |
+
def load_demo_side_by_side_named(models, url_params):
|
50 |
+
states = (None,) * num_sides
|
51 |
+
|
52 |
+
model_left = models[0] if len(models) > 0 else ""
|
53 |
+
if len(models) > 1:
|
54 |
+
weights = ([8] * 4 + [4] * 8 + [1] * 32)[: len(models) - 1]
|
55 |
+
weights = weights / np.sum(weights)
|
56 |
+
model_right = np.random.choice(models[1:], p=weights)
|
57 |
+
else:
|
58 |
+
model_right = model_left
|
59 |
+
|
60 |
+
selector_updates = (
|
61 |
+
gr.Dropdown.update(choices=models, value=model_left, visible=True),
|
62 |
+
gr.Dropdown.update(choices=models, value=model_right, visible=True),
|
63 |
+
)
|
64 |
+
|
65 |
+
return states + selector_updates
|
66 |
+
|
67 |
+
|
68 |
+
def vote_last_response(states, vote_type, model_selectors, request: gr.Request):
|
69 |
+
with open(get_conv_log_filename(), "a") as fout:
|
70 |
+
data = {
|
71 |
+
"tstamp": round(time.time(), 4),
|
72 |
+
"type": vote_type,
|
73 |
+
"models": [x for x in model_selectors],
|
74 |
+
"states": [x.dict() for x in states],
|
75 |
+
"ip": get_ip(request),
|
76 |
+
}
|
77 |
+
fout.write(json.dumps(data) + "\n")
|
78 |
+
|
79 |
+
|
80 |
+
def leftvote_last_response(
|
81 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
82 |
+
):
|
83 |
+
logger.info(f"leftvote (named). ip: {get_ip(request)}")
|
84 |
+
vote_last_response(
|
85 |
+
[state0, state1], "leftvote", [model_selector0, model_selector1], request
|
86 |
+
)
|
87 |
+
return ("",) + (disable_btn,) * 4
|
88 |
+
|
89 |
+
|
90 |
+
def rightvote_last_response(
|
91 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
92 |
+
):
|
93 |
+
logger.info(f"rightvote (named). ip: {get_ip(request)}")
|
94 |
+
vote_last_response(
|
95 |
+
[state0, state1], "rightvote", [model_selector0, model_selector1], request
|
96 |
+
)
|
97 |
+
return ("",) + (disable_btn,) * 4
|
98 |
+
|
99 |
+
|
100 |
+
def tievote_last_response(
|
101 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
102 |
+
):
|
103 |
+
logger.info(f"tievote (named). ip: {get_ip(request)}")
|
104 |
+
vote_last_response(
|
105 |
+
[state0, state1], "tievote", [model_selector0, model_selector1], request
|
106 |
+
)
|
107 |
+
return ("",) + (disable_btn,) * 4
|
108 |
+
|
109 |
+
|
110 |
+
def bothbad_vote_last_response(
|
111 |
+
state0, state1, model_selector0, model_selector1, request: gr.Request
|
112 |
+
):
|
113 |
+
logger.info(f"bothbad_vote (named). ip: {get_ip(request)}")
|
114 |
+
vote_last_response(
|
115 |
+
[state0, state1], "bothbad_vote", [model_selector0, model_selector1], request
|
116 |
+
)
|
117 |
+
return ("",) + (disable_btn,) * 4
|
118 |
+
|
119 |
+
|
120 |
+
def regenerate(state0, state1, request: gr.Request):
|
121 |
+
logger.info(f"regenerate (named). ip: {get_ip(request)}")
|
122 |
+
states = [state0, state1]
|
123 |
+
for i in range(num_sides):
|
124 |
+
states[i].conv.update_last_message(None)
|
125 |
+
return states + [x.to_gradio_chatbot() for x in states] + [""] + [disable_btn] * 6
|
126 |
+
|
127 |
+
|
128 |
+
def clear_history(request: gr.Request):
|
129 |
+
logger.info(f"clear_history (named). ip: {get_ip(request)}")
|
130 |
+
return (
|
131 |
+
[None] * num_sides
|
132 |
+
+ [None] * num_sides
|
133 |
+
+ [""]
|
134 |
+
+ [invisible_btn] * 4
|
135 |
+
+ [disable_btn] * 2
|
136 |
+
)
|
137 |
+
|
138 |
+
|
139 |
+
def share_click(state0, state1, model_selector0, model_selector1, request: gr.Request):
|
140 |
+
logger.info(f"share (named). ip: {get_ip(request)}")
|
141 |
+
if state0 is not None and state1 is not None:
|
142 |
+
vote_last_response(
|
143 |
+
[state0, state1], "share", [model_selector0, model_selector1], request
|
144 |
+
)
|
145 |
+
|
146 |
+
|
147 |
+
def add_text(
|
148 |
+
state0, state1, model_selector0, model_selector1, text, request: gr.Request
|
149 |
+
):
|
150 |
+
ip = get_ip(request)
|
151 |
+
logger.info(f"add_text (named). ip: {ip}. len: {len(text)}")
|
152 |
+
states = [state0, state1]
|
153 |
+
model_selectors = [model_selector0, model_selector1]
|
154 |
+
|
155 |
+
# Init states if necessary
|
156 |
+
for i in range(num_sides):
|
157 |
+
if states[i] is None:
|
158 |
+
states[i] = State(model_selectors[i])
|
159 |
+
|
160 |
+
if len(text) <= 0:
|
161 |
+
for i in range(num_sides):
|
162 |
+
states[i].skip_next = True
|
163 |
+
return (
|
164 |
+
states
|
165 |
+
+ [x.to_gradio_chatbot() for x in states]
|
166 |
+
+ [""]
|
167 |
+
+ [
|
168 |
+
no_change_btn,
|
169 |
+
]
|
170 |
+
* 6
|
171 |
+
)
|
172 |
+
|
173 |
+
model_list = [states[i].model_name for i in range(num_sides)]
|
174 |
+
flagged = moderation_filter(text, model_list)
|
175 |
+
if flagged:
|
176 |
+
logger.info(f"violate moderation (named). ip: {ip}. text: {text}")
|
177 |
+
# overwrite the original text
|
178 |
+
text = MODERATION_MSG
|
179 |
+
|
180 |
+
conv = states[0].conv
|
181 |
+
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
|
182 |
+
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
|
183 |
+
for i in range(num_sides):
|
184 |
+
states[i].skip_next = True
|
185 |
+
return (
|
186 |
+
states
|
187 |
+
+ [x.to_gradio_chatbot() for x in states]
|
188 |
+
+ [CONVERSATION_LIMIT_MSG]
|
189 |
+
+ [
|
190 |
+
no_change_btn,
|
191 |
+
]
|
192 |
+
* 6
|
193 |
+
)
|
194 |
+
|
195 |
+
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
|
196 |
+
for i in range(num_sides):
|
197 |
+
states[i].conv.append_message(states[i].conv.roles[0], text)
|
198 |
+
states[i].conv.append_message(states[i].conv.roles[1], None)
|
199 |
+
states[i].skip_next = False
|
200 |
+
|
201 |
+
return (
|
202 |
+
states
|
203 |
+
+ [x.to_gradio_chatbot() for x in states]
|
204 |
+
+ [""]
|
205 |
+
+ [
|
206 |
+
disable_btn,
|
207 |
+
]
|
208 |
+
* 6
|
209 |
+
)
|
210 |
+
|
211 |
+
|
212 |
+
def bot_response_multi(
|
213 |
+
state0,
|
214 |
+
state1,
|
215 |
+
temperature,
|
216 |
+
top_p,
|
217 |
+
max_new_tokens,
|
218 |
+
request: gr.Request,
|
219 |
+
):
|
220 |
+
logger.info(f"bot_response_multi (named). ip: {get_ip(request)}")
|
221 |
+
|
222 |
+
if state0.skip_next:
|
223 |
+
# This generate call is skipped due to invalid inputs
|
224 |
+
yield (
|
225 |
+
state0,
|
226 |
+
state1,
|
227 |
+
state0.to_gradio_chatbot(),
|
228 |
+
state1.to_gradio_chatbot(),
|
229 |
+
) + (no_change_btn,) * 6
|
230 |
+
return
|
231 |
+
|
232 |
+
states = [state0, state1]
|
233 |
+
gen = []
|
234 |
+
for i in range(num_sides):
|
235 |
+
gen.append(
|
236 |
+
bot_response(
|
237 |
+
states[i],
|
238 |
+
temperature,
|
239 |
+
top_p,
|
240 |
+
max_new_tokens,
|
241 |
+
request,
|
242 |
+
)
|
243 |
+
)
|
244 |
+
|
245 |
+
chatbots = [None] * num_sides
|
246 |
+
while True:
|
247 |
+
stop = True
|
248 |
+
for i in range(num_sides):
|
249 |
+
try:
|
250 |
+
ret = next(gen[i])
|
251 |
+
states[i], chatbots[i] = ret[0], ret[1]
|
252 |
+
stop = False
|
253 |
+
except StopIteration:
|
254 |
+
pass
|
255 |
+
yield states + chatbots + [disable_btn] * 6
|
256 |
+
if stop:
|
257 |
+
break
|
258 |
+
|
259 |
+
|
260 |
+
def flash_buttons():
|
261 |
+
btn_updates = [
|
262 |
+
[disable_btn] * 4 + [enable_btn] * 2,
|
263 |
+
[enable_btn] * 6,
|
264 |
+
]
|
265 |
+
for i in range(4):
|
266 |
+
yield btn_updates[i % 2]
|
267 |
+
time.sleep(0.5)
|
268 |
+
|
269 |
+
|
270 |
+
def build_side_by_side_ui_named(models):
|
271 |
+
notice_markdown = """
|
272 |
+
# ⚔️ Chatbot Arena ⚔️ : Benchmarking LLMs in the Wild
|
273 |
+
| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
|
274 |
+
|
275 |
+
## 📜 Rules
|
276 |
+
- Chat with any two models side-by-side and vote!
|
277 |
+
- You can continue chatting for multiple rounds.
|
278 |
+
- Click "Clear history" to start a new round.
|
279 |
+
|
280 |
+
## 🤖 Choose two models to compare
|
281 |
+
"""
|
282 |
+
|
283 |
+
states = [gr.State() for _ in range(num_sides)]
|
284 |
+
model_selectors = [None] * num_sides
|
285 |
+
chatbots = [None] * num_sides
|
286 |
+
|
287 |
+
model_description_md = get_model_description_md(models)
|
288 |
+
notice = gr.Markdown(
|
289 |
+
notice_markdown + model_description_md, elem_id="notice_markdown"
|
290 |
+
)
|
291 |
+
|
292 |
+
with gr.Box(elem_id="share-region-named"):
|
293 |
+
with gr.Row():
|
294 |
+
for i in range(num_sides):
|
295 |
+
with gr.Column():
|
296 |
+
model_selectors[i] = gr.Dropdown(
|
297 |
+
choices=models,
|
298 |
+
value=models[i] if len(models) > i else "",
|
299 |
+
interactive=True,
|
300 |
+
show_label=False,
|
301 |
+
container=False,
|
302 |
+
)
|
303 |
+
|
304 |
+
with gr.Row():
|
305 |
+
for i in range(num_sides):
|
306 |
+
label = "Model A" if i == 0 else "Model B"
|
307 |
+
with gr.Column():
|
308 |
+
chatbots[i] = gr.Chatbot(
|
309 |
+
label=label, elem_id=f"chatbot", height=550
|
310 |
+
)
|
311 |
+
|
312 |
+
with gr.Row():
|
313 |
+
leftvote_btn = gr.Button(
|
314 |
+
value="👈 A is better", visible=False, interactive=False
|
315 |
+
)
|
316 |
+
rightvote_btn = gr.Button(
|
317 |
+
value="👉 B is better", visible=False, interactive=False
|
318 |
+
)
|
319 |
+
tie_btn = gr.Button(value="🤝 Tie", visible=False, interactive=False)
|
320 |
+
bothbad_btn = gr.Button(
|
321 |
+
value="👎 Both are bad", visible=False, interactive=False
|
322 |
+
)
|
323 |
+
|
324 |
+
with gr.Row():
|
325 |
+
with gr.Column(scale=20):
|
326 |
+
textbox = gr.Textbox(
|
327 |
+
show_label=False,
|
328 |
+
placeholder="Enter your prompt here and press ENTER",
|
329 |
+
container=False,
|
330 |
+
elem_id="input_box",
|
331 |
+
)
|
332 |
+
with gr.Column(scale=1, min_width=50):
|
333 |
+
send_btn = gr.Button(value="Send", variant="primary")
|
334 |
+
|
335 |
+
with gr.Row() as button_row:
|
336 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
337 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
|
338 |
+
share_btn = gr.Button(value="📷 Share")
|
339 |
+
|
340 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
341 |
+
temperature = gr.Slider(
|
342 |
+
minimum=0.0,
|
343 |
+
maximum=1.0,
|
344 |
+
value=0.7,
|
345 |
+
step=0.1,
|
346 |
+
interactive=True,
|
347 |
+
label="Temperature",
|
348 |
+
)
|
349 |
+
top_p = gr.Slider(
|
350 |
+
minimum=0.0,
|
351 |
+
maximum=1.0,
|
352 |
+
value=1.0,
|
353 |
+
step=0.1,
|
354 |
+
interactive=True,
|
355 |
+
label="Top P",
|
356 |
+
)
|
357 |
+
max_output_tokens = gr.Slider(
|
358 |
+
minimum=16,
|
359 |
+
maximum=1024,
|
360 |
+
value=512,
|
361 |
+
step=64,
|
362 |
+
interactive=True,
|
363 |
+
label="Max output tokens",
|
364 |
+
)
|
365 |
+
|
366 |
+
gr.Markdown(acknowledgment_md)
|
367 |
+
|
368 |
+
# Register listeners
|
369 |
+
btn_list = [
|
370 |
+
leftvote_btn,
|
371 |
+
rightvote_btn,
|
372 |
+
tie_btn,
|
373 |
+
bothbad_btn,
|
374 |
+
regenerate_btn,
|
375 |
+
clear_btn,
|
376 |
+
]
|
377 |
+
leftvote_btn.click(
|
378 |
+
leftvote_last_response,
|
379 |
+
states + model_selectors,
|
380 |
+
[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
381 |
+
)
|
382 |
+
rightvote_btn.click(
|
383 |
+
rightvote_last_response,
|
384 |
+
states + model_selectors,
|
385 |
+
[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
386 |
+
)
|
387 |
+
tie_btn.click(
|
388 |
+
tievote_last_response,
|
389 |
+
states + model_selectors,
|
390 |
+
[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
391 |
+
)
|
392 |
+
bothbad_btn.click(
|
393 |
+
bothbad_vote_last_response,
|
394 |
+
states + model_selectors,
|
395 |
+
[textbox, leftvote_btn, rightvote_btn, tie_btn, bothbad_btn],
|
396 |
+
)
|
397 |
+
regenerate_btn.click(
|
398 |
+
regenerate, states, states + chatbots + [textbox] + btn_list
|
399 |
+
).then(
|
400 |
+
bot_response_multi,
|
401 |
+
states + [temperature, top_p, max_output_tokens],
|
402 |
+
states + chatbots + btn_list,
|
403 |
+
).then(
|
404 |
+
flash_buttons, [], btn_list
|
405 |
+
)
|
406 |
+
clear_btn.click(clear_history, None, states + chatbots + [textbox] + btn_list)
|
407 |
+
|
408 |
+
share_js = """
|
409 |
+
function (a, b, c, d) {
|
410 |
+
const captureElement = document.querySelector('#share-region-named');
|
411 |
+
html2canvas(captureElement)
|
412 |
+
.then(canvas => {
|
413 |
+
canvas.style.display = 'none'
|
414 |
+
document.body.appendChild(canvas)
|
415 |
+
return canvas
|
416 |
+
})
|
417 |
+
.then(canvas => {
|
418 |
+
const image = canvas.toDataURL('image/png')
|
419 |
+
const a = document.createElement('a')
|
420 |
+
a.setAttribute('download', 'chatbot-arena.png')
|
421 |
+
a.setAttribute('href', image)
|
422 |
+
a.click()
|
423 |
+
canvas.remove()
|
424 |
+
});
|
425 |
+
return [a, b, c, d];
|
426 |
+
}
|
427 |
+
"""
|
428 |
+
share_btn.click(share_click, states + model_selectors, [], _js=share_js)
|
429 |
+
|
430 |
+
for i in range(num_sides):
|
431 |
+
model_selectors[i].change(
|
432 |
+
clear_history, None, states + chatbots + [textbox] + btn_list
|
433 |
+
)
|
434 |
+
|
435 |
+
textbox.submit(
|
436 |
+
add_text,
|
437 |
+
states + model_selectors + [textbox],
|
438 |
+
states + chatbots + [textbox] + btn_list,
|
439 |
+
).then(
|
440 |
+
bot_response_multi,
|
441 |
+
states + [temperature, top_p, max_output_tokens],
|
442 |
+
states + chatbots + btn_list,
|
443 |
+
).then(
|
444 |
+
flash_buttons, [], btn_list
|
445 |
+
)
|
446 |
+
send_btn.click(
|
447 |
+
add_text,
|
448 |
+
states + model_selectors + [textbox],
|
449 |
+
states + chatbots + [textbox] + btn_list,
|
450 |
+
).then(
|
451 |
+
bot_response_multi,
|
452 |
+
states + [temperature, top_p, max_output_tokens],
|
453 |
+
states + chatbots + btn_list,
|
454 |
+
).then(
|
455 |
+
flash_buttons, [], btn_list
|
456 |
+
)
|
457 |
+
|
458 |
+
return states + model_selectors
|
gradio_web_server.py
ADDED
@@ -0,0 +1,883 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The gradio demo server for chatting with a single model.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import argparse
|
6 |
+
from collections import defaultdict
|
7 |
+
import datetime
|
8 |
+
import json
|
9 |
+
import os
|
10 |
+
import random
|
11 |
+
import time
|
12 |
+
import uuid
|
13 |
+
|
14 |
+
import gradio as gr
|
15 |
+
import requests
|
16 |
+
|
17 |
+
from fastchat.conversation import SeparatorStyle
|
18 |
+
from fastchat.constants import (
|
19 |
+
LOGDIR,
|
20 |
+
WORKER_API_TIMEOUT,
|
21 |
+
ErrorCode,
|
22 |
+
MODERATION_MSG,
|
23 |
+
CONVERSATION_LIMIT_MSG,
|
24 |
+
SERVER_ERROR_MSG,
|
25 |
+
INPUT_CHAR_LEN_LIMIT,
|
26 |
+
CONVERSATION_TURN_LIMIT,
|
27 |
+
SESSION_EXPIRATION_TIME,
|
28 |
+
)
|
29 |
+
from fastchat.model.model_adapter import get_conversation_template
|
30 |
+
from fastchat.conversation import get_conv_template
|
31 |
+
from fastchat.model.model_registry import get_model_info, model_info
|
32 |
+
from fastchat.serve.api_provider import (
|
33 |
+
anthropic_api_stream_iter,
|
34 |
+
openai_api_stream_iter,
|
35 |
+
palm_api_stream_iter,
|
36 |
+
init_palm_chat,
|
37 |
+
)
|
38 |
+
from fastchat.utils import (
|
39 |
+
build_logger,
|
40 |
+
moderation_filter,
|
41 |
+
get_window_url_params_js,
|
42 |
+
get_window_url_params_with_tos_js,
|
43 |
+
parse_gradio_auth_creds,
|
44 |
+
)
|
45 |
+
|
46 |
+
CONV_TEMPLATE = ''
|
47 |
+
|
48 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
49 |
+
|
50 |
+
headers = {"User-Agent": "FastChat Client"}
|
51 |
+
|
52 |
+
no_change_btn = gr.Button.update()
|
53 |
+
enable_btn = gr.Button.update(interactive=True, visible=True)
|
54 |
+
disable_btn = gr.Button.update(interactive=False)
|
55 |
+
invisible_btn = gr.Button.update(interactive=False, visible=False)
|
56 |
+
|
57 |
+
controller_url = None
|
58 |
+
enable_moderation = False
|
59 |
+
|
60 |
+
acknowledgment_md = """
|
61 |
+
### Acknowledgment
|
62 |
+
<div class="image-container">
|
63 |
+
<p> We thank <a href="https://www.kaggle.com/" target="_blank">Kaggle</a>, <a href="https://mbzuai.ac.ae/" target="_blank">MBZUAI</a>, <a href="https://www.anyscale.com/" target="_blank">AnyScale</a>, and <a href="https://huggingface.co/" target="_blank">HuggingFace</a> for their <a href="https://lmsys.org/donations/" target="_blank">sponsorship</a>. </p>
|
64 |
+
<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/7/7c/Kaggle_logo.png/400px-Kaggle_logo.png" alt="Image 1">
|
65 |
+
<img src="https://mma.prnewswire.com/media/1227419/MBZUAI_Logo.jpg?p=facebookg" alt="Image 2">
|
66 |
+
<img src="https://docs.anyscale.com/site-assets/logo.png" alt="Image 3">
|
67 |
+
<img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo-with-title.png" alt="Image 4">
|
68 |
+
</div>
|
69 |
+
"""
|
70 |
+
|
71 |
+
ip_expiration_dict = defaultdict(lambda: 0)
|
72 |
+
|
73 |
+
# Information about custom OpenAI compatible API models.
|
74 |
+
# JSON file format:
|
75 |
+
# {
|
76 |
+
# "vicuna-7b": {
|
77 |
+
# "model_name": "vicuna-7b-v1.5",
|
78 |
+
# "api_base": "http://8.8.8.55:5555/v1",
|
79 |
+
# "api_key": "password"
|
80 |
+
# },
|
81 |
+
# }
|
82 |
+
openai_compatible_models_info = {}
|
83 |
+
|
84 |
+
|
85 |
+
class State:
|
86 |
+
def __init__(self, model_name):
|
87 |
+
# if model_name=='checkpoint-800':
|
88 |
+
# self.conv = get_conv_template(CONV_TEMPLATE)
|
89 |
+
# elif model_name=='MiniCPM-2B-sft-bf16':
|
90 |
+
ret = requests.post(
|
91 |
+
controller_url + "/get_worker_address", json={"model": model_name}
|
92 |
+
)
|
93 |
+
worker_addr = ret.json()["address"]
|
94 |
+
conv_name = requests.post(
|
95 |
+
worker_addr + "/worker_get_conv_template",
|
96 |
+
).json()['conv']['name']
|
97 |
+
self.conv = get_conv_template(conv_name)
|
98 |
+
# self.conv = get_conv_template('minicpm')
|
99 |
+
# print(self.conv)
|
100 |
+
# self.conv = get_conversation_template(model_name)
|
101 |
+
self.conv_id = uuid.uuid4().hex
|
102 |
+
self.skip_next = False
|
103 |
+
self.model_name = model_name
|
104 |
+
|
105 |
+
if model_name == "palm-2":
|
106 |
+
# According to release note, "chat-bison@001" is PaLM 2 for chat.
|
107 |
+
# https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023
|
108 |
+
self.palm_chat = init_palm_chat("chat-bison@001")
|
109 |
+
|
110 |
+
def to_gradio_chatbot(self):
|
111 |
+
return self.conv.to_gradio_chatbot()
|
112 |
+
|
113 |
+
def dict(self):
|
114 |
+
base = self.conv.dict()
|
115 |
+
base.update(
|
116 |
+
{
|
117 |
+
"conv_id": self.conv_id,
|
118 |
+
"model_name": self.model_name,
|
119 |
+
}
|
120 |
+
)
|
121 |
+
return base
|
122 |
+
|
123 |
+
|
124 |
+
def set_global_vars(controller_url_, enable_moderation_):
|
125 |
+
global controller_url, enable_moderation
|
126 |
+
controller_url = controller_url_
|
127 |
+
enable_moderation = enable_moderation_
|
128 |
+
|
129 |
+
|
130 |
+
def get_conv_log_filename():
|
131 |
+
t = datetime.datetime.now()
|
132 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
133 |
+
return name
|
134 |
+
|
135 |
+
|
136 |
+
def get_model_list(
|
137 |
+
controller_url, register_openai_compatible_models, add_chatgpt, add_claude, add_palm
|
138 |
+
):
|
139 |
+
if controller_url:
|
140 |
+
ret = requests.post(controller_url + "/refresh_all_workers")
|
141 |
+
assert ret.status_code == 200
|
142 |
+
ret = requests.post(controller_url + "/list_models")
|
143 |
+
# ret = requests.post(controller_url + "/get_worker_address")
|
144 |
+
# ret = requests.post(controller_url + "/worker_get_status")
|
145 |
+
models = ret.json()["models"]
|
146 |
+
else:
|
147 |
+
models = []
|
148 |
+
|
149 |
+
# Add API providers
|
150 |
+
if register_openai_compatible_models:
|
151 |
+
global openai_compatible_models_info
|
152 |
+
openai_compatible_models_info = json.load(
|
153 |
+
open(register_openai_compatible_models)
|
154 |
+
)
|
155 |
+
models += list(openai_compatible_models_info.keys())
|
156 |
+
|
157 |
+
if add_chatgpt:
|
158 |
+
models += ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo-1106"]
|
159 |
+
if add_claude:
|
160 |
+
models += ["claude-2", "claude-instant-1"]
|
161 |
+
if add_palm:
|
162 |
+
models += ["palm-2"]
|
163 |
+
models = list(set(models))
|
164 |
+
|
165 |
+
if "deluxe-chat-v1" in models:
|
166 |
+
del models[models.index("deluxe-chat-v1")]
|
167 |
+
if "deluxe-chat-v1.1" in models:
|
168 |
+
del models[models.index("deluxe-chat-v1.1")]
|
169 |
+
|
170 |
+
priority = {k: f"___{i:02d}" for i, k in enumerate(model_info)}
|
171 |
+
models.sort(key=lambda x: priority.get(x, x))
|
172 |
+
logger.info(f"Models: {models}")
|
173 |
+
return models
|
174 |
+
|
175 |
+
|
176 |
+
def load_demo_single(models, url_params):
|
177 |
+
selected_model = models[0] if len(models) > 0 else ""
|
178 |
+
if "model" in url_params:
|
179 |
+
model = url_params["model"]
|
180 |
+
if model in models:
|
181 |
+
selected_model = model
|
182 |
+
|
183 |
+
dropdown_update = gr.Dropdown.update(
|
184 |
+
choices=models, value=selected_model, visible=True
|
185 |
+
)
|
186 |
+
|
187 |
+
state = None
|
188 |
+
return state, dropdown_update
|
189 |
+
|
190 |
+
|
191 |
+
def load_demo(url_params, request: gr.Request):
|
192 |
+
global models
|
193 |
+
|
194 |
+
ip = get_ip(request)
|
195 |
+
logger.info(f"load_demo. ip: {ip}. params: {url_params}")
|
196 |
+
ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME
|
197 |
+
|
198 |
+
if args.model_list_mode == "reload":
|
199 |
+
models = get_model_list(
|
200 |
+
controller_url,
|
201 |
+
args.register_openai_compatible_models,
|
202 |
+
args.add_chatgpt,
|
203 |
+
args.add_claude,
|
204 |
+
args.add_palm,
|
205 |
+
)
|
206 |
+
|
207 |
+
return load_demo_single(models, url_params)
|
208 |
+
|
209 |
+
|
210 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
211 |
+
with open('./web_chat_downvote.jsonl', "a+") as fout:
|
212 |
+
# data = {
|
213 |
+
# "tstamp": round(time.time(), 4),
|
214 |
+
# "type": vote_type,
|
215 |
+
# "model": model_selector,
|
216 |
+
# "state": state.dict(),
|
217 |
+
# "ip": get_ip(request),
|
218 |
+
# }
|
219 |
+
conversations = []
|
220 |
+
for i, turn in enumerate(state.dict()['messages']):
|
221 |
+
role = 'user' if i % 2 == 0 else 'assistant'
|
222 |
+
conversations.append({'role': role, 'content': turn[1]})
|
223 |
+
data = {
|
224 |
+
'conversations': conversations,
|
225 |
+
'idx': state.dict()['conv_id'],
|
226 |
+
'tinder': 'badcase',
|
227 |
+
'model': state.dict()['model_name'],
|
228 |
+
'tokens_in': -1,
|
229 |
+
'tokens_out': -1,
|
230 |
+
}
|
231 |
+
fout.write(json.dumps(data, ensure_ascii=False) + "\n")
|
232 |
+
|
233 |
+
|
234 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
235 |
+
ip = get_ip(request)
|
236 |
+
logger.info(f"upvote. ip: {ip}")
|
237 |
+
vote_last_response(state, "upvote", model_selector, request)
|
238 |
+
return ("",) + (disable_btn,) * 3
|
239 |
+
|
240 |
+
|
241 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
242 |
+
ip = get_ip(request)
|
243 |
+
logger.info(f"downvote. ip: {ip}")
|
244 |
+
vote_last_response(state, "downvote", model_selector, request)
|
245 |
+
return ("",) + (disable_btn,) * 3
|
246 |
+
|
247 |
+
|
248 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
249 |
+
ip = get_ip(request)
|
250 |
+
logger.info(f"flag. ip: {ip}")
|
251 |
+
vote_last_response(state, "flag", model_selector, request)
|
252 |
+
return ("",) + (disable_btn,) * 3
|
253 |
+
|
254 |
+
|
255 |
+
def regenerate(state, request: gr.Request):
|
256 |
+
ip = get_ip(request)
|
257 |
+
logger.info(f"regenerate. ip: {ip}")
|
258 |
+
state.conv.update_last_message(None)
|
259 |
+
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
260 |
+
|
261 |
+
|
262 |
+
def clear_history(request: gr.Request):
|
263 |
+
ip = get_ip(request)
|
264 |
+
logger.info(f"clear_history. ip: {ip}")
|
265 |
+
state = None
|
266 |
+
return (state, [], "") + (disable_btn,) * 5
|
267 |
+
|
268 |
+
|
269 |
+
def get_ip(request: gr.Request):
|
270 |
+
if "cf-connecting-ip" in request.headers:
|
271 |
+
ip = request.headers["cf-connecting-ip"]
|
272 |
+
else:
|
273 |
+
ip = request.client.host
|
274 |
+
return ip
|
275 |
+
|
276 |
+
|
277 |
+
def add_text(state, model_selector, text, request: gr.Request):
|
278 |
+
ip = get_ip(request)
|
279 |
+
logger.info(f"add_text. ip: {ip}. len: {len(text)}")
|
280 |
+
|
281 |
+
if state is None:
|
282 |
+
state = State(model_selector)
|
283 |
+
|
284 |
+
if len(text) <= 0:
|
285 |
+
state.skip_next = True
|
286 |
+
return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 5
|
287 |
+
|
288 |
+
flagged = moderation_filter(text, [state.model_name])
|
289 |
+
if flagged:
|
290 |
+
logger.info(f"violate moderation. ip: {ip}. text: {text}")
|
291 |
+
# overwrite the original text
|
292 |
+
text = MODERATION_MSG
|
293 |
+
|
294 |
+
conv = state.conv
|
295 |
+
if (len(conv.messages) - conv.offset) // 2 >= CONVERSATION_TURN_LIMIT:
|
296 |
+
logger.info(f"conversation turn limit. ip: {ip}. text: {text}")
|
297 |
+
state.skip_next = True
|
298 |
+
return (state, state.to_gradio_chatbot(), CONVERSATION_LIMIT_MSG) + (
|
299 |
+
no_change_btn,
|
300 |
+
) * 5
|
301 |
+
|
302 |
+
text = text[:INPUT_CHAR_LEN_LIMIT] # Hard cut-off
|
303 |
+
conv.append_message(conv.roles[0], text)
|
304 |
+
conv.append_message(conv.roles[1], None)
|
305 |
+
return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 5
|
306 |
+
|
307 |
+
|
308 |
+
def post_process_code(code):
|
309 |
+
sep = "\n```"
|
310 |
+
if sep in code:
|
311 |
+
blocks = code.split(sep)
|
312 |
+
if len(blocks) % 2 == 1:
|
313 |
+
for i in range(1, len(blocks), 2):
|
314 |
+
blocks[i] = blocks[i].replace("\\_", "_")
|
315 |
+
code = sep.join(blocks)
|
316 |
+
return code
|
317 |
+
|
318 |
+
|
319 |
+
def model_worker_stream_iter(
|
320 |
+
conv,
|
321 |
+
model_name,
|
322 |
+
worker_addr,
|
323 |
+
prompt,
|
324 |
+
temperature,
|
325 |
+
repetition_penalty,
|
326 |
+
top_p,
|
327 |
+
max_new_tokens,
|
328 |
+
):
|
329 |
+
# Make requests
|
330 |
+
gen_params = {
|
331 |
+
"model": model_name,
|
332 |
+
"prompt": prompt,
|
333 |
+
"temperature": temperature,
|
334 |
+
"repetition_penalty": repetition_penalty,
|
335 |
+
"top_p": top_p,
|
336 |
+
"max_new_tokens": max_new_tokens,
|
337 |
+
"stop": conv.stop_str,
|
338 |
+
"stop_token_ids": conv.stop_token_ids,
|
339 |
+
"echo": False,
|
340 |
+
}
|
341 |
+
logger.info(f"==== request ====\n{gen_params}")
|
342 |
+
|
343 |
+
# Stream output
|
344 |
+
response = requests.post(
|
345 |
+
worker_addr + "/worker_generate_stream",
|
346 |
+
headers=headers,
|
347 |
+
json=gen_params,
|
348 |
+
stream=True,
|
349 |
+
timeout=WORKER_API_TIMEOUT,
|
350 |
+
)
|
351 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
352 |
+
if chunk:
|
353 |
+
data = json.loads(chunk.decode())
|
354 |
+
yield data
|
355 |
+
|
356 |
+
|
357 |
+
def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request):
|
358 |
+
ip = get_ip(request)
|
359 |
+
logger.info(f"bot_response. ip: {ip}")
|
360 |
+
start_tstamp = time.time()
|
361 |
+
temperature = float(temperature)
|
362 |
+
top_p = float(top_p)
|
363 |
+
max_new_tokens = int(max_new_tokens)
|
364 |
+
|
365 |
+
if state.skip_next:
|
366 |
+
# This generate call is skipped due to invalid inputs
|
367 |
+
state.skip_next = False
|
368 |
+
yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
|
369 |
+
return
|
370 |
+
|
371 |
+
conv, model_name = state.conv, state.model_name
|
372 |
+
if model_name in ["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", "gpt-3.5-turbo-1106"]:
|
373 |
+
prompt = conv.to_openai_api_messages()
|
374 |
+
stream_iter = openai_api_stream_iter(
|
375 |
+
model_name, prompt, temperature, top_p, max_new_tokens
|
376 |
+
)
|
377 |
+
elif model_name in ["claude-2", "claude-1", "claude-instant-1"]:
|
378 |
+
prompt = conv.get_prompt()
|
379 |
+
stream_iter = anthropic_api_stream_iter(
|
380 |
+
model_name, prompt, temperature, top_p, max_new_tokens
|
381 |
+
)
|
382 |
+
elif model_name == "palm-2":
|
383 |
+
stream_iter = palm_api_stream_iter(
|
384 |
+
state.palm_chat, conv.messages[-2][1], temperature, top_p, max_new_tokens
|
385 |
+
)
|
386 |
+
elif model_name in openai_compatible_models_info:
|
387 |
+
model_info = openai_compatible_models_info[model_name]
|
388 |
+
prompt = conv.to_openai_api_messages()
|
389 |
+
stream_iter = openai_api_stream_iter(
|
390 |
+
model_info["model_name"],
|
391 |
+
prompt,
|
392 |
+
temperature,
|
393 |
+
top_p,
|
394 |
+
max_new_tokens,
|
395 |
+
api_base=model_info["api_base"],
|
396 |
+
api_key=model_info["api_key"],
|
397 |
+
)
|
398 |
+
else:
|
399 |
+
# Query worker address
|
400 |
+
ret = requests.post(
|
401 |
+
controller_url + "/get_worker_address", json={"model": model_name}
|
402 |
+
)
|
403 |
+
worker_addr = ret.json()["address"]
|
404 |
+
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
|
405 |
+
|
406 |
+
# No available worker
|
407 |
+
if worker_addr == "":
|
408 |
+
conv.update_last_message(SERVER_ERROR_MSG)
|
409 |
+
yield (
|
410 |
+
state,
|
411 |
+
state.to_gradio_chatbot(),
|
412 |
+
disable_btn,
|
413 |
+
disable_btn,
|
414 |
+
disable_btn,
|
415 |
+
enable_btn,
|
416 |
+
enable_btn,
|
417 |
+
)
|
418 |
+
return
|
419 |
+
|
420 |
+
# Construct prompt.
|
421 |
+
# We need to call it here, so it will not be affected by "▌".
|
422 |
+
prompt = conv.get_prompt()
|
423 |
+
# Set repetition_penalty
|
424 |
+
if "t5" in model_name:
|
425 |
+
repetition_penalty = 1.2
|
426 |
+
else:
|
427 |
+
repetition_penalty = 1.0
|
428 |
+
|
429 |
+
stream_iter = model_worker_stream_iter(
|
430 |
+
conv,
|
431 |
+
model_name,
|
432 |
+
worker_addr,
|
433 |
+
prompt,
|
434 |
+
temperature,
|
435 |
+
repetition_penalty,
|
436 |
+
top_p,
|
437 |
+
max_new_tokens,
|
438 |
+
)
|
439 |
+
|
440 |
+
conv.update_last_message("▌")
|
441 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
442 |
+
|
443 |
+
try:
|
444 |
+
for i, data in enumerate(stream_iter):
|
445 |
+
if data["error_code"] == 0:
|
446 |
+
output = data["text"].strip()
|
447 |
+
conv.update_last_message(output + "▌")
|
448 |
+
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
|
449 |
+
else:
|
450 |
+
output = data["text"] + f"\n\n(error_code: {data['error_code']})"
|
451 |
+
conv.update_last_message(output)
|
452 |
+
yield (state, state.to_gradio_chatbot()) + (
|
453 |
+
disable_btn,
|
454 |
+
disable_btn,
|
455 |
+
disable_btn,
|
456 |
+
enable_btn,
|
457 |
+
enable_btn,
|
458 |
+
)
|
459 |
+
return
|
460 |
+
output = data["text"].strip()
|
461 |
+
if "vicuna" in model_name:
|
462 |
+
output = post_process_code(output)
|
463 |
+
conv.update_last_message(output)
|
464 |
+
yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
|
465 |
+
except requests.exceptions.RequestException as e:
|
466 |
+
conv.update_last_message(
|
467 |
+
f"{SERVER_ERROR_MSG}\n\n"
|
468 |
+
f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
|
469 |
+
)
|
470 |
+
yield (state, state.to_gradio_chatbot()) + (
|
471 |
+
disable_btn,
|
472 |
+
disable_btn,
|
473 |
+
disable_btn,
|
474 |
+
enable_btn,
|
475 |
+
enable_btn,
|
476 |
+
)
|
477 |
+
return
|
478 |
+
except Exception as e:
|
479 |
+
conv.update_last_message(
|
480 |
+
f"{SERVER_ERROR_MSG}\n\n"
|
481 |
+
f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
|
482 |
+
)
|
483 |
+
yield (state, state.to_gradio_chatbot()) + (
|
484 |
+
disable_btn,
|
485 |
+
disable_btn,
|
486 |
+
disable_btn,
|
487 |
+
enable_btn,
|
488 |
+
enable_btn,
|
489 |
+
)
|
490 |
+
return
|
491 |
+
|
492 |
+
finish_tstamp = time.time()
|
493 |
+
logger.info(f"{output}")
|
494 |
+
|
495 |
+
with open(get_conv_log_filename(), "a") as fout:
|
496 |
+
data = {
|
497 |
+
"tstamp": round(finish_tstamp, 4),
|
498 |
+
"type": "chat",
|
499 |
+
"model": model_name,
|
500 |
+
"gen_params": {
|
501 |
+
"temperature": temperature,
|
502 |
+
"top_p": top_p,
|
503 |
+
"max_new_tokens": max_new_tokens,
|
504 |
+
},
|
505 |
+
"start": round(start_tstamp, 4),
|
506 |
+
"finish": round(finish_tstamp, 4),
|
507 |
+
"state": state.dict(),
|
508 |
+
"ip": get_ip(request),
|
509 |
+
}
|
510 |
+
fout.write(json.dumps(data) + "\n")
|
511 |
+
|
512 |
+
|
513 |
+
block_css = """
|
514 |
+
#notice_markdown {
|
515 |
+
font-size: 110%
|
516 |
+
}
|
517 |
+
#notice_markdown th {
|
518 |
+
display: none;
|
519 |
+
}
|
520 |
+
#notice_markdown td {
|
521 |
+
padding-top: 6px;
|
522 |
+
padding-bottom: 6px;
|
523 |
+
}
|
524 |
+
#leaderboard_markdown {
|
525 |
+
font-size: 110%
|
526 |
+
}
|
527 |
+
#leaderboard_markdown td {
|
528 |
+
padding-top: 6px;
|
529 |
+
padding-bottom: 6px;
|
530 |
+
}
|
531 |
+
#leaderboard_dataframe td {
|
532 |
+
line-height: 0.1em;
|
533 |
+
}
|
534 |
+
#about_markdown {
|
535 |
+
font-size: 110%
|
536 |
+
}
|
537 |
+
#input_box textarea {
|
538 |
+
}
|
539 |
+
footer {
|
540 |
+
display:none !important
|
541 |
+
}
|
542 |
+
.image-container {
|
543 |
+
display: flex;
|
544 |
+
align-items: center;
|
545 |
+
padding: 1px;
|
546 |
+
}
|
547 |
+
.image-container img {
|
548 |
+
margin: 0 30px;
|
549 |
+
height: 20px;
|
550 |
+
max-height: 100%;
|
551 |
+
width: auto;
|
552 |
+
max-width: 20%;
|
553 |
+
}
|
554 |
+
.image-about img {
|
555 |
+
margin: 0 30px;
|
556 |
+
margin-top: 30px;
|
557 |
+
height: 60px;
|
558 |
+
max-height: 100%;
|
559 |
+
width: auto;
|
560 |
+
float: left;
|
561 |
+
}
|
562 |
+
"""
|
563 |
+
|
564 |
+
|
565 |
+
def get_model_description_md(models):
|
566 |
+
model_description_md = """
|
567 |
+
| | | |
|
568 |
+
| ---- | ---- | ---- |
|
569 |
+
"""
|
570 |
+
ct = 0
|
571 |
+
visited = set()
|
572 |
+
for i, name in enumerate(models):
|
573 |
+
minfo = get_model_info(name)
|
574 |
+
if minfo.simple_name in visited:
|
575 |
+
continue
|
576 |
+
visited.add(minfo.simple_name)
|
577 |
+
one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
|
578 |
+
|
579 |
+
if ct % 3 == 0:
|
580 |
+
model_description_md += "|"
|
581 |
+
model_description_md += f" {one_model_md} |"
|
582 |
+
if ct % 3 == 2:
|
583 |
+
model_description_md += "\n"
|
584 |
+
ct += 1
|
585 |
+
return model_description_md
|
586 |
+
|
587 |
+
|
588 |
+
def build_about():
|
589 |
+
about_markdown = f"""
|
590 |
+
# About Us
|
591 |
+
Chatbot Arena is an open-source research project developed by members from [LMSYS](https://lmsys.org/about/) and UC Berkeley [SkyLab](https://sky.cs.berkeley.edu/). Our mission is to build an open crowdsourced platform to collect human feedback and evaluate LLMs under real-world scenarios. We open-source our code at [GitHub](https://github.com/lm-sys/FastChat) and release chat and human feedback datasets [here](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md). We invite everyone to join us in this journey!
|
592 |
+
|
593 |
+
## Read More
|
594 |
+
- Chatbot Arena [launch post](https://lmsys.org/blog/2023-05-03-arena/), [data release](https://lmsys.org/blog/2023-07-20-dataset/)
|
595 |
+
- LMSYS-Chat-1M [report](https://arxiv.org/abs/2309.11998)
|
596 |
+
|
597 |
+
## Core Members
|
598 |
+
[Lianmin Zheng](https://lmzheng.net/), [Wei-Lin Chiang](https://infwinston.github.io/), [Ying Sheng](https://sites.google.com/view/yingsheng/home), [Siyuan Zhuang](https://scholar.google.com/citations?user=KSZmI5EAAAAJ)
|
599 |
+
|
600 |
+
## Advisors
|
601 |
+
[Ion Stoica](http://people.eecs.berkeley.edu/~istoica/), [Joseph E. Gonzalez](https://people.eecs.berkeley.edu/~jegonzal/), [Hao Zhang](https://cseweb.ucsd.edu/~haozhang/)
|
602 |
+
|
603 |
+
## Contact Us
|
604 |
+
- Follow our [Twitter](https://twitter.com/lmsysorg), [Discord](https://discord.gg/HSWAKCrnFx) or email us at [email protected]
|
605 |
+
- File issues on [GitHub](https://github.com/lm-sys/FastChat)
|
606 |
+
- Download our datasets and models on [HuggingFace](https://huggingface.co/lmsys)
|
607 |
+
|
608 |
+
## Sponsors
|
609 |
+
We thank [Kaggle](https://www.kaggle.com/), [MBZUAI](https://mbzuai.ac.ae/), [Anyscale](https://www.anyscale.com/), [HuggingFace](https://huggingface.co/) for their generous sponsorship.
|
610 |
+
Learn more about partnership [here](https://lmsys.org/donations/).
|
611 |
+
|
612 |
+
<div class="image-about">
|
613 |
+
<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/7/7c/Kaggle_logo.png/400px-Kaggle_logo.png" alt="Image 1">
|
614 |
+
<img src="https://upload.wikimedia.org/wikipedia/en/5/55/Mohamed_bin_Zayed_University_of_Artificial_Intelligence_logo.png" alt="Image 2">
|
615 |
+
<img src="https://docs.anyscale.com/site-assets/logo.png" alt="Image 3">
|
616 |
+
<img src="https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.png" alt="Image 4">
|
617 |
+
</div>
|
618 |
+
"""
|
619 |
+
|
620 |
+
# state = gr.State()
|
621 |
+
gr.Markdown(about_markdown, elem_id="about_markdown")
|
622 |
+
|
623 |
+
# return [state]
|
624 |
+
|
625 |
+
|
626 |
+
def build_single_model_ui(models, add_promotion_links=False):
|
627 |
+
promotion = (
|
628 |
+
"""
|
629 |
+
- | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
|
630 |
+
- Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/)
|
631 |
+
- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/)
|
632 |
+
"""
|
633 |
+
if add_promotion_links
|
634 |
+
else ""
|
635 |
+
)
|
636 |
+
|
637 |
+
notice_markdown = f"""
|
638 |
+
# 🏔️ Chat with Open Large Language Models
|
639 |
+
{promotion}
|
640 |
+
|
641 |
+
## 👉 Choose any model to chat
|
642 |
+
"""
|
643 |
+
|
644 |
+
state = gr.State()
|
645 |
+
model_description_md = get_model_description_md(models)
|
646 |
+
gr.Markdown(notice_markdown + model_description_md, elem_id="notice_markdown")
|
647 |
+
|
648 |
+
with gr.Row(elem_id="model_selector_row"):
|
649 |
+
model_selector = gr.Dropdown(
|
650 |
+
choices=models,
|
651 |
+
value=models[0] if len(models) > 0 else "",
|
652 |
+
interactive=True,
|
653 |
+
show_label=False,
|
654 |
+
container=False,
|
655 |
+
)
|
656 |
+
|
657 |
+
chatbot = gr.Chatbot(
|
658 |
+
elem_id="chatbot",
|
659 |
+
label="Scroll down and start chatting",
|
660 |
+
height=550,
|
661 |
+
)
|
662 |
+
with gr.Row():
|
663 |
+
with gr.Column(scale=20):
|
664 |
+
textbox = gr.Textbox(
|
665 |
+
show_label=False,
|
666 |
+
placeholder="Enter your prompt here and press ENTER",
|
667 |
+
container=False,
|
668 |
+
elem_id="input_box",
|
669 |
+
)
|
670 |
+
with gr.Column(scale=1, min_width=50):
|
671 |
+
send_btn = gr.Button(value="Send", variant="primary")
|
672 |
+
|
673 |
+
with gr.Row() as button_row:
|
674 |
+
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
|
675 |
+
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
|
676 |
+
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
|
677 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
|
678 |
+
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
|
679 |
+
|
680 |
+
with gr.Accordion("Parameters", open=False) as parameter_row:
|
681 |
+
temperature = gr.Slider(
|
682 |
+
minimum=0.0,
|
683 |
+
maximum=1.0,
|
684 |
+
value=0.7,
|
685 |
+
step=0.1,
|
686 |
+
interactive=True,
|
687 |
+
label="Temperature",
|
688 |
+
)
|
689 |
+
top_p = gr.Slider(
|
690 |
+
minimum=0.0,
|
691 |
+
maximum=1.0,
|
692 |
+
value=1.0,
|
693 |
+
step=0.1,
|
694 |
+
interactive=True,
|
695 |
+
label="Top P",
|
696 |
+
)
|
697 |
+
max_output_tokens = gr.Slider(
|
698 |
+
minimum=16,
|
699 |
+
maximum=3072,
|
700 |
+
value=2048,
|
701 |
+
step=1,
|
702 |
+
interactive=True,
|
703 |
+
label="Max output tokens",
|
704 |
+
)
|
705 |
+
|
706 |
+
if add_promotion_links:
|
707 |
+
gr.Markdown(acknowledgment_md)
|
708 |
+
|
709 |
+
# Register listeners
|
710 |
+
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
|
711 |
+
upvote_btn.click(
|
712 |
+
upvote_last_response,
|
713 |
+
[state, model_selector],
|
714 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
715 |
+
)
|
716 |
+
downvote_btn.click(
|
717 |
+
downvote_last_response,
|
718 |
+
[state, model_selector],
|
719 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
720 |
+
)
|
721 |
+
flag_btn.click(
|
722 |
+
flag_last_response,
|
723 |
+
[state, model_selector],
|
724 |
+
[textbox, upvote_btn, downvote_btn, flag_btn],
|
725 |
+
)
|
726 |
+
regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
|
727 |
+
bot_response,
|
728 |
+
[state, temperature, top_p, max_output_tokens],
|
729 |
+
[state, chatbot] + btn_list,
|
730 |
+
)
|
731 |
+
clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)
|
732 |
+
|
733 |
+
model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)
|
734 |
+
|
735 |
+
textbox.submit(
|
736 |
+
add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
|
737 |
+
).then(
|
738 |
+
bot_response,
|
739 |
+
[state, temperature, top_p, max_output_tokens],
|
740 |
+
[state, chatbot] + btn_list,
|
741 |
+
)
|
742 |
+
send_btn.click(
|
743 |
+
add_text,
|
744 |
+
[state, model_selector, textbox],
|
745 |
+
[state, chatbot, textbox] + btn_list,
|
746 |
+
).then(
|
747 |
+
bot_response,
|
748 |
+
[state, temperature, top_p, max_output_tokens],
|
749 |
+
[state, chatbot] + btn_list,
|
750 |
+
)
|
751 |
+
|
752 |
+
return [state, model_selector]
|
753 |
+
|
754 |
+
|
755 |
+
def build_demo(models):
|
756 |
+
with gr.Blocks(
|
757 |
+
title="Chat with Open Large Language Models",
|
758 |
+
theme=gr.themes.Default(),
|
759 |
+
css=block_css,
|
760 |
+
) as demo:
|
761 |
+
url_params = gr.JSON(visible=False)
|
762 |
+
|
763 |
+
state, model_selector = build_single_model_ui(models)
|
764 |
+
|
765 |
+
if args.model_list_mode not in ["once", "reload"]:
|
766 |
+
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
767 |
+
|
768 |
+
if args.show_terms_of_use:
|
769 |
+
load_js = get_window_url_params_with_tos_js
|
770 |
+
else:
|
771 |
+
load_js = get_window_url_params_js
|
772 |
+
|
773 |
+
demo.load(
|
774 |
+
load_demo,
|
775 |
+
[url_params],
|
776 |
+
[
|
777 |
+
state,
|
778 |
+
model_selector,
|
779 |
+
],
|
780 |
+
_js=load_js,
|
781 |
+
)
|
782 |
+
|
783 |
+
return demo
|
784 |
+
|
785 |
+
|
786 |
+
if __name__ == "__main__":
|
787 |
+
parser = argparse.ArgumentParser()
|
788 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
789 |
+
parser.add_argument("--port", type=int)
|
790 |
+
parser.add_argument(
|
791 |
+
"--conv-template",
|
792 |
+
type=str,
|
793 |
+
default="megrez",
|
794 |
+
help="The address of the controller",
|
795 |
+
)
|
796 |
+
parser.add_argument(
|
797 |
+
"--share",
|
798 |
+
action="store_true",
|
799 |
+
help="Whether to generate a public, shareable link",
|
800 |
+
)
|
801 |
+
parser.add_argument(
|
802 |
+
"--controller-url",
|
803 |
+
type=str,
|
804 |
+
default="http://localhost:21001",
|
805 |
+
help="The address of the controller",
|
806 |
+
)
|
807 |
+
parser.add_argument(
|
808 |
+
"--concurrency-count",
|
809 |
+
type=int,
|
810 |
+
default=10,
|
811 |
+
help="The concurrency count of the gradio queue",
|
812 |
+
)
|
813 |
+
parser.add_argument(
|
814 |
+
"--model-list-mode",
|
815 |
+
type=str,
|
816 |
+
default="once",
|
817 |
+
choices=["once", "reload"],
|
818 |
+
help="Whether to load the model list once or reload the model list every time",
|
819 |
+
)
|
820 |
+
parser.add_argument(
|
821 |
+
"--moderate",
|
822 |
+
action="store_true",
|
823 |
+
help="Enable content moderation to block unsafe inputs",
|
824 |
+
)
|
825 |
+
parser.add_argument(
|
826 |
+
"--show-terms-of-use",
|
827 |
+
action="store_true",
|
828 |
+
help="Shows term of use before loading the demo",
|
829 |
+
)
|
830 |
+
parser.add_argument(
|
831 |
+
"--add-chatgpt",
|
832 |
+
action="store_true",
|
833 |
+
help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)",
|
834 |
+
)
|
835 |
+
parser.add_argument(
|
836 |
+
"--add-claude",
|
837 |
+
action="store_true",
|
838 |
+
help="Add Anthropic's Claude models (claude-2, claude-instant-1)",
|
839 |
+
)
|
840 |
+
parser.add_argument(
|
841 |
+
"--add-palm",
|
842 |
+
action="store_true",
|
843 |
+
help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)",
|
844 |
+
)
|
845 |
+
parser.add_argument(
|
846 |
+
"--register-openai-compatible-models",
|
847 |
+
type=str,
|
848 |
+
help="Register custom OpenAI API compatible models by loading them from a JSON file",
|
849 |
+
)
|
850 |
+
parser.add_argument(
|
851 |
+
"--gradio-auth-path",
|
852 |
+
type=str,
|
853 |
+
help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
|
854 |
+
)
|
855 |
+
args = parser.parse_args()
|
856 |
+
logger.info(f"args: {args}")
|
857 |
+
CONV_TEMPLATE = args.conv_template
|
858 |
+
# Set global variables
|
859 |
+
set_global_vars(args.controller_url, args.moderate)
|
860 |
+
models = get_model_list(
|
861 |
+
args.controller_url,
|
862 |
+
args.register_openai_compatible_models,
|
863 |
+
args.add_chatgpt,
|
864 |
+
args.add_claude,
|
865 |
+
args.add_palm,
|
866 |
+
)
|
867 |
+
# Set authorization credentials
|
868 |
+
auth = None
|
869 |
+
if args.gradio_auth_path is not None:
|
870 |
+
auth = parse_gradio_auth_creds(args.gradio_auth_path)
|
871 |
+
|
872 |
+
# Launch the demo
|
873 |
+
demo = build_demo(models)
|
874 |
+
ret = demo.queue(
|
875 |
+
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
|
876 |
+
).launch(
|
877 |
+
server_name=args.host,
|
878 |
+
server_port=args.port,
|
879 |
+
share=args.share,
|
880 |
+
max_threads=200,
|
881 |
+
auth=auth,
|
882 |
+
)
|
883 |
+
from IPython import embed;embed()
|
gradio_web_server_multi.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
The gradio demo server with multiple tabs.
|
3 |
+
It supports chatting with a single model or chatting with two models side-by-side.
|
4 |
+
"""
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import pickle
|
8 |
+
import time
|
9 |
+
|
10 |
+
import gradio as gr
|
11 |
+
|
12 |
+
from fastchat.constants import (
|
13 |
+
SESSION_EXPIRATION_TIME,
|
14 |
+
)
|
15 |
+
from fastchat.serve.gradio_block_arena_anony import (
|
16 |
+
build_side_by_side_ui_anony,
|
17 |
+
load_demo_side_by_side_anony,
|
18 |
+
set_global_vars_anony,
|
19 |
+
)
|
20 |
+
from fastchat.serve.gradio_block_arena_named import (
|
21 |
+
build_side_by_side_ui_named,
|
22 |
+
load_demo_side_by_side_named,
|
23 |
+
set_global_vars_named,
|
24 |
+
)
|
25 |
+
from fastchat.serve.gradio_web_server import (
|
26 |
+
set_global_vars,
|
27 |
+
block_css,
|
28 |
+
build_single_model_ui,
|
29 |
+
build_about,
|
30 |
+
get_model_list,
|
31 |
+
load_demo_single,
|
32 |
+
ip_expiration_dict,
|
33 |
+
get_ip,
|
34 |
+
)
|
35 |
+
from fastchat.serve.monitor.monitor import build_leaderboard_tab
|
36 |
+
from fastchat.utils import (
|
37 |
+
build_logger,
|
38 |
+
get_window_url_params_js,
|
39 |
+
get_window_url_params_with_tos_js,
|
40 |
+
parse_gradio_auth_creds,
|
41 |
+
)
|
42 |
+
|
43 |
+
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")
|
44 |
+
|
45 |
+
|
46 |
+
def load_demo(url_params, request: gr.Request):
|
47 |
+
global models
|
48 |
+
|
49 |
+
ip = get_ip(request)
|
50 |
+
logger.info(f"load_demo. ip: {ip}. params: {url_params}")
|
51 |
+
ip_expiration_dict[ip] = time.time() + SESSION_EXPIRATION_TIME
|
52 |
+
|
53 |
+
selected = 0
|
54 |
+
if "arena" in url_params:
|
55 |
+
selected = 0
|
56 |
+
elif "compare" in url_params:
|
57 |
+
selected = 1
|
58 |
+
elif "single" in url_params:
|
59 |
+
selected = 2
|
60 |
+
elif "leaderboard" in url_params:
|
61 |
+
selected = 3
|
62 |
+
|
63 |
+
if args.model_list_mode == "reload":
|
64 |
+
if args.anony_only_for_proprietary_model:
|
65 |
+
models = get_model_list(
|
66 |
+
args.controller_url,
|
67 |
+
args.register_openai_compatible_models,
|
68 |
+
False,
|
69 |
+
False,
|
70 |
+
False,
|
71 |
+
)
|
72 |
+
else:
|
73 |
+
models = get_model_list(
|
74 |
+
args.controller_url,
|
75 |
+
args.register_openai_compatible_models,
|
76 |
+
args.add_chatgpt,
|
77 |
+
args.add_claude,
|
78 |
+
args.add_palm,
|
79 |
+
)
|
80 |
+
|
81 |
+
single_updates = load_demo_single(models, url_params)
|
82 |
+
|
83 |
+
models_anony = list(models)
|
84 |
+
if args.anony_only_for_proprietary_model:
|
85 |
+
# Only enable these models in anony battles.
|
86 |
+
if args.add_chatgpt:
|
87 |
+
models_anony += [
|
88 |
+
"gpt-4",
|
89 |
+
"gpt-3.5-turbo",
|
90 |
+
"gpt-4-turbo",
|
91 |
+
"gpt-3.5-turbo-1106",
|
92 |
+
]
|
93 |
+
if args.add_claude:
|
94 |
+
models_anony += ["claude-2", "claude-1", "claude-instant-1"]
|
95 |
+
if args.add_palm:
|
96 |
+
models_anony += ["palm-2"]
|
97 |
+
models_anony = list(set(models_anony))
|
98 |
+
|
99 |
+
side_by_side_anony_updates = load_demo_side_by_side_anony(models_anony, url_params)
|
100 |
+
side_by_side_named_updates = load_demo_side_by_side_named(models, url_params)
|
101 |
+
return (
|
102 |
+
(gr.Tabs.update(selected=selected),)
|
103 |
+
+ single_updates
|
104 |
+
+ side_by_side_anony_updates
|
105 |
+
+ side_by_side_named_updates
|
106 |
+
)
|
107 |
+
|
108 |
+
|
109 |
+
def build_demo(models, elo_results_file, leaderboard_table_file):
|
110 |
+
text_size = gr.themes.sizes.text_md
|
111 |
+
with gr.Blocks(
|
112 |
+
title="Chat with Open Large Language Models",
|
113 |
+
theme=gr.themes.Default(text_size=text_size),
|
114 |
+
css=block_css,
|
115 |
+
) as demo:
|
116 |
+
with gr.Tabs() as tabs:
|
117 |
+
with gr.Tab("Arena (battle)", id=0):
|
118 |
+
side_by_side_anony_list = build_side_by_side_ui_anony(models)
|
119 |
+
|
120 |
+
with gr.Tab("Arena (side-by-side)", id=1):
|
121 |
+
side_by_side_named_list = build_side_by_side_ui_named(models)
|
122 |
+
|
123 |
+
with gr.Tab("Direct Chat", id=2):
|
124 |
+
single_model_list = build_single_model_ui(
|
125 |
+
models, add_promotion_links=True
|
126 |
+
)
|
127 |
+
if elo_results_file:
|
128 |
+
with gr.Tab("Leaderboard", id=3):
|
129 |
+
build_leaderboard_tab(elo_results_file, leaderboard_table_file)
|
130 |
+
with gr.Tab("About Us", id=4):
|
131 |
+
about = build_about()
|
132 |
+
|
133 |
+
url_params = gr.JSON(visible=False)
|
134 |
+
|
135 |
+
if args.model_list_mode not in ["once", "reload"]:
|
136 |
+
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
|
137 |
+
|
138 |
+
if args.show_terms_of_use:
|
139 |
+
load_js = get_window_url_params_with_tos_js
|
140 |
+
else:
|
141 |
+
load_js = get_window_url_params_js
|
142 |
+
|
143 |
+
demo.load(
|
144 |
+
load_demo,
|
145 |
+
[url_params],
|
146 |
+
[tabs]
|
147 |
+
+ single_model_list
|
148 |
+
+ side_by_side_anony_list
|
149 |
+
+ side_by_side_named_list,
|
150 |
+
_js=load_js,
|
151 |
+
)
|
152 |
+
|
153 |
+
return demo
|
154 |
+
|
155 |
+
|
156 |
+
if __name__ == "__main__":
|
157 |
+
parser = argparse.ArgumentParser()
|
158 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
159 |
+
parser.add_argument("--port", type=int)
|
160 |
+
parser.add_argument(
|
161 |
+
"--share",
|
162 |
+
action="store_true",
|
163 |
+
help="Whether to generate a public, shareable link",
|
164 |
+
)
|
165 |
+
parser.add_argument(
|
166 |
+
"--controller-url",
|
167 |
+
type=str,
|
168 |
+
default="http://localhost:21001",
|
169 |
+
help="The address of the controller",
|
170 |
+
)
|
171 |
+
parser.add_argument(
|
172 |
+
"--concurrency-count",
|
173 |
+
type=int,
|
174 |
+
default=10,
|
175 |
+
help="The concurrency count of the gradio queue",
|
176 |
+
)
|
177 |
+
parser.add_argument(
|
178 |
+
"--model-list-mode",
|
179 |
+
type=str,
|
180 |
+
default="once",
|
181 |
+
choices=["once", "reload"],
|
182 |
+
help="Whether to load the model list once or reload the model list every time.",
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--moderate",
|
186 |
+
action="store_true",
|
187 |
+
help="Enable content moderation to block unsafe inputs",
|
188 |
+
)
|
189 |
+
parser.add_argument(
|
190 |
+
"--show-terms-of-use",
|
191 |
+
action="store_true",
|
192 |
+
help="Shows term of use before loading the demo",
|
193 |
+
)
|
194 |
+
parser.add_argument(
|
195 |
+
"--add-chatgpt",
|
196 |
+
action="store_true",
|
197 |
+
help="Add OpenAI's ChatGPT models (gpt-3.5-turbo, gpt-4)",
|
198 |
+
)
|
199 |
+
parser.add_argument(
|
200 |
+
"--add-claude",
|
201 |
+
action="store_true",
|
202 |
+
help="Add Anthropic's Claude models (claude-2, claude-instant-1)",
|
203 |
+
)
|
204 |
+
parser.add_argument(
|
205 |
+
"--add-palm",
|
206 |
+
action="store_true",
|
207 |
+
help="Add Google's PaLM model (PaLM 2 for Chat: chat-bison@001)",
|
208 |
+
)
|
209 |
+
parser.add_argument(
|
210 |
+
"--anony-only-for-proprietary-model",
|
211 |
+
action="store_true",
|
212 |
+
help="Only add ChatGPT, Claude, Bard under anony battle tab",
|
213 |
+
)
|
214 |
+
parser.add_argument(
|
215 |
+
"--register-openai-compatible-models",
|
216 |
+
type=str,
|
217 |
+
help="Register custom OpenAI API compatible models by loading them from a JSON file",
|
218 |
+
)
|
219 |
+
parser.add_argument(
|
220 |
+
"--gradio-auth-path",
|
221 |
+
type=str,
|
222 |
+
help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
|
223 |
+
default=None,
|
224 |
+
)
|
225 |
+
parser.add_argument(
|
226 |
+
"--elo-results-file", type=str, help="Load leaderboard results and plots"
|
227 |
+
)
|
228 |
+
parser.add_argument(
|
229 |
+
"--leaderboard-table-file", type=str, help="Load leaderboard results and plots"
|
230 |
+
)
|
231 |
+
args = parser.parse_args()
|
232 |
+
logger.info(f"args: {args}")
|
233 |
+
|
234 |
+
# Set global variables
|
235 |
+
set_global_vars(args.controller_url, args.moderate)
|
236 |
+
set_global_vars_named(args.moderate)
|
237 |
+
set_global_vars_anony(args.moderate)
|
238 |
+
if args.anony_only_for_proprietary_model:
|
239 |
+
models = get_model_list(
|
240 |
+
args.controller_url,
|
241 |
+
args.register_openai_compatible_models,
|
242 |
+
False,
|
243 |
+
False,
|
244 |
+
False,
|
245 |
+
)
|
246 |
+
else:
|
247 |
+
models = get_model_list(
|
248 |
+
args.controller_url,
|
249 |
+
args.register_openai_compatible_models,
|
250 |
+
args.add_chatgpt,
|
251 |
+
args.add_claude,
|
252 |
+
args.add_palm,
|
253 |
+
)
|
254 |
+
|
255 |
+
# Set authorization credentials
|
256 |
+
auth = None
|
257 |
+
if args.gradio_auth_path is not None:
|
258 |
+
auth = parse_gradio_auth_creds(args.gradio_auth_path)
|
259 |
+
|
260 |
+
# Launch the demo
|
261 |
+
demo = build_demo(models, args.elo_results_file, args.leaderboard_table_file)
|
262 |
+
demo.queue(
|
263 |
+
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
|
264 |
+
).launch(
|
265 |
+
server_name=args.host,
|
266 |
+
server_port=args.port,
|
267 |
+
share=args.share,
|
268 |
+
max_threads=200,
|
269 |
+
auth=auth,
|
270 |
+
)
|
huggingface_api.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Use FastChat with Hugging Face generation APIs.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 -m fastchat.serve.huggingface_api --model lmsys/vicuna-7b-v1.5
|
6 |
+
python3 -m fastchat.serve.huggingface_api --model lmsys/fastchat-t5-3b-v1.0
|
7 |
+
"""
|
8 |
+
import argparse
|
9 |
+
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from fastchat.model import load_model, get_conversation_template, add_model_args
|
13 |
+
|
14 |
+
|
15 |
+
@torch.inference_mode()
|
16 |
+
def main(args):
|
17 |
+
# Load model
|
18 |
+
model, tokenizer = load_model(
|
19 |
+
args.model_path,
|
20 |
+
device=args.device,
|
21 |
+
num_gpus=args.num_gpus,
|
22 |
+
max_gpu_memory=args.max_gpu_memory,
|
23 |
+
load_8bit=args.load_8bit,
|
24 |
+
cpu_offloading=args.cpu_offloading,
|
25 |
+
revision=args.revision,
|
26 |
+
debug=args.debug,
|
27 |
+
)
|
28 |
+
|
29 |
+
# Build the prompt with a conversation template
|
30 |
+
msg = args.message
|
31 |
+
conv = get_conversation_template(args.model_path)
|
32 |
+
conv.append_message(conv.roles[0], msg)
|
33 |
+
conv.append_message(conv.roles[1], None)
|
34 |
+
prompt = conv.get_prompt()
|
35 |
+
|
36 |
+
# Run inference
|
37 |
+
inputs = tokenizer([prompt], return_tensors="pt").to(args.device)
|
38 |
+
output_ids = model.generate(
|
39 |
+
**inputs,
|
40 |
+
do_sample=True if args.temperature > 1e-5 else False,
|
41 |
+
temperature=args.temperature,
|
42 |
+
repetition_penalty=args.repetition_penalty,
|
43 |
+
max_new_tokens=args.max_new_tokens,
|
44 |
+
)
|
45 |
+
|
46 |
+
if model.config.is_encoder_decoder:
|
47 |
+
output_ids = output_ids[0]
|
48 |
+
else:
|
49 |
+
output_ids = output_ids[0][len(inputs["input_ids"][0]) :]
|
50 |
+
outputs = tokenizer.decode(
|
51 |
+
output_ids, skip_special_tokens=True, spaces_between_special_tokens=False
|
52 |
+
)
|
53 |
+
|
54 |
+
# Print results
|
55 |
+
print(f"{conv.roles[0]}: {msg}")
|
56 |
+
print(f"{conv.roles[1]}: {outputs}")
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
parser = argparse.ArgumentParser()
|
61 |
+
add_model_args(parser)
|
62 |
+
parser.add_argument("--temperature", type=float, default=0.7)
|
63 |
+
parser.add_argument("--repetition_penalty", type=float, default=1.0)
|
64 |
+
parser.add_argument("--max-new-tokens", type=int, default=512)
|
65 |
+
parser.add_argument("--debug", action="store_true")
|
66 |
+
parser.add_argument("--message", type=str, default="Hello! Who are you?")
|
67 |
+
args = parser.parse_args()
|
68 |
+
|
69 |
+
# Reset default repetition penalty for T5 models.
|
70 |
+
if "t5" in args.model_path and args.repetition_penalty == 1.0:
|
71 |
+
args.repetition_penalty = 1.2
|
72 |
+
|
73 |
+
main(args)
|
huggingface_api_worker.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A model worker that calls huggingface inference endpoint.
|
3 |
+
|
4 |
+
Register models in a JSON file with the following format:
|
5 |
+
{
|
6 |
+
"falcon-180b-chat": {
|
7 |
+
"model_path": "tiiuae/falcon-180B-chat",
|
8 |
+
"api_base": "https://api-inference.huggingface.co/models",
|
9 |
+
"token": "hf_xxx",
|
10 |
+
"context_length": 2048,
|
11 |
+
"model_names": "falcon-180b-chat",
|
12 |
+
"conv_template": null
|
13 |
+
}
|
14 |
+
}
|
15 |
+
|
16 |
+
"model_path", "api_base", "token", and "context_length" are necessary, while others are optional.
|
17 |
+
"""
|
18 |
+
import argparse
|
19 |
+
import asyncio
|
20 |
+
import json
|
21 |
+
import uuid
|
22 |
+
from typing import List, Optional
|
23 |
+
|
24 |
+
import requests
|
25 |
+
import uvicorn
|
26 |
+
from fastapi import BackgroundTasks, FastAPI, Request
|
27 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
28 |
+
from huggingface_hub import InferenceClient
|
29 |
+
|
30 |
+
from fastchat.constants import SERVER_ERROR_MSG, ErrorCode
|
31 |
+
from fastchat.serve.base_model_worker import BaseModelWorker
|
32 |
+
from fastchat.utils import build_logger
|
33 |
+
|
34 |
+
worker_id = str(uuid.uuid4())[:8]
|
35 |
+
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
36 |
+
|
37 |
+
workers = []
|
38 |
+
worker_map = {}
|
39 |
+
app = FastAPI()
|
40 |
+
|
41 |
+
|
42 |
+
# reference to
|
43 |
+
# https://github.com/philschmid/easyllm/blob/cbd908b3b3f44a97a22cb0fc2c93df3660bacdad/easyllm/clients/huggingface.py#L374-L392
|
44 |
+
def get_gen_kwargs(
|
45 |
+
params,
|
46 |
+
seed: Optional[int] = None,
|
47 |
+
):
|
48 |
+
stop = params.get("stop", None)
|
49 |
+
if isinstance(stop, list):
|
50 |
+
stop_sequences = stop
|
51 |
+
elif isinstance(stop, str):
|
52 |
+
stop_sequences = [stop]
|
53 |
+
else:
|
54 |
+
stop_sequences = []
|
55 |
+
gen_kwargs = {
|
56 |
+
"do_sample": True,
|
57 |
+
"return_full_text": bool(params.get("echo", False)),
|
58 |
+
"max_new_tokens": int(params.get("max_new_tokens", 256)),
|
59 |
+
"top_p": float(params.get("top_p", 1.0)),
|
60 |
+
"temperature": float(params.get("temperature", 1.0)),
|
61 |
+
"stop_sequences": stop_sequences,
|
62 |
+
"repetition_penalty": float(params.get("repetition_penalty", 1.0)),
|
63 |
+
"top_k": params.get("top_k", None),
|
64 |
+
"seed": seed,
|
65 |
+
}
|
66 |
+
if gen_kwargs["top_p"] == 1:
|
67 |
+
gen_kwargs["top_p"] = 0.9999999
|
68 |
+
if gen_kwargs["top_p"] == 0:
|
69 |
+
gen_kwargs.pop("top_p")
|
70 |
+
if gen_kwargs["temperature"] == 0:
|
71 |
+
gen_kwargs.pop("temperature")
|
72 |
+
gen_kwargs["do_sample"] = False
|
73 |
+
return gen_kwargs
|
74 |
+
|
75 |
+
|
76 |
+
def could_be_stop(text, stop):
|
77 |
+
for s in stop:
|
78 |
+
if any(text.endswith(s[:i]) for i in range(1, len(s) + 1)):
|
79 |
+
return True
|
80 |
+
return False
|
81 |
+
|
82 |
+
|
83 |
+
class HuggingfaceApiWorker(BaseModelWorker):
|
84 |
+
def __init__(
|
85 |
+
self,
|
86 |
+
controller_addr: str,
|
87 |
+
worker_addr: str,
|
88 |
+
worker_id: str,
|
89 |
+
model_path: str,
|
90 |
+
api_base: str,
|
91 |
+
token: str,
|
92 |
+
context_length: int,
|
93 |
+
model_names: List[str],
|
94 |
+
limit_worker_concurrency: int,
|
95 |
+
no_register: bool,
|
96 |
+
conv_template: Optional[str] = None,
|
97 |
+
seed: Optional[int] = None,
|
98 |
+
**kwargs,
|
99 |
+
):
|
100 |
+
super().__init__(
|
101 |
+
controller_addr,
|
102 |
+
worker_addr,
|
103 |
+
worker_id,
|
104 |
+
model_path,
|
105 |
+
model_names,
|
106 |
+
limit_worker_concurrency,
|
107 |
+
conv_template=conv_template,
|
108 |
+
)
|
109 |
+
|
110 |
+
self.model_path = model_path
|
111 |
+
self.api_base = api_base
|
112 |
+
self.token = token
|
113 |
+
self.context_len = context_length
|
114 |
+
self.seed = seed
|
115 |
+
|
116 |
+
logger.info(
|
117 |
+
f"Connecting with huggingface api {self.model_path} as {self.model_names} on worker {worker_id} ..."
|
118 |
+
)
|
119 |
+
|
120 |
+
if not no_register:
|
121 |
+
self.init_heart_beat()
|
122 |
+
|
123 |
+
def count_token(self, params):
|
124 |
+
# No tokenizer here
|
125 |
+
ret = {
|
126 |
+
"count": 0,
|
127 |
+
"error_code": 0,
|
128 |
+
}
|
129 |
+
return ret
|
130 |
+
|
131 |
+
def generate_stream_gate(self, params):
|
132 |
+
self.call_ct += 1
|
133 |
+
|
134 |
+
prompt = params["prompt"]
|
135 |
+
gen_kwargs = get_gen_kwargs(params, seed=self.seed)
|
136 |
+
stop = gen_kwargs["stop_sequences"]
|
137 |
+
if "falcon" in self.model_path and "chat" in self.model_path:
|
138 |
+
stop.extend(["\nUser:", "<|endoftext|>", " User:", "###"])
|
139 |
+
stop = list(set(stop))
|
140 |
+
gen_kwargs["stop_sequences"] = stop
|
141 |
+
|
142 |
+
logger.info(f"prompt: {prompt}")
|
143 |
+
logger.info(f"gen_kwargs: {gen_kwargs}")
|
144 |
+
|
145 |
+
try:
|
146 |
+
if self.model_path == "":
|
147 |
+
url = f"{self.api_base}"
|
148 |
+
else:
|
149 |
+
url = f"{self.api_base}/{self.model_path}"
|
150 |
+
client = InferenceClient(url, token=self.token)
|
151 |
+
res = client.text_generation(
|
152 |
+
prompt, stream=True, details=True, **gen_kwargs
|
153 |
+
)
|
154 |
+
|
155 |
+
reason = None
|
156 |
+
text = ""
|
157 |
+
for chunk in res:
|
158 |
+
if chunk.token.special:
|
159 |
+
continue
|
160 |
+
text += chunk.token.text
|
161 |
+
|
162 |
+
s = next((x for x in stop if text.endswith(x)), None)
|
163 |
+
if s is not None:
|
164 |
+
text = text[: -len(s)]
|
165 |
+
reason = "stop"
|
166 |
+
break
|
167 |
+
if could_be_stop(text, stop):
|
168 |
+
continue
|
169 |
+
if (
|
170 |
+
chunk.details is not None
|
171 |
+
and chunk.details.finish_reason is not None
|
172 |
+
):
|
173 |
+
reason = chunk.details.finish_reason
|
174 |
+
if reason not in ["stop", "length"]:
|
175 |
+
reason = None
|
176 |
+
ret = {
|
177 |
+
"text": text,
|
178 |
+
"error_code": 0,
|
179 |
+
"finish_reason": reason,
|
180 |
+
}
|
181 |
+
yield json.dumps(ret).encode() + b"\0"
|
182 |
+
except Exception as e:
|
183 |
+
ret = {
|
184 |
+
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
|
185 |
+
"error_code": ErrorCode.INTERNAL_ERROR,
|
186 |
+
}
|
187 |
+
yield json.dumps(ret).encode() + b"\0"
|
188 |
+
|
189 |
+
def generate_gate(self, params):
|
190 |
+
for x in self.generate_stream_gate(params):
|
191 |
+
pass
|
192 |
+
return json.loads(x[:-1].decode())
|
193 |
+
|
194 |
+
def get_embeddings(self, params):
|
195 |
+
raise NotImplementedError()
|
196 |
+
|
197 |
+
|
198 |
+
def release_worker_semaphore(worker):
|
199 |
+
worker.semaphore.release()
|
200 |
+
|
201 |
+
|
202 |
+
def acquire_worker_semaphore(worker):
|
203 |
+
if worker.semaphore is None:
|
204 |
+
worker.semaphore = asyncio.Semaphore(worker.limit_worker_concurrency)
|
205 |
+
return worker.semaphore.acquire()
|
206 |
+
|
207 |
+
|
208 |
+
def create_background_tasks(worker):
|
209 |
+
background_tasks = BackgroundTasks()
|
210 |
+
background_tasks.add_task(lambda: release_worker_semaphore(worker))
|
211 |
+
return background_tasks
|
212 |
+
|
213 |
+
|
214 |
+
@app.post("/worker_generate_stream")
|
215 |
+
async def api_generate_stream(request: Request):
|
216 |
+
params = await request.json()
|
217 |
+
worker = worker_map[params["model"]]
|
218 |
+
await acquire_worker_semaphore(worker)
|
219 |
+
generator = worker.generate_stream_gate(params)
|
220 |
+
background_tasks = create_background_tasks(worker)
|
221 |
+
return StreamingResponse(generator, background=background_tasks)
|
222 |
+
|
223 |
+
|
224 |
+
@app.post("/worker_generate")
|
225 |
+
async def api_generate(request: Request):
|
226 |
+
params = await request.json()
|
227 |
+
worker = worker_map[params["model"]]
|
228 |
+
await acquire_worker_semaphore(worker)
|
229 |
+
output = worker.generate_gate(params)
|
230 |
+
release_worker_semaphore(worker)
|
231 |
+
return JSONResponse(output)
|
232 |
+
|
233 |
+
|
234 |
+
@app.post("/worker_get_embeddings")
|
235 |
+
async def api_get_embeddings(request: Request):
|
236 |
+
params = await request.json()
|
237 |
+
worker = worker_map[params["model"]]
|
238 |
+
await acquire_worker_semaphore(worker)
|
239 |
+
embedding = worker.get_embeddings(params)
|
240 |
+
release_worker_semaphore(worker)
|
241 |
+
return JSONResponse(content=embedding)
|
242 |
+
|
243 |
+
|
244 |
+
@app.post("/worker_get_status")
|
245 |
+
async def api_get_status(request: Request):
|
246 |
+
return {
|
247 |
+
"model_names": [m for w in workers for m in w.model_names],
|
248 |
+
"speed": 1,
|
249 |
+
"queue_length": sum([w.get_queue_length() for w in workers]),
|
250 |
+
}
|
251 |
+
|
252 |
+
|
253 |
+
@app.post("/count_token")
|
254 |
+
async def api_count_token(request: Request):
|
255 |
+
params = await request.json()
|
256 |
+
worker = worker_map[params["model"]]
|
257 |
+
return worker.count_token(params)
|
258 |
+
|
259 |
+
|
260 |
+
@app.post("/worker_get_conv_template")
|
261 |
+
async def api_get_conv(request: Request):
|
262 |
+
params = await request.json()
|
263 |
+
worker = worker_map[params["model"]]
|
264 |
+
return worker.get_conv_template()
|
265 |
+
|
266 |
+
|
267 |
+
@app.post("/model_details")
|
268 |
+
async def api_model_details(request: Request):
|
269 |
+
params = await request.json()
|
270 |
+
worker = worker_map[params["model"]]
|
271 |
+
return {"context_length": worker.context_len}
|
272 |
+
|
273 |
+
|
274 |
+
def create_huggingface_api_worker():
|
275 |
+
parser = argparse.ArgumentParser()
|
276 |
+
parser.add_argument("--host", type=str, default="localhost")
|
277 |
+
parser.add_argument("--port", type=int, default=21002)
|
278 |
+
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
|
279 |
+
parser.add_argument(
|
280 |
+
"--controller-address", type=str, default="http://localhost:21001"
|
281 |
+
)
|
282 |
+
# all model-related parameters are listed in --model-info-file
|
283 |
+
parser.add_argument(
|
284 |
+
"--model-info-file",
|
285 |
+
type=str,
|
286 |
+
required=True,
|
287 |
+
help="Huggingface API model's info file path",
|
288 |
+
)
|
289 |
+
|
290 |
+
parser.add_argument(
|
291 |
+
"--limit-worker-concurrency",
|
292 |
+
type=int,
|
293 |
+
default=5,
|
294 |
+
help="Limit the model concurrency to prevent OOM.",
|
295 |
+
)
|
296 |
+
parser.add_argument("--no-register", action="store_true")
|
297 |
+
parser.add_argument(
|
298 |
+
"--seed",
|
299 |
+
type=int,
|
300 |
+
default=None,
|
301 |
+
help="Overwrite the random seed for each generation.",
|
302 |
+
)
|
303 |
+
args = parser.parse_args()
|
304 |
+
|
305 |
+
with open(args.model_info_file, "r", encoding="UTF-8") as f:
|
306 |
+
model_info = json.load(f)
|
307 |
+
|
308 |
+
logger.info(f"args: {args}")
|
309 |
+
|
310 |
+
model_path_list = []
|
311 |
+
api_base_list = []
|
312 |
+
token_list = []
|
313 |
+
context_length_list = []
|
314 |
+
model_names_list = []
|
315 |
+
conv_template_list = []
|
316 |
+
|
317 |
+
for m in model_info:
|
318 |
+
model_path_list.append(model_info[m]["model_path"])
|
319 |
+
api_base_list.append(model_info[m]["api_base"])
|
320 |
+
token_list.append(model_info[m]["token"])
|
321 |
+
|
322 |
+
context_length = model_info[m]["context_length"]
|
323 |
+
model_names = model_info[m].get("model_names", [m.split("/")[-1]])
|
324 |
+
if isinstance(model_names, str):
|
325 |
+
model_names = [model_names]
|
326 |
+
conv_template = model_info[m].get("conv_template", None)
|
327 |
+
|
328 |
+
context_length_list.append(context_length)
|
329 |
+
model_names_list.append(model_names)
|
330 |
+
conv_template_list.append(conv_template)
|
331 |
+
|
332 |
+
logger.info(f"Model paths: {model_path_list}")
|
333 |
+
logger.info(f"API bases: {api_base_list}")
|
334 |
+
logger.info(f"Tokens: {token_list}")
|
335 |
+
logger.info(f"Context lengths: {context_length_list}")
|
336 |
+
logger.info(f"Model names: {model_names_list}")
|
337 |
+
logger.info(f"Conv templates: {conv_template_list}")
|
338 |
+
|
339 |
+
for (
|
340 |
+
model_names,
|
341 |
+
conv_template,
|
342 |
+
model_path,
|
343 |
+
api_base,
|
344 |
+
token,
|
345 |
+
context_length,
|
346 |
+
) in zip(
|
347 |
+
model_names_list,
|
348 |
+
conv_template_list,
|
349 |
+
model_path_list,
|
350 |
+
api_base_list,
|
351 |
+
token_list,
|
352 |
+
context_length_list,
|
353 |
+
):
|
354 |
+
m = HuggingfaceApiWorker(
|
355 |
+
args.controller_address,
|
356 |
+
args.worker_address,
|
357 |
+
worker_id,
|
358 |
+
model_path,
|
359 |
+
api_base,
|
360 |
+
token,
|
361 |
+
context_length,
|
362 |
+
model_names,
|
363 |
+
args.limit_worker_concurrency,
|
364 |
+
no_register=args.no_register,
|
365 |
+
conv_template=conv_template,
|
366 |
+
seed=args.seed,
|
367 |
+
)
|
368 |
+
workers.append(m)
|
369 |
+
for name in model_names:
|
370 |
+
worker_map[name] = m
|
371 |
+
|
372 |
+
# register all the models
|
373 |
+
url = args.controller_address + "/register_worker"
|
374 |
+
data = {
|
375 |
+
"worker_name": workers[0].worker_addr,
|
376 |
+
"check_heart_beat": not args.no_register,
|
377 |
+
"worker_status": {
|
378 |
+
"model_names": [m for w in workers for m in w.model_names],
|
379 |
+
"speed": 1,
|
380 |
+
"queue_length": sum([w.get_queue_length() for w in workers]),
|
381 |
+
},
|
382 |
+
}
|
383 |
+
r = requests.post(url, json=data)
|
384 |
+
assert r.status_code == 200
|
385 |
+
|
386 |
+
return args, workers
|
387 |
+
|
388 |
+
|
389 |
+
if __name__ == "__main__":
|
390 |
+
args, workers = create_huggingface_api_worker()
|
391 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
inference.py
ADDED
@@ -0,0 +1,596 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Inference for FastChat models."""
|
2 |
+
import abc
|
3 |
+
import gc
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
import time
|
9 |
+
from typing import Iterable, Optional, Dict
|
10 |
+
import warnings
|
11 |
+
|
12 |
+
import psutil
|
13 |
+
import torch
|
14 |
+
from transformers import (
|
15 |
+
AutoTokenizer,
|
16 |
+
AutoModelForCausalLM,
|
17 |
+
LlamaTokenizer,
|
18 |
+
LlamaForCausalLM,
|
19 |
+
AutoModel,
|
20 |
+
AutoModelForSeq2SeqLM,
|
21 |
+
T5Tokenizer,
|
22 |
+
AutoConfig,
|
23 |
+
)
|
24 |
+
from transformers.generation.logits_process import (
|
25 |
+
LogitsProcessorList,
|
26 |
+
RepetitionPenaltyLogitsProcessor,
|
27 |
+
TemperatureLogitsWarper,
|
28 |
+
TopKLogitsWarper,
|
29 |
+
TopPLogitsWarper,
|
30 |
+
)
|
31 |
+
|
32 |
+
from fastchat.conversation import get_conv_template, SeparatorStyle
|
33 |
+
from fastchat.model.model_adapter import (
|
34 |
+
load_model,
|
35 |
+
get_conversation_template,
|
36 |
+
get_generate_stream_function,
|
37 |
+
)
|
38 |
+
from fastchat.modules.awq import AWQConfig
|
39 |
+
from fastchat.modules.gptq import GptqConfig
|
40 |
+
from fastchat.modules.exllama import ExllamaConfig
|
41 |
+
from fastchat.modules.xfastertransformer import XftConfig
|
42 |
+
from fastchat.utils import is_partial_stop, is_sentence_complete, get_context_length
|
43 |
+
|
44 |
+
|
45 |
+
def prepare_logits_processor(
|
46 |
+
temperature: float, repetition_penalty: float, top_p: float, top_k: int
|
47 |
+
) -> LogitsProcessorList:
|
48 |
+
processor_list = LogitsProcessorList()
|
49 |
+
# TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
|
50 |
+
if temperature >= 1e-5 and temperature != 1.0:
|
51 |
+
processor_list.append(TemperatureLogitsWarper(temperature))
|
52 |
+
if repetition_penalty > 1.0:
|
53 |
+
processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
|
54 |
+
if 1e-8 <= top_p < 1.0:
|
55 |
+
processor_list.append(TopPLogitsWarper(top_p))
|
56 |
+
if top_k > 0:
|
57 |
+
processor_list.append(TopKLogitsWarper(top_k))
|
58 |
+
return processor_list
|
59 |
+
|
60 |
+
|
61 |
+
@torch.inference_mode()
|
62 |
+
def generate_stream(
|
63 |
+
model,
|
64 |
+
tokenizer,
|
65 |
+
params: Dict,
|
66 |
+
device: str,
|
67 |
+
context_len: int,
|
68 |
+
stream_interval: int = 2,
|
69 |
+
judge_sent_end: bool = False,
|
70 |
+
):
|
71 |
+
if hasattr(model, "device"):
|
72 |
+
device = model.device
|
73 |
+
|
74 |
+
# Read parameters
|
75 |
+
prompt = params["prompt"]
|
76 |
+
len_prompt = len(prompt)
|
77 |
+
temperature = float(params.get("temperature", 1.0))
|
78 |
+
repetition_penalty = float(params.get("repetition_penalty", 1.0))
|
79 |
+
top_p = float(params.get("top_p", 1.0))
|
80 |
+
top_k = int(params.get("top_k", -1)) # -1 means disable
|
81 |
+
max_new_tokens = int(params.get("max_new_tokens", 256))
|
82 |
+
logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1.
|
83 |
+
echo = bool(params.get("echo", True))
|
84 |
+
stop_str = params.get("stop", None)
|
85 |
+
stop_token_ids = params.get("stop_token_ids", None) or []
|
86 |
+
if tokenizer.eos_token_id not in stop_token_ids:
|
87 |
+
stop_token_ids.append(tokenizer.eos_token_id)
|
88 |
+
if params.get('none_stop'):
|
89 |
+
stop_token_ids = []
|
90 |
+
skip_special_tokens = params.get('skip_special_tokens')
|
91 |
+
|
92 |
+
logits_processor = prepare_logits_processor(
|
93 |
+
temperature, repetition_penalty, top_p, top_k
|
94 |
+
)
|
95 |
+
input_ids = tokenizer(prompt).input_ids
|
96 |
+
|
97 |
+
if model.config.is_encoder_decoder:
|
98 |
+
max_src_len = context_len
|
99 |
+
else: # truncate
|
100 |
+
max_src_len = context_len - max_new_tokens - 1
|
101 |
+
|
102 |
+
input_ids = input_ids[-max_src_len:]
|
103 |
+
output_ids = list(input_ids)
|
104 |
+
input_echo_len = len(input_ids)
|
105 |
+
|
106 |
+
if model.config.is_encoder_decoder:
|
107 |
+
if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models.
|
108 |
+
raise NotImplementedError
|
109 |
+
encoder_output = model.encoder(
|
110 |
+
input_ids=torch.as_tensor([input_ids], device=device)
|
111 |
+
)[0]
|
112 |
+
start_ids = torch.as_tensor(
|
113 |
+
[[model.generation_config.decoder_start_token_id]],
|
114 |
+
dtype=torch.int64,
|
115 |
+
device=device,
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
start_ids = torch.as_tensor([input_ids], device=device)
|
119 |
+
|
120 |
+
past_key_values = out = None
|
121 |
+
token_logprobs = [None] # The first token has no logprobs.
|
122 |
+
sent_interrupt = False
|
123 |
+
finish_reason = None
|
124 |
+
for i in range(max_new_tokens):
|
125 |
+
if i == 0: # prefill
|
126 |
+
if model.config.is_encoder_decoder:
|
127 |
+
out = model.decoder(
|
128 |
+
input_ids=start_ids,
|
129 |
+
encoder_hidden_states=encoder_output,
|
130 |
+
use_cache=True,
|
131 |
+
)
|
132 |
+
logits = model.lm_head(out[0])
|
133 |
+
else:
|
134 |
+
out = model(input_ids=start_ids, use_cache=True)
|
135 |
+
logits = out.logits
|
136 |
+
past_key_values = out.past_key_values
|
137 |
+
|
138 |
+
if logprobs is not None:
|
139 |
+
# Prefull logprobs for the prompt.
|
140 |
+
shift_input_ids = start_ids[..., 1:].contiguous()
|
141 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
142 |
+
shift_logits = torch.log_softmax(shift_logits, dim=-1).tolist()
|
143 |
+
for label_id, logit in zip(
|
144 |
+
shift_input_ids[0].tolist(), shift_logits[0]
|
145 |
+
):
|
146 |
+
token_logprobs.append(logit[label_id])
|
147 |
+
else: # decoding
|
148 |
+
if model.config.is_encoder_decoder:
|
149 |
+
out = model.decoder(
|
150 |
+
input_ids=torch.as_tensor(
|
151 |
+
[[token] if not sent_interrupt else output_ids],
|
152 |
+
device=device,
|
153 |
+
),
|
154 |
+
encoder_hidden_states=encoder_output,
|
155 |
+
use_cache=True,
|
156 |
+
past_key_values=past_key_values if not sent_interrupt else None,
|
157 |
+
)
|
158 |
+
sent_interrupt = False
|
159 |
+
|
160 |
+
logits = model.lm_head(out[0])
|
161 |
+
else:
|
162 |
+
out = model(
|
163 |
+
input_ids=torch.as_tensor(
|
164 |
+
[[token] if not sent_interrupt else output_ids],
|
165 |
+
device=device,
|
166 |
+
),
|
167 |
+
use_cache=True,
|
168 |
+
past_key_values=past_key_values if not sent_interrupt else None,
|
169 |
+
)
|
170 |
+
sent_interrupt = False
|
171 |
+
logits = out.logits
|
172 |
+
past_key_values = out.past_key_values
|
173 |
+
|
174 |
+
if logits_processor:
|
175 |
+
if repetition_penalty > 1.0:
|
176 |
+
tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
|
177 |
+
else:
|
178 |
+
tmp_output_ids = None
|
179 |
+
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
|
180 |
+
else:
|
181 |
+
last_token_logits = logits[0, -1, :]
|
182 |
+
|
183 |
+
if device == "mps":
|
184 |
+
# Switch to CPU by avoiding some bugs in mps backend.
|
185 |
+
last_token_logits = last_token_logits.float().to("cpu")
|
186 |
+
|
187 |
+
if temperature < 1e-5 or top_p < 1e-8: # greedy
|
188 |
+
_, indices = torch.topk(last_token_logits, 2)
|
189 |
+
tokens = [int(index) for index in indices.tolist()]
|
190 |
+
else:
|
191 |
+
probs = torch.softmax(last_token_logits, dim=-1)
|
192 |
+
indices = torch.multinomial(probs, num_samples=2)
|
193 |
+
tokens = [int(token) for token in indices.tolist()]
|
194 |
+
token = tokens[0]
|
195 |
+
output_ids.append(token)
|
196 |
+
if logprobs is not None:
|
197 |
+
# Cannot use last_token_logits because logprobs is based on raw logits.
|
198 |
+
token_logprobs.append(
|
199 |
+
torch.log_softmax(logits[0, -1, :], dim=-1)[token].tolist()
|
200 |
+
)
|
201 |
+
|
202 |
+
if token in stop_token_ids:
|
203 |
+
stopped = True
|
204 |
+
else:
|
205 |
+
stopped = False
|
206 |
+
|
207 |
+
# Yield the output tokens
|
208 |
+
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
|
209 |
+
if echo:
|
210 |
+
tmp_output_ids = output_ids
|
211 |
+
rfind_start = len_prompt
|
212 |
+
else:
|
213 |
+
tmp_output_ids = output_ids[input_echo_len:]
|
214 |
+
rfind_start = 0
|
215 |
+
|
216 |
+
output = tokenizer.decode(
|
217 |
+
tmp_output_ids,
|
218 |
+
skip_special_tokens=skip_special_tokens,
|
219 |
+
spaces_between_special_tokens=False,
|
220 |
+
clean_up_tokenization_spaces=True,
|
221 |
+
)
|
222 |
+
ret_logprobs = None
|
223 |
+
if logprobs is not None:
|
224 |
+
ret_logprobs = {
|
225 |
+
"text_offset": [],
|
226 |
+
"tokens": [
|
227 |
+
tokenizer.decode(token)
|
228 |
+
for token in (
|
229 |
+
output_ids if echo else output_ids[input_echo_len:]
|
230 |
+
)
|
231 |
+
],
|
232 |
+
"token_logprobs": token_logprobs
|
233 |
+
if echo
|
234 |
+
else token_logprobs[input_echo_len:],
|
235 |
+
"top_logprobs": [{}]
|
236 |
+
* len(token_logprobs if echo else token_logprobs[input_echo_len:]),
|
237 |
+
}
|
238 |
+
# Compute text_offset
|
239 |
+
curr_pos = 0
|
240 |
+
for text in ret_logprobs["tokens"]:
|
241 |
+
ret_logprobs["text_offset"].append(curr_pos)
|
242 |
+
curr_pos += len(text)
|
243 |
+
|
244 |
+
# TODO: For the issue of incomplete sentences interrupting output, apply a patch and others can also modify it to a more elegant way
|
245 |
+
if judge_sent_end and stopped and not is_sentence_complete(output):
|
246 |
+
if len(tokens) > 1:
|
247 |
+
token = tokens[1]
|
248 |
+
output_ids[-1] = token
|
249 |
+
else:
|
250 |
+
output_ids.pop()
|
251 |
+
stopped = False
|
252 |
+
sent_interrupt = True
|
253 |
+
|
254 |
+
partially_stopped = False
|
255 |
+
if stop_str:
|
256 |
+
if isinstance(stop_str, str):
|
257 |
+
pos = output.rfind(stop_str, rfind_start)
|
258 |
+
if pos != -1:
|
259 |
+
output = output[:pos]
|
260 |
+
stopped = True
|
261 |
+
else:
|
262 |
+
partially_stopped = is_partial_stop(output, stop_str)
|
263 |
+
elif isinstance(stop_str, Iterable):
|
264 |
+
for each_stop in stop_str:
|
265 |
+
pos = output.rfind(each_stop, rfind_start)
|
266 |
+
if pos != -1:
|
267 |
+
output = output[:pos]
|
268 |
+
stopped = True
|
269 |
+
break
|
270 |
+
else:
|
271 |
+
partially_stopped = is_partial_stop(output, each_stop)
|
272 |
+
if partially_stopped:
|
273 |
+
break
|
274 |
+
else:
|
275 |
+
raise ValueError("Invalid stop field type.")
|
276 |
+
|
277 |
+
# Prevent yielding partial stop sequence
|
278 |
+
if not partially_stopped:
|
279 |
+
yield {
|
280 |
+
"text": output,
|
281 |
+
"logprobs": ret_logprobs,
|
282 |
+
"usage": {
|
283 |
+
"prompt_tokens": input_echo_len,
|
284 |
+
"completion_tokens": i,
|
285 |
+
"total_tokens": input_echo_len + i,
|
286 |
+
},
|
287 |
+
"finish_reason": None,
|
288 |
+
}
|
289 |
+
|
290 |
+
if stopped:
|
291 |
+
break
|
292 |
+
|
293 |
+
# Finish stream event, which contains finish reason
|
294 |
+
else:
|
295 |
+
finish_reason = "length"
|
296 |
+
|
297 |
+
if stopped:
|
298 |
+
finish_reason = "stop"
|
299 |
+
|
300 |
+
yield {
|
301 |
+
"text": output,
|
302 |
+
"logprobs": ret_logprobs,
|
303 |
+
"usage": {
|
304 |
+
"prompt_tokens": input_echo_len,
|
305 |
+
"completion_tokens": i,
|
306 |
+
"total_tokens": input_echo_len + i,
|
307 |
+
},
|
308 |
+
"finish_reason": finish_reason,
|
309 |
+
}
|
310 |
+
|
311 |
+
# Clean
|
312 |
+
del past_key_values, out
|
313 |
+
gc.collect()
|
314 |
+
torch.cuda.empty_cache()
|
315 |
+
if device == "xpu":
|
316 |
+
torch.xpu.empty_cache()
|
317 |
+
if device == "npu":
|
318 |
+
torch.npu.empty_cache()
|
319 |
+
|
320 |
+
|
321 |
+
class ChatIO(abc.ABC):
|
322 |
+
@abc.abstractmethod
|
323 |
+
def prompt_for_input(self, role: str) -> str:
|
324 |
+
"""Prompt for input from a role."""
|
325 |
+
|
326 |
+
@abc.abstractmethod
|
327 |
+
def prompt_for_output(self, role: str):
|
328 |
+
"""Prompt for output from a role."""
|
329 |
+
|
330 |
+
@abc.abstractmethod
|
331 |
+
def stream_output(self, output_stream):
|
332 |
+
"""Stream output."""
|
333 |
+
|
334 |
+
@abc.abstractmethod
|
335 |
+
def print_output(self, text: str):
|
336 |
+
"""Print output."""
|
337 |
+
|
338 |
+
|
339 |
+
def convert_message_format(message):
|
340 |
+
formated_message = []
|
341 |
+
for i, turn in enumerate(message):
|
342 |
+
role = 'user' if i % 2 == 0 else 'assistant'
|
343 |
+
formated_message.append({'role': role, 'content': turn[1]})
|
344 |
+
|
345 |
+
data = {
|
346 |
+
'conversations': formated_message,
|
347 |
+
'idx': -1,
|
348 |
+
'tinder': 'badcase',
|
349 |
+
'model': '',
|
350 |
+
'tokens_in': 0,
|
351 |
+
'tokens_out': 0,
|
352 |
+
}
|
353 |
+
|
354 |
+
return data
|
355 |
+
|
356 |
+
|
357 |
+
def chat_loop(
|
358 |
+
model_path: str,
|
359 |
+
device: str,
|
360 |
+
num_gpus: int,
|
361 |
+
max_gpu_memory: str,
|
362 |
+
dtype: Optional[torch.dtype],
|
363 |
+
load_8bit: bool,
|
364 |
+
cpu_offloading: bool,
|
365 |
+
conv_template: Optional[str],
|
366 |
+
conv_system_msg: Optional[str],
|
367 |
+
temperature: float,
|
368 |
+
repetition_penalty: float,
|
369 |
+
max_new_tokens: int,
|
370 |
+
chatio: ChatIO,
|
371 |
+
gptq_config: Optional[GptqConfig] = None,
|
372 |
+
awq_config: Optional[AWQConfig] = None,
|
373 |
+
exllama_config: Optional[ExllamaConfig] = None,
|
374 |
+
xft_config: Optional[XftConfig] = None,
|
375 |
+
revision: str = "main",
|
376 |
+
judge_sent_end: bool = True,
|
377 |
+
debug: bool = True,
|
378 |
+
history: bool = True,
|
379 |
+
):
|
380 |
+
# Model
|
381 |
+
model, tokenizer = load_model(
|
382 |
+
model_path,
|
383 |
+
device=device,
|
384 |
+
num_gpus=num_gpus,
|
385 |
+
max_gpu_memory=max_gpu_memory,
|
386 |
+
dtype=dtype,
|
387 |
+
load_8bit=load_8bit,
|
388 |
+
cpu_offloading=cpu_offloading,
|
389 |
+
gptq_config=gptq_config,
|
390 |
+
awq_config=awq_config,
|
391 |
+
exllama_config=exllama_config,
|
392 |
+
xft_config=xft_config,
|
393 |
+
revision=revision,
|
394 |
+
debug=debug,
|
395 |
+
)
|
396 |
+
generate_stream_func = get_generate_stream_function(model, model_path)
|
397 |
+
|
398 |
+
model_type = str(type(model)).lower()
|
399 |
+
is_t5 = "t5" in model_type
|
400 |
+
is_codet5p = "codet5p" in model_type
|
401 |
+
is_xft = "xft" in model_type
|
402 |
+
|
403 |
+
# Hardcode T5's default repetition penalty to be 1.2
|
404 |
+
if is_t5 and repetition_penalty == 1.0:
|
405 |
+
repetition_penalty = 1.2
|
406 |
+
|
407 |
+
# Set context length
|
408 |
+
context_len = get_context_length(model.config)
|
409 |
+
|
410 |
+
# Chat
|
411 |
+
def new_chat():
|
412 |
+
if conv_template:
|
413 |
+
conv = get_conv_template(conv_template)
|
414 |
+
else:
|
415 |
+
conv = get_conversation_template(model_path)
|
416 |
+
if conv_system_msg is not None:
|
417 |
+
conv.set_system_message(conv_system_msg)
|
418 |
+
return conv
|
419 |
+
|
420 |
+
def reload_conv(conv):
|
421 |
+
"""
|
422 |
+
Reprints the conversation from the start.
|
423 |
+
"""
|
424 |
+
for message in conv.messages[conv.offset :]:
|
425 |
+
chatio.prompt_for_output(message[0])
|
426 |
+
chatio.print_output(message[1])
|
427 |
+
|
428 |
+
conv = None
|
429 |
+
|
430 |
+
while True:
|
431 |
+
if not history or not conv:
|
432 |
+
conv = new_chat()
|
433 |
+
|
434 |
+
try:
|
435 |
+
inp = chatio.prompt_for_input(conv.roles[0])
|
436 |
+
except EOFError:
|
437 |
+
inp = ""
|
438 |
+
|
439 |
+
if inp == "!!exit":# or not inp:
|
440 |
+
print("exit...")
|
441 |
+
break
|
442 |
+
elif inp == "!!reset":
|
443 |
+
print("resetting...")
|
444 |
+
conv = new_chat()
|
445 |
+
continue
|
446 |
+
elif inp == "!!remove":
|
447 |
+
print("removing last message...")
|
448 |
+
if len(conv.messages) > conv.offset:
|
449 |
+
# Assistant
|
450 |
+
if conv.messages[-1][0] == conv.roles[1]:
|
451 |
+
conv.messages.pop()
|
452 |
+
# User
|
453 |
+
if conv.messages[-1][0] == conv.roles[0]:
|
454 |
+
conv.messages.pop()
|
455 |
+
reload_conv(conv)
|
456 |
+
else:
|
457 |
+
print("No messages to remove.")
|
458 |
+
continue
|
459 |
+
elif inp == "!!regen":
|
460 |
+
print("regenerating last message...")
|
461 |
+
if len(conv.messages) > conv.offset:
|
462 |
+
# Assistant
|
463 |
+
if conv.messages[-1][0] == conv.roles[1]:
|
464 |
+
conv.messages.pop()
|
465 |
+
# User
|
466 |
+
if conv.messages[-1][0] == conv.roles[0]:
|
467 |
+
reload_conv(conv)
|
468 |
+
# Set inp to previous message
|
469 |
+
inp = conv.messages.pop()[1]
|
470 |
+
else:
|
471 |
+
# Shouldn't happen in normal circumstances
|
472 |
+
print("No user message to regenerate from.")
|
473 |
+
continue
|
474 |
+
else:
|
475 |
+
print("No messages to regenerate.")
|
476 |
+
continue
|
477 |
+
elif inp.startswith("!!save"):
|
478 |
+
args = inp.split(" ", 1)
|
479 |
+
|
480 |
+
if len(args) != 2:
|
481 |
+
print("usage: !!save <filename>")
|
482 |
+
continue
|
483 |
+
else:
|
484 |
+
filename = args[1]
|
485 |
+
|
486 |
+
# Add .json if extension not present
|
487 |
+
if not "." in filename:
|
488 |
+
filename += ".json"
|
489 |
+
|
490 |
+
print("saving...", filename)
|
491 |
+
with open(filename, "w", encoding="utf-8") as outfile:
|
492 |
+
json.dump(conv.dict(), outfile, ensure_ascii=False)
|
493 |
+
continue
|
494 |
+
elif inp.startswith("!!badcase"):
|
495 |
+
args = inp.split(" ", 1)
|
496 |
+
|
497 |
+
if len(args) != 2:
|
498 |
+
print("usage: !!save <filename>")
|
499 |
+
continue
|
500 |
+
else:
|
501 |
+
filename = args[1]
|
502 |
+
|
503 |
+
# Add .json if extension not present
|
504 |
+
if not "." in filename:
|
505 |
+
filename += ".jsonl"
|
506 |
+
|
507 |
+
print("saving...", filename)
|
508 |
+
with open(filename, "a+", encoding="utf-8") as outfile:
|
509 |
+
data = convert_message_format(conv.messages)
|
510 |
+
json.dump(data, outfile, ensure_ascii=False)
|
511 |
+
outfile.write('\n')
|
512 |
+
continue
|
513 |
+
elif inp.startswith("!!load"):
|
514 |
+
args = inp.split(" ", 1)
|
515 |
+
|
516 |
+
if len(args) != 2:
|
517 |
+
print("usage: !!load <filename>")
|
518 |
+
continue
|
519 |
+
else:
|
520 |
+
filename = args[1]
|
521 |
+
|
522 |
+
# Check if file exists and add .json if needed
|
523 |
+
if not os.path.exists(filename):
|
524 |
+
if (not filename.endswith(".json")) and os.path.exists(
|
525 |
+
filename + ".json"
|
526 |
+
):
|
527 |
+
filename += ".json"
|
528 |
+
else:
|
529 |
+
print("file not found:", filename)
|
530 |
+
continue
|
531 |
+
|
532 |
+
print("loading...", filename)
|
533 |
+
with open(filename, "r") as infile:
|
534 |
+
new_conv = json.load(infile)
|
535 |
+
|
536 |
+
conv = get_conv_template(new_conv["template_name"])
|
537 |
+
conv.set_system_message(new_conv["system_message"])
|
538 |
+
conv.messages = new_conv["messages"]
|
539 |
+
reload_conv(conv)
|
540 |
+
continue
|
541 |
+
|
542 |
+
conv.append_message(conv.roles[0], inp)
|
543 |
+
conv.append_message(conv.roles[1], None)
|
544 |
+
prompt = conv.get_prompt(tokenizer)
|
545 |
+
|
546 |
+
if is_codet5p: # codet5p is a code completion model.
|
547 |
+
prompt = inp
|
548 |
+
|
549 |
+
gen_params = {
|
550 |
+
"model": model_path,
|
551 |
+
"prompt": prompt,
|
552 |
+
"temperature": temperature,
|
553 |
+
"repetition_penalty": repetition_penalty,
|
554 |
+
"max_new_tokens": max_new_tokens,
|
555 |
+
"stop": conv.stop_str,
|
556 |
+
"stop_token_ids": conv.stop_token_ids,
|
557 |
+
"none_stop": conv.none_stop,
|
558 |
+
"skip_special_tokens": conv.skip_special_tokens,
|
559 |
+
"echo": False,
|
560 |
+
}
|
561 |
+
|
562 |
+
try:
|
563 |
+
chatio.prompt_for_output(conv.roles[1])
|
564 |
+
output_stream = generate_stream_func(
|
565 |
+
model,
|
566 |
+
tokenizer,
|
567 |
+
gen_params,
|
568 |
+
device,
|
569 |
+
context_len=context_len,
|
570 |
+
judge_sent_end=judge_sent_end,
|
571 |
+
)
|
572 |
+
t = time.time()
|
573 |
+
outputs = chatio.stream_output(output_stream)
|
574 |
+
duration = time.time() - t
|
575 |
+
conv.update_last_message(outputs.strip())
|
576 |
+
|
577 |
+
if debug:
|
578 |
+
num_tokens = len(tokenizer.encode(outputs))
|
579 |
+
msg = {
|
580 |
+
"conv_template": conv.name,
|
581 |
+
"prompt": prompt,
|
582 |
+
"outputs": outputs,
|
583 |
+
"speed (token/s)": round(num_tokens / duration, 2),
|
584 |
+
}
|
585 |
+
print(f"\n{msg}\n")
|
586 |
+
|
587 |
+
except KeyboardInterrupt:
|
588 |
+
print("stopped generation.")
|
589 |
+
# If generation didn't finish
|
590 |
+
if conv.messages[-1][1] is None:
|
591 |
+
conv.messages.pop()
|
592 |
+
# Remove last user message, so there isn't a double up
|
593 |
+
if conv.messages[-1][0] == conv.roles[0]:
|
594 |
+
conv.messages.pop()
|
595 |
+
|
596 |
+
reload_conv(conv)
|
launch_all_serve.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Usage: python launch_all_serve_by_shell.py --model-path-address "THUDM/chatglm2-6b@localhost@2021" "huggyllama/llama-7b@localhost@2022"
|
3 |
+
|
4 |
+
Workers are listed in format of `model-path`@`host`@`port`
|
5 |
+
|
6 |
+
The key mechanism behind this scripts is:
|
7 |
+
1, execute shell cmd to launch the controller/worker/openai-api-server;
|
8 |
+
2, check the log of controller/worker/openai-api-server to ensure that the serve is launched properly.
|
9 |
+
Note that a few of non-critical `fastchat.serve` cmd options are not supported currently.
|
10 |
+
"""
|
11 |
+
import sys
|
12 |
+
import os
|
13 |
+
|
14 |
+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
15 |
+
|
16 |
+
import subprocess
|
17 |
+
import re
|
18 |
+
import argparse
|
19 |
+
|
20 |
+
LOGDIR = "./logs/"
|
21 |
+
|
22 |
+
if not os.path.exists(LOGDIR):
|
23 |
+
os.makedirs(LOGDIR)
|
24 |
+
|
25 |
+
parser = argparse.ArgumentParser()
|
26 |
+
# ------multi worker-----------------
|
27 |
+
parser.add_argument(
|
28 |
+
"--model-path-address",
|
29 |
+
default="THUDM/chatglm2-6b@localhost@20002",
|
30 |
+
nargs="+",
|
31 |
+
type=str,
|
32 |
+
help="model path, host, and port, formatted as model-path@host@port",
|
33 |
+
)
|
34 |
+
# ---------------controller-------------------------
|
35 |
+
|
36 |
+
parser.add_argument("--controller-host", type=str, default="localhost")
|
37 |
+
parser.add_argument("--controller-port", type=int, default=21001)
|
38 |
+
parser.add_argument(
|
39 |
+
"--dispatch-method",
|
40 |
+
type=str,
|
41 |
+
choices=["lottery", "shortest_queue"],
|
42 |
+
default="shortest_queue",
|
43 |
+
)
|
44 |
+
controller_args = ["controller-host", "controller-port", "dispatch-method"]
|
45 |
+
|
46 |
+
# ----------------------worker------------------------------------------
|
47 |
+
|
48 |
+
parser.add_argument("--worker-host", type=str, default="localhost")
|
49 |
+
parser.add_argument("--worker-port", type=int, default=21002)
|
50 |
+
# parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
|
51 |
+
# parser.add_argument(
|
52 |
+
# "--controller-address", type=str, default="http://localhost:21001"
|
53 |
+
# )
|
54 |
+
parser.add_argument(
|
55 |
+
"--model-path",
|
56 |
+
type=str,
|
57 |
+
default="lmsys/vicuna-7b-v1.5",
|
58 |
+
help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
|
59 |
+
)
|
60 |
+
parser.add_argument(
|
61 |
+
"--revision",
|
62 |
+
type=str,
|
63 |
+
default="main",
|
64 |
+
help="Hugging Face Hub model revision identifier",
|
65 |
+
)
|
66 |
+
parser.add_argument(
|
67 |
+
"--device",
|
68 |
+
type=str,
|
69 |
+
choices=["cpu", "cuda", "mps", "xpu", "npu"],
|
70 |
+
default="cuda",
|
71 |
+
help="The device type",
|
72 |
+
)
|
73 |
+
parser.add_argument(
|
74 |
+
"--gpus",
|
75 |
+
type=str,
|
76 |
+
default="0",
|
77 |
+
help="A single GPU like 1 or multiple GPUs like 0,2",
|
78 |
+
)
|
79 |
+
parser.add_argument("--num-gpus", type=int, default=1)
|
80 |
+
parser.add_argument(
|
81 |
+
"--max-gpu-memory",
|
82 |
+
type=str,
|
83 |
+
help="The maximum memory per gpu. Use a string like '13Gib'",
|
84 |
+
)
|
85 |
+
parser.add_argument("--load-8bit", action="store_true", help="Use 8-bit quantization")
|
86 |
+
parser.add_argument(
|
87 |
+
"--cpu-offloading",
|
88 |
+
action="store_true",
|
89 |
+
help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
|
90 |
+
)
|
91 |
+
parser.add_argument(
|
92 |
+
"--gptq-ckpt",
|
93 |
+
type=str,
|
94 |
+
default=None,
|
95 |
+
help="Load quantized model. The path to the local GPTQ checkpoint.",
|
96 |
+
)
|
97 |
+
parser.add_argument(
|
98 |
+
"--gptq-wbits",
|
99 |
+
type=int,
|
100 |
+
default=16,
|
101 |
+
choices=[2, 3, 4, 8, 16],
|
102 |
+
help="#bits to use for quantization",
|
103 |
+
)
|
104 |
+
parser.add_argument(
|
105 |
+
"--gptq-groupsize",
|
106 |
+
type=int,
|
107 |
+
default=-1,
|
108 |
+
help="Groupsize to use for quantization; default uses full row.",
|
109 |
+
)
|
110 |
+
parser.add_argument(
|
111 |
+
"--gptq-act-order",
|
112 |
+
action="store_true",
|
113 |
+
help="Whether to apply the activation order GPTQ heuristic",
|
114 |
+
)
|
115 |
+
parser.add_argument(
|
116 |
+
"--model-names",
|
117 |
+
type=lambda s: s.split(","),
|
118 |
+
help="Optional display comma separated names",
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--limit-worker-concurrency",
|
122 |
+
type=int,
|
123 |
+
default=5,
|
124 |
+
help="Limit the model concurrency to prevent OOM.",
|
125 |
+
)
|
126 |
+
parser.add_argument("--stream-interval", type=int, default=2)
|
127 |
+
parser.add_argument("--no-register", action="store_true")
|
128 |
+
|
129 |
+
worker_args = [
|
130 |
+
"worker-host",
|
131 |
+
"worker-port",
|
132 |
+
"model-path",
|
133 |
+
"revision",
|
134 |
+
"device",
|
135 |
+
"gpus",
|
136 |
+
"num-gpus",
|
137 |
+
"max-gpu-memory",
|
138 |
+
"load-8bit",
|
139 |
+
"cpu-offloading",
|
140 |
+
"gptq-ckpt",
|
141 |
+
"gptq-wbits",
|
142 |
+
"gptq-groupsize",
|
143 |
+
"gptq-act-order",
|
144 |
+
"model-names",
|
145 |
+
"limit-worker-concurrency",
|
146 |
+
"stream-interval",
|
147 |
+
"no-register",
|
148 |
+
"controller-address",
|
149 |
+
]
|
150 |
+
# -----------------openai server---------------------------
|
151 |
+
|
152 |
+
parser.add_argument("--server-host", type=str, default="localhost", help="host name")
|
153 |
+
parser.add_argument("--server-port", type=int, default=8001, help="port number")
|
154 |
+
parser.add_argument(
|
155 |
+
"--allow-credentials", action="store_true", help="allow credentials"
|
156 |
+
)
|
157 |
+
# parser.add_argument(
|
158 |
+
# "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
|
159 |
+
# )
|
160 |
+
# parser.add_argument(
|
161 |
+
# "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
|
162 |
+
# )
|
163 |
+
# parser.add_argument(
|
164 |
+
# "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
|
165 |
+
# )
|
166 |
+
parser.add_argument(
|
167 |
+
"--api-keys",
|
168 |
+
type=lambda s: s.split(","),
|
169 |
+
help="Optional list of comma separated API keys",
|
170 |
+
)
|
171 |
+
server_args = [
|
172 |
+
"server-host",
|
173 |
+
"server-port",
|
174 |
+
"allow-credentials",
|
175 |
+
"api-keys",
|
176 |
+
"controller-address",
|
177 |
+
]
|
178 |
+
|
179 |
+
args = parser.parse_args()
|
180 |
+
|
181 |
+
args = argparse.Namespace(
|
182 |
+
**vars(args),
|
183 |
+
**{"controller-address": f"http://{args.controller_host}:{args.controller_port}"},
|
184 |
+
)
|
185 |
+
|
186 |
+
if args.gpus:
|
187 |
+
if len(args.gpus.split(",")) < args.num_gpus:
|
188 |
+
raise ValueError(
|
189 |
+
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
|
190 |
+
)
|
191 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
192 |
+
|
193 |
+
# 0,controller, model_worker, openai_api_server
|
194 |
+
# 1, cmd options
|
195 |
+
# 2,LOGDIR
|
196 |
+
# 3, log file name
|
197 |
+
base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"
|
198 |
+
|
199 |
+
# 0 LOGDIR
|
200 |
+
#! 1 log file name
|
201 |
+
# 2 controller, worker, openai_api_server
|
202 |
+
base_check_sh = """while [ `grep -c "Uvicorn running on" {0}/{1}.log` -eq '0' ];do
|
203 |
+
sleep 1s;
|
204 |
+
echo "wait {2} running"
|
205 |
+
done
|
206 |
+
echo '{2} running' """
|
207 |
+
|
208 |
+
|
209 |
+
def string_args(args, args_list):
|
210 |
+
args_str = ""
|
211 |
+
for key, value in args._get_kwargs():
|
212 |
+
key = key.replace("_", "-")
|
213 |
+
if key not in args_list:
|
214 |
+
continue
|
215 |
+
|
216 |
+
key = key.split("-")[-1] if re.search("port|host", key) else key
|
217 |
+
if not value:
|
218 |
+
pass
|
219 |
+
# 1==True -> True
|
220 |
+
elif isinstance(value, bool) and value == True:
|
221 |
+
args_str += f" --{key} "
|
222 |
+
elif (
|
223 |
+
isinstance(value, list)
|
224 |
+
or isinstance(value, tuple)
|
225 |
+
or isinstance(value, set)
|
226 |
+
):
|
227 |
+
value = " ".join(value)
|
228 |
+
args_str += f" --{key} {value} "
|
229 |
+
else:
|
230 |
+
args_str += f" --{key} {value} "
|
231 |
+
|
232 |
+
return args_str
|
233 |
+
|
234 |
+
|
235 |
+
def launch_worker(item):
|
236 |
+
log_name = (
|
237 |
+
item.split("/")[-1]
|
238 |
+
.split("\\")[-1]
|
239 |
+
.replace("-", "_")
|
240 |
+
.replace("@", "_")
|
241 |
+
.replace(".", "_")
|
242 |
+
)
|
243 |
+
|
244 |
+
args.model_path, args.worker_host, args.worker_port = item.split("@")
|
245 |
+
print("*" * 80)
|
246 |
+
worker_str_args = string_args(args, worker_args)
|
247 |
+
print(worker_str_args)
|
248 |
+
worker_sh = base_launch_sh.format(
|
249 |
+
"model_worker", worker_str_args, LOGDIR, f"worker_{log_name}"
|
250 |
+
)
|
251 |
+
worker_check_sh = base_check_sh.format(LOGDIR, f"worker_{log_name}", "model_worker")
|
252 |
+
subprocess.run(worker_sh, shell=True, check=True)
|
253 |
+
subprocess.run(worker_check_sh, shell=True, check=True)
|
254 |
+
|
255 |
+
|
256 |
+
def launch_all():
|
257 |
+
controller_str_args = string_args(args, controller_args)
|
258 |
+
controller_sh = base_launch_sh.format(
|
259 |
+
"controller", controller_str_args, LOGDIR, "controller"
|
260 |
+
)
|
261 |
+
controller_check_sh = base_check_sh.format(LOGDIR, "controller", "controller")
|
262 |
+
subprocess.run(controller_sh, shell=True, check=True)
|
263 |
+
subprocess.run(controller_check_sh, shell=True, check=True)
|
264 |
+
|
265 |
+
if isinstance(args.model_path_address, str):
|
266 |
+
launch_worker(args.model_path_address)
|
267 |
+
else:
|
268 |
+
for idx, item in enumerate(args.model_path_address):
|
269 |
+
print(f"loading {idx}th model:{item}")
|
270 |
+
launch_worker(item)
|
271 |
+
|
272 |
+
server_str_args = string_args(args, server_args)
|
273 |
+
server_sh = base_launch_sh.format(
|
274 |
+
"openai_api_server", server_str_args, LOGDIR, "openai_api_server"
|
275 |
+
)
|
276 |
+
server_check_sh = base_check_sh.format(
|
277 |
+
LOGDIR, "openai_api_server", "openai_api_server"
|
278 |
+
)
|
279 |
+
subprocess.run(server_sh, shell=True, check=True)
|
280 |
+
subprocess.run(server_check_sh, shell=True, check=True)
|
281 |
+
|
282 |
+
|
283 |
+
if __name__ == "__main__":
|
284 |
+
launch_all()
|
model_worker.py
ADDED
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A model worker that executes the model.
|
3 |
+
"""
|
4 |
+
import argparse
|
5 |
+
import base64
|
6 |
+
import gc
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
from typing import List, Optional
|
10 |
+
import uuid
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn.functional as F
|
14 |
+
from transformers import set_seed
|
15 |
+
import uvicorn
|
16 |
+
|
17 |
+
from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
|
18 |
+
from fastchat.model.model_adapter import (
|
19 |
+
load_model,
|
20 |
+
add_model_args,
|
21 |
+
get_generate_stream_function,
|
22 |
+
)
|
23 |
+
from fastchat.modules.awq import AWQConfig
|
24 |
+
from fastchat.modules.exllama import ExllamaConfig
|
25 |
+
from fastchat.modules.xfastertransformer import XftConfig
|
26 |
+
from fastchat.modules.gptq import GptqConfig
|
27 |
+
from fastchat.serve.base_model_worker import BaseModelWorker, app
|
28 |
+
from fastchat.utils import (
|
29 |
+
build_logger,
|
30 |
+
get_context_length,
|
31 |
+
str_to_torch_dtype,
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
worker_id = str(uuid.uuid4())[:8]
|
36 |
+
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
|
37 |
+
|
38 |
+
|
39 |
+
class ModelWorker(BaseModelWorker):
|
40 |
+
def __init__(
|
41 |
+
self,
|
42 |
+
controller_addr: str,
|
43 |
+
worker_addr: str,
|
44 |
+
worker_id: str,
|
45 |
+
model_path: str,
|
46 |
+
model_names: List[str],
|
47 |
+
limit_worker_concurrency: int,
|
48 |
+
no_register: bool,
|
49 |
+
device: str,
|
50 |
+
num_gpus: int,
|
51 |
+
max_gpu_memory: str,
|
52 |
+
dtype: Optional[torch.dtype] = None,
|
53 |
+
load_8bit: bool = False,
|
54 |
+
cpu_offloading: bool = False,
|
55 |
+
gptq_config: Optional[GptqConfig] = None,
|
56 |
+
awq_config: Optional[AWQConfig] = None,
|
57 |
+
exllama_config: Optional[ExllamaConfig] = None,
|
58 |
+
xft_config: Optional[XftConfig] = None,
|
59 |
+
stream_interval: int = 2,
|
60 |
+
conv_template: Optional[str] = None,
|
61 |
+
embed_in_truncate: bool = False,
|
62 |
+
seed: Optional[int] = None,
|
63 |
+
debug: bool = False,
|
64 |
+
**kwargs,
|
65 |
+
):
|
66 |
+
super().__init__(
|
67 |
+
controller_addr,
|
68 |
+
worker_addr,
|
69 |
+
worker_id,
|
70 |
+
model_path,
|
71 |
+
model_names,
|
72 |
+
limit_worker_concurrency,
|
73 |
+
conv_template=conv_template,
|
74 |
+
)
|
75 |
+
|
76 |
+
logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
|
77 |
+
self.model, self.tokenizer = load_model(
|
78 |
+
model_path,
|
79 |
+
device=device,
|
80 |
+
num_gpus=num_gpus,
|
81 |
+
max_gpu_memory=max_gpu_memory,
|
82 |
+
dtype=dtype,
|
83 |
+
load_8bit=load_8bit,
|
84 |
+
cpu_offloading=cpu_offloading,
|
85 |
+
gptq_config=gptq_config,
|
86 |
+
awq_config=awq_config,
|
87 |
+
exllama_config=exllama_config,
|
88 |
+
xft_config=xft_config,
|
89 |
+
debug=debug,
|
90 |
+
model_name=model_names[0],
|
91 |
+
)
|
92 |
+
self.device = device
|
93 |
+
if self.tokenizer.pad_token == None:
|
94 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
95 |
+
self.context_len = get_context_length(self.model.config)
|
96 |
+
self.generate_stream_func = get_generate_stream_function(self.model, model_path)
|
97 |
+
self.stream_interval = stream_interval
|
98 |
+
self.embed_in_truncate = embed_in_truncate
|
99 |
+
self.seed = seed
|
100 |
+
|
101 |
+
if not no_register:
|
102 |
+
self.init_heart_beat()
|
103 |
+
|
104 |
+
def generate_stream_gate(self, params):
|
105 |
+
self.call_ct += 1
|
106 |
+
|
107 |
+
try:
|
108 |
+
if self.seed is not None:
|
109 |
+
set_seed(self.seed)
|
110 |
+
for output in self.generate_stream_func(
|
111 |
+
self.model,
|
112 |
+
self.tokenizer,
|
113 |
+
params,
|
114 |
+
self.device,
|
115 |
+
self.context_len,
|
116 |
+
self.stream_interval,
|
117 |
+
):
|
118 |
+
ret = {
|
119 |
+
"text": output["text"],
|
120 |
+
"error_code": 0,
|
121 |
+
}
|
122 |
+
if "usage" in output:
|
123 |
+
ret["usage"] = output["usage"]
|
124 |
+
if "finish_reason" in output:
|
125 |
+
ret["finish_reason"] = output["finish_reason"]
|
126 |
+
if "logprobs" in output:
|
127 |
+
ret["logprobs"] = output["logprobs"]
|
128 |
+
yield json.dumps(ret).encode() + b"\0"
|
129 |
+
except torch.cuda.OutOfMemoryError as e:
|
130 |
+
ret = {
|
131 |
+
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
|
132 |
+
"error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
|
133 |
+
}
|
134 |
+
yield json.dumps(ret).encode() + b"\0"
|
135 |
+
except (ValueError, RuntimeError) as e:
|
136 |
+
ret = {
|
137 |
+
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
|
138 |
+
"error_code": ErrorCode.INTERNAL_ERROR,
|
139 |
+
}
|
140 |
+
yield json.dumps(ret).encode() + b"\0"
|
141 |
+
|
142 |
+
def generate_gate(self, params):
|
143 |
+
for x in self.generate_stream_gate(params):
|
144 |
+
pass
|
145 |
+
return json.loads(x[:-1].decode())
|
146 |
+
|
147 |
+
def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
|
148 |
+
if model_type_dict.get("is_bert"):
|
149 |
+
model_output = self.model(input_ids)
|
150 |
+
if model_type_dict.get("is_robert"):
|
151 |
+
data = model_output.last_hidden_state
|
152 |
+
else:
|
153 |
+
data = model_output[0]
|
154 |
+
elif model_type_dict.get("is_t5"):
|
155 |
+
model_output = self.model(input_ids, decoder_input_ids=input_ids)
|
156 |
+
data = model_output.encoder_last_hidden_state
|
157 |
+
else:
|
158 |
+
model_output = self.model(input_ids, output_hidden_states=True)
|
159 |
+
if model_type_dict.get("is_chatglm"):
|
160 |
+
data = model_output.hidden_states[-1].transpose(0, 1)
|
161 |
+
else:
|
162 |
+
data = model_output.hidden_states[-1]
|
163 |
+
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
|
164 |
+
masked_embeddings = data * mask
|
165 |
+
sum_embeddings = torch.sum(masked_embeddings, dim=1)
|
166 |
+
token_num = torch.sum(attention_mask).item()
|
167 |
+
|
168 |
+
return sum_embeddings, token_num
|
169 |
+
|
170 |
+
def __encode_base64(self, embeddings: torch.Tensor) -> List[str]:
|
171 |
+
embeddings = embeddings.cpu()
|
172 |
+
return [
|
173 |
+
base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings
|
174 |
+
]
|
175 |
+
|
176 |
+
@torch.inference_mode()
|
177 |
+
def get_embeddings(self, params):
|
178 |
+
self.call_ct += 1
|
179 |
+
|
180 |
+
try:
|
181 |
+
tokenizer = self.tokenizer
|
182 |
+
ret = {"embedding": [], "token_num": 0}
|
183 |
+
|
184 |
+
model_type_dict = {
|
185 |
+
"is_llama": "llama" in str(type(self.model)),
|
186 |
+
"is_t5": "t5" in str(type(self.model)),
|
187 |
+
"is_chatglm": "chatglm" in str(type(self.model)),
|
188 |
+
"is_bert": "bert" in str(type(self.model)),
|
189 |
+
"is_robert": "robert" in str(type(self.model)),
|
190 |
+
}
|
191 |
+
|
192 |
+
if self.embed_in_truncate:
|
193 |
+
encoding = tokenizer.batch_encode_plus(
|
194 |
+
params["input"],
|
195 |
+
padding=True,
|
196 |
+
truncation="longest_first",
|
197 |
+
return_tensors="pt",
|
198 |
+
max_length=self.context_len,
|
199 |
+
)
|
200 |
+
else:
|
201 |
+
encoding = tokenizer.batch_encode_plus(
|
202 |
+
params["input"], padding=True, return_tensors="pt"
|
203 |
+
)
|
204 |
+
input_ids = encoding["input_ids"].to(self.device)
|
205 |
+
attention_mask = input_ids != tokenizer.pad_token_id
|
206 |
+
|
207 |
+
base64_encode = params.get("encoding_format", None)
|
208 |
+
|
209 |
+
if self.embed_in_truncate:
|
210 |
+
chunk_embeddings, token_num = self.__process_embed_chunk(
|
211 |
+
input_ids, attention_mask, **model_type_dict
|
212 |
+
)
|
213 |
+
embedding = chunk_embeddings / token_num
|
214 |
+
normalized_embeddings = F.normalize(embedding, p=2, dim=1)
|
215 |
+
ret["token_num"] = token_num
|
216 |
+
else:
|
217 |
+
all_embeddings = []
|
218 |
+
all_token_num = 0
|
219 |
+
for i in range(0, input_ids.size(1), self.context_len):
|
220 |
+
chunk_input_ids = input_ids[:, i : i + self.context_len]
|
221 |
+
chunk_attention_mask = attention_mask[:, i : i + self.context_len]
|
222 |
+
|
223 |
+
chunk_embeddings, token_num = self.__process_embed_chunk(
|
224 |
+
chunk_input_ids, chunk_attention_mask, **model_type_dict
|
225 |
+
)
|
226 |
+
all_embeddings.append(chunk_embeddings)
|
227 |
+
all_token_num += token_num
|
228 |
+
|
229 |
+
all_embeddings_tensor = torch.stack(all_embeddings)
|
230 |
+
embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
|
231 |
+
normalized_embeddings = F.normalize(embedding, p=2, dim=1)
|
232 |
+
|
233 |
+
ret["token_num"] = all_token_num
|
234 |
+
|
235 |
+
if base64_encode == "base64":
|
236 |
+
out_embeddings = self.__encode_base64(normalized_embeddings)
|
237 |
+
else:
|
238 |
+
out_embeddings = normalized_embeddings.tolist()
|
239 |
+
ret["embedding"] = out_embeddings
|
240 |
+
|
241 |
+
gc.collect()
|
242 |
+
torch.cuda.empty_cache()
|
243 |
+
if self.device == "xpu":
|
244 |
+
torch.xpu.empty_cache()
|
245 |
+
if self.device == "npu":
|
246 |
+
torch.npu.empty_cache()
|
247 |
+
except torch.cuda.OutOfMemoryError as e:
|
248 |
+
ret = {
|
249 |
+
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
|
250 |
+
"error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
|
251 |
+
}
|
252 |
+
except (ValueError, RuntimeError) as e:
|
253 |
+
ret = {
|
254 |
+
"text": f"{SERVER_ERROR_MSG}\n\n({e})",
|
255 |
+
"error_code": ErrorCode.INTERNAL_ERROR,
|
256 |
+
}
|
257 |
+
return ret
|
258 |
+
|
259 |
+
|
260 |
+
def create_model_worker():
|
261 |
+
parser = argparse.ArgumentParser()
|
262 |
+
parser.add_argument("--host", type=str, default="localhost")
|
263 |
+
parser.add_argument("--port", type=int, default=21002)
|
264 |
+
parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
|
265 |
+
parser.add_argument(
|
266 |
+
"--controller-address", type=str, default="http://localhost:21001"
|
267 |
+
)
|
268 |
+
add_model_args(parser)
|
269 |
+
parser.add_argument(
|
270 |
+
"--model-names",
|
271 |
+
type=lambda s: s.split(","),
|
272 |
+
help="Optional display comma separated names",
|
273 |
+
)
|
274 |
+
parser.add_argument(
|
275 |
+
"--conv-template", type=str, default=None, help="Conversation prompt template."
|
276 |
+
)
|
277 |
+
parser.add_argument("--embed-in-truncate", action="store_true")
|
278 |
+
parser.add_argument(
|
279 |
+
"--limit-worker-concurrency",
|
280 |
+
type=int,
|
281 |
+
default=5,
|
282 |
+
help="Limit the model concurrency to prevent OOM.",
|
283 |
+
)
|
284 |
+
parser.add_argument("--stream-interval", type=int, default=2)
|
285 |
+
parser.add_argument("--no-register", action="store_true")
|
286 |
+
parser.add_argument(
|
287 |
+
"--seed",
|
288 |
+
type=int,
|
289 |
+
default=None,
|
290 |
+
help="Overwrite the random seed for each generation.",
|
291 |
+
)
|
292 |
+
parser.add_argument(
|
293 |
+
"--debug", type=bool, default=False, help="Print debugging messages"
|
294 |
+
)
|
295 |
+
args = parser.parse_args()
|
296 |
+
logger.info(f"args: {args}")
|
297 |
+
|
298 |
+
if args.gpus:
|
299 |
+
if len(args.gpus.split(",")) < args.num_gpus:
|
300 |
+
raise ValueError(
|
301 |
+
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
|
302 |
+
)
|
303 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
|
304 |
+
|
305 |
+
gptq_config = GptqConfig(
|
306 |
+
ckpt=args.gptq_ckpt or args.model_path,
|
307 |
+
wbits=args.gptq_wbits,
|
308 |
+
groupsize=args.gptq_groupsize,
|
309 |
+
act_order=args.gptq_act_order,
|
310 |
+
)
|
311 |
+
awq_config = AWQConfig(
|
312 |
+
ckpt=args.awq_ckpt or args.model_path,
|
313 |
+
wbits=args.awq_wbits,
|
314 |
+
groupsize=args.awq_groupsize,
|
315 |
+
)
|
316 |
+
if args.enable_exllama:
|
317 |
+
exllama_config = ExllamaConfig(
|
318 |
+
max_seq_len=args.exllama_max_seq_len,
|
319 |
+
gpu_split=args.exllama_gpu_split,
|
320 |
+
)
|
321 |
+
else:
|
322 |
+
exllama_config = None
|
323 |
+
if args.enable_xft:
|
324 |
+
xft_config = XftConfig(
|
325 |
+
max_seq_len=args.xft_max_seq_len,
|
326 |
+
data_type=args.xft_dtype,
|
327 |
+
)
|
328 |
+
if args.device != "cpu":
|
329 |
+
print("xFasterTransformer now is only support CPUs. Reset device to CPU")
|
330 |
+
args.device = "cpu"
|
331 |
+
else:
|
332 |
+
xft_config = None
|
333 |
+
|
334 |
+
worker = ModelWorker(
|
335 |
+
args.controller_address,
|
336 |
+
args.worker_address,
|
337 |
+
worker_id,
|
338 |
+
args.model_path,
|
339 |
+
args.model_names,
|
340 |
+
args.limit_worker_concurrency,
|
341 |
+
no_register=args.no_register,
|
342 |
+
device=args.device,
|
343 |
+
num_gpus=args.num_gpus,
|
344 |
+
max_gpu_memory=args.max_gpu_memory,
|
345 |
+
dtype=str_to_torch_dtype(args.dtype),
|
346 |
+
load_8bit=args.load_8bit,
|
347 |
+
cpu_offloading=args.cpu_offloading,
|
348 |
+
gptq_config=gptq_config,
|
349 |
+
awq_config=awq_config,
|
350 |
+
exllama_config=exllama_config,
|
351 |
+
xft_config=xft_config,
|
352 |
+
stream_interval=args.stream_interval,
|
353 |
+
conv_template=args.conv_template,
|
354 |
+
embed_in_truncate=args.embed_in_truncate,
|
355 |
+
seed=args.seed,
|
356 |
+
debug=args.debug,
|
357 |
+
)
|
358 |
+
return args, worker
|
359 |
+
|
360 |
+
|
361 |
+
if __name__ == "__main__":
|
362 |
+
args, worker = create_model_worker()
|
363 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
monitor/basic_stats.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import code
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
from pytz import timezone
|
7 |
+
import time
|
8 |
+
|
9 |
+
import pandas as pd # pandas>=2.0.3
|
10 |
+
import plotly.express as px
|
11 |
+
import plotly.graph_objects as go
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
|
15 |
+
NUM_SERVERS = 14
|
16 |
+
|
17 |
+
|
18 |
+
def get_log_files(max_num_files=None):
|
19 |
+
dates = []
|
20 |
+
for month in range(4, 12):
|
21 |
+
for day in range(1, 33):
|
22 |
+
dates.append(f"2023-{month:02d}-{day:02d}")
|
23 |
+
|
24 |
+
filenames = []
|
25 |
+
for d in dates:
|
26 |
+
for i in range(NUM_SERVERS):
|
27 |
+
name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
|
28 |
+
if os.path.exists(name):
|
29 |
+
filenames.append(name)
|
30 |
+
max_num_files = max_num_files or len(filenames)
|
31 |
+
filenames = filenames[-max_num_files:]
|
32 |
+
return filenames
|
33 |
+
|
34 |
+
|
35 |
+
def load_log_files(log_files):
|
36 |
+
data = []
|
37 |
+
for filename in tqdm(log_files, desc="read files"):
|
38 |
+
for retry in range(5):
|
39 |
+
try:
|
40 |
+
lines = open(filename).readlines()
|
41 |
+
break
|
42 |
+
except FileNotFoundError:
|
43 |
+
time.sleep(2)
|
44 |
+
|
45 |
+
for l in lines:
|
46 |
+
row = json.loads(l)
|
47 |
+
|
48 |
+
data.append(
|
49 |
+
dict(
|
50 |
+
type=row["type"],
|
51 |
+
tstamp=row["tstamp"],
|
52 |
+
model=row.get("model", ""),
|
53 |
+
models=row.get("models", ["", ""]),
|
54 |
+
)
|
55 |
+
)
|
56 |
+
|
57 |
+
return data
|
58 |
+
|
59 |
+
|
60 |
+
def get_anony_vote_df(df):
|
61 |
+
anony_vote_df = df[
|
62 |
+
df["type"].isin(["leftvote", "rightvote", "tievote", "bothbad_vote"])
|
63 |
+
]
|
64 |
+
anony_vote_df = anony_vote_df[anony_vote_df["models"].apply(lambda x: x[0] == "")]
|
65 |
+
return anony_vote_df
|
66 |
+
|
67 |
+
|
68 |
+
def merge_counts(series, on, names):
|
69 |
+
ret = pd.merge(series[0], series[1], on=on)
|
70 |
+
for i in range(2, len(series)):
|
71 |
+
ret = pd.merge(ret, series[i], on=on)
|
72 |
+
ret = ret.reset_index()
|
73 |
+
old_names = list(ret.columns)[-len(series) :]
|
74 |
+
rename = {old_name: new_name for old_name, new_name in zip(old_names, names)}
|
75 |
+
ret = ret.rename(columns=rename)
|
76 |
+
return ret
|
77 |
+
|
78 |
+
|
79 |
+
def report_basic_stats(log_files):
|
80 |
+
df_all = load_log_files(log_files)
|
81 |
+
df_all = pd.DataFrame(df_all)
|
82 |
+
now_t = df_all["tstamp"].max()
|
83 |
+
df_1_hour = df_all[df_all["tstamp"] > (now_t - 3600)]
|
84 |
+
df_1_day = df_all[df_all["tstamp"] > (now_t - 3600 * 24)]
|
85 |
+
anony_vote_df_all = get_anony_vote_df(df_all)
|
86 |
+
|
87 |
+
# Chat trends
|
88 |
+
chat_dates = [
|
89 |
+
datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
|
90 |
+
"%Y-%m-%d"
|
91 |
+
)
|
92 |
+
for x in df_all[df_all["type"] == "chat"]["tstamp"]
|
93 |
+
]
|
94 |
+
chat_dates_counts = pd.value_counts(chat_dates)
|
95 |
+
vote_dates = [
|
96 |
+
datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime(
|
97 |
+
"%Y-%m-%d"
|
98 |
+
)
|
99 |
+
for x in anony_vote_df_all["tstamp"]
|
100 |
+
]
|
101 |
+
vote_dates_counts = pd.value_counts(vote_dates)
|
102 |
+
chat_dates_bar = go.Figure(
|
103 |
+
data=[
|
104 |
+
go.Bar(
|
105 |
+
name="Anony. Vote",
|
106 |
+
x=vote_dates_counts.index,
|
107 |
+
y=vote_dates_counts,
|
108 |
+
text=[f"{val:.0f}" for val in vote_dates_counts],
|
109 |
+
textposition="auto",
|
110 |
+
),
|
111 |
+
go.Bar(
|
112 |
+
name="Chat",
|
113 |
+
x=chat_dates_counts.index,
|
114 |
+
y=chat_dates_counts,
|
115 |
+
text=[f"{val:.0f}" for val in chat_dates_counts],
|
116 |
+
textposition="auto",
|
117 |
+
),
|
118 |
+
]
|
119 |
+
)
|
120 |
+
chat_dates_bar.update_layout(
|
121 |
+
barmode="stack",
|
122 |
+
xaxis_title="Dates",
|
123 |
+
yaxis_title="Count",
|
124 |
+
height=300,
|
125 |
+
width=1200,
|
126 |
+
)
|
127 |
+
|
128 |
+
# Model call counts
|
129 |
+
model_hist_all = df_all[df_all["type"] == "chat"]["model"].value_counts()
|
130 |
+
model_hist_1_day = df_1_day[df_1_day["type"] == "chat"]["model"].value_counts()
|
131 |
+
model_hist_1_hour = df_1_hour[df_1_hour["type"] == "chat"]["model"].value_counts()
|
132 |
+
model_hist = merge_counts(
|
133 |
+
[model_hist_all, model_hist_1_day, model_hist_1_hour],
|
134 |
+
on="model",
|
135 |
+
names=["All", "Last Day", "Last Hour"],
|
136 |
+
)
|
137 |
+
model_hist_md = model_hist.to_markdown(index=False, tablefmt="github")
|
138 |
+
|
139 |
+
# Action counts
|
140 |
+
action_hist_all = df_all["type"].value_counts()
|
141 |
+
action_hist_1_day = df_1_day["type"].value_counts()
|
142 |
+
action_hist_1_hour = df_1_hour["type"].value_counts()
|
143 |
+
action_hist = merge_counts(
|
144 |
+
[action_hist_all, action_hist_1_day, action_hist_1_hour],
|
145 |
+
on="type",
|
146 |
+
names=["All", "Last Day", "Last Hour"],
|
147 |
+
)
|
148 |
+
action_hist_md = action_hist.to_markdown(index=False, tablefmt="github")
|
149 |
+
|
150 |
+
# Anony vote counts
|
151 |
+
anony_vote_hist_all = anony_vote_df_all["type"].value_counts()
|
152 |
+
anony_vote_df_1_day = get_anony_vote_df(df_1_day)
|
153 |
+
anony_vote_hist_1_day = anony_vote_df_1_day["type"].value_counts()
|
154 |
+
# anony_vote_df_1_hour = get_anony_vote_df(df_1_hour)
|
155 |
+
# anony_vote_hist_1_hour = anony_vote_df_1_hour["type"].value_counts()
|
156 |
+
anony_vote_hist = merge_counts(
|
157 |
+
[anony_vote_hist_all, anony_vote_hist_1_day],
|
158 |
+
on="type",
|
159 |
+
names=["All", "Last Day"],
|
160 |
+
)
|
161 |
+
anony_vote_hist_md = anony_vote_hist.to_markdown(index=False, tablefmt="github")
|
162 |
+
|
163 |
+
# Last 24 hours
|
164 |
+
chat_1_day = df_1_day[df_1_day["type"] == "chat"]
|
165 |
+
num_chats_last_24_hours = []
|
166 |
+
base = df_1_day["tstamp"].min()
|
167 |
+
for i in range(24, 0, -1):
|
168 |
+
left = base + (i - 1) * 3600
|
169 |
+
right = base + i * 3600
|
170 |
+
num = ((chat_1_day["tstamp"] >= left) & (chat_1_day["tstamp"] < right)).sum()
|
171 |
+
num_chats_last_24_hours.append(num)
|
172 |
+
times = [
|
173 |
+
datetime.datetime.fromtimestamp(
|
174 |
+
base + i * 3600, tz=timezone("US/Pacific")
|
175 |
+
).strftime("%Y-%m-%d %H:%M:%S %Z")
|
176 |
+
for i in range(24, 0, -1)
|
177 |
+
]
|
178 |
+
last_24_hours_df = pd.DataFrame({"time": times, "value": num_chats_last_24_hours})
|
179 |
+
last_24_hours_md = last_24_hours_df.to_markdown(index=False, tablefmt="github")
|
180 |
+
|
181 |
+
# Last update datetime
|
182 |
+
last_updated_tstamp = now_t
|
183 |
+
last_updated_datetime = datetime.datetime.fromtimestamp(
|
184 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
185 |
+
).strftime("%Y-%m-%d %H:%M:%S %Z")
|
186 |
+
|
187 |
+
# code.interact(local=locals())
|
188 |
+
|
189 |
+
return {
|
190 |
+
"chat_dates_bar": chat_dates_bar,
|
191 |
+
"model_hist_md": model_hist_md,
|
192 |
+
"action_hist_md": action_hist_md,
|
193 |
+
"anony_vote_hist_md": anony_vote_hist_md,
|
194 |
+
"num_chats_last_24_hours": last_24_hours_md,
|
195 |
+
"last_updated_datetime": last_updated_datetime,
|
196 |
+
}
|
197 |
+
|
198 |
+
|
199 |
+
if __name__ == "__main__":
|
200 |
+
parser = argparse.ArgumentParser()
|
201 |
+
parser.add_argument("--max-num-files", type=int)
|
202 |
+
args = parser.parse_args()
|
203 |
+
|
204 |
+
log_files = get_log_files(args.max_num_files)
|
205 |
+
basic_stats = report_basic_stats(log_files)
|
206 |
+
|
207 |
+
print(basic_stats["action_hist_md"] + "\n")
|
208 |
+
print(basic_stats["model_hist_md"] + "\n")
|
209 |
+
print(basic_stats["anony_vote_hist_md"] + "\n")
|
210 |
+
print(basic_stats["num_chats_last_24_hours"] + "\n")
|
monitor/clean_battle_data.py
ADDED
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Clean chatbot arena battle log.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 clean_battle_data.py --mode conv_release
|
6 |
+
"""
|
7 |
+
import argparse
|
8 |
+
import datetime
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
from pytz import timezone
|
12 |
+
import time
|
13 |
+
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from fastchat.serve.monitor.basic_stats import get_log_files, NUM_SERVERS
|
17 |
+
from fastchat.utils import detect_language
|
18 |
+
|
19 |
+
|
20 |
+
VOTES = ["tievote", "leftvote", "rightvote", "bothbad_vote"]
|
21 |
+
IDENTITY_WORDS = [
|
22 |
+
"vicuna",
|
23 |
+
"lmsys",
|
24 |
+
"koala",
|
25 |
+
"uc berkeley",
|
26 |
+
"open assistant",
|
27 |
+
"laion",
|
28 |
+
"chatglm",
|
29 |
+
"chatgpt",
|
30 |
+
"openai",
|
31 |
+
"anthropic",
|
32 |
+
"claude",
|
33 |
+
"bard",
|
34 |
+
"palm",
|
35 |
+
"lamda",
|
36 |
+
"google",
|
37 |
+
"llama",
|
38 |
+
"NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.",
|
39 |
+
"$MODERATION$ YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES.",
|
40 |
+
]
|
41 |
+
|
42 |
+
for i in range(len(IDENTITY_WORDS)):
|
43 |
+
IDENTITY_WORDS[i] = IDENTITY_WORDS[i].lower()
|
44 |
+
|
45 |
+
|
46 |
+
def get_log_files(max_num_files=None):
|
47 |
+
dates = []
|
48 |
+
for month in range(4, 12):
|
49 |
+
for day in range(1, 33):
|
50 |
+
dates.append(f"2023-{month:02d}-{day:02d}")
|
51 |
+
|
52 |
+
filenames = []
|
53 |
+
for d in dates:
|
54 |
+
for i in range(NUM_SERVERS):
|
55 |
+
name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
|
56 |
+
if os.path.exists(name):
|
57 |
+
filenames.append(name)
|
58 |
+
max_num_files = max_num_files or len(filenames)
|
59 |
+
filenames = filenames[-max_num_files:]
|
60 |
+
return filenames
|
61 |
+
|
62 |
+
|
63 |
+
def remove_html(raw):
|
64 |
+
if raw.startswith("<h3>"):
|
65 |
+
return raw[raw.find(": ") + 2 : -len("</h3>\n")]
|
66 |
+
return raw
|
67 |
+
|
68 |
+
|
69 |
+
def to_openai_format(messages):
|
70 |
+
roles = ["user", "assistant"]
|
71 |
+
ret = []
|
72 |
+
for i, x in enumerate(messages):
|
73 |
+
ret.append({"role": roles[i % 2], "content": x[1]})
|
74 |
+
return ret
|
75 |
+
|
76 |
+
|
77 |
+
def replace_model_name(old_name):
|
78 |
+
return (
|
79 |
+
old_name.replace("bard", "palm-2")
|
80 |
+
.replace("claude-v1", "claude-1")
|
81 |
+
.replace("claude-instant-v1", "claude-instant-1")
|
82 |
+
.replace("oasst-sft-1-pythia-12b", "oasst-pythia-12b")
|
83 |
+
)
|
84 |
+
|
85 |
+
|
86 |
+
def clean_battle_data(log_files, exclude_model_names):
|
87 |
+
data = []
|
88 |
+
for filename in tqdm(log_files, desc="read files"):
|
89 |
+
for retry in range(5):
|
90 |
+
try:
|
91 |
+
lines = open(filename).readlines()
|
92 |
+
break
|
93 |
+
except FileNotFoundError:
|
94 |
+
time.sleep(2)
|
95 |
+
|
96 |
+
for l in lines:
|
97 |
+
row = json.loads(l)
|
98 |
+
if row["type"] in VOTES:
|
99 |
+
data.append(row)
|
100 |
+
|
101 |
+
convert_type = {
|
102 |
+
"leftvote": "model_a",
|
103 |
+
"rightvote": "model_b",
|
104 |
+
"tievote": "tie",
|
105 |
+
"bothbad_vote": "tie (bothbad)",
|
106 |
+
}
|
107 |
+
|
108 |
+
all_models = set()
|
109 |
+
all_ips = dict()
|
110 |
+
ct_anony = 0
|
111 |
+
ct_invalid = 0
|
112 |
+
ct_leaked_identity = 0
|
113 |
+
battles = []
|
114 |
+
for row in data:
|
115 |
+
if row["models"][0] is None or row["models"][1] is None:
|
116 |
+
continue
|
117 |
+
|
118 |
+
# Resolve model names
|
119 |
+
models_public = [remove_html(row["models"][0]), remove_html(row["models"][1])]
|
120 |
+
if "model_name" in row["states"][0]:
|
121 |
+
models_hidden = [
|
122 |
+
row["states"][0]["model_name"],
|
123 |
+
row["states"][1]["model_name"],
|
124 |
+
]
|
125 |
+
if models_hidden[0] is None:
|
126 |
+
models_hidden = models_public
|
127 |
+
else:
|
128 |
+
models_hidden = models_public
|
129 |
+
|
130 |
+
if (models_public[0] == "" and models_public[1] != "") or (
|
131 |
+
models_public[1] == "" and models_public[0] != ""
|
132 |
+
):
|
133 |
+
ct_invalid += 1
|
134 |
+
continue
|
135 |
+
|
136 |
+
if models_public[0] == "" or models_public[0] == "Model A":
|
137 |
+
anony = True
|
138 |
+
models = models_hidden
|
139 |
+
ct_anony += 1
|
140 |
+
else:
|
141 |
+
anony = False
|
142 |
+
models = models_public
|
143 |
+
if not models_public == models_hidden:
|
144 |
+
ct_invalid += 1
|
145 |
+
continue
|
146 |
+
|
147 |
+
# Detect langauge
|
148 |
+
state = row["states"][0]
|
149 |
+
if state["offset"] >= len(state["messages"]):
|
150 |
+
ct_invalid += 1
|
151 |
+
continue
|
152 |
+
lang_code = detect_language(state["messages"][state["offset"]][1])
|
153 |
+
|
154 |
+
# Drop conversations if the model names are leaked
|
155 |
+
leaked_identity = False
|
156 |
+
messages = ""
|
157 |
+
for i in range(2):
|
158 |
+
state = row["states"][i]
|
159 |
+
for role, msg in state["messages"][state["offset"] :]:
|
160 |
+
if msg:
|
161 |
+
messages += msg.lower()
|
162 |
+
for word in IDENTITY_WORDS:
|
163 |
+
if word in messages:
|
164 |
+
leaked_identity = True
|
165 |
+
break
|
166 |
+
|
167 |
+
if leaked_identity:
|
168 |
+
ct_leaked_identity += 1
|
169 |
+
continue
|
170 |
+
|
171 |
+
# Replace bard with palm
|
172 |
+
models = [replace_model_name(m) for m in models]
|
173 |
+
|
174 |
+
# Exclude certain models
|
175 |
+
if any(x in exclude_model_names for x in models):
|
176 |
+
ct_invalid += 1
|
177 |
+
continue
|
178 |
+
|
179 |
+
question_id = row["states"][0]["conv_id"]
|
180 |
+
conversation_a = to_openai_format(
|
181 |
+
row["states"][0]["messages"][row["states"][0]["offset"] :]
|
182 |
+
)
|
183 |
+
conversation_b = to_openai_format(
|
184 |
+
row["states"][1]["messages"][row["states"][1]["offset"] :]
|
185 |
+
)
|
186 |
+
|
187 |
+
ip = row["ip"]
|
188 |
+
if ip not in all_ips:
|
189 |
+
all_ips[ip] = len(all_ips)
|
190 |
+
user_id = all_ips[ip]
|
191 |
+
|
192 |
+
# Save the results
|
193 |
+
battles.append(
|
194 |
+
dict(
|
195 |
+
question_id=question_id,
|
196 |
+
model_a=models[0],
|
197 |
+
model_b=models[1],
|
198 |
+
winner=convert_type[row["type"]],
|
199 |
+
judge=f"arena_user_{user_id}",
|
200 |
+
conversation_a=conversation_a,
|
201 |
+
conversation_b=conversation_b,
|
202 |
+
turn=len(conversation_a) // 2,
|
203 |
+
anony=anony,
|
204 |
+
language=lang_code,
|
205 |
+
tstamp=row["tstamp"],
|
206 |
+
)
|
207 |
+
)
|
208 |
+
|
209 |
+
all_models.update(models_hidden)
|
210 |
+
battles.sort(key=lambda x: x["tstamp"])
|
211 |
+
last_updated_tstamp = battles[-1]["tstamp"]
|
212 |
+
|
213 |
+
last_updated_datetime = datetime.datetime.fromtimestamp(
|
214 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
215 |
+
).strftime("%Y-%m-%d %H:%M:%S %Z")
|
216 |
+
|
217 |
+
print(
|
218 |
+
f"#votes: {len(data)}, #invalid votes: {ct_invalid}, "
|
219 |
+
f"#leaked_identity: {ct_leaked_identity}"
|
220 |
+
)
|
221 |
+
print(f"#battles: {len(battles)}, #anony: {ct_anony}")
|
222 |
+
print(f"#models: {len(all_models)}, {all_models}")
|
223 |
+
print(f"last-updated: {last_updated_datetime}")
|
224 |
+
|
225 |
+
return battles
|
226 |
+
|
227 |
+
|
228 |
+
if __name__ == "__main__":
|
229 |
+
parser = argparse.ArgumentParser()
|
230 |
+
parser.add_argument("--max-num-files", type=int)
|
231 |
+
parser.add_argument(
|
232 |
+
"--mode", type=str, choices=["simple", "conv_release"], default="simple"
|
233 |
+
)
|
234 |
+
parser.add_argument("--exclude-model-names", type=str, nargs="+")
|
235 |
+
args = parser.parse_args()
|
236 |
+
|
237 |
+
log_files = get_log_files(args.max_num_files)
|
238 |
+
battles = clean_battle_data(log_files, args.exclude_model_names or [])
|
239 |
+
last_updated_tstamp = battles[-1]["tstamp"]
|
240 |
+
cutoff_date = datetime.datetime.fromtimestamp(
|
241 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
242 |
+
).strftime("%Y%m%d")
|
243 |
+
|
244 |
+
if args.mode == "simple":
|
245 |
+
for x in battles:
|
246 |
+
for key in [
|
247 |
+
"conversation_a",
|
248 |
+
"conversation_b",
|
249 |
+
"question_id",
|
250 |
+
]:
|
251 |
+
del x[key]
|
252 |
+
print("Samples:")
|
253 |
+
for i in range(4):
|
254 |
+
print(battles[i])
|
255 |
+
output = f"clean_battle_{cutoff_date}.json"
|
256 |
+
elif args.mode == "conv_release":
|
257 |
+
new_battles = []
|
258 |
+
for x in battles:
|
259 |
+
if not x["anony"]:
|
260 |
+
continue
|
261 |
+
for key in []:
|
262 |
+
del x[key]
|
263 |
+
new_battles.append(x)
|
264 |
+
battles = new_battles
|
265 |
+
output = f"clean_battle_conv_{cutoff_date}.json"
|
266 |
+
|
267 |
+
with open(output, "w") as fout:
|
268 |
+
json.dump(battles, fout, indent=2, ensure_ascii=False)
|
269 |
+
print(f"Write cleaned data to {output}")
|
monitor/clean_chat_data.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Clean chatbot arena chat log.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 clean_chat_data.py --mode conv_release
|
6 |
+
"""
|
7 |
+
import argparse
|
8 |
+
import datetime
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
from pytz import timezone
|
12 |
+
import time
|
13 |
+
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from fastchat.serve.monitor.basic_stats import NUM_SERVERS
|
17 |
+
from fastchat.serve.monitor.clean_battle_data import (
|
18 |
+
to_openai_format,
|
19 |
+
replace_model_name,
|
20 |
+
)
|
21 |
+
from fastchat.utils import detect_language
|
22 |
+
|
23 |
+
|
24 |
+
NETWORK_ERROR_MSG = (
|
25 |
+
"NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.".lower()
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
def get_log_files(max_num_files=None):
|
30 |
+
dates = []
|
31 |
+
for month in range(4, 12):
|
32 |
+
for day in range(1, 33):
|
33 |
+
dates.append(f"2023-{month:02d}-{day:02d}")
|
34 |
+
|
35 |
+
filenames = []
|
36 |
+
for d in dates:
|
37 |
+
for i in range(NUM_SERVERS):
|
38 |
+
name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
|
39 |
+
if os.path.exists(name):
|
40 |
+
filenames.append(name)
|
41 |
+
max_num_files = max_num_files or len(filenames)
|
42 |
+
# filenames = list(reversed(filenames))
|
43 |
+
filenames = filenames[-max_num_files:]
|
44 |
+
return filenames
|
45 |
+
|
46 |
+
|
47 |
+
def clean_chat_data(log_files, action_type):
|
48 |
+
raw_data = []
|
49 |
+
for filename in tqdm(log_files, desc="read files"):
|
50 |
+
for retry in range(5):
|
51 |
+
try:
|
52 |
+
lines = open(filename).readlines()
|
53 |
+
break
|
54 |
+
except FileNotFoundError:
|
55 |
+
time.sleep(2)
|
56 |
+
|
57 |
+
for l in lines:
|
58 |
+
row = json.loads(l)
|
59 |
+
if row["type"] == action_type:
|
60 |
+
raw_data.append(row)
|
61 |
+
|
62 |
+
all_models = set()
|
63 |
+
all_ips = dict()
|
64 |
+
chats = []
|
65 |
+
ct_invalid_conv_id = 0
|
66 |
+
ct_invalid = 0
|
67 |
+
ct_network_error = 0
|
68 |
+
for row in raw_data:
|
69 |
+
try:
|
70 |
+
if action_type in ["chat", "upvote", "downvote"]:
|
71 |
+
state = row["state"]
|
72 |
+
model = row["model"]
|
73 |
+
elif action_type == "leftvote":
|
74 |
+
state = row["states"][0]
|
75 |
+
model = row["states"][0]["model_name"]
|
76 |
+
elif action_type == "rightvote":
|
77 |
+
state = row["states"][1]
|
78 |
+
model = row["states"][1]["model_name"]
|
79 |
+
conversation_id = state["conv_id"]
|
80 |
+
except KeyError:
|
81 |
+
ct_invalid_conv_id += 1
|
82 |
+
continue
|
83 |
+
|
84 |
+
if conversation_id is None:
|
85 |
+
ct_invalid_conv_id += 1
|
86 |
+
continue
|
87 |
+
|
88 |
+
conversation = to_openai_format(state["messages"][state["offset"] :])
|
89 |
+
if not isinstance(model, str):
|
90 |
+
ct_invalid += 1
|
91 |
+
continue
|
92 |
+
model = replace_model_name(model)
|
93 |
+
|
94 |
+
try:
|
95 |
+
lang_code = detect_language(state["messages"][state["offset"]][1])
|
96 |
+
except IndexError:
|
97 |
+
ct_invalid += 1
|
98 |
+
continue
|
99 |
+
|
100 |
+
if not all(isinstance(x["content"], str) for x in conversation):
|
101 |
+
ct_invalid += 1
|
102 |
+
continue
|
103 |
+
|
104 |
+
messages = "".join([x["content"] for x in conversation]).lower()
|
105 |
+
if NETWORK_ERROR_MSG in messages:
|
106 |
+
ct_network_error += 1
|
107 |
+
continue
|
108 |
+
|
109 |
+
ip = row["ip"]
|
110 |
+
if ip not in all_ips:
|
111 |
+
all_ips[ip] = len(all_ips)
|
112 |
+
user_id = all_ips[ip]
|
113 |
+
|
114 |
+
chats.append(
|
115 |
+
dict(
|
116 |
+
conversation_id=conversation_id,
|
117 |
+
model=model,
|
118 |
+
conversation=conversation,
|
119 |
+
turn=len(conversation) // 2,
|
120 |
+
language=lang_code,
|
121 |
+
user_id=user_id,
|
122 |
+
tstamp=row["tstamp"],
|
123 |
+
)
|
124 |
+
)
|
125 |
+
|
126 |
+
all_models.update([model])
|
127 |
+
|
128 |
+
chats.sort(key=lambda x: x["tstamp"])
|
129 |
+
last_updated_tstamp = chats[-1]["tstamp"]
|
130 |
+
last_updated_datetime = datetime.datetime.fromtimestamp(
|
131 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
132 |
+
).strftime("%Y-%m-%d %H:%M:%S %Z")
|
133 |
+
|
134 |
+
# Deduplication
|
135 |
+
dedup_chats = []
|
136 |
+
visited_conv_ids = set()
|
137 |
+
for i in reversed(range(len(chats))):
|
138 |
+
if chats[i]["conversation_id"] in visited_conv_ids:
|
139 |
+
continue
|
140 |
+
visited_conv_ids.add(chats[i]["conversation_id"])
|
141 |
+
dedup_chats.append(chats[i])
|
142 |
+
|
143 |
+
print(
|
144 |
+
f"#raw: {len(raw_data)}, #chat: {len(chats)}, #dedup_chat: {len(dedup_chats)}"
|
145 |
+
)
|
146 |
+
print(
|
147 |
+
f"#invalid_conv_id: {ct_invalid_conv_id}, #network_error: {ct_network_error}, #invalid: {ct_invalid}"
|
148 |
+
)
|
149 |
+
print(f"#models: {len(all_models)}, {all_models}")
|
150 |
+
print(f"last-updated: {last_updated_datetime}")
|
151 |
+
|
152 |
+
return list(reversed(dedup_chats))
|
153 |
+
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
parser = argparse.ArgumentParser()
|
157 |
+
parser.add_argument("--action-type", type=str, default="chat")
|
158 |
+
parser.add_argument("--max-num-files", type=int)
|
159 |
+
args = parser.parse_args()
|
160 |
+
|
161 |
+
log_files = get_log_files(args.max_num_files)
|
162 |
+
chats = clean_chat_data(log_files, args.action_type)
|
163 |
+
last_updated_tstamp = chats[-1]["tstamp"]
|
164 |
+
cutoff_date = datetime.datetime.fromtimestamp(
|
165 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
166 |
+
).strftime("%Y%m%d")
|
167 |
+
|
168 |
+
output = f"clean_{args.action_type}_conv_{cutoff_date}.json"
|
169 |
+
with open(output, "w") as fout:
|
170 |
+
json.dump(chats, fout, indent=2, ensure_ascii=False)
|
171 |
+
print(f"Write cleaned data to {output}")
|
monitor/dataset_release_scripts/arena_33k/count_unique_users.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Count the unique users in a battle log file."""
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
|
6 |
+
|
7 |
+
if __name__ == "__main__":
|
8 |
+
parser = argparse.ArgumentParser()
|
9 |
+
parser.add_argument("--input", type=str)
|
10 |
+
args = parser.parse_args()
|
11 |
+
|
12 |
+
lines = json.load(open(args.input))
|
13 |
+
ct_anony_votes = 0
|
14 |
+
all_users = set()
|
15 |
+
all_models = set()
|
16 |
+
for l in lines:
|
17 |
+
if not l["anony"]:
|
18 |
+
continue
|
19 |
+
all_users.add(l["judge"])
|
20 |
+
all_models.add(l["model_a"])
|
21 |
+
all_models.add(l["model_b"])
|
22 |
+
ct_anony_votes += 1
|
23 |
+
|
24 |
+
print(f"#anony_vote: {ct_anony_votes}, #user: {len(all_users)}")
|
25 |
+
print(f"#model: {len(all_models)}")
|
monitor/dataset_release_scripts/arena_33k/filter_bad_conv.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Filter conversations for release.
|
3 |
+
|
4 |
+
Usage: python3 filter_bad_conv.py --in clean_battle_conv_20230630_tagged_v1_pii.json
|
5 |
+
"""
|
6 |
+
import argparse
|
7 |
+
from collections import defaultdict
|
8 |
+
from enum import Enum, auto
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
import random
|
12 |
+
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
BLOCKED_WORDS_FILENAME = "blocked_words.json"
|
16 |
+
blocked_words = []
|
17 |
+
frequency = defaultdict(lambda: 0)
|
18 |
+
|
19 |
+
|
20 |
+
class TypeCode(Enum):
|
21 |
+
CORRECT = auto()
|
22 |
+
ANONYMIZED = auto()
|
23 |
+
REDACTED = auto()
|
24 |
+
BAD_FORMAT = auto()
|
25 |
+
BLOCKED_WORD = auto()
|
26 |
+
BLOCKED_MODEL = auto()
|
27 |
+
TOO_SHORT = auto()
|
28 |
+
TOO_FREQUENT = auto()
|
29 |
+
|
30 |
+
|
31 |
+
def detect_type(conv):
|
32 |
+
for key in ["conversation_a", "conversation_b"]:
|
33 |
+
messages = [row["content"] for row in conv[key]]
|
34 |
+
for msg in messages:
|
35 |
+
if not isinstance(msg, str):
|
36 |
+
return TypeCode.BAD_FORMAT
|
37 |
+
|
38 |
+
user_prompts = [
|
39 |
+
row["content"].lower().strip() for row in conv[key] if row["role"] == "user"
|
40 |
+
]
|
41 |
+
if len(messages) <= 2 and all(len(x) < 16 for x in user_prompts):
|
42 |
+
return TypeCode.TOO_SHORT
|
43 |
+
|
44 |
+
if all(x in frequent_prompts for x in user_prompts):
|
45 |
+
return TypeCode.TOO_FREQUENT
|
46 |
+
|
47 |
+
for msg in messages:
|
48 |
+
msg = msg.lower()
|
49 |
+
if "<anonymized>" in msg:
|
50 |
+
return TypeCode.ANONYMIZED
|
51 |
+
if "<redacted>" in msg:
|
52 |
+
return TypeCode.REDACTED
|
53 |
+
|
54 |
+
for w in blocked_words:
|
55 |
+
if w in msg:
|
56 |
+
return TypeCode.BLOCKED_WORD
|
57 |
+
|
58 |
+
for key in ["model_a", "model_b"]:
|
59 |
+
if conv[key] in ["vicuna-33b", "mpt-30b-chat"]:
|
60 |
+
return TypeCode.BLOCKED_MODEL
|
61 |
+
|
62 |
+
return TypeCode.CORRECT
|
63 |
+
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
parser = argparse.ArgumentParser()
|
67 |
+
parser.add_argument("--in-file", type=str, required=True)
|
68 |
+
parser.add_argument("--sample", type=int)
|
69 |
+
args = parser.parse_args()
|
70 |
+
|
71 |
+
# Read conversations
|
72 |
+
convs = json.load(open(args.in_file))
|
73 |
+
print(f"#conv: {len(convs)}")
|
74 |
+
|
75 |
+
# Read blocked words
|
76 |
+
if os.path.exists(BLOCKED_WORDS_FILENAME):
|
77 |
+
blocked_words = json.load(open(BLOCKED_WORDS_FILENAME))
|
78 |
+
|
79 |
+
# Count frequency
|
80 |
+
for conv in convs:
|
81 |
+
for key in ["conversation_a", "conversation_b"]:
|
82 |
+
messages = [row["content"] for row in conv[key] if row["role"] == "user"]
|
83 |
+
for msg in messages:
|
84 |
+
if not isinstance(msg, str):
|
85 |
+
continue
|
86 |
+
msg = msg.lower().strip()
|
87 |
+
frequency[msg] += 1
|
88 |
+
|
89 |
+
keys = list(frequency.keys())
|
90 |
+
keys.sort(key=lambda x: -frequency[x])
|
91 |
+
frequent_prompts = keys[:10]
|
92 |
+
frequent_prompts = set(frequent_prompts)
|
93 |
+
frequent_prompts.add("")
|
94 |
+
|
95 |
+
# Start filter
|
96 |
+
ct_bad_format = 0
|
97 |
+
ct_anonymized = 0
|
98 |
+
ct_redacted = 0
|
99 |
+
ct_error = 0
|
100 |
+
ct_lang_filter = 0
|
101 |
+
ct_flagged = 0
|
102 |
+
ct_blocked_word = 0
|
103 |
+
ct_blocked_model = 0
|
104 |
+
ct_too_short = 0
|
105 |
+
ct_too_frequent = 0
|
106 |
+
|
107 |
+
new_convs = []
|
108 |
+
for conv in tqdm(convs):
|
109 |
+
type_code = detect_type(conv)
|
110 |
+
|
111 |
+
if type_code == TypeCode.BAD_FORMAT:
|
112 |
+
ct_bad_format += 1
|
113 |
+
continue
|
114 |
+
|
115 |
+
if type_code == TypeCode.ANONYMIZED:
|
116 |
+
ct_anonymized += 1
|
117 |
+
continue
|
118 |
+
elif type_code == TypeCode.REDACTED:
|
119 |
+
ct_redacted += 1
|
120 |
+
continue
|
121 |
+
elif type_code == TypeCode.BLOCKED_WORD:
|
122 |
+
ct_blocked_word += 1
|
123 |
+
continue
|
124 |
+
elif type_code == TypeCode.BLOCKED_MODEL:
|
125 |
+
ct_blocked_model += 1
|
126 |
+
continue
|
127 |
+
elif type_code == TypeCode.TOO_SHORT:
|
128 |
+
ct_too_short += 1
|
129 |
+
continue
|
130 |
+
elif type_code == TypeCode.TOO_FREQUENT:
|
131 |
+
ct_too_frequent += 1
|
132 |
+
continue
|
133 |
+
|
134 |
+
if conv["openai_moderation"]["flagged"]:
|
135 |
+
ct_flagged += 1
|
136 |
+
continue
|
137 |
+
|
138 |
+
if type_code in [TypeCode.CORRECT]:
|
139 |
+
new_convs.append(conv)
|
140 |
+
|
141 |
+
if args.sample:
|
142 |
+
# random.seed(0)
|
143 |
+
# random.shuffle(new_convs)
|
144 |
+
new_convs = new_convs[: args.sample]
|
145 |
+
|
146 |
+
print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}")
|
147 |
+
print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}")
|
148 |
+
print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}")
|
149 |
+
print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_anonymized}")
|
150 |
+
print(f"new_conv: {len(new_convs)}")
|
151 |
+
|
152 |
+
out_file = args.in_file.replace(".json", ".out.json")
|
153 |
+
print(f"Output to {out_file}")
|
154 |
+
with open(out_file, "w") as fout:
|
155 |
+
json.dump(new_convs, fout, indent=2, ensure_ascii=False)
|
monitor/dataset_release_scripts/arena_33k/merge_field.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Count the unique users in a battle log file."""
|
2 |
+
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
|
6 |
+
|
7 |
+
if __name__ == "__main__":
|
8 |
+
parser = argparse.ArgumentParser()
|
9 |
+
parser.add_argument("--input", type=str)
|
10 |
+
parser.add_argument("--tag-file", type=str)
|
11 |
+
args = parser.parse_args()
|
12 |
+
|
13 |
+
# build index
|
14 |
+
objs = json.load(open(args.tag_file))
|
15 |
+
new_field_dict = {}
|
16 |
+
for obj in objs:
|
17 |
+
new_field_dict[obj["question_id"]] = obj["toxic_chat"]
|
18 |
+
|
19 |
+
objs = json.load(open(args.input))
|
20 |
+
for obj in objs:
|
21 |
+
obj["toxic_chat_tag"] = new_field_dict[obj["question_id"]]
|
22 |
+
|
23 |
+
output = args.input.replace(".json", "_added.json")
|
24 |
+
with open(output, "w") as fout:
|
25 |
+
json.dump(objs, fout, indent=2, ensure_ascii=False)
|
monitor/dataset_release_scripts/arena_33k/sample.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Count the unique users in a battle log file.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 -input in.json --number 1000
|
6 |
+
"""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import json
|
10 |
+
import random
|
11 |
+
|
12 |
+
K = 1000
|
13 |
+
|
14 |
+
if __name__ == "__main__":
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument("--input", type=str)
|
17 |
+
parser.add_argument("--number", type=int, nargs="+")
|
18 |
+
args = parser.parse_args()
|
19 |
+
|
20 |
+
convs = json.load(open(args.input))
|
21 |
+
random.seed(0)
|
22 |
+
random.shuffle(convs)
|
23 |
+
|
24 |
+
for number in args.number:
|
25 |
+
new_convs = convs[:number]
|
26 |
+
|
27 |
+
output = args.input.replace(".json", f"_{number//K}k.json")
|
28 |
+
with open(output, "w") as fout:
|
29 |
+
json.dump(new_convs, fout, indent=2, ensure_ascii=False)
|
30 |
+
|
31 |
+
print(f"#in: {len(convs)}, #out: {len(new_convs)}")
|
32 |
+
print(f"Write to file: {output}")
|
monitor/dataset_release_scripts/arena_33k/upload_hf_dataset.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Upload to huggingface.
|
3 |
+
"""
|
4 |
+
import json
|
5 |
+
from datasets import Dataset, DatasetDict, load_dataset
|
6 |
+
|
7 |
+
objs = json.load(open("clean_battle_conv_20230630_tagged_v3_pii_33k_added.json"))
|
8 |
+
data = Dataset.from_list(objs)
|
9 |
+
data.push_to_hub("lmsys/chatbot_arena_conversations", private=True)
|
monitor/dataset_release_scripts/lmsys_chat_1m/approve_all.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
|
3 |
+
headers = {"authorization": "Bearer hf_XXX"}
|
4 |
+
|
5 |
+
url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/pending"
|
6 |
+
a = requests.get(url, headers=headers)
|
7 |
+
|
8 |
+
for u in a.json():
|
9 |
+
user = u["user"]["user"]
|
10 |
+
url = "https://huggingface.co/api/datasets/lmsys/lmsys-chat-1m/user-access-request/grant"
|
11 |
+
ret = requests.post(url, headers=headers, json={"user": user})
|
12 |
+
print(user, ret.status_code)
|
13 |
+
assert ret.status_code == 200
|
monitor/dataset_release_scripts/lmsys_chat_1m/compute_stats.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
From colab:
|
3 |
+
https://colab.research.google.com/drive/1oMdw_Lqgmd6DletSOLHsyD-Rc96cRShs?usp=sharing
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import datetime
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
from pytz import timezone
|
10 |
+
import time
|
11 |
+
|
12 |
+
import kaleido
|
13 |
+
import numpy as np
|
14 |
+
import pandas as pd
|
15 |
+
import plotly.express as px
|
16 |
+
import plotly.graph_objects as go
|
17 |
+
from tqdm import tqdm
|
18 |
+
|
19 |
+
import plotly.io as pio
|
20 |
+
|
21 |
+
pio.kaleido.scope.mathjax = None
|
22 |
+
|
23 |
+
parser = argparse.ArgumentParser()
|
24 |
+
parser.add_argument("--in-file", type=str, required=True)
|
25 |
+
parser.add_argument("--scale", type=int, required=True)
|
26 |
+
args = parser.parse_args()
|
27 |
+
|
28 |
+
filename = args.in_file
|
29 |
+
scale = args.scale
|
30 |
+
convs = json.load(open(filename))
|
31 |
+
df = pd.DataFrame(convs)
|
32 |
+
df
|
33 |
+
|
34 |
+
print(f"#ips: {df['user_id'].nunique() * scale}")
|
35 |
+
print(f"#models: {df['model'].nunique()}")
|
36 |
+
print(f"#language: {df['language'].nunique()}")
|
37 |
+
print(f"#turns: {df['turn'].mean()}")
|
38 |
+
|
39 |
+
model_counts = df["model"].value_counts() * scale
|
40 |
+
# print("model counts", model_counts)
|
41 |
+
fig = px.bar(x=model_counts.index, y=model_counts)
|
42 |
+
fig.update_layout(
|
43 |
+
xaxis_title=None,
|
44 |
+
yaxis_title="Count",
|
45 |
+
height=200,
|
46 |
+
width=950,
|
47 |
+
margin=dict(l=0, r=0, t=0, b=0),
|
48 |
+
)
|
49 |
+
fig.show()
|
50 |
+
fig.write_image("model_count.pdf")
|
51 |
+
|
52 |
+
|
53 |
+
model_counts = df["language"].value_counts().head(25) * scale
|
54 |
+
fig = px.bar(x=model_counts.index, y=model_counts)
|
55 |
+
fig.update_layout(
|
56 |
+
xaxis_title=None,
|
57 |
+
yaxis_title="Count",
|
58 |
+
height=200,
|
59 |
+
width=950,
|
60 |
+
margin=dict(l=0, r=0, t=0, b=0),
|
61 |
+
)
|
62 |
+
fig.show()
|
63 |
+
fig.write_image("language_count.pdf")
|
64 |
+
|
65 |
+
chat_dates = [
|
66 |
+
datetime.datetime.fromtimestamp(x, tz=timezone("US/Pacific")).strftime("%Y-%m-%d")
|
67 |
+
for x in df["tstamp"]
|
68 |
+
]
|
69 |
+
|
70 |
+
|
71 |
+
def to_remove(x):
|
72 |
+
for d in ["08-09", "08-08", "08-07", "08-06", "08-05", "08-04"]:
|
73 |
+
if d in x:
|
74 |
+
return True
|
75 |
+
return False
|
76 |
+
|
77 |
+
|
78 |
+
chat_dates = [x for x in chat_dates if not to_remove(x)]
|
79 |
+
|
80 |
+
chat_dates_counts = pd.value_counts(chat_dates) * scale
|
81 |
+
print(f"mean #chat per day: {np.mean(chat_dates_counts):.2f}")
|
82 |
+
|
83 |
+
fig = px.bar(x=chat_dates_counts.index, y=chat_dates_counts)
|
84 |
+
fig.update_layout(
|
85 |
+
xaxis_title="Dates",
|
86 |
+
yaxis_title="Count",
|
87 |
+
height=200,
|
88 |
+
width=950,
|
89 |
+
margin=dict(l=0, r=0, t=0, b=0),
|
90 |
+
)
|
91 |
+
fig.show()
|
92 |
+
fig.write_image("daily_conversation_count.pdf")
|
93 |
+
|
94 |
+
import transformers
|
95 |
+
|
96 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
97 |
+
"lmsys/vicuna-7b-v1.5", use_fast=False
|
98 |
+
)
|
99 |
+
|
100 |
+
prompts = []
|
101 |
+
responses = []
|
102 |
+
for conv in df["conversation"]:
|
103 |
+
for row in conv:
|
104 |
+
if row["role"] == "user":
|
105 |
+
prompts.append(row["content"])
|
106 |
+
else:
|
107 |
+
responses.append(row["content"])
|
108 |
+
|
109 |
+
print(f"#prompts: {len(prompts)}")
|
110 |
+
print(f"#responses: {len(responses)}")
|
111 |
+
|
112 |
+
|
113 |
+
prompt_lens = [len(tokenizer(x).input_ids) for x in tqdm(prompts)]
|
114 |
+
print()
|
115 |
+
print(f"mean prompt len: {np.mean(prompt_lens):.2f}")
|
116 |
+
|
117 |
+
response_lens = [len(tokenizer(x).input_ids) if x else 0 for x in tqdm(responses)]
|
118 |
+
print()
|
119 |
+
print(f"mean response len: {np.mean(response_lens):.2f}")
|
monitor/dataset_release_scripts/lmsys_chat_1m/filter_bad_conv.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Filter conversations for release.
|
3 |
+
|
4 |
+
Dependency:
|
5 |
+
pip install opencc-python-reimplementedpip install opencc-python-reimplemented
|
6 |
+
|
7 |
+
Usage:
|
8 |
+
python3 filter_bad_conv_lmsys_chat_1m.py --in clean_battle_conv_20230630_tagged_v1_pii.json
|
9 |
+
"""
|
10 |
+
import argparse
|
11 |
+
from concurrent.futures import ProcessPoolExecutor
|
12 |
+
from collections import defaultdict
|
13 |
+
from enum import Enum, auto
|
14 |
+
import json
|
15 |
+
import os
|
16 |
+
import random
|
17 |
+
|
18 |
+
from tqdm import tqdm
|
19 |
+
import opencc
|
20 |
+
|
21 |
+
BLOCKED_WORDS_FILENAME = "blocked_words.json"
|
22 |
+
blocked_words = []
|
23 |
+
frequency = defaultdict(lambda: 0)
|
24 |
+
|
25 |
+
cc_converter = opencc.OpenCC("t2s")
|
26 |
+
|
27 |
+
|
28 |
+
class TypeCode(Enum):
|
29 |
+
CORRECT = auto()
|
30 |
+
ANONYMIZED = auto()
|
31 |
+
REDACTED = auto()
|
32 |
+
BAD_FORMAT = auto()
|
33 |
+
BLOCKED_WORD = auto()
|
34 |
+
BLOCKED_MODEL = auto()
|
35 |
+
TOO_SHORT = auto()
|
36 |
+
TOO_FREQUENT = auto()
|
37 |
+
|
38 |
+
|
39 |
+
def detect_type(conv):
|
40 |
+
for key in ["conversation_a", "conversation_b", "conversation"]:
|
41 |
+
if key not in conv:
|
42 |
+
continue
|
43 |
+
|
44 |
+
messages = [row["content"] for row in conv[key]]
|
45 |
+
for msg in messages:
|
46 |
+
if not isinstance(msg, str):
|
47 |
+
return TypeCode.BAD_FORMAT
|
48 |
+
|
49 |
+
if len(messages) == 0:
|
50 |
+
return TypeCode.BAD_FORMAT
|
51 |
+
|
52 |
+
user_prompts = [
|
53 |
+
row["content"].lower().strip() for row in conv[key] if row["role"] == "user"
|
54 |
+
]
|
55 |
+
|
56 |
+
for msg in messages:
|
57 |
+
msg = cc_converter.convert(msg.lower())
|
58 |
+
if "<anonymized>" in msg:
|
59 |
+
return TypeCode.ANONYMIZED
|
60 |
+
if "<redacted>" in msg:
|
61 |
+
return TypeCode.REDACTED
|
62 |
+
|
63 |
+
for w in blocked_words:
|
64 |
+
if w in msg:
|
65 |
+
return TypeCode.BLOCKED_WORD
|
66 |
+
|
67 |
+
return TypeCode.CORRECT
|
68 |
+
|
69 |
+
|
70 |
+
if __name__ == "__main__":
|
71 |
+
parser = argparse.ArgumentParser()
|
72 |
+
parser.add_argument("--in-file", type=str, required=True)
|
73 |
+
parser.add_argument("--sample", type=int)
|
74 |
+
args = parser.parse_args()
|
75 |
+
|
76 |
+
# Read conversations
|
77 |
+
convs = json.load(open(args.in_file))
|
78 |
+
print(f"#conv: {len(convs)}")
|
79 |
+
|
80 |
+
# Read blocked words
|
81 |
+
if os.path.exists(BLOCKED_WORDS_FILENAME):
|
82 |
+
blocked_words = json.load(open(BLOCKED_WORDS_FILENAME))
|
83 |
+
blocked_words = [cc_converter.convert(w) for w in blocked_words]
|
84 |
+
|
85 |
+
# Start filter
|
86 |
+
ct_bad_format = 0
|
87 |
+
ct_anonymized = 0
|
88 |
+
ct_redacted = 0
|
89 |
+
ct_error = 0
|
90 |
+
ct_lang_filter = 0
|
91 |
+
ct_flagged = 0
|
92 |
+
ct_blocked_word = 0
|
93 |
+
ct_blocked_model = 0
|
94 |
+
ct_too_short = 0
|
95 |
+
ct_too_frequent = 0
|
96 |
+
|
97 |
+
type_codes = []
|
98 |
+
with ProcessPoolExecutor() as executor:
|
99 |
+
for result in tqdm(executor.map(detect_type, convs), total=len(convs)):
|
100 |
+
type_codes.append(result)
|
101 |
+
|
102 |
+
new_convs = []
|
103 |
+
for conv, type_code in zip(convs, type_codes):
|
104 |
+
if type_code == TypeCode.BAD_FORMAT:
|
105 |
+
ct_bad_format += 1
|
106 |
+
continue
|
107 |
+
|
108 |
+
if type_code == TypeCode.ANONYMIZED:
|
109 |
+
ct_anonymized += 1
|
110 |
+
continue
|
111 |
+
elif type_code == TypeCode.REDACTED:
|
112 |
+
ct_redacted += 1
|
113 |
+
continue
|
114 |
+
elif type_code == TypeCode.BLOCKED_WORD:
|
115 |
+
ct_blocked_word += 1
|
116 |
+
continue
|
117 |
+
elif type_code == TypeCode.BLOCKED_MODEL:
|
118 |
+
ct_blocked_model += 1
|
119 |
+
continue
|
120 |
+
elif type_code == TypeCode.TOO_SHORT:
|
121 |
+
ct_too_short += 1
|
122 |
+
continue
|
123 |
+
elif type_code == TypeCode.TOO_FREQUENT:
|
124 |
+
ct_too_frequent += 1
|
125 |
+
continue
|
126 |
+
|
127 |
+
if "openai_moderation" in conv and conv["openai_moderation"]["flagged"]:
|
128 |
+
ct_flagged += 1
|
129 |
+
continue
|
130 |
+
|
131 |
+
if type_code in [TypeCode.CORRECT]:
|
132 |
+
new_convs.append(conv)
|
133 |
+
|
134 |
+
if args.sample:
|
135 |
+
random.seed(42)
|
136 |
+
random.shuffle(new_convs)
|
137 |
+
new_convs = new_convs[: args.sample]
|
138 |
+
|
139 |
+
print(f"ct_anonymized: {ct_anonymized}, ct_redacted: {ct_redacted}")
|
140 |
+
print(f"ct_bad_format: {ct_bad_format}, ct_flagged: {ct_flagged}")
|
141 |
+
print(f"ct_blocked_word: {ct_blocked_word}, ct_blocked_model: {ct_blocked_model}")
|
142 |
+
print(f"ct_too_short: {ct_too_short}, ct_too_frequent: {ct_too_frequent}")
|
143 |
+
print(f"new_conv: {len(new_convs)}")
|
144 |
+
|
145 |
+
out_file = args.in_file.replace(".json", ".s1.json")
|
146 |
+
print(f"Output to {out_file}")
|
147 |
+
with open(out_file, "w") as fout:
|
148 |
+
json.dump(new_convs, fout, indent=2, ensure_ascii=False)
|
monitor/dataset_release_scripts/lmsys_chat_1m/final_post_processing.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
|
4 |
+
from tqdm import tqdm
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
if __name__ == "__main__":
|
9 |
+
parser = argparse.ArgumentParser()
|
10 |
+
parser.add_argument("--in-file", type=str, required=True)
|
11 |
+
args = parser.parse_args()
|
12 |
+
|
13 |
+
# Read conversations
|
14 |
+
convs = json.load(open(args.in_file))
|
15 |
+
print(f"#conv: {len(convs)}")
|
16 |
+
|
17 |
+
# Delete some fileds
|
18 |
+
for c in convs:
|
19 |
+
del c["tstamp"]
|
20 |
+
del c["user_id"]
|
21 |
+
|
22 |
+
# Write
|
23 |
+
print(f"#out conv: {len(convs)}")
|
24 |
+
out_file = args.in_file.replace(".json", ".s2.json")
|
25 |
+
print(f"Output to {out_file}")
|
26 |
+
with open(out_file, "w") as fout:
|
27 |
+
json.dump(convs, fout, indent=2, ensure_ascii=False)
|
monitor/dataset_release_scripts/lmsys_chat_1m/instructions.md
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
```
|
2 |
+
export BASE=clean_conv_20230809_100k_pii
|
3 |
+
export SCALE=10
|
4 |
+
|
5 |
+
# filter words
|
6 |
+
python3 filter_bad_conv.py --in $BASE.json
|
7 |
+
|
8 |
+
# Clean up some fileds (e.g., timestamps)
|
9 |
+
python3 final_post_processing.py --in $BASE.s1.json
|
10 |
+
|
11 |
+
# upload to hf
|
12 |
+
python3 upload_hf_dataset.py --in $BASE.s1.s2.json
|
13 |
+
|
14 |
+
# Make another version with openai moderation tag
|
15 |
+
python3 merge_oai_tag.py --in $BASE.s1.s2.json
|
16 |
+
|
17 |
+
# Make visualizations
|
18 |
+
python3 compute_stats.py --in $BASE.s1.json --scale $SCALE
|
19 |
+
|
20 |
+
# Copy figures
|
21 |
+
scp "atlas:/data/lmzheng/FastChat/fastchat/serve/monitor/dataset_release_scripts/lmsys_chat_1m/*.pdf" .
|
22 |
+
```
|
23 |
+
|
monitor/dataset_release_scripts/lmsys_chat_1m/merge_oai_tag.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import time
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
if __name__ == "__main__":
|
9 |
+
parser = argparse.ArgumentParser()
|
10 |
+
parser.add_argument("--in-file", type=str, required=True)
|
11 |
+
parser.add_argument("--sample", type=int)
|
12 |
+
args = parser.parse_args()
|
13 |
+
|
14 |
+
tag_file = "clean_conv_20230809_1.5M_oai_filter_v2.json"
|
15 |
+
# tag_file = "clean_conv_20230809_1.5M_oai_filter_v2_100k.json"
|
16 |
+
in_file = args.in_file
|
17 |
+
tic = time.time()
|
18 |
+
|
19 |
+
# Load tags
|
20 |
+
print("Load tags...")
|
21 |
+
tag_data = json.load(open(tag_file))
|
22 |
+
tag_dict = {}
|
23 |
+
for c in tqdm(tag_data):
|
24 |
+
tag_dict[c["conversation_id"]] = [x["oai_filter"] for x in c["conversation"]]
|
25 |
+
print(f"elapsed: {time.time() - tic:.2f} s")
|
26 |
+
|
27 |
+
# Append to input_file
|
28 |
+
print("Load inputs...")
|
29 |
+
input_data = json.load(open(in_file))
|
30 |
+
for c in tqdm(input_data):
|
31 |
+
cid = c["conversation_id"]
|
32 |
+
if cid in tag_dict:
|
33 |
+
c["openai_moderation"] = tag_dict[cid]
|
34 |
+
else:
|
35 |
+
print(f"missing tag for conv {cid}")
|
36 |
+
exit()
|
37 |
+
print(f"elapsed: {time.time() - tic:.2f} s")
|
38 |
+
|
39 |
+
# Write output
|
40 |
+
print("Write outputs...")
|
41 |
+
out_file = in_file.replace(".json", ".with_tag.json")
|
42 |
+
print(f"Output to {out_file}")
|
43 |
+
with open(out_file, "w") as fout:
|
44 |
+
json.dump(input_data, fout, indent=2, ensure_ascii=False)
|
45 |
+
print(f"elapsed: {time.time() - tic:.2f} s")
|
monitor/dataset_release_scripts/lmsys_chat_1m/process_all.sh
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export BASE=clean_conv_20230809_1.5M_pii
|
2 |
+
#export BASE=clean_conv_20230809_100k_pii
|
3 |
+
export SCALE=1
|
4 |
+
|
5 |
+
# Filter words
|
6 |
+
python3 filter_bad_conv.py --in $BASE.json --sample 1000000
|
7 |
+
|
8 |
+
# Clean up some fileds (e.g., timestamps)
|
9 |
+
python3 final_post_processing.py --in $BASE.s1.json
|
10 |
+
|
11 |
+
# Upload to hf
|
12 |
+
python3 upload_hf_dataset.py --in $BASE.s1.s2.json
|
13 |
+
|
14 |
+
# Make another version with openai moderation tag
|
15 |
+
python3 merge_oai_tag.py --in $BASE.s1.s2.json
|
16 |
+
|
17 |
+
# Make visualizations
|
18 |
+
python3 compute_stats.py --in $BASE.s1.json --scale $SCALE
|
monitor/dataset_release_scripts/lmsys_chat_1m/sample.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Count the unique users in a battle log file.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 -input in.json --number 1000
|
6 |
+
"""
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import json
|
10 |
+
import random
|
11 |
+
|
12 |
+
K = 1000
|
13 |
+
|
14 |
+
if __name__ == "__main__":
|
15 |
+
parser = argparse.ArgumentParser()
|
16 |
+
parser.add_argument("--input", type=str)
|
17 |
+
parser.add_argument("--number", type=int, nargs="+")
|
18 |
+
args = parser.parse_args()
|
19 |
+
|
20 |
+
convs = json.load(open(args.input))
|
21 |
+
random.seed(42)
|
22 |
+
random.shuffle(convs)
|
23 |
+
|
24 |
+
for number in args.number:
|
25 |
+
new_convs = convs[:number]
|
26 |
+
|
27 |
+
output = args.input.replace(".json", f"_{number//K}k.json")
|
28 |
+
with open(output, "w") as fout:
|
29 |
+
json.dump(new_convs, fout, indent=2, ensure_ascii=False)
|
30 |
+
|
31 |
+
print(f"#in: {len(convs)}, #out: {len(new_convs)}")
|
32 |
+
print(f"Write to file: {output}")
|
monitor/dataset_release_scripts/lmsys_chat_1m/upload_hf_dataset.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Upload to huggingface.
|
3 |
+
"""
|
4 |
+
import argparse
|
5 |
+
import json
|
6 |
+
from datasets import Dataset, DatasetDict, load_dataset
|
7 |
+
|
8 |
+
|
9 |
+
if __name__ == "__main__":
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
parser.add_argument("--in-file", type=str, required=True)
|
12 |
+
args = parser.parse_args()
|
13 |
+
|
14 |
+
objs = json.load(open(args.in_file))
|
15 |
+
print(f"#convs: {len(objs)}")
|
16 |
+
data = Dataset.from_list(objs)
|
17 |
+
data.push_to_hub("lmsys/lmsys-chat-1m", private=True)
|
monitor/elo_analysis.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from collections import defaultdict
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import math
|
6 |
+
import pickle
|
7 |
+
from pytz import timezone
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import pandas as pd
|
11 |
+
import plotly.express as px
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
from fastchat.model.model_registry import get_model_info
|
15 |
+
from fastchat.serve.monitor.basic_stats import get_log_files
|
16 |
+
from fastchat.serve.monitor.clean_battle_data import clean_battle_data
|
17 |
+
|
18 |
+
|
19 |
+
pd.options.display.float_format = "{:.2f}".format
|
20 |
+
|
21 |
+
|
22 |
+
def compute_elo(battles, K=4, SCALE=400, BASE=10, INIT_RATING=1000):
|
23 |
+
rating = defaultdict(lambda: INIT_RATING)
|
24 |
+
|
25 |
+
for rd, model_a, model_b, winner in battles[
|
26 |
+
["model_a", "model_b", "winner"]
|
27 |
+
].itertuples():
|
28 |
+
ra = rating[model_a]
|
29 |
+
rb = rating[model_b]
|
30 |
+
ea = 1 / (1 + BASE ** ((rb - ra) / SCALE))
|
31 |
+
eb = 1 / (1 + BASE ** ((ra - rb) / SCALE))
|
32 |
+
if winner == "model_a":
|
33 |
+
sa = 1
|
34 |
+
elif winner == "model_b":
|
35 |
+
sa = 0
|
36 |
+
elif winner == "tie" or winner == "tie (bothbad)":
|
37 |
+
sa = 0.5
|
38 |
+
else:
|
39 |
+
raise Exception(f"unexpected vote {winner}")
|
40 |
+
rating[model_a] += K * (sa - ea)
|
41 |
+
rating[model_b] += K * (1 - sa - eb)
|
42 |
+
|
43 |
+
return dict(rating)
|
44 |
+
|
45 |
+
|
46 |
+
def get_bootstrap_result(battles, func_compute_elo, num_round=1000):
|
47 |
+
rows = []
|
48 |
+
for i in tqdm(range(num_round), desc="bootstrap"):
|
49 |
+
tmp_battles = battles.sample(frac=1.0, replace=True)
|
50 |
+
rows.append(func_compute_elo(tmp_battles))
|
51 |
+
df = pd.DataFrame(rows)
|
52 |
+
return df[df.median().sort_values(ascending=False).index]
|
53 |
+
|
54 |
+
|
55 |
+
def get_median_elo_from_bootstrap(bootstrap_df):
|
56 |
+
median = dict(bootstrap_df.quantile(0.5))
|
57 |
+
median = {k: int(v + 0.5) for k, v in median.items()}
|
58 |
+
return median
|
59 |
+
|
60 |
+
|
61 |
+
def compute_pairwise_win_fraction(battles, model_order, limit_show_number=None):
|
62 |
+
# Times each model wins as Model A
|
63 |
+
a_win_ptbl = pd.pivot_table(
|
64 |
+
battles[battles["winner"] == "model_a"],
|
65 |
+
index="model_a",
|
66 |
+
columns="model_b",
|
67 |
+
aggfunc="size",
|
68 |
+
fill_value=0,
|
69 |
+
)
|
70 |
+
|
71 |
+
# Table counting times each model wins as Model B
|
72 |
+
b_win_ptbl = pd.pivot_table(
|
73 |
+
battles[battles["winner"] == "model_b"],
|
74 |
+
index="model_a",
|
75 |
+
columns="model_b",
|
76 |
+
aggfunc="size",
|
77 |
+
fill_value=0,
|
78 |
+
)
|
79 |
+
|
80 |
+
# Table counting number of A-B pairs
|
81 |
+
num_battles_ptbl = pd.pivot_table(
|
82 |
+
battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
|
83 |
+
)
|
84 |
+
|
85 |
+
# Computing the proportion of wins for each model as A and as B
|
86 |
+
# against all other models
|
87 |
+
row_beats_col_freq = (a_win_ptbl + b_win_ptbl.T) / (
|
88 |
+
num_battles_ptbl + num_battles_ptbl.T
|
89 |
+
)
|
90 |
+
|
91 |
+
if model_order is None:
|
92 |
+
prop_wins = row_beats_col_freq.mean(axis=1).sort_values(ascending=False)
|
93 |
+
model_order = list(prop_wins.keys())
|
94 |
+
|
95 |
+
if limit_show_number is not None:
|
96 |
+
model_order = model_order[:limit_show_number]
|
97 |
+
|
98 |
+
# Arrange ordering according to proprition of wins
|
99 |
+
row_beats_col = row_beats_col_freq.loc[model_order, model_order]
|
100 |
+
return row_beats_col
|
101 |
+
|
102 |
+
|
103 |
+
def visualize_leaderboard_table(rating):
|
104 |
+
models = list(rating.keys())
|
105 |
+
models.sort(key=lambda k: -rating[k])
|
106 |
+
|
107 |
+
emoji_dict = {
|
108 |
+
1: "🥇",
|
109 |
+
2: "🥈",
|
110 |
+
3: "🥉",
|
111 |
+
}
|
112 |
+
|
113 |
+
md = ""
|
114 |
+
md += "| Rank | Model | Elo Rating | Description |\n"
|
115 |
+
md += "| --- | --- | --- | --- |\n"
|
116 |
+
for i, model in enumerate(models):
|
117 |
+
rank = i + 1
|
118 |
+
minfo = get_model_info(model)
|
119 |
+
emoji = emoji_dict.get(rank, "")
|
120 |
+
md += f"| {rank} | {emoji} [{model}]({minfo.link}) | {rating[model]:.0f} | {minfo.description} |\n"
|
121 |
+
|
122 |
+
return md
|
123 |
+
|
124 |
+
|
125 |
+
def visualize_pairwise_win_fraction(battles, model_order):
|
126 |
+
row_beats_col = compute_pairwise_win_fraction(battles, model_order)
|
127 |
+
fig = px.imshow(
|
128 |
+
row_beats_col,
|
129 |
+
color_continuous_scale="RdBu",
|
130 |
+
text_auto=".2f",
|
131 |
+
height=700,
|
132 |
+
width=700,
|
133 |
+
)
|
134 |
+
fig.update_layout(
|
135 |
+
xaxis_title="Model B",
|
136 |
+
yaxis_title="Model A",
|
137 |
+
xaxis_side="top",
|
138 |
+
title_y=0.07,
|
139 |
+
title_x=0.5,
|
140 |
+
)
|
141 |
+
fig.update_traces(
|
142 |
+
hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Fraction of A Wins: %{z}<extra></extra>"
|
143 |
+
)
|
144 |
+
|
145 |
+
return fig
|
146 |
+
|
147 |
+
|
148 |
+
def visualize_battle_count(battles, model_order):
|
149 |
+
ptbl = pd.pivot_table(
|
150 |
+
battles, index="model_a", columns="model_b", aggfunc="size", fill_value=0
|
151 |
+
)
|
152 |
+
battle_counts = ptbl + ptbl.T
|
153 |
+
fig = px.imshow(
|
154 |
+
battle_counts.loc[model_order, model_order],
|
155 |
+
text_auto=True,
|
156 |
+
height=700,
|
157 |
+
width=700,
|
158 |
+
)
|
159 |
+
fig.update_layout(
|
160 |
+
xaxis_title="Model B",
|
161 |
+
yaxis_title="Model A",
|
162 |
+
xaxis_side="top",
|
163 |
+
title_y=0.07,
|
164 |
+
title_x=0.5,
|
165 |
+
)
|
166 |
+
fig.update_traces(
|
167 |
+
hovertemplate="Model A: %{y}<br>Model B: %{x}<br>Count: %{z}<extra></extra>"
|
168 |
+
)
|
169 |
+
return fig
|
170 |
+
|
171 |
+
|
172 |
+
def visualize_average_win_rate(battles, limit_show_number):
|
173 |
+
row_beats_col_freq = compute_pairwise_win_fraction(
|
174 |
+
battles, None, limit_show_number=limit_show_number
|
175 |
+
)
|
176 |
+
fig = px.bar(
|
177 |
+
row_beats_col_freq.mean(axis=1).sort_values(ascending=False),
|
178 |
+
text_auto=".2f",
|
179 |
+
height=500,
|
180 |
+
width=700,
|
181 |
+
)
|
182 |
+
fig.update_layout(
|
183 |
+
yaxis_title="Average Win Rate", xaxis_title="Model", showlegend=False
|
184 |
+
)
|
185 |
+
return fig
|
186 |
+
|
187 |
+
|
188 |
+
def visualize_bootstrap_elo_rating(df, limit_show_number):
|
189 |
+
bars = (
|
190 |
+
pd.DataFrame(
|
191 |
+
dict(
|
192 |
+
lower=df.quantile(0.025),
|
193 |
+
rating=df.quantile(0.5),
|
194 |
+
upper=df.quantile(0.975),
|
195 |
+
)
|
196 |
+
)
|
197 |
+
.reset_index(names="model")
|
198 |
+
.sort_values("rating", ascending=False)
|
199 |
+
)
|
200 |
+
bars = bars[:limit_show_number]
|
201 |
+
bars["error_y"] = bars["upper"] - bars["rating"]
|
202 |
+
bars["error_y_minus"] = bars["rating"] - bars["lower"]
|
203 |
+
bars["rating_rounded"] = np.round(bars["rating"], 2)
|
204 |
+
fig = px.scatter(
|
205 |
+
bars,
|
206 |
+
x="model",
|
207 |
+
y="rating",
|
208 |
+
error_y="error_y",
|
209 |
+
error_y_minus="error_y_minus",
|
210 |
+
text="rating_rounded",
|
211 |
+
height=500,
|
212 |
+
width=700,
|
213 |
+
)
|
214 |
+
fig.update_layout(xaxis_title="Model", yaxis_title="Rating")
|
215 |
+
return fig
|
216 |
+
|
217 |
+
|
218 |
+
def report_elo_analysis_results(battles_json):
|
219 |
+
battles = pd.DataFrame(battles_json)
|
220 |
+
battles = battles.sort_values(ascending=True, by=["tstamp"])
|
221 |
+
# Only use anonymous votes
|
222 |
+
battles = battles[battles["anony"]].reset_index(drop=True)
|
223 |
+
battles_no_ties = battles[~battles["winner"].str.contains("tie")]
|
224 |
+
|
225 |
+
# Online update
|
226 |
+
elo_rating_online = compute_elo(battles)
|
227 |
+
|
228 |
+
# Bootstrap
|
229 |
+
bootstrap_df = get_bootstrap_result(battles, compute_elo)
|
230 |
+
elo_rating_median = get_median_elo_from_bootstrap(bootstrap_df)
|
231 |
+
model_order = list(elo_rating_median.keys())
|
232 |
+
model_order.sort(key=lambda k: -elo_rating_median[k])
|
233 |
+
|
234 |
+
limit_show_number = 25 # limit show number to make plots smaller
|
235 |
+
model_order = model_order[:limit_show_number]
|
236 |
+
|
237 |
+
# Plots
|
238 |
+
leaderboard_table = visualize_leaderboard_table(elo_rating_median)
|
239 |
+
win_fraction_heatmap = visualize_pairwise_win_fraction(battles_no_ties, model_order)
|
240 |
+
battle_count_heatmap = visualize_battle_count(battles_no_ties, model_order)
|
241 |
+
average_win_rate_bar = visualize_average_win_rate(
|
242 |
+
battles_no_ties, limit_show_number
|
243 |
+
)
|
244 |
+
bootstrap_elo_rating = visualize_bootstrap_elo_rating(
|
245 |
+
bootstrap_df, limit_show_number
|
246 |
+
)
|
247 |
+
|
248 |
+
last_updated_tstamp = battles["tstamp"].max()
|
249 |
+
last_updated_datetime = datetime.datetime.fromtimestamp(
|
250 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
251 |
+
).strftime("%Y-%m-%d %H:%M:%S %Z")
|
252 |
+
|
253 |
+
return {
|
254 |
+
"elo_rating_online": elo_rating_online,
|
255 |
+
"elo_rating_median": elo_rating_median,
|
256 |
+
"leaderboard_table": leaderboard_table,
|
257 |
+
"win_fraction_heatmap": win_fraction_heatmap,
|
258 |
+
"battle_count_heatmap": battle_count_heatmap,
|
259 |
+
"average_win_rate_bar": average_win_rate_bar,
|
260 |
+
"bootstrap_elo_rating": bootstrap_elo_rating,
|
261 |
+
"last_updated_datetime": last_updated_datetime,
|
262 |
+
"last_updated_tstamp": last_updated_tstamp,
|
263 |
+
}
|
264 |
+
|
265 |
+
|
266 |
+
def pretty_print_elo_rating(rating):
|
267 |
+
model_order = list(rating.keys())
|
268 |
+
model_order.sort(key=lambda k: -rating[k])
|
269 |
+
for i, model in enumerate(model_order):
|
270 |
+
print(f"{i+1:2d}, {model:25s}, {rating[model]:.0f}")
|
271 |
+
|
272 |
+
|
273 |
+
if __name__ == "__main__":
|
274 |
+
parser = argparse.ArgumentParser()
|
275 |
+
parser.add_argument("--clean-battle-file", type=str)
|
276 |
+
parser.add_argument("--max-num-files", type=int)
|
277 |
+
args = parser.parse_args()
|
278 |
+
|
279 |
+
np.random.seed(42)
|
280 |
+
|
281 |
+
if args.clean_battle_file:
|
282 |
+
# Read data from a cleaned battle files
|
283 |
+
battles = pd.read_json(args.clean_battle_file)
|
284 |
+
else:
|
285 |
+
# Read data from all log files
|
286 |
+
log_files = get_log_files(args.max_num_files)
|
287 |
+
battles = clean_battle_data(log_files)
|
288 |
+
|
289 |
+
results = report_elo_analysis_results(battles)
|
290 |
+
|
291 |
+
print("# Online")
|
292 |
+
pretty_print_elo_rating(results["elo_rating_online"])
|
293 |
+
print("# Median")
|
294 |
+
pretty_print_elo_rating(results["elo_rating_median"])
|
295 |
+
print(f"last update : {results['last_updated_datetime']}")
|
296 |
+
|
297 |
+
last_updated_tstamp = results["last_updated_tstamp"]
|
298 |
+
cutoff_date = datetime.datetime.fromtimestamp(
|
299 |
+
last_updated_tstamp, tz=timezone("US/Pacific")
|
300 |
+
).strftime("%Y%m%d")
|
301 |
+
|
302 |
+
with open(f"elo_results_{cutoff_date}.pkl", "wb") as fout:
|
303 |
+
pickle.dump(results, fout)
|
monitor/inspect_conv.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import code
|
3 |
+
import datetime
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
from pytz import timezone
|
7 |
+
import time
|
8 |
+
|
9 |
+
import pandas as pd
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
|
13 |
+
def get_log_files(max_num_files=None):
|
14 |
+
dates = []
|
15 |
+
for month in [4, 5]:
|
16 |
+
for day in range(1, 32):
|
17 |
+
dates.append(f"2023-{month:02d}-{day:02d}")
|
18 |
+
|
19 |
+
num_servers = 14
|
20 |
+
filenames = []
|
21 |
+
for d in dates:
|
22 |
+
for i in range(num_servers):
|
23 |
+
name = os.path.expanduser(f"~/fastchat_logs/server{i}/{d}-conv.json")
|
24 |
+
if os.path.exists(name):
|
25 |
+
filenames.append(name)
|
26 |
+
max_num_files = max_num_files or len(filenames)
|
27 |
+
filenames = filenames[-max_num_files:]
|
28 |
+
return filenames
|
29 |
+
|
30 |
+
|
31 |
+
def pretty_print_conversation(messages):
|
32 |
+
for role, msg in messages:
|
33 |
+
print(f"[[{role}]]: {msg}")
|
34 |
+
|
35 |
+
|
36 |
+
def inspect_convs(log_files):
|
37 |
+
data = []
|
38 |
+
for filename in tqdm(log_files, desc="read files"):
|
39 |
+
for retry in range(5):
|
40 |
+
try:
|
41 |
+
lines = open(filename).readlines()
|
42 |
+
break
|
43 |
+
except FileNotFoundError:
|
44 |
+
time.sleep(2)
|
45 |
+
|
46 |
+
for l in lines:
|
47 |
+
row = json.loads(l)
|
48 |
+
|
49 |
+
if "states" not in row:
|
50 |
+
continue
|
51 |
+
if row["type"] not in ["leftvote", "rightvote", "bothbad_vote"]:
|
52 |
+
continue
|
53 |
+
|
54 |
+
model_names = row["states"][0]["model_name"], row["states"][1]["model_name"]
|
55 |
+
if row["type"] == "leftvote":
|
56 |
+
winner, loser = model_names[0], model_names[1]
|
57 |
+
winner_conv, loser_conv = row["states"][0], row["states"][1]
|
58 |
+
elif row["type"] == "rightvote":
|
59 |
+
loser, winner = model_names[0], model_names[1]
|
60 |
+
loser_conv, winner_conv = row["states"][0], row["states"][1]
|
61 |
+
|
62 |
+
if loser == "bard" and winner == "vicuna-13b":
|
63 |
+
print("=" * 20)
|
64 |
+
print(f"Winner: {winner}")
|
65 |
+
pretty_print_conversation(winner_conv["messages"])
|
66 |
+
print(f"Loser: {loser}")
|
67 |
+
pretty_print_conversation(loser_conv["messages"])
|
68 |
+
print("=" * 20)
|
69 |
+
input()
|
70 |
+
|
71 |
+
# if row["type"] == "bothbad_vote" and "gpt-4" in model_names:
|
72 |
+
# print("=" * 20)
|
73 |
+
# print(f"Model A: {model_names[0]}")
|
74 |
+
# pretty_print_conversation(row["states"][0]["messages"])
|
75 |
+
# print(f"Model B: {model_names[1]}")
|
76 |
+
# pretty_print_conversation(row["states"][1]["messages"])
|
77 |
+
# print("=" * 20)
|
78 |
+
# input()
|
79 |
+
|
80 |
+
|
81 |
+
if __name__ == "__main__":
|
82 |
+
parser = argparse.ArgumentParser()
|
83 |
+
parser.add_argument("--max-num-files", type=int)
|
84 |
+
args = parser.parse_args()
|
85 |
+
|
86 |
+
log_files = get_log_files(args.max_num_files)
|
87 |
+
inspect_convs(log_files)
|
monitor/intersect_conv_file.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Take the intersection of two conversation files.
|
3 |
+
|
4 |
+
Usage: python3 -m fastchat.data.merge --input input.json --conv-id conv_id_file.json --out intersect.json
|
5 |
+
"""
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import json
|
9 |
+
|
10 |
+
|
11 |
+
if __name__ == "__main__":
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument("--input", type=str, required=True)
|
14 |
+
parser.add_argument("--conv-id", type=str, required=True)
|
15 |
+
parser.add_argument("--out-file", type=str, default="intersect.json")
|
16 |
+
args = parser.parse_args()
|
17 |
+
|
18 |
+
conv_id_objs = json.load(open(args.conv_id, "r"))
|
19 |
+
conv_ids = set(x["conversation_id"] for x in conv_id_objs)
|
20 |
+
|
21 |
+
objs = json.load(open(args.input, "r"))
|
22 |
+
after_objs = [x for x in objs if x["conversation_id"] in conv_ids]
|
23 |
+
|
24 |
+
print(f"#in: {len(objs)}, #out: {len(after_objs)}")
|
25 |
+
json.dump(after_objs, open(args.out_file, "w"), indent=2, ensure_ascii=False)
|
monitor/leaderboard_csv_to_html.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Convert a leaderboard csv file to html table used in the blog.
|
3 |
+
|
4 |
+
Usage:
|
5 |
+
python3 leaderboard_csv_to_html.py --in leaderboard_table_20230619.csv
|
6 |
+
"""
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from fastchat.serve.monitor.monitor import load_leaderboard_table_csv
|
12 |
+
|
13 |
+
|
14 |
+
def model_hyperlink(model_name, link):
|
15 |
+
return f'<a target="_blank" href="{link}"> {model_name} </a>'
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == "__main__":
|
19 |
+
parser = argparse.ArgumentParser()
|
20 |
+
parser.add_argument("--input", type=str, required=True)
|
21 |
+
args = parser.parse_args()
|
22 |
+
|
23 |
+
data = load_leaderboard_table_csv(args.input, add_hyperlink=False)
|
24 |
+
headers = [
|
25 |
+
"Model",
|
26 |
+
"MT-bench (score)",
|
27 |
+
"Arena Elo rating",
|
28 |
+
"MMLU",
|
29 |
+
"License",
|
30 |
+
]
|
31 |
+
values = []
|
32 |
+
for item in data:
|
33 |
+
row = []
|
34 |
+
for key in headers:
|
35 |
+
value = item[key]
|
36 |
+
row.append(value)
|
37 |
+
row[0] = model_hyperlink(item["Model"], item["Link"])
|
38 |
+
values.append(row)
|
39 |
+
values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9)
|
40 |
+
|
41 |
+
for value in values:
|
42 |
+
row = "<tr>"
|
43 |
+
for x in value:
|
44 |
+
try:
|
45 |
+
if np.isnan(x):
|
46 |
+
x = "-"
|
47 |
+
except TypeError:
|
48 |
+
pass
|
49 |
+
row += f" <td>{x}</td> "
|
50 |
+
row += "</tr>"
|
51 |
+
print(row)
|
monitor/monitor.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Live monitor of the website statistics and leaderboard.
|
3 |
+
|
4 |
+
Dependency:
|
5 |
+
sudo apt install pkg-config libicu-dev
|
6 |
+
pip install pytz gradio gdown plotly polyglot pyicu pycld2 tabulate
|
7 |
+
"""
|
8 |
+
|
9 |
+
import argparse
|
10 |
+
import ast
|
11 |
+
import pickle
|
12 |
+
import os
|
13 |
+
import threading
|
14 |
+
import time
|
15 |
+
|
16 |
+
import gradio as gr
|
17 |
+
import numpy as np
|
18 |
+
|
19 |
+
from fastchat.serve.monitor.basic_stats import report_basic_stats, get_log_files
|
20 |
+
from fastchat.serve.monitor.clean_battle_data import clean_battle_data
|
21 |
+
from fastchat.serve.monitor.elo_analysis import report_elo_analysis_results
|
22 |
+
from fastchat.utils import build_logger, get_window_url_params_js
|
23 |
+
|
24 |
+
|
25 |
+
notebook_url = "https://colab.research.google.com/drive/1RAWb22-PFNI-X1gPVzc927SGUdfr6nsR?usp=sharing"
|
26 |
+
|
27 |
+
|
28 |
+
basic_component_values = [None] * 6
|
29 |
+
leader_component_values = [None] * 5
|
30 |
+
|
31 |
+
|
32 |
+
def make_leaderboard_md(elo_results):
|
33 |
+
leaderboard_md = f"""
|
34 |
+
# 🏆 Chatbot Arena Leaderboard
|
35 |
+
| [Blog](https://lmsys.org/blog/2023-05-03-arena/) | [GitHub](https://github.com/lm-sys/FastChat) | [Paper](https://arxiv.org/abs/2306.05685) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
|
36 |
+
|
37 |
+
This leaderboard is based on the following three benchmarks.
|
38 |
+
- [Chatbot Arena](https://lmsys.org/blog/2023-05-03-arena/) - a crowdsourced, randomized battle platform. We use 100K+ user votes to compute Elo ratings.
|
39 |
+
- [MT-Bench](https://arxiv.org/abs/2306.05685) - a set of challenging multi-turn questions. We use GPT-4 to grade the model responses.
|
40 |
+
- [MMLU](https://arxiv.org/abs/2009.03300) (5-shot) - a test to measure a model's multitask accuracy on 57 tasks.
|
41 |
+
|
42 |
+
💻 Code: The Arena Elo ratings are computed by this [notebook]({notebook_url}). The MT-bench scores (single-answer grading on a scale of 10) are computed by [fastchat.llm_judge](https://github.com/lm-sys/FastChat/tree/main/fastchat/llm_judge). The MMLU scores are mostly computed by [InstructEval](https://github.com/declare-lab/instruct-eval). Higher values are better for all benchmarks. Empty cells mean not available. Last updated: November, 2023.
|
43 |
+
"""
|
44 |
+
return leaderboard_md
|
45 |
+
|
46 |
+
|
47 |
+
def make_leaderboard_md_live(elo_results):
|
48 |
+
leaderboard_md = f"""
|
49 |
+
# Leaderboard
|
50 |
+
Last updated: {elo_results["last_updated_datetime"]}
|
51 |
+
{elo_results["leaderboard_table"]}
|
52 |
+
"""
|
53 |
+
return leaderboard_md
|
54 |
+
|
55 |
+
|
56 |
+
def update_elo_components(max_num_files, elo_results_file):
|
57 |
+
log_files = get_log_files(max_num_files)
|
58 |
+
|
59 |
+
# Leaderboard
|
60 |
+
if elo_results_file is None: # Do live update
|
61 |
+
battles = clean_battle_data(log_files, [])
|
62 |
+
elo_results = report_elo_analysis_results(battles)
|
63 |
+
|
64 |
+
leader_component_values[0] = make_leaderboard_md_live(elo_results)
|
65 |
+
leader_component_values[1] = elo_results["win_fraction_heatmap"]
|
66 |
+
leader_component_values[2] = elo_results["battle_count_heatmap"]
|
67 |
+
leader_component_values[3] = elo_results["bootstrap_elo_rating"]
|
68 |
+
leader_component_values[4] = elo_results["average_win_rate_bar"]
|
69 |
+
|
70 |
+
# Basic stats
|
71 |
+
basic_stats = report_basic_stats(log_files)
|
72 |
+
md0 = f"Last updated: {basic_stats['last_updated_datetime']}"
|
73 |
+
|
74 |
+
md1 = "### Action Histogram\n"
|
75 |
+
md1 += basic_stats["action_hist_md"] + "\n"
|
76 |
+
|
77 |
+
md2 = "### Anony. Vote Histogram\n"
|
78 |
+
md2 += basic_stats["anony_vote_hist_md"] + "\n"
|
79 |
+
|
80 |
+
md3 = "### Model Call Histogram\n"
|
81 |
+
md3 += basic_stats["model_hist_md"] + "\n"
|
82 |
+
|
83 |
+
md4 = "### Model Call (Last 24 Hours)\n"
|
84 |
+
md4 += basic_stats["num_chats_last_24_hours"] + "\n"
|
85 |
+
|
86 |
+
basic_component_values[0] = md0
|
87 |
+
basic_component_values[1] = basic_stats["chat_dates_bar"]
|
88 |
+
basic_component_values[2] = md1
|
89 |
+
basic_component_values[3] = md2
|
90 |
+
basic_component_values[4] = md3
|
91 |
+
basic_component_values[5] = md4
|
92 |
+
|
93 |
+
|
94 |
+
def update_worker(max_num_files, interval, elo_results_file):
|
95 |
+
while True:
|
96 |
+
tic = time.time()
|
97 |
+
update_elo_components(max_num_files, elo_results_file)
|
98 |
+
durtaion = time.time() - tic
|
99 |
+
print(f"update duration: {durtaion:.2f} s")
|
100 |
+
time.sleep(max(interval - durtaion, 0))
|
101 |
+
|
102 |
+
|
103 |
+
def load_demo(url_params, request: gr.Request):
|
104 |
+
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
105 |
+
return basic_component_values + leader_component_values
|
106 |
+
|
107 |
+
|
108 |
+
def model_hyperlink(model_name, link):
|
109 |
+
return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_name}</a>'
|
110 |
+
|
111 |
+
|
112 |
+
def load_leaderboard_table_csv(filename, add_hyperlink=True):
|
113 |
+
lines = open(filename).readlines()
|
114 |
+
heads = [v.strip() for v in lines[0].split(",")]
|
115 |
+
rows = []
|
116 |
+
for i in range(1, len(lines)):
|
117 |
+
row = [v.strip() for v in lines[i].split(",")]
|
118 |
+
for j in range(len(heads)):
|
119 |
+
item = {}
|
120 |
+
for h, v in zip(heads, row):
|
121 |
+
if h == "Arena Elo rating":
|
122 |
+
if v != "-":
|
123 |
+
v = int(ast.literal_eval(v))
|
124 |
+
else:
|
125 |
+
v = np.nan
|
126 |
+
elif h == "MMLU":
|
127 |
+
if v != "-":
|
128 |
+
v = round(ast.literal_eval(v) * 100, 1)
|
129 |
+
else:
|
130 |
+
v = np.nan
|
131 |
+
elif h == "MT-bench (win rate %)":
|
132 |
+
if v != "-":
|
133 |
+
v = round(ast.literal_eval(v[:-1]), 1)
|
134 |
+
else:
|
135 |
+
v = np.nan
|
136 |
+
elif h == "MT-bench (score)":
|
137 |
+
if v != "-":
|
138 |
+
v = round(ast.literal_eval(v), 2)
|
139 |
+
else:
|
140 |
+
v = np.nan
|
141 |
+
item[h] = v
|
142 |
+
if add_hyperlink:
|
143 |
+
item["Model"] = model_hyperlink(item["Model"], item["Link"])
|
144 |
+
rows.append(item)
|
145 |
+
|
146 |
+
return rows
|
147 |
+
|
148 |
+
|
149 |
+
def build_basic_stats_tab():
|
150 |
+
empty = "Loading ..."
|
151 |
+
basic_component_values[:] = [empty, None, empty, empty, empty, empty]
|
152 |
+
|
153 |
+
md0 = gr.Markdown(empty)
|
154 |
+
gr.Markdown("#### Figure 1: Number of model calls and votes")
|
155 |
+
plot_1 = gr.Plot(show_label=False)
|
156 |
+
with gr.Row():
|
157 |
+
with gr.Column():
|
158 |
+
md1 = gr.Markdown(empty)
|
159 |
+
with gr.Column():
|
160 |
+
md2 = gr.Markdown(empty)
|
161 |
+
with gr.Row():
|
162 |
+
with gr.Column():
|
163 |
+
md3 = gr.Markdown(empty)
|
164 |
+
with gr.Column():
|
165 |
+
md4 = gr.Markdown(empty)
|
166 |
+
return [md0, plot_1, md1, md2, md3, md4]
|
167 |
+
|
168 |
+
|
169 |
+
def build_leaderboard_tab(elo_results_file, leaderboard_table_file):
|
170 |
+
if elo_results_file is None: # Do live update
|
171 |
+
md = "Loading ..."
|
172 |
+
p1 = p2 = p3 = p4 = None
|
173 |
+
else:
|
174 |
+
with open(elo_results_file, "rb") as fin:
|
175 |
+
elo_results = pickle.load(fin)
|
176 |
+
|
177 |
+
md = make_leaderboard_md(elo_results)
|
178 |
+
p1 = elo_results["win_fraction_heatmap"]
|
179 |
+
p2 = elo_results["battle_count_heatmap"]
|
180 |
+
p3 = elo_results["bootstrap_elo_rating"]
|
181 |
+
p4 = elo_results["average_win_rate_bar"]
|
182 |
+
|
183 |
+
md_1 = gr.Markdown(md, elem_id="leaderboard_markdown")
|
184 |
+
|
185 |
+
if leaderboard_table_file:
|
186 |
+
data = load_leaderboard_table_csv(leaderboard_table_file)
|
187 |
+
headers = [
|
188 |
+
"Model",
|
189 |
+
"Arena Elo rating",
|
190 |
+
"MT-bench (score)",
|
191 |
+
"MMLU",
|
192 |
+
"License",
|
193 |
+
]
|
194 |
+
values = []
|
195 |
+
for item in data:
|
196 |
+
row = []
|
197 |
+
for key in headers:
|
198 |
+
value = item[key]
|
199 |
+
row.append(value)
|
200 |
+
values.append(row)
|
201 |
+
values.sort(key=lambda x: -x[1] if not np.isnan(x[1]) else 1e9)
|
202 |
+
|
203 |
+
headers[1] = "⭐ " + headers[1]
|
204 |
+
headers[2] = "📈 " + headers[2]
|
205 |
+
|
206 |
+
gr.Dataframe(
|
207 |
+
headers=headers,
|
208 |
+
datatype=["markdown", "number", "number", "number", "str"],
|
209 |
+
value=values,
|
210 |
+
elem_id="leaderboard_dataframe",
|
211 |
+
)
|
212 |
+
gr.Markdown(
|
213 |
+
""" ## Visit our [HF space](https://huggingface.co/spaces/lmsys/chatbot-arena-leaderboard) for more analysis!
|
214 |
+
If you want to see more models, please help us [add them](https://github.com/lm-sys/FastChat/blob/main/docs/arena.md#how-to-add-a-new-model).
|
215 |
+
""",
|
216 |
+
elem_id="leaderboard_markdown",
|
217 |
+
)
|
218 |
+
else:
|
219 |
+
pass
|
220 |
+
|
221 |
+
leader_component_values[:] = [md, p1, p2, p3, p4]
|
222 |
+
|
223 |
+
"""
|
224 |
+
with gr.Row():
|
225 |
+
with gr.Column():
|
226 |
+
gr.Markdown(
|
227 |
+
"#### Figure 1: Fraction of Model A Wins for All Non-tied A vs. B Battles"
|
228 |
+
)
|
229 |
+
plot_1 = gr.Plot(p1, show_label=False)
|
230 |
+
with gr.Column():
|
231 |
+
gr.Markdown(
|
232 |
+
"#### Figure 2: Battle Count for Each Combination of Models (without Ties)"
|
233 |
+
)
|
234 |
+
plot_2 = gr.Plot(p2, show_label=False)
|
235 |
+
with gr.Row():
|
236 |
+
with gr.Column():
|
237 |
+
gr.Markdown(
|
238 |
+
"#### Figure 3: Bootstrap of Elo Estimates (1000 Rounds of Random Sampling)"
|
239 |
+
)
|
240 |
+
plot_3 = gr.Plot(p3, show_label=False)
|
241 |
+
with gr.Column():
|
242 |
+
gr.Markdown(
|
243 |
+
"#### Figure 4: Average Win Rate Against All Other Models (Assuming Uniform Sampling and No Ties)"
|
244 |
+
)
|
245 |
+
plot_4 = gr.Plot(p4, show_label=False)
|
246 |
+
"""
|
247 |
+
|
248 |
+
from fastchat.serve.gradio_web_server import acknowledgment_md
|
249 |
+
|
250 |
+
gr.Markdown(acknowledgment_md)
|
251 |
+
|
252 |
+
# return [md_1, plot_1, plot_2, plot_3, plot_4]
|
253 |
+
return [md_1]
|
254 |
+
|
255 |
+
|
256 |
+
def build_demo(elo_results_file, leaderboard_table_file):
|
257 |
+
from fastchat.serve.gradio_web_server import block_css
|
258 |
+
|
259 |
+
text_size = gr.themes.sizes.text_lg
|
260 |
+
|
261 |
+
with gr.Blocks(
|
262 |
+
title="Monitor",
|
263 |
+
theme=gr.themes.Base(text_size=text_size),
|
264 |
+
css=block_css,
|
265 |
+
) as demo:
|
266 |
+
with gr.Tabs() as tabs:
|
267 |
+
with gr.Tab("Leaderboard", id=0):
|
268 |
+
leader_components = build_leaderboard_tab(
|
269 |
+
elo_results_file, leaderboard_table_file
|
270 |
+
)
|
271 |
+
|
272 |
+
with gr.Tab("Basic Stats", id=1):
|
273 |
+
basic_components = build_basic_stats_tab()
|
274 |
+
|
275 |
+
url_params = gr.JSON(visible=False)
|
276 |
+
demo.load(
|
277 |
+
load_demo,
|
278 |
+
[url_params],
|
279 |
+
basic_components + leader_components,
|
280 |
+
_js=get_window_url_params_js,
|
281 |
+
)
|
282 |
+
|
283 |
+
return demo
|
284 |
+
|
285 |
+
|
286 |
+
if __name__ == "__main__":
|
287 |
+
parser = argparse.ArgumentParser()
|
288 |
+
parser.add_argument("--host", type=str, default="0.0.0.0")
|
289 |
+
parser.add_argument("--port", type=int)
|
290 |
+
parser.add_argument("--share", action="store_true")
|
291 |
+
parser.add_argument("--concurrency-count", type=int, default=10)
|
292 |
+
parser.add_argument("--update-interval", type=int, default=300)
|
293 |
+
parser.add_argument("--max-num-files", type=int)
|
294 |
+
parser.add_argument("--elo-results-file", type=str)
|
295 |
+
parser.add_argument("--leaderboard-table-file", type=str)
|
296 |
+
args = parser.parse_args()
|
297 |
+
|
298 |
+
logger = build_logger("monitor", "monitor.log")
|
299 |
+
logger.info(f"args: {args}")
|
300 |
+
|
301 |
+
if args.elo_results_file is None: # Do live update
|
302 |
+
update_thread = threading.Thread(
|
303 |
+
target=update_worker,
|
304 |
+
args=(args.max_num_files, args.update_interval, args.elo_results_file),
|
305 |
+
)
|
306 |
+
update_thread.start()
|
307 |
+
|
308 |
+
demo = build_demo(args.elo_results_file, args.leaderboard_table_file)
|
309 |
+
demo.queue(
|
310 |
+
concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
|
311 |
+
).launch(
|
312 |
+
server_name=args.host, server_port=args.port, share=args.share, max_threads=200
|
313 |
+
)
|