|
import datetime |
|
import os |
|
import sqlite3 |
|
import websockets |
|
import websocket |
|
import asyncio |
|
import sqlite3 |
|
import json |
|
import requests |
|
import asyncio |
|
import time |
|
import gradio as gr |
|
import fireworks.client |
|
import openai |
|
import fireworks.client |
|
import chainlit as cl |
|
from chainlit import make_async |
|
from gradio_client import Client |
|
from websockets.sync.client import connect |
|
from tempfile import TemporaryDirectory |
|
from typing import List |
|
from chainlit.input_widget import Select, Switch, Slider |
|
from chainlit import AskUserMessage, Message, on_chat_start |
|
|
|
from langchain.embeddings import CohereEmbeddings |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.vectorstores.chroma import Chroma |
|
from langchain.chains import ( |
|
ConversationalRetrievalChain, |
|
) |
|
from langchain.llms.fireworks import Fireworks |
|
from langchain.chat_models.fireworks import ChatFireworks |
|
from langchain.prompts.chat import ( |
|
ChatPromptTemplate, |
|
SystemMessagePromptTemplate, |
|
HumanMessagePromptTemplate, |
|
) |
|
from langchain.docstore.document import Document |
|
from langchain.memory import ChatMessageHistory, ConversationBufferMemory |
|
from langsmith_config import setup_langsmith_config |
|
|
|
|
|
# API credentials are read from the environment at import time; each value is
# exposed under both an upper-case and a lower-case module-level name.
COHERE_API_KEY = cohere_api_key = os.environ.get("COHERE_API_KEY")
FIREWORKS_API_KEY = fireworks_api_key = os.environ.get("FIREWORKS_API_KEY")

# Ports passed to start_websockets() / start_client(), in the order received.
server_ports = list()
client_ports = list()

setup_langsmith_config()

# Splitter used on the uploaded text before it is embedded.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
)
|
|
|
# System instructions for the chat prompt. ``{summaries}`` is the template
# variable that follows the instructions; ``{question}`` carries the user turn.
system_template = """Use the following pieces of context to answer the users question.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer.
The "SOURCES" part should be a reference to the source of the document from which you got your answer.

And if the user greets with greetings like Hi, hello, How are you, etc reply accordingly as well.

Example of your response should be:

The answer is foo
SOURCES: xyz


Begin!
----------------
{summaries}"""

# System message first, then the human message.
messages = [
    SystemMessagePromptTemplate.from_template(system_template),
    HumanMessagePromptTemplate.from_template("{question}"),
]

prompt = ChatPromptTemplate.from_messages(messages)

# Keyword arguments bundle holding the assembled prompt.
chain_type_kwargs = dict(prompt=prompt)
|
|
|
@cl.on_chat_start
async def on_chat_start():
    """Prepare a chat session: settings panel, uploaded file, retrieval chain.

    Sends the two port sliders, asks for a plain-text file until one is
    uploaded, splits and embeds the text into a Chroma store, and stores a
    ``ConversationalRetrievalChain`` in the user session under ``"chain"``.
    """
    # ``initial=False`` keeps both ports falsy until the user picks a value;
    # the settings-update handler treats a falsy port as "not provided".
    await cl.ChatSettings(
        [
            Slider(
                id="websocketPort",
                label="Websocket server port",
                initial=False,
                min=1000,
                max=9999,
                step=10,
            ),
            Slider(
                id="clientPort",
                label="Websocket client port",
                initial=False,
                min=1000,
                max=9999,
                step=10,
            ),
        ],
    ).send()

    # Re-prompt until the upload request yields a result.
    files = None
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a text file to begin!",
            accept=["text/plain"],
            max_size_mb=20,
            timeout=180,
        ).send()

    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()

    # Decode the upload and cut it into chunks for embedding.
    text = file.content.decode("utf-8")
    texts = text_splitter.split_text(text)

    # One metadata record per chunk so answers can point back to a source.
    metadatas = [{"source": f"{i}-pl"} for i in range(len(texts))]

    # The API key comes from the environment; credentials must never be
    # embedded in source code.
    embeddings = CohereEmbeddings(cohere_api_key=COHERE_API_KEY)

    # Chroma.from_texts is synchronous, so run it through make_async.
    docsearch = await cl.make_async(Chroma.from_texts)(
        texts, embeddings, metadatas=metadatas
    )

    message_history = ChatMessageHistory()
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        chat_memory=message_history,
        return_messages=True,
    )

    chain = ConversationalRetrievalChain.from_llm(
        ChatFireworks(
            model="accounts/fireworks/models/llama-v2-70b-chat",
            model_kwargs={"temperature": 0, "max_tokens": 1500, "top_p": 1.0},
            streaming=True,
        ),
        chain_type="stuff",
        retriever=docsearch.as_retriever(),
        memory=memory,
        return_source_documents=True,
    )

    msg.content = f"Processing `{file.name}` done. You can now ask questions!"
    await msg.update()

    cl.user_session.set("chain", chain)
|
|
|
@cl.action_callback("server_button")
async def on_server_button(action):
    """Start the WebSocket server on the port chosen in the chat settings.

    Args:
        action: The Chainlit action that triggered this callback (unused).
    """
    # There is no module-level ``settings``; Chainlit keeps the current
    # chat-settings values in the user session under "chat_settings".
    settings = cl.user_session.get("chat_settings") or {}
    websocketPort = settings.get("websocketPort")
    if not websocketPort:
        print("Server port number wasn't provided.")
        return
    await start_websockets(websocketPort)
|
|
|
@cl.action_callback("client_button")
async def on_client_button(action):
    """Start the WebSocket client on the port chosen in the chat settings.

    Args:
        action: The Chainlit action that triggered this callback (unused).
    """
    # There is no module-level ``settings``; Chainlit keeps the current
    # chat-settings values in the user session under "chat_settings".
    settings = cl.user_session.get("chat_settings") or {}
    clientPort = settings.get("clientPort")
    if not clientPort:
        print("Client port number wasn't provided.")
        return
    await start_client(clientPort)
|
|
|
@cl.on_settings_update
async def server_start(settings):
    """React to a settings change by starting the server and/or the client.

    Args:
        settings: Mapping with the ``"websocketPort"`` and ``"clientPort"``
            values from the settings panel.
    """
    # Both keys are read up front, before anything is started.
    server_port, client_port = settings["websocketPort"], settings["clientPort"]

    if not server_port:
        print("Server port number wasn't provided.")
    else:
        await start_websockets(server_port)

    if not client_port:
        print("Client port number wasn't provided.")
    else:
        await start_client(client_port)
|
|
|
async def handleWebSocket(ws):
    """Serve one WebSocket client: greet it, then answer each message.

    Every incoming message and every reply is inserted into the ``messages``
    table of ``chat-hub.db``.  The loop keeps answering until the peer
    disconnects; a failure while answering one message is printed and the
    loop moves on to the next message.

    Args:
        ws: Connection object supplied by ``websockets.serve``; it must
            provide awaitable ``send`` and ``recv``.

    Returns:
        The last reply produced by ``main``, or ``None`` if no reply was
        produced before the connection closed.
    """
    print('New connection')
    instruction = "Hello! You are now entering a chat room for AI agents working as instances of NeuralGPT - a project of hierarchical cooperative multi-agent framework. Keep in mind that you are speaking with another chatbot. Please note that you may choose to ignore or not respond to repeating inputs from specific clients as needed to prevent unnecessary traffic."
    await ws.send(json.dumps(instruction))

    insert_sql = 'INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)'
    response = None

    # One database connection per client connection, always closed on exit.
    db = sqlite3.connect('chat-hub.db')
    try:
        while True:
            try:
                message = await ws.recv()
            except websockets.exceptions.ConnectionClosed as e:
                # The peer went away: stop serving instead of looping on a
                # dead socket.
                print(f"Connection closed: {e}")
                break

            print(f'Received message: {message}')
            timestamp = datetime.datetime.now().isoformat()
            db.execute(insert_sql, ('client', message, timestamp))
            db.commit()

            try:
                response = await main(cl.Message(content=message))
                serverResponse = "server response: " + response
                print(serverResponse)

                await ws.send(serverResponse)
                db.execute(insert_sql, ('server', serverResponse, timestamp))
                db.commit()
            except websockets.exceptions.ConnectionClosed as e:
                print(f"Connection closed: {e}")
                break
            except Exception as e:
                # Best effort: report the failure and keep serving.
                print(f"Error: {e}")
    finally:
        db.close()

    return response
|
|
|
async def awaitMsg(ws):
    """Receive one message from ``ws``, answer it, and log both to SQLite.

    The incoming message and the reply are inserted into the ``messages``
    table of ``chat-hub.db``.  The database connection is closed before
    returning, whether or not the reply succeeded.

    Args:
        ws: Connection object providing awaitable ``recv`` and ``send``.

    Returns:
        The reply produced by ``main``, or ``None`` when producing or sending
        the reply failed (the failure is printed).
    """
    message = await ws.recv()
    print(f'Received message: {message}')
    timestamp = datetime.datetime.now().isoformat()

    insert_sql = 'INSERT INTO messages (sender, message, timestamp) VALUES (?, ?, ?)'
    db = sqlite3.connect('chat-hub.db')
    try:
        db.execute(insert_sql, ('client', message, timestamp))
        db.commit()

        try:
            response = await main(cl.Message(content=message))
            serverResponse = "server response: " + response
            print(serverResponse)

            await ws.send(serverResponse)
            db.execute(insert_sql, ('server', serverResponse, timestamp))
            db.commit()
            return response
        except websockets.exceptions.ConnectionClosedError as e:
            print(f"Connection closed: {e}")
        except Exception as e:
            print(f"Error: {e}")
        return None
    finally:
        # Release the SQLite handle on every path.
        db.close()
|
|
|
|
|
async def start_websockets(websocketPort):
    """Start a WebSocket server on localhost and record its port.

    The server object is kept in the module-level ``server`` so that
    ``stop_websockets`` can close it.  The coroutine returns as soon as the
    server is listening; it does not block for the server's lifetime.

    Args:
        websocketPort: Port for the server to listen on.

    Returns:
        A newline-separated report of every port in ``server_ports``.
    """
    global server
    print(f"Starting WebSocket server on port {websocketPort}...")
    server = await websockets.serve(handleWebSocket, 'localhost', websocketPort)
    # Record the port only once the server has actually started.
    server_ports.append(websocketPort)
    return "Used ports:\n" + '\n'.join(map(str, server_ports))
|
|
|
async def start_client(clientPort):
    """Connect to a local WebSocket server and answer its messages.

    Each message received from the server is passed to ``main`` and the reply
    is sent back, until the server closes the connection.

    Args:
        clientPort: Port of the server on ``localhost`` to connect to.

    Returns:
        The last message received from the server, or ``None`` if the
        connection closed before any message arrived.
    """
    uri = f'ws://localhost:{clientPort}'
    client_ports.append(clientPort)

    input_message = None
    async with websockets.connect(uri) as ws:
        while True:
            try:
                input_message = await ws.recv()
                output_message = await main(cl.Message(content=input_message))
                # NOTE(review): ``main`` already returns a JSON-encoded
                # string, so this encodes it a second time - confirm the
                # format the peer expects.
                await ws.send(json.dumps(output_message))
            except websockets.exceptions.ConnectionClosed as e:
                print(f"Connection closed: {e}")
                break
            # Brief pause between exchanges.
            await asyncio.sleep(0.1)

    return input_message
|
|
|
|
|
def stop_websockets():
    """Close the running WebSocket server, if there is one.

    Looks up the module-level ``server`` without assuming it exists, so the
    function can be called before any server was started.  Module-level
    ``cursor`` and ``db`` objects are closed as well when they are present.
    After closing, ``server`` is reset to ``None`` so a repeated call reports
    that nothing is running.
    """
    global server

    running = globals().get("server")
    if not running:
        print("WebSocket server is not running.")
        return

    # These handles are optional module globals; close whichever exist.
    for handle_name in ("cursor", "db"):
        handle = globals().get(handle_name)
        if handle is not None:
            handle.close()

    running.close()
    server = None
    print("WebSocket server stopped.")
|
|
|
@cl.on_message
async def main(message: cl.Message):
    """Answer ``message`` with the retrieval chain stored in the user session.

    The answer is sent to the chat UI with the retrieved passages attached as
    text elements, and is also returned so that other callers (the WebSocket
    handlers) can forward it.

    Args:
        message: Incoming message; its ``content`` is the question.

    Returns:
        The answer text, including its sources line, encoded with
        ``json.dumps``.
    """
    chain = cl.user_session.get("chain")
    cb = cl.AsyncLangchainCallbackHandler()

    res = await chain.acall(message.content, callbacks=[cb])
    answer = res["answer"]
    source_documents = res["source_documents"]

    # One text element per retrieved passage, named source_0, source_1, ...
    text_elements = [
        cl.Text(content=source_doc.page_content, name=f"source_{source_idx}")
        for source_idx, source_doc in enumerate(source_documents or [])
    ]

    if text_elements:
        source_names = [text_el.name for text_el in text_elements]
        answer += f"\nSources: {', '.join(source_names)}"
    else:
        answer += "\nNo sources found"

    # Send to the UI before returning; a return placed first would make the
    # send unreachable.
    await cl.Message(content=answer, elements=text_elements).send()

    return json.dumps(answer)
|
|