# documind-api / main.py
# pvanand's picture
# Upload 7 files
# 5d42805 verified
import uuid
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from langchain_core.messages import (
BaseMessage,
HumanMessage,
trim_messages,
)
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from pydantic import BaseModel
import json
from typing import Optional, Annotated
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import InjectedState
from document_rag_router import router as document_rag_router
from document_rag_router import QueryInput, query_collection, SearchResult
from fastapi import HTTPException
import requests
from sse_starlette.sse import EventSourceResponse
from fastapi.middleware.cors import CORSMiddleware
import re
# Application object; the document RAG router's routes are mounted on it.
app = FastAPI()
app.include_router(document_rag_router)
# CORS is fully open: any origin, method and header is accepted, and
# credentials are allowed as well.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@tool
def get_user_age(name: str) -> str:
    """Use this tool to find the user's age."""
    # The docstring above is left verbatim: the @tool decorator presumably
    # exposes it to the model as the tool description — confirm before editing.
    # Stub lookup: any name containing "bob" (case-insensitive) maps to 42,
    # every other name to 41.
    is_bob = "bob" in name.lower()
    return "42 years old" if is_bob else "41 years old"
@tool
async def query_documents(
    query: str,
    config: RunnableConfig,
) -> str:
    """Use this tool to retrieve relevant data from the collection.
    Args:
    query: The search query to find relevant document passages
    """
    # The docstring above is kept word-for-word: the @tool decorator presumably
    # hands it to the model as the tool description — confirm before editing.
    #
    # collection_id and user_id are not tool arguments; they travel in the
    # "configurable" section of the run config set up by the endpoints.
    configurable = config.get("configurable", {})
    collection_id = configurable.get("collection_id")
    user_id = configurable.get("user_id")
    if not (collection_id and user_id):
        return "Error: collection_id and user_id are required in the config"

    try:
        response = await query_collection(
            QueryInput(
                collection_id=collection_id,
                query=query,
                user_id=user_id,
                top_k=6,
            )
        )
        # Keep only the fields the agent needs from each hit; the result is
        # handed back as the string form of a list of dicts.
        hits = [
            {
                "text": hit.text,
                "distance": hit.distance,
                "metadata": {
                    "document_id": hit.metadata.get("document_id"),
                    "chunk_index": hit.metadata.get("location", {}).get("chunk_index"),
                },
            }
            for hit in response.results
        ]
        return str(hits)
    except Exception as exc:
        # Best-effort: report the failure to the agent as text instead of raising.
        print(exc)
        return f"Error querying documents: {exc} PAUSE AND ASK USER FOR HELP"
async def query_documents_raw(
    query: str,
    config: RunnableConfig,
) -> SearchResult:
    """Run a collection query and return the unformatted results.

    Same lookup as the ``query_documents`` tool, but ``response.results`` is
    returned as-is instead of being flattened into a string.

    Args:
        query: The search query to find relevant document passages.
        config: Run config; ``collection_id`` and ``user_id`` are read from
            its ``"configurable"`` section.

    Returns:
        ``response.results`` on success. When the ids are missing or the query
        raises, an error message string is returned instead, so callers
        receive either the results object or a ``str`` (the annotation only
        names the former).
    """
    settings = config.get("configurable", {})
    collection_id = settings.get("collection_id")
    user_id = settings.get("user_id")
    if not (collection_id and user_id):
        return "Error: collection_id and user_id are required in the config"

    try:
        request = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6,
        )
        return (await query_collection(request)).results
    except Exception as exc:
        # Best-effort: failures are reported as text rather than raised.
        print(exc)
        return f"Error querying documents: {exc} PAUSE AND ASK USER FOR HELP"
# Checkpointer passed to the agent below; conversation state is keyed by the
# thread_id placed in each request's config.
# NOTE(review): presumably held in process memory only (lost on restart) — confirm.
memory = MemorySaver()
# Chat model driving the agent; streaming is enabled so the endpoints can relay
# tokens as they arrive.
model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
def state_modifier(state) -> list[BaseMessage]:
    """Return the part of the conversation history that is sent to the model.

    The most recent messages are kept (``strategy="last"``), the kept window
    must start on a human message, a system message is retained if present,
    and individual messages are never split (``allow_partial=False``).

    Args:
        state: Agent state; only ``state["messages"]`` is read.

    Returns:
        The trimmed list of messages.
    """
    return trim_messages(
        state["messages"],
        # Count real tokens using the chat model. With ``len`` as the counter
        # every message counts as a single token, so a 16000 budget would mean
        # "16000 messages" and the history would effectively never be trimmed.
        token_counter=model,
        max_tokens=16000,
        strategy="last",
        start_on="human",
        include_system=True,
        allow_partial=False,
    )
# Agent shared by the /chat and /chat2 endpoints. Only query_documents is
# registered as a tool (get_user_age is defined above but not included here).
# The checkpointer stores per-thread state and state_modifier trims the
# message history before each model call.
agent = create_react_agent(
    model,
    tools=[query_documents],
    checkpointer=memory,
    state_modifier=state_modifier,
)
# Request body accepted by the /chat and /chat2 endpoints.
class ChatInput(BaseModel):
    # Text of the user's message; the endpoints wrap it in a HumanMessage.
    message: str
    # Conversation thread key; the endpoints generate a fresh UUID when omitted.
    thread_id: Optional[str] = None
    # Forwarded through the run config; query_documents returns an error
    # string when either of the two ids below is missing.
    collection_id: Optional[str] = None
    user_id: Optional[str] = None
@app.post("/chat")
async def chat(input_data: ChatInput):
    # Streams one agent run as server-sent events. Each event body is a JSON
    # object whose "type" is "token", "tool_start" or "tool_end".
    run_config = {
        "configurable": {
            # Reuse the caller's thread, or start a new one with a random id.
            "thread_id": input_data.thread_id or str(uuid.uuid4()),
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }
    user_message = HumanMessage(content=input_data.message)

    async def generate():
        events = agent.astream_events(
            {"messages": [user_message]},
            run_config,
            version="v2",
        )
        async for event in events:
            kind = event["event"]
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if not content:
                    # Empty chunks are not forwarded.
                    continue
                payload = {"type": "token", "content": content}
            elif kind == "on_tool_start":
                tool_input = str(event["data"].get("input", ""))
                payload = {"type": "tool_start", "tool": event["name"], "input": tool_input}
            elif kind == "on_tool_end":
                tool_output = str(event["data"].get("output", ""))
                payload = {"type": "tool_end", "tool": event["name"], "output": tool_output}
            else:
                # Every other event kind is ignored.
                continue
            yield json.dumps(payload)

    return EventSourceResponse(generate(), media_type="text/event-stream")
async def clean_tool_input(tool_input: str):
    """Extract the first key/value pair from a stringified tool-input dict.

    The input is expected to look like the ``str()`` of a dict, e.g.
    ``"{'query': 'hello'}"``. It is first parsed as a Python literal so that
    values whose repr uses double quotes (any text containing an apostrophe,
    such as ``"what's new"``) are handled; the regex is kept as a fallback for
    strings that are not valid literals.

    Args:
        tool_input: String form of the tool's input.

    Returns:
        ``{key: value}`` for the first pair when both are strings, otherwise
        ``[tool_input]`` (the untouched input wrapped in a list).
    """
    import ast  # local import, as clean_tool_response does

    try:
        parsed = ast.literal_eval(tool_input)
    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, dict) and parsed:
        first_key, first_value = next(iter(parsed.items()))
        # Only string pairs are taken from the parsed dict; anything else goes
        # through the regex path below.
        if isinstance(first_key, str) and isinstance(first_value, str):
            return {first_key: first_value}

    # Fallback: use regex to parse the first single-quoted key and value.
    pattern = r"{\s*'([^']+)':\s*'([^']+)'"
    match = re.search(pattern, tool_input)
    if match:
        key, value = match.groups()
        return {key: value}
    return [tool_input]
async def clean_tool_response(tool_output: str):
    """Clean and extract relevant information from tool response if it contains query_documents.

    When ``tool_output`` mentions ``query_documents`` and contains a
    ``[{ ... }]`` list literal, that literal is parsed and reduced to the
    ``text`` and ``document_id`` of each entry.

    Args:
        tool_output: String form of a tool's output.

    Returns:
        A list of ``{"text", "document_id"}`` dicts when a result list was
        parsed; an error message string when the list could not be parsed;
        otherwise ``tool_output`` unchanged.
    """
    import ast

    if "query_documents" not in tool_output:
        return tool_output
    try:
        print(tool_output)
        # Locate the list literal inside the surrounding text.
        start = tool_output.find("[{")
        end_marker = tool_output.rfind("}]")
        # Both markers must be present and in order. Testing the raw rfind()
        # result matters: -1 + 2 is still positive, which would let a missing
        # closing marker slip through the guard.
        if start >= 0 and end_marker > start:
            list_str = tool_output[start:end_marker + 2]
            # Convert string to Python object using ast.literal_eval
            results = ast.literal_eval(list_str)
            # Return only relevant fields
            return [
                {"text": r["text"], "document_id": r["metadata"]["document_id"]}
                for r in results
            ]
    except SyntaxError as e:
        print(f"Syntax error in parsing: {e}")
        return f"Error parsing document results: {str(e)}"
    except Exception as e:
        print(f"General error: {e}")
        return f"Error processing results: {str(e)}"
    # No list literal found: pass the output through untouched.
    return tool_output
@app.post("/chat2")
async def chat2(input_data: ChatInput):
    """Stream one agent run as server-sent events with structured tool payloads.

    Each event body is a JSON object whose ``type`` is ``token``,
    ``tool_start`` (with cleaned ``inputs``) or ``tool_end``. For the
    ``query_documents`` tool, the ``tool_end`` output is a JSON-encoded list of
    the raw search results (text, distance, metadata).

    Args:
        input_data: Message plus optional thread, collection and user ids.

    Returns:
        An ``EventSourceResponse`` wrapping the event generator.
    """
    thread_id = input_data.thread_id or str(uuid.uuid4())
    config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }
    input_message = HumanMessage(content=input_data.message)

    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]},
            config,
            version="v2",
        ):
            kind = event["event"]
            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    yield json.dumps({"type": "token", "content": content})
            elif kind == "on_tool_start":
                tool_name = event["name"]
                tool_input = event["data"].get("input", "")
                clean_input = await clean_tool_input(str(tool_input))
                yield json.dumps({"type": "tool_start", "tool": tool_name, "inputs": clean_input})
            elif kind == "on_tool_end":
                tool_name = event["name"]
                if "query_documents" in tool_name:
                    print(event)
                    tool_input = event["data"].get("input", "")
                    # Re-run the search with the query text itself. Passing
                    # str(tool_input) would search for the dict's repr
                    # (e.g. "{'query': 'foo'}") rather than for "foo".
                    query_text = None
                    if isinstance(tool_input, dict):
                        query_text = tool_input.get("query")
                    if query_text is None:
                        query_text = str(tool_input)
                    # NOTE(review): this repeats the lookup the tool has just
                    # performed, only to obtain the unformatted results.
                    raw_output = await query_documents_raw(str(query_text), config)
                    try:
                        output = json.dumps(
                            [
                                {
                                    "text": result.text,
                                    "distance": result.distance,
                                    "metadata": result.metadata,
                                }
                                for result in raw_output
                            ]
                        )
                    except Exception as e:
                        # raw_output may be an error string, or hold values
                        # that are not JSON-serializable; fall back to its
                        # plain string form.
                        print(e)
                        output = str(raw_output)
                    yield json.dumps({"type": "tool_end", "tool": tool_name, "output": output})
                else:
                    raw_output = str(event["data"].get("output", ""))
                    clean_output = await clean_tool_response(raw_output)
                    yield json.dumps({"type": "tool_end", "tool": tool_name, "output": clean_output})

    return EventSourceResponse(
        generate(),
        media_type="text/event-stream",
    )
@app.get("/health")
async def health_check():
    # Liveness probe: always answers with the same fixed status payload.
    status_payload = {"status": "healthy"}
    return status_payload