# NOTE: removed scraped hosting-page header ("Spaces: / Running / Running") — not Python source.
import uuid | |
from fastapi import FastAPI | |
from fastapi.responses import StreamingResponse | |
from langchain_core.messages import ( | |
BaseMessage, | |
HumanMessage, | |
trim_messages, | |
) | |
from langchain_core.tools import tool | |
from langchain_openai import ChatOpenAI | |
from langgraph.checkpoint.memory import MemorySaver | |
from langgraph.prebuilt import create_react_agent | |
from pydantic import BaseModel | |
import json | |
from typing import Optional, Annotated | |
from langchain_core.runnables import RunnableConfig | |
from langgraph.prebuilt import InjectedState | |
from document_rag_router import router as document_rag_router | |
from document_rag_router import QueryInput, query_collection, SearchResult | |
from fastapi import HTTPException | |
import requests | |
from sse_starlette.sse import EventSourceResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
import re | |
# FastAPI application exposing the RAG chat endpoints plus the routes
# mounted from document_rag_router.
app = FastAPI()
app.include_router(document_rag_router)
# CORS is wide open (any origin/method/header, credentials allowed).
# NOTE(review): tighten allow_origins before production deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_user_age(name: str) -> str:
    """Use this tool to find the user's age."""
    # Demo stub: "bob" (case-insensitive substring) gets a fixed age,
    # everyone else gets the other fixed age.
    return "42 years old" if "bob" in name.lower() else "41 years old"
async def query_documents(
    query: str,
    config: RunnableConfig,
) -> str:
    """Use this tool to retrieve relevant data from the collection.
    Args:
        query: The search query to find relevant document passages
    """
    # The retrieval context (collection/user) travels in the per-thread
    # runnable config rather than as model-visible tool arguments.
    conf = config.get("configurable", {})
    collection_id = conf.get("collection_id")
    user_id = conf.get("user_id")
    if not (collection_id and user_id):
        return "Error: collection_id and user_id are required in the config"
    try:
        response = await query_collection(
            QueryInput(
                collection_id=collection_id,
                query=query,
                user_id=user_id,
                top_k=6,
            )
        )
        # Project each hit down to the fields the agent needs; the Python
        # repr of this list is what downstream parsing expects.
        hits = [
            {
                "text": hit.text,
                "distance": hit.distance,
                "metadata": {
                    "document_id": hit.metadata.get("document_id"),
                    "chunk_index": hit.metadata.get("location", {}).get("chunk_index"),
                },
            }
            for hit in response.results
        ]
        return str(hits)
    except Exception as e:
        print(e)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
async def query_documents_raw(
    query: str,
    config: RunnableConfig,
) -> "list[SearchResult] | str":
    """Use this tool to retrieve relevant data from the collection.
    Args:
        query: The search query to find relevant document passages

    Returns:
        The list of search results on success, or an error string when the
        config is incomplete or the query fails.
    """
    # Retrieval context (collection/user) comes from the per-thread config.
    thread_config = config.get("configurable", {})
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")
    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"
    try:
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6
        )
        response = await query_collection(input_data)
        # BUG FIX: the original annotated the return type as a single
        # ``SearchResult`` even though it returns the full result list
        # (or an error string on the paths above/below).
        return response.results
    except Exception as e:
        print(e)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
# In-memory checkpointer: per-thread conversation state is lost on restart.
memory = MemorySaver()
# streaming=True is required so on_chat_model_stream events carry token chunks.
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
def state_modifier(state) -> list[BaseMessage]:
    """Return the trimmed message window to use as the model prompt.

    Keeps the most recent messages (starting on a human turn, system message
    retained). NOTE: ``token_counter=len`` counts *messages*, not tokens, so
    ``max_tokens=16000`` effectively caps the number of retained messages.
    """
    history = state["messages"]
    return trim_messages(
        history,
        strategy="last",
        token_counter=len,
        max_tokens=16000,
        include_system=True,
        start_on="human",
        allow_partial=False,
    )
# ReAct agent wired with the RAG retrieval tool; state_modifier bounds the
# prompt window and MemorySaver persists state per thread_id.
agent = create_react_agent(
    model,
    tools=[query_documents],
    checkpointer=memory,
    state_modifier=state_modifier,
)
class ChatInput(BaseModel):
    # Request payload shared by the chat endpoints.
    message: str
    thread_id: Optional[str] = None  # resumes an existing conversation when set
    collection_id: Optional[str] = None  # forwarded to query_documents via config
    user_id: Optional[str] = None  # forwarded to query_documents via config
async def chat(input_data: ChatInput):
    """Stream an agent run as server-sent events.

    Emits JSON payloads of type 'token' (model output chunks), 'tool_start',
    and 'tool_end'. A fresh thread_id is minted when the caller omits one.
    """
    config = {
        "configurable": {
            "thread_id": input_data.thread_id or str(uuid.uuid4()),
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }

    async def event_stream():
        message = HumanMessage(content=input_data.message)
        async for ev in agent.astream_events(
            {"messages": [message]}, config, version="v2"
        ):
            label = ev["event"]
            if label == "on_chat_model_stream":
                chunk = ev["data"]["chunk"].content
                # Skip empty chunks (e.g. tool-call deltas with no text).
                if chunk:
                    yield json.dumps({'type': 'token', 'content': chunk})
            elif label == "on_tool_start":
                yield json.dumps({'type': 'tool_start', 'tool': ev['name'], 'input': str(ev['data'].get('input', ''))})
            elif label == "on_tool_end":
                yield json.dumps({'type': 'tool_end', 'tool': ev['name'], 'output': str(ev['data'].get('output', ''))})

    return EventSourceResponse(event_stream(), media_type="text/event-stream")
async def clean_tool_input(tool_input: str):
    """Reduce a stringified tool-input dict to its first key/value pair.

    Args:
        tool_input: ``str(...)`` of the tool's input, typically the repr of a
            dict such as ``"{'query': 'foo'}"``.

    Returns:
        ``{key: value}`` for the first pair when one can be recovered,
        otherwise ``[tool_input]`` (legacy fallback shape, kept for
        backward compatibility with existing consumers).
    """
    import ast

    # BUG FIX: the original regex-only parse truncated values containing
    # apostrophes and failed entirely when the first value was not
    # single-quoted (e.g. an int). Prefer a real literal parse of the repr.
    try:
        parsed = ast.literal_eval(tool_input)
        if isinstance(parsed, dict) and parsed:
            key, value = next(iter(parsed.items()))
            return {key: value}
    except (ValueError, SyntaxError):
        pass
    # Fallback: original regex for loosely formatted, non-literal inputs.
    pattern = r"{\s*'([^']+)':\s*'([^']+)'"
    match = re.search(pattern, tool_input)
    if match:
        key, value = match.groups()
        return {key: value}
    return [tool_input]
async def clean_tool_response(tool_output: str):
    """Clean and extract relevant information from tool response if it contains query_documents.

    When the output embeds a Python-repr list of result dicts ("[{...}]"),
    parse it and keep only ``text`` and ``document_id`` per result; any other
    output is returned unchanged.
    """
    if "query_documents" not in tool_output:
        return tool_output
    import ast
    print(tool_output)
    # Locate the embedded list repr inside the surrounding tool message.
    start = tool_output.find("[{")
    end = tool_output.rfind("}]")
    # BUG FIX: the original computed ``rfind("}]") + 2`` and tested ``end > 0``,
    # which is true even when "}]" is absent (rfind returns -1, so end == 1),
    # slicing garbage. Pass the output through untouched when no well-formed
    # list span is present.
    if start < 0 or end < start:
        return tool_output
    try:
        # Safe literal parse of the repr (no code execution).
        results = ast.literal_eval(tool_output[start:end + 2])
        # Return only relevant fields.
        return [{"text": r["text"], "document_id": r["metadata"]["document_id"]}
                for r in results]
    except SyntaxError as e:
        print(f"Syntax error in parsing: {e}")
        return f"Error parsing document results: {str(e)}"
    except Exception as e:
        print(f"General error: {e}")
        return f"Error processing results: {str(e)}"
async def chat2(input_data: ChatInput):
    """SSE chat endpoint variant that post-processes tool events.

    Unlike ``chat``, tool inputs are reduced to their first key/value pair
    and query_documents output is re-fetched in structured form before
    being streamed to the client.
    """
    # Reuse the caller's thread_id so conversation state resumes; otherwise
    # mint a fresh thread.
    thread_id = input_data.thread_id or str(uuid.uuid4())
    config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id
        }
    }
    input_message = HumanMessage(content=input_data.message)
    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]},
            config,
            version="v2"
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                # Stream individual model token chunks as they arrive.
                content = event["data"]["chunk"].content
                if content:
                    yield f"{json.dumps({'type': 'token', 'content': content})}"
            elif kind == "on_tool_start":
                tool_name = event['name']
                tool_input = event['data'].get('input', '')
                # Reduce the raw input repr to its first key/value pair.
                clean_input = await clean_tool_input(str(tool_input))
                yield f"{json.dumps({'type': 'tool_start', 'tool': tool_name, 'inputs': clean_input})}"
            elif kind == "on_tool_end":
                if "query_documents" in event['name']:
                    print(event)
                    # Re-runs the retrieval to obtain structured results
                    # (a second call per tool end).
                    # NOTE(review): the *stringified input dict* (e.g.
                    # "{'query': ...}") is passed as the query text here,
                    # not the query value itself — confirm this is intended.
                    raw_output = await query_documents_raw(str(event['data'].get('input', '')), config)
                    try:
                        serializable_output = [
                            {
                                "text": result.text,
                                "distance": result.distance,
                                "metadata": result.metadata
                            }
                            for result in raw_output
                        ]
                        yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': json.dumps(serializable_output)})}"
                    except Exception as e:
                        # Fall back to the stringified raw output (e.g. when
                        # query_documents_raw returned an error string).
                        print(e)
                        yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': str(raw_output)})}"
                else:
                    tool_name = event['name']
                    raw_output = str(event['data'].get('output', ''))
                    clean_output = await clean_tool_response(raw_output)
                    yield f"{json.dumps({'type': 'tool_end', 'tool': tool_name, 'output': clean_output})}"
    return EventSourceResponse(
        generate(),
        media_type="text/event-stream"
    )
async def health_check():
    """Liveness probe: report that the service is responsive."""
    return dict(status="healthy")