"""FastAPI service exposing a LangGraph ReAct agent with document-RAG tooling.

Two SSE chat endpoints stream agent events to the client:

- ``/chat``  : raw token / tool-event stream.
- ``/chat2`` : same stream, but tool inputs are parsed into structured form
  and ``query_documents`` outputs are re-fetched as JSON-serializable
  objects for nicer client-side rendering.

Conversation state is checkpointed in-memory per ``thread_id``.
"""

import ast
import json
import re
import uuid
from typing import Annotated, Optional

import requests  # noqa: F401 -- kept: imported by the original module
from fastapi import FastAPI, HTTPException  # noqa: F401 -- HTTPException kept for parity
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse  # noqa: F401 -- kept for parity
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    trim_messages,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import InjectedState, create_react_agent  # noqa: F401
from pydantic import BaseModel
from sse_starlette.sse import EventSourceResponse

from document_rag_router import QueryInput, SearchResult, query_collection
from document_rag_router import router as document_rag_router

app = FastAPI()
app.include_router(document_rag_router)

# NOTE(review): wide-open CORS -- fine for local development, tighten before
# deploying anywhere public.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@tool
def get_user_age(name: str) -> str:
    """Use this tool to find the user's age."""
    # Toy stub: only names containing "bob" get a distinct answer.
    if "bob" in name.lower():
        return "42 years old"
    return "41 years old"


def _extract_rag_ids(config: RunnableConfig) -> tuple[Optional[str], Optional[str]]:
    """Return ``(collection_id, user_id)`` from the per-thread configurable map."""
    thread_config = config.get("configurable", {})
    return thread_config.get("collection_id"), thread_config.get("user_id")


@tool
async def query_documents(
    query: str,
    config: RunnableConfig,
    # state: Annotated[dict, InjectedState]
) -> str:
    """Use this tool to retrieve relevant data from the collection.

    Args:
        query: The search query to find relevant document passages
    """
    # collection_id / user_id are injected per-request via the thread config.
    collection_id, user_id = _extract_rag_ids(config)
    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"

    try:
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6,
        )
        response = await query_collection(input_data)

        # ``response`` is a Pydantic model; flatten each hit to just the
        # fields the LLM needs, then stringify for the tool message.
        results = [
            {
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    "chunk_index": r.metadata.get("location", {}).get("chunk_index"),
                },
            }
            for r in response.results
        ]
        return str(results)
    except Exception as e:
        print(e)
        # Trailing sentence is a prompt-level hint steering the agent to
        # surface the failure to the user instead of retrying blindly.
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"


async def query_documents_raw(
    query: str,
    config: RunnableConfig,
    # state: Annotated[dict, InjectedState]
) -> "list[SearchResult] | str":
    """Use this tool to retrieve relevant data from the collection.

    Unlike :func:`query_documents`, returns the raw result objects (for
    re-serialization by ``/chat2``) rather than a stringified summary.

    Args:
        query: The search query to find relevant document passages

    Returns:
        The list of ``SearchResult`` objects on success, or an error
        string on failure / missing config.
    """
    collection_id, user_id = _extract_rag_ids(config)
    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"

    try:
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6,
        )
        response = await query_collection(input_data)
        return response.results
    except Exception as e:
        print(e)
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"


memory = MemorySaver()
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)


def state_modifier(state) -> list[BaseMessage]:
    """Trim conversation history before each model call.

    NOTE(review): ``token_counter=len`` counts *messages*, not tokens, so
    ``max_tokens=16000`` effectively caps the history at 16000 messages --
    confirm this is intentional.
    """
    return trim_messages(
        state["messages"],
        token_counter=len,
        max_tokens=16000,
        strategy="last",
        start_on="human",
        include_system=True,
        allow_partial=False,
    )


agent = create_react_agent(
    model,
    tools=[query_documents],
    checkpointer=memory,
    state_modifier=state_modifier,
)


class ChatInput(BaseModel):
    # User's message text (required); the rest scope the RAG lookup and
    # conversation thread.
    message: str
    thread_id: Optional[str] = None
    collection_id: Optional[str] = None
    user_id: Optional[str] = None


def _thread_config(input_data: ChatInput) -> dict:
    """Build the per-request runnable config, minting a thread_id if absent."""
    return {
        "configurable": {
            "thread_id": input_data.thread_id or str(uuid.uuid4()),
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }


@app.post("/chat")
async def chat(input_data: ChatInput):
    """Stream raw agent events (tokens, tool start/end) as SSE."""
    config = _thread_config(input_data)
    input_message = HumanMessage(content=input_data.message)

    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]}, config, version="v2"
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    yield json.dumps({"type": "token", "content": content})
            elif kind == "on_tool_start":
                tool_input = str(event["data"].get("input", ""))
                yield json.dumps(
                    {"type": "tool_start", "tool": event["name"], "input": tool_input}
                )
            elif kind == "on_tool_end":
                tool_output = str(event["data"].get("output", ""))
                yield json.dumps(
                    {"type": "tool_end", "tool": event["name"], "output": tool_output}
                )

    return EventSourceResponse(generate(), media_type="text/event-stream")


async def clean_tool_input(tool_input: str):
    """Parse the first ``'key': 'value'`` pair out of a stringified dict.

    Returns ``{key: value}`` on a match; otherwise falls back to wrapping
    the raw string in a single-element list (NOTE(review): inconsistent
    return types -- dict vs list -- preserved for client compatibility).
    """
    pattern = r"{\s*'([^']+)':\s*'([^']+)'"
    match = re.search(pattern, tool_input)
    if match:
        key, value = match.groups()
        return {key: value}
    return [tool_input]


async def clean_tool_response(tool_output: str):
    """Extract the result list from a stringified query_documents output.

    Only outputs mentioning ``query_documents`` are processed; anything
    else is passed through unchanged.
    """
    if "query_documents" not in tool_output:
        return tool_output
    try:
        print(tool_output)
        # Locate the stringified list-of-dicts embedded in the content.
        start = tool_output.find("[{")
        end = tool_output.rfind("}]")
        # Bug fix: rfind() returns -1 when absent; the original computed
        # end = rfind + 2 and tested end > 0, which passed on no match.
        if start != -1 and end != -1:
            list_str = tool_output[start : end + 2]
            # literal_eval (not eval) -- safe for untrusted tool output.
            results = ast.literal_eval(list_str)
            return [
                {"text": r["text"], "document_id": r["metadata"]["document_id"]}
                for r in results
            ]
    except SyntaxError as e:
        print(f"Syntax error in parsing: {e}")
        return f"Error parsing document results: {str(e)}"
    except Exception as e:
        print(f"General error: {e}")
        return f"Error processing results: {str(e)}"
    return tool_output


@app.post("/chat2")
async def chat2(input_data: ChatInput):
    """Stream agent events as SSE with cleaned tool inputs/outputs."""
    config = _thread_config(input_data)
    input_message = HumanMessage(content=input_data.message)

    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]}, config, version="v2"
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    yield json.dumps({"type": "token", "content": content})
            elif kind == "on_tool_start":
                tool_name = event["name"]
                tool_input = event["data"].get("input", "")
                clean_input = await clean_tool_input(str(tool_input))
                yield json.dumps(
                    {"type": "tool_start", "tool": tool_name, "inputs": clean_input}
                )
            elif kind == "on_tool_end":
                if "query_documents" in event["name"]:
                    print(event)
                    # Re-run the query out-of-band to obtain structured,
                    # JSON-serializable results for the client.
                    # Bug fix: the original stringified the whole input
                    # mapping (e.g. "{'query': '...'}") and used THAT as
                    # the search query; extract the actual query text.
                    tool_input = event["data"].get("input", "")
                    if isinstance(tool_input, dict):
                        query = str(tool_input.get("query", tool_input))
                    else:
                        query = str(tool_input)
                    raw_output = await query_documents_raw(query, config)
                    try:
                        serializable_output = [
                            {
                                "text": result.text,
                                "distance": result.distance,
                                "metadata": result.metadata,
                            }
                            for result in raw_output
                        ]
                        yield json.dumps(
                            {
                                "type": "tool_end",
                                "tool": event["name"],
                                "output": json.dumps(serializable_output),
                            }
                        )
                    except Exception as e:
                        # raw_output may be an error string (iterating it
                        # yields chars without .text, raising here).
                        print(e)
                        yield json.dumps(
                            {
                                "type": "tool_end",
                                "tool": event["name"],
                                "output": str(raw_output),
                            }
                        )
                else:
                    tool_name = event["name"]
                    raw_output = str(event["data"].get("output", ""))
                    clean_output = await clean_tool_response(raw_output)
                    yield json.dumps(
                        {"type": "tool_end", "tool": tool_name, "output": clean_output}
                    )

    return EventSourceResponse(generate(), media_type="text/event-stream")


@app.get("/health")
async def health_check():
    """Liveness probe."""
    return {"status": "healthy"}