# Last commit f41b192 by maulana-m: "reset chat history when clear".
import gradio as gr
import json
from uuid import uuid4
import requests
from pydantic import BaseModel
from typing import List
import ast
import os
import sseclient
# Runtime configuration, read once from the environment when the module loads.
# Passed to demo.queue(max_size=...).
QUEUE_MAX_SIZE = int(os.environ.get("QUEUE_MAX_SIZE", "20"))
# Passed to demo.queue(concurrency_count=...). The misspelling is kept on
# purpose: it is also the environment-variable name operators set.
QUEUE_CONCURENCY_COUNT = int(os.environ.get("QUEUE_CONCURENCY_COUNT", "10"))
# Login credentials handed to launch(auth=...); both are None when unset.
# NOTE(review): USERNAME is also a common OS-provided variable (e.g. on
# Windows) — consider a prefixed name if that causes surprises.
USERNAME = os.environ.get("USERNAME")
PASSWORD = os.environ.get("PASSWORD")
# Initial value of the "Endpoint API" textbox that requests are posted to.
CHATBOT_ENDPOINT = os.environ.get("CHATBOT_ENDPOINT", "http://localhost:5000")
class LearningBotRequest(BaseModel):
    """Request body posted to the chatbot endpoint (serialised with ``.json()``)."""

    # Conversation so far: single-key dicts alternating {"human": ...} / {"AI": ...}.
    message: List[dict]
    persona: str
    session_id: str
    # Free-form extra context parsed from the UI's "context" textbox.
    context: dict
    user_serial: str
    # Whether the backend should stream its reply as server-sent events.
    stream: bool
    # Product identifier. The caller already passes this, but without a declared
    # field pydantic drops unknown keyword arguments by default, so it never
    # reached the backend. The default keeps existing constructor calls valid.
    product: str = "learning_companion"
def generate_uuid():
    """Return a freshly generated random (version 4) UUID in string form."""
    new_id = uuid4()
    return f"{new_id}"
def construct_message(list_message):
    """Flatten chatbot (user, bot) pairs into the role-tagged list sent upstream.

    Each pair ``[user_text, ai_text]`` becomes two dicts,
    ``{"human": user_text}`` followed by ``{"AI": ai_text}``. Pairs with fewer
    than two items are skipped.

    Args:
        list_message: Sequence of (user, bot) pairs, e.g. the ``gr.Chatbot``
            value. ``None`` (what the "clear" handler stores) is treated as an
            empty history instead of raising ``TypeError``.

    Returns:
        A new list of single-key dicts alternating ``"human"`` and ``"AI"``.
    """
    messages = []
    for pair in list_message or []:
        if len(pair) < 2:
            continue
        user_text, ai_text = pair[0], pair[1]
        messages.append({"human": user_text})
        messages.append({"AI": ai_text})
    return messages
def send_message(url, request, timeout=None):
    """POST *request* as JSON to *url* and return the open streaming response.

    Args:
        url: Chatbot endpoint to post to.
        request: Pydantic model whose ``.json()`` output becomes the body.
        timeout: Optional ``requests`` timeout (seconds, or a
            ``(connect, read)`` tuple). ``None`` keeps the previous behaviour
            of waiting indefinitely.

    Returns:
        The ``requests.Response`` opened with ``stream=True`` so the caller can
        consume server-sent events incrementally. The caller must close it.
    """
    headers = {
        # The body is a JSON string; requests sets no Content-Type for str
        # data, which servers that dispatch on the header may reject.
        "Content-Type": "application/json",
        # The caller parses the body as a server-sent-event stream.
        "Accept": "text/event-stream",
    }
    return requests.post(
        url,
        data=request.json(),
        headers=headers,
        stream=True,
        timeout=timeout,
    )
def _parse_context(context):
    """Parse the "context" textbox text into a dict; blank text means ``{}``."""
    if context is None or not context.strip():
        return {}
    # literal_eval only accepts Python literals, so text typed into the box is
    # never executed as code (unlike eval).
    parsed = ast.literal_eval(context)
    if not isinstance(parsed, dict):
        raise ValueError("context must be a dict literal, e.g. {} or {'key': 'value'}")
    return parsed


def _pair_history(history):
    """Group a flat [user, bot, user, bot, ...] list into (user, bot) tuples."""
    return list(zip(history[0::2], history[1::2]))


def respond(chat_history, message, history, session_id, user_serial, persona, context, endpoint):
    """Send *message* to the chatbot endpoint and stream the reply into the UI.

    Args:
        chat_history: Current chatbot value (list of (user, bot) pairs), used
            to rebuild the conversation sent to the backend.
        message: The new user message.
        history: Flat per-session list ``[user, bot, user, bot, ...]``; ``None``
            starts a new one.
        session_id: Backend session id; ``None`` generates a fresh one.
        user_serial: Forwarded unchanged in the request.
        persona: Forwarded unchanged in the request.
        context: Textbox text holding a Python dict literal; blank means ``{}``.
        endpoint: URL the request is posted to.

    Yields:
        ``(chat_pairs, history, status, session_id, session_id)`` after every
        streamed token. If the stream produces no events, one final tuple is
        yielded with an empty reply and status ``"No reply received"``.

    Raises:
        ValueError / SyntaxError: if *context* is not a dict literal. Raised
        before ``history`` is touched, so the session state stays consistent.
    """
    if history is None:
        history = []
    if session_id is None:
        session_id = generate_uuid()

    # Build the full request before mutating ``history`` so that a failure here
    # cannot leave an unanswered user turn that misaligns later pairs.
    messages = construct_message(chat_history)
    messages.append({"human": message})
    request = LearningBotRequest(
        message=messages,
        persona=persona,
        session_id=session_id,
        context=_parse_context(context),
        user_serial=user_serial,
        stream=True,
        product="learning_companion",
    )

    response = send_message(endpoint, request)
    try:
        # Record the user turn together with an AI slot, so ``history`` always
        # holds complete pairs even if the stream is empty or fails midway.
        history.append(message)
        history.append("")
        partial_reply = ""
        received_any = False
        for event in sseclient.SSEClient(response).events():
            payload = json.loads(event.data)["data"]
            partial_reply += payload["data"]["reply"]
            history[-1] = partial_reply
            received_any = True
            yield _pair_history(history), history, "Success", session_id, session_id
        if not received_any:
            yield _pair_history(history), history, "No reply received", session_id, session_id
    finally:
        # The response was opened with stream=True; release the connection
        # even if the generator is abandoned or a parse error is raised.
        response.close()
def reset_textbox():
    """Return a Gradio update that empties the textbox it is wired to."""
    cleared = ''
    return gr.update(value=cleared)
with gr.Blocks() as demo:
    # Start every browser session without an id. gr.State deep-copies its
    # initial value per session, so a uuid generated here (once, at build
    # time) would be shared by all visitors. respond() creates a fresh id when
    # it receives None.
    session_id = gr.State(value=None)
    with gr.Row():
        with gr.Column(scale=5):
            clear = gr.Button("Clear all conversation")
        with gr.Column(scale=5):
            endpoint = gr.Textbox(label="Endpoint API", value=CHATBOT_ENDPOINT)
    with gr.Accordion("Parameters", open=False):
        user_serial = gr.Textbox(label="User serial")
        # Parsed by respond() as a Python dict literal, so the default is text.
        context = gr.Textbox(label="context", value="{}")
        persona = gr.Textbox(label="persona", value="a493700848d84d0dab8d0095c2477c1e")
    chatbot = gr.Chatbot()
    message = gr.Textbox(placeholder="Halo kak, aku mau bertanya", label="Chat Here")
    # Flat [user, bot, user, bot, ...] history maintained by respond().
    state = gr.State([])
    with gr.Row():
        with gr.Column(scale=5):
            send = gr.Button("Send")
        with gr.Column(scale=5):
            status_box = gr.Textbox(label="Status code from OpenAI server")
            session = gr.Textbox(label="session_id")

    # Both triggers share one wiring so they cannot drift apart.
    respond_inputs = [chatbot, message, state, session_id, user_serial, persona, context, endpoint]
    respond_outputs = [chatbot, state, status_box, session, session_id]

    message.submit(respond, respond_inputs, respond_outputs)
    # Reset the transcript, the flat history, the session id and its display in
    # one event; the next message then starts a fresh backend session.
    clear.click(
        lambda: (None, None, None, ""),
        None,
        [chatbot, state, session_id, session],
        queue=False,
    )
    send.click(respond, respond_inputs, respond_outputs)
    send.click(reset_textbox, [], [message])
    message.submit(reset_textbox, [], [message])
# Enable the request queue, then start the server behind a login page.
# NOTE(review): USERNAME/PASSWORD come straight from the environment; if they
# are unset, auth receives (None, None) — confirm how the installed Gradio
# version treats that before deploying.
queued_demo = demo.queue(max_size=QUEUE_MAX_SIZE, concurrency_count=QUEUE_CONCURENCY_COUNT)
credentials = (USERNAME, PASSWORD)
queued_demo.launch(auth=credentials, debug=True)