# Gradio Space UI: streaming chat frontend for the Learning Companion chatbot backend.
# --- Imports (stdlib first, then third-party) --------------------------------
import ast
import json
import os
from typing import List
from uuid import uuid4

import gradio as gr
import requests
import sseclient
from pydantic import BaseModel

# --- Configuration, overridable via environment variables --------------------
# Maximum number of requests allowed to wait in the Gradio queue.
QUEUE_MAX_SIZE = int(os.getenv("QUEUE_MAX_SIZE", 20))
# Number of requests processed concurrently. NOTE: the env var keeps the
# original "CONCURENCY" spelling so existing deployments keep working.
QUEUE_CONCURENCY_COUNT = int(os.getenv("QUEUE_CONCURENCY_COUNT", 10))
# Basic-auth credentials for the Gradio UI (None disables that credential).
USERNAME = os.getenv("USERNAME")
PASSWORD = os.getenv("PASSWORD")
# Chatbot backend endpoint; expected to answer with an SSE token stream.
CHATBOT_ENDPOINT = os.getenv("CHATBOT_ENDPOINT", "http://localhost:5000")
class LearningBotRequest(BaseModel):
    """Request payload sent to the chatbot backend API.

    Field names mirror the backend schema; the model is serialized with
    ``.json()`` and posted as the request body.
    """

    # Conversation so far: alternating {"human": ...} / {"AI": ...} dicts.
    message: List[dict]
    # Persona identifier understood by the backend.
    persona: str
    # Conversation session identifier (UUID string).
    session_id: str
    # Extra context forwarded verbatim to the backend.
    context: dict
    # Identifier of the end user.
    user_serial: str
    # Whether the backend should stream the reply over SSE.
    stream: bool
    # Product name. Previously the caller passed product="learning_companion"
    # to the constructor, but the field was missing from the model, so pydantic
    # silently dropped it and it never reached the backend. Declaring it with
    # the same default is backward compatible and makes the payload complete.
    product: str = "learning_companion"
def generate_uuid():
    """Return a freshly generated random UUID4 as a string."""
    new_id = uuid4()
    return str(new_id)
def construct_message(list_message):
    """Flatten gradio-style (user, bot) chat pairs into the API message list.

    Each complete pair contributes two dicts, ``{"human": ...}`` followed by
    ``{"AI": ...}``, preserving conversation order. Pairs with fewer than two
    elements are skipped rather than crashing.

    :param list_message: iterable of (user_text, bot_text) pairs.
    :return: flat list of single-key dicts in conversation order.
    """
    messages = []
    for pair_message in list_message:  # enumerate index was unused; dropped
        if len(pair_message) < 2:
            # Malformed/incomplete pair: ignore it.
            continue
        user_text, bot_text = pair_message[0], pair_message[1]
        messages.append({"human": user_text})
        messages.append({"AI": bot_text})
    return messages
def send_message(url, request):
    """POST the serialized request to *url* and return the streaming response.

    The body is the pydantic model's JSON serialization. The Content-Type
    header is set explicitly: previously the JSON string was sent via
    ``data=`` with no content type, so the backend could not reliably parse
    the body as JSON.

    :param url: backend endpoint URL.
    :param request: a LearningBotRequest (or any object with a ``.json()``).
    :return: a streaming ``requests.Response`` suitable for SSE consumption.
    """
    headers = {"Content-Type": "application/json"}
    return requests.post(url, stream=True, data=request.json(), headers=headers)
def respond(chat_history, message, history, session_id, user_serial, persona, context, endpoint):
    """Stream a chatbot reply for *message*, updating the UI on every SSE event.

    Generator handler for gradio: yields (chat pairs for the Chatbot widget,
    flat history list, status string, session id, session id) after each SSE
    event so the outputs refresh incrementally while tokens arrive.
    """
    if history is None:
        history = []
    history.append(message)

    # Lazily create a session id on the first message after a reset.
    if session_id is None:
        session_id = generate_uuid()

    # The context textbox holds a Python-literal dict (e.g. "{}").
    context = ast.literal_eval(context)

    # Rebuild the full backend message list from prior turns plus this one.
    messages = construct_message(chat_history)
    messages.append({"human": message})

    request = LearningBotRequest(
        message=messages,
        persona=persona,
        session_id=session_id,
        context=context,
        user_serial=user_serial,
        stream=True,
        product="learning_companion",
    )
    response = send_message(endpoint, request)

    partial_reply = ""
    client = sseclient.SSEClient(response)
    for event_index, event in enumerate(client.events()):
        payload = json.loads(event.data)["data"]
        partial_reply += payload["data"]["reply"]
        if event_index == 0:
            # First chunk: create the bot slot in the flat history. The
            # leading space matches the original implementation's behavior;
            # later chunks overwrite the whole slot below.
            history.append(" " + partial_reply)
        else:
            history[-1] = partial_reply
        # Pair up [user, bot, user, bot, ...] into (user, bot) tuples.
        chat = [(history[j], history[j + 1]) for j in range(0, len(history) - 1, 2)]
        yield chat, history, "Success", session_id, session_id
def reset_textbox():
    """Return a gradio update that clears the message textbox."""
    cleared = gr.update(value="")
    return cleared
with gr.Blocks() as demo:
    # Per-browser-session chat session id (reset to None by the Clear button;
    # respond() regenerates it on the next message).
    session_id = gr.State(value=generate_uuid())

    with gr.Row():
        with gr.Column(scale=5):
            # Typo fixed: "converstation" -> "conversation".
            clear = gr.Button("Clear all conversation")
        with gr.Column(scale=5):
            endpoint = gr.Textbox(label="Endpoint API", value=CHATBOT_ENDPOINT)

    with gr.Accordion("Parameters", open=False):
        user_serial = gr.Textbox(label="User serial")
        # Must be a string: respond() parses it with ast.literal_eval.
        # Passing a dict ({}) as a Textbox value was a type bug.
        context = gr.Textbox(label="context", value="{}")
        persona = gr.Textbox(label="persona", value="a493700848d84d0dab8d0095c2477c1e")

    chatbot = gr.Chatbot()
    message = gr.Textbox(placeholder="Halo kak, aku mau bertanya", label="Chat Here")
    # Flat [user, bot, user, bot, ...] history backing the chatbot widget.
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=5):
            send = gr.Button("Send")
        with gr.Column(scale=5):
            status_box = gr.Textbox(label="Status code from OpenAI server")
            session = gr.Textbox(label="session_id")

    # Enter in the textbox and the Send button trigger the same handler with
    # the same wiring; the lists are shared so they cannot drift apart.
    respond_inputs = [chatbot, message, state, session_id, user_serial, persona, context, endpoint]
    respond_outputs = [chatbot, state, status_box, session, session_id]
    message.submit(respond, respond_inputs, respond_outputs)
    send.click(respond, respond_inputs, respond_outputs)

    # Clear wipes the chat widget, the session id, and the flat history.
    clear.click(lambda: None, None, chatbot, queue=False)
    clear.click(lambda: None, None, session_id, queue=False)
    clear.click(lambda: None, None, state, queue=False)

    # Empty the input textbox after sending (registered after respond so the
    # message value is captured before being cleared).
    send.click(reset_textbox, [], [message])
    message.submit(reset_textbox, [], [message])
# Queue shields the backend from request bursts; basic auth guards the UI.
queued_app = demo.queue(
    max_size=QUEUE_MAX_SIZE,
    concurrency_count=QUEUE_CONCURENCY_COUNT,
)
queued_app.launch(auth=(USERNAME, PASSWORD), debug=True)