pvanand committed
Commit 5d42805 · verified · 1 parent: d53c11a

Upload 7 files

Files changed (7)
  1. Dockerfile +20 -0
  2. docker-compose.yml +10 -0
  3. document_rag_router.py +400 -0
  4. main.py +293 -0
  5. readme.md +91 -0
  6. requirements.txt +12 -0
  7. utils.py +253 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
+ FROM python:3.11-slim
+
+ # Install Tkinter dependencies
+ RUN apt-get update && apt-get install -y \
+     tk \
+     && apt-get clean \
+     && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /app
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt \
+     && pip install torch --index-url https://download.pytorch.org/whl/cpu \
+     && pip install sentence-transformers
+
+ COPY . .
+
+ EXPOSE 80
+
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "80", "--log-level", "debug"]
docker-compose.yml ADDED
@@ -0,0 +1,10 @@
+ services:
+   rag-api:
+     build: .
+     container_name: rag-api
+     restart: unless-stopped
+     environment:
+       - OPENAI_API_KEY=${OPENAI_API_KEY}
+
+     ports:
+       - "9004:80"
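The compose file expects `OPENAI_API_KEY` in the host environment and publishes the API on host port 9004. As a quick smoke test once the stack is running (a minimal sketch, assuming the `requests` package is installed on the host):

```python
import requests

# The compose file maps container port 80 to host port 9004.
BASE_URL = "http://localhost:9004"

resp = requests.get(f"{BASE_URL}/health", timeout=5)
resp.raise_for_status()
print(resp.json())  # main.py's /health endpoint returns {"status": "healthy"}
```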
document_rag_router.py ADDED
@@ -0,0 +1,400 @@
+ from fastapi import UploadFile, File, Form, HTTPException, APIRouter
+ from typing import List, Optional, Dict, Tuple
+ import lancedb
+ from lancedb.pydantic import LanceModel, Vector
+ from lancedb.embeddings import get_registry
+ import pandas as pd
+ from utils import process_pdf_to_chunks
+ import hashlib
+ import uuid
+ import json
+ from datetime import datetime
+ from pydantic import BaseModel
+ import logging
+
+ # Create router
+ router = APIRouter(
+     prefix="/rag",
+     tags=["rag"]
+ )
+
+ # Initialize LanceDB and embedding model
+ db = lancedb.connect("/tmp/db")
+ model = get_registry().get("sentence-transformers").create(
+     name="Snowflake/snowflake-arctic-embed-xs",
+     device="cpu"
+ )
+
+ def get_user_collection(user_id: str, collection_name: str) -> str:
+     """Generate user-specific collection name"""
+     return f"{user_id}_{collection_name}"
+
+ class DocumentChunk(LanceModel):
+     text: str = model.SourceField()
+     vector: Vector(model.ndims()) = model.VectorField()
+     document_id: str
+     chunk_index: int
+     file_name: str
+     file_type: str
+     created_date: str
+     collection_id: str
+     user_id: str
+     metadata_json: str
+     char_start: int
+     char_end: int
+     page_numbers: List[int]
+     images: List[str]
+
+ class QueryInput(BaseModel):
+     collection_id: str
+     query: str
+     top_k: Optional[int] = 3
+     user_id: str
+
+ class SearchResult(BaseModel):
+     text: str
+     distance: float
+     metadata: Dict  # Added metadata field
+
+ class SearchResponse(BaseModel):
+     results: List[SearchResult]
+
+ async def process_file(file: UploadFile, collection_id: str, user_id: str) -> Tuple[List[dict], str]:
+     """Process single file and return chunks with metadata"""
+     content = await file.read()
+     file_type = file.filename.split('.')[-1].lower()
+
+     chunks = []
+     doc_id = ""
+     if file_type == 'pdf':
+         chunks, doc_id = process_pdf_to_chunks(
+             pdf_content=content,
+             file_name=file.filename
+         )
+     elif file_type == 'txt':
+         doc_id = hashlib.sha256(content).hexdigest()[:4]
+         text_content = content.decode('utf-8')
+         chunks = [{
+             "text": text_content,
+             "metadata": {
+                 "created_date": datetime.now().isoformat(),
+                 "file_name": file.filename,
+                 "document_id": doc_id,
+                 "user_id": user_id,
+                 "location": {
+                     "chunk_index": 0,
+                     "char_start": 0,
+                     "char_end": len(text_content),
+                     "pages": [1],
+                     "total_chunks": 1
+                 },
+                 "images": []
+             }
+         }]
+
+     return chunks, doc_id
+
+ @router.post("/upload_files")
+ async def upload_files(
+     files: List[UploadFile] = File(...),
+     collection_name: Optional[str] = Form(None),
+     user_id: str = Form(...)
+ ):
+     try:
+         collection_id = get_user_collection(
+             user_id,
+             collection_name if collection_name else f"col_{uuid.uuid4().hex[:8]}"
+         )
+         all_chunks = []
+         doc_ids = {}
+         for file in files:
+             try:
+                 chunks, doc_id = await process_file(file, collection_id, user_id)
+                 for chunk in chunks:
+                     chunk_data = {
+                         "text": chunk["text"],
+                         "document_id": chunk["metadata"]["document_id"],
+                         "chunk_index": chunk["metadata"]["location"]["chunk_index"],
+                         "file_name": chunk["metadata"]["file_name"],
+                         "file_type": file.filename.split('.')[-1].lower(),
+                         "created_date": chunk["metadata"]["created_date"],
+                         "collection_id": collection_id,
+                         "user_id": user_id,
+                         "metadata_json": json.dumps(chunk["metadata"]),
+                         "char_start": chunk["metadata"]["location"]["char_start"],
+                         "char_end": chunk["metadata"]["location"]["char_end"],
+                         "page_numbers": chunk["metadata"]["location"]["pages"],
+                         "images": chunk["metadata"].get("images", [])
+                     }
+                     all_chunks.append(chunk_data)
+                 doc_ids[doc_id] = file.filename
+             except Exception as e:
+                 logging.error(f"Error processing file {file.filename}: {str(e)}")
+                 raise HTTPException(
+                     status_code=400,
+                     detail=f"Error processing file {file.filename}: {str(e)}"
+                 )
+
+         try:
+             table = db.open_table(collection_id)
+         except Exception as e:
+             logging.error(f"Error opening table: {str(e)}")
+             try:
+                 table = db.create_table(
+                     collection_id,
+                     schema=DocumentChunk,
+                     mode="create"
+                 )
+                 # Create FTS index on the text column for hybrid search support
+
+                 # table.create_fts_index(
+                 #     field_names="text",
+                 #     replace=True,
+                 #     tokenizer_name="en_stem",            # Use English stemming
+                 #     lower_case=True,                     # Convert text to lowercase
+                 #     remove_stop_words=True,              # Remove common words like "the", "is", "at"
+                 #     writer_heap_size=1024 * 1024 * 1024  # 1GB heap size
+                 # )
+
+             except Exception as e:
+                 logging.error(f"Error creating table: {str(e)}")
+                 raise HTTPException(
+                     status_code=500,
+                     detail=f"Error creating database table: {str(e)}"
+                 )
+
+         try:
+             df = pd.DataFrame(all_chunks)
+             table.add(data=df)
+         except Exception as e:
+             logging.error(f"Error adding data to table: {str(e)}")
+             raise HTTPException(
+                 status_code=500,
+                 detail=f"Error adding data to database: {str(e)}"
+             )
+
+         return {
+             "message": f"Successfully processed {len(files)} files",
+             "collection_id": collection_id,
+             "total_chunks": len(all_chunks),
+             "user_id": user_id,
+             "document_ids": doc_ids
+         }
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logging.error(f"Unexpected error during file upload: {str(e)}")
+         raise HTTPException(
+             status_code=500,
+             detail=f"Unexpected error: {str(e)}"
+         )
+
+ @router.get("/get_document/{collection_id}/{document_id}")
+ async def get_document(
+     collection_id: str,
+     document_id: str,
+     user_id: str
+ ):
+     try:
+         table = db.open_table(f"{user_id}_{collection_id}")
+     except Exception as e:
+         logging.error(f"Error opening table: {str(e)}")
+         raise HTTPException(
+             status_code=404,
+             detail=f"Collection not found: {str(e)}"
+         )
+
+     try:
+         chunks = table.to_pandas()
+         doc_chunks = chunks[
+             (chunks['document_id'] == document_id) &
+             (chunks['user_id'] == user_id)
+         ].sort_values('chunk_index')
+
+         if len(doc_chunks) == 0:
+             raise HTTPException(
+                 status_code=404,
+                 detail=f"Document {document_id} not found in collection {collection_id}"
+             )
+
+         return {
+             "document_id": document_id,
+             "file_name": doc_chunks.iloc[0]['file_name'],
+             "chunks": [
+                 {
+                     "text": row['text'],
+                     "metadata": json.loads(row['metadata_json'])
+                 }
+                 for _, row in doc_chunks.iterrows()
+             ]
+         }
+     except HTTPException:
+         raise
+     except Exception as e:
+         logging.error(f"Error retrieving document: {str(e)}")
+         raise HTTPException(
+             status_code=500,
+             detail=f"Error retrieving document: {str(e)}"
+         )
+
+ @router.post("/query_collection", response_model=SearchResponse)
+ async def query_collection(input_data: QueryInput):
+     try:
+         collection_id = get_user_collection(input_data.user_id, input_data.collection_id)
+
+         try:
+             table = db.open_table(collection_id)
+         except Exception as e:
+             logging.error(f"Error opening table: {str(e)}")
+             raise HTTPException(
+                 status_code=404,
+                 detail=f"Collection not found: {str(e)}"
+             )
+
+         try:
+             results = (
+                 table.search(input_data.query)
+                 .where(f"user_id = '{input_data.user_id}'")
+                 .limit(input_data.top_k)
+                 .to_list()
+             )
+         except Exception as e:
+             logging.error(f"Error searching collection: {str(e)}")
+             raise HTTPException(
+                 status_code=500,
+                 detail=f"Error searching collection: {str(e)}"
+             )
+
+         return SearchResponse(results=[
+             SearchResult(
+                 text=r['text'],
+                 distance=float(r['_distance']),
+                 metadata=json.loads(r['metadata_json'])
+             )
+             for r in results
+         ])
+     except HTTPException:
+         raise
+     except Exception as e:
+         logging.error(f"Unexpected error during query: {str(e)}")
+         raise HTTPException(
+             status_code=500,
+             detail=f"Unexpected error: {str(e)}"
+         )
+
+
+
+ @router.get("/list_collections")
+ async def list_collections(user_id: str):
+     try:
+         all_collections = db.table_names()
+         user_collections = [
+             c for c in all_collections
+             if c.startswith(f"{user_id}_")
+         ]
+
+         # Get documents for each collection
+         collections_info = []
+         for collection_name in user_collections:
+             try:
+                 table = db.open_table(collection_name)
+                 df = table.to_pandas()
+
+                 # Group by document_id to get unique documents
+                 documents = df.groupby('document_id').agg({
+                     'file_name': 'first',
+                     'created_date': 'first'
+                 }).reset_index()
+
+                 collections_info.append({
+                     "collection_id": collection_name.replace(f"{user_id}_", ""),
+                     "documents": [
+                         {
+                             "document_id": row['document_id'],
+                             "file_name": row['file_name'],
+                             "created_date": row['created_date']
+                         }
+                         for _, row in documents.iterrows()
+                     ]
+                 })
+             except Exception as e:
+                 logging.error(f"Error processing collection {collection_name}: {str(e)}")
+                 continue
+
+         return {"collections": collections_info}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @router.delete("/delete_collection/{collection_id}")
+ async def delete_collection(collection_id: str, user_id: str):
+     try:
+         full_collection_id = f"{user_id}_{collection_id}"
+
+         # Check if collection exists
+         try:
+             table = db.open_table(full_collection_id)
+         except Exception as e:
+             logging.error(f"Collection not found: {str(e)}")
+             raise HTTPException(
+                 status_code=404,
+                 detail=f"Collection {collection_id} not found"
+             )
+
+         # Verify ownership
+         if not full_collection_id.startswith(f"{user_id}_"):
+             logging.error(f"Unauthorized deletion attempt for collection {collection_id} by user {user_id}")
+             raise HTTPException(
+                 status_code=403,
+                 detail="Not authorized to delete this collection"
+             )
+
+         try:
+             db.drop_table(full_collection_id)
+         except Exception as e:
+             logging.error(f"Error deleting collection {collection_id}: {str(e)}")
+             raise HTTPException(
+                 status_code=500,
+                 detail=f"Error deleting collection: {str(e)}"
+             )
+
+         return {
+             "message": f"Collection {collection_id} deleted successfully",
+             "collection_id": collection_id
+         }
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logging.error(f"Unexpected error deleting collection {collection_id}: {str(e)}")
+         raise HTTPException(
+             status_code=500,
+             detail=f"Unexpected error: {str(e)}"
+         )
+
+ @router.post("/query_collection_tool")
+ async def query_collection_tool(input_data: QueryInput):
+     try:
+         response = await query_collection(input_data)
+         results = []
+
+         # Access response directly since it's a Pydantic model
+         for r in response.results:
+             result_dict = {
+                 "text": r.text,
+                 "distance": r.distance,
+                 "metadata": {
+                     "document_id": r.metadata.get("document_id"),
+                     "chunk_index": r.metadata.get("location", {}).get("chunk_index")
+                 }
+             }
+             results.append(result_dict)
+
+         return str(results)
+
+     except Exception as e:
+         logging.error(f"Unexpected error during query: {str(e)}")
+         raise HTTPException(
+             status_code=500,
+             detail=f"Unexpected error: {str(e)}"
+         )
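For reference, a minimal client sketch for the router above, assuming the service is reachable at `http://localhost:9004` (the docker-compose mapping), the `requests` package is installed, and a placeholder file `example.pdf` exists locally:

```python
import requests

BASE_URL = "http://localhost:9004"  # assumption: docker-compose port mapping
USER_ID = "demo_user"               # placeholder user identifier

# Upload a PDF into a named collection (POST /rag/upload_files).
with open("example.pdf", "rb") as f:  # placeholder local file
    upload = requests.post(
        f"{BASE_URL}/rag/upload_files",
        files=[("files", ("example.pdf", f, "application/pdf"))],
        data={"collection_name": "demo_collection", "user_id": USER_ID},
    )
upload.raise_for_status()
print(upload.json())  # includes collection_id, total_chunks, document_ids

# Query the collection (POST /rag/query_collection).
# Note: query_collection prepends the "{user_id}_" prefix itself, so pass the
# collection name, not the prefixed table name returned by upload_files.
query = requests.post(
    f"{BASE_URL}/rag/query_collection",
    json={
        "collection_id": "demo_collection",
        "query": "What is this document about?",
        "top_k": 3,
        "user_id": USER_ID,
    },
)
query.raise_for_status()
for hit in query.json()["results"]:
    print(hit["distance"], hit["text"][:80])
```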
main.py ADDED
@@ -0,0 +1,293 @@
+ import uuid
+ from fastapi import FastAPI
+ from fastapi.responses import StreamingResponse
+ from langchain_core.messages import (
+     BaseMessage,
+     HumanMessage,
+     trim_messages,
+ )
+ from langchain_core.tools import tool
+ from langchain_openai import ChatOpenAI
+ from langgraph.checkpoint.memory import MemorySaver
+ from langgraph.prebuilt import create_react_agent
+ from pydantic import BaseModel
+ import json
+ from typing import List, Optional, Annotated
+ from langchain_core.runnables import RunnableConfig
+ from langgraph.prebuilt import InjectedState
+ from document_rag_router import router as document_rag_router
+ from document_rag_router import QueryInput, query_collection, SearchResult
+ from fastapi import HTTPException
+ import requests
+ from sse_starlette.sse import EventSourceResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ import re
+
+ app = FastAPI()
+ app.include_router(document_rag_router)
+
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ @tool
+ def get_user_age(name: str) -> str:
+     """Use this tool to find the user's age."""
+     if "bob" in name.lower():
+         return "42 years old"
+     return "41 years old"
+
+ @tool
+ async def query_documents(
+     query: str,
+     config: RunnableConfig,
+     # state: Annotated[dict, InjectedState]
+ ) -> str:
+     """Use this tool to retrieve relevant data from the collection.
+
+     Args:
+         query: The search query to find relevant document passages
+     """
+     # Get collection_id and user_id from config
+     thread_config = config.get("configurable", {})
+     collection_id = thread_config.get("collection_id")
+     user_id = thread_config.get("user_id")
+
+     if not collection_id or not user_id:
+         return "Error: collection_id and user_id are required in the config"
+     try:
+         # Create query input
+         input_data = QueryInput(
+             collection_id=collection_id,
+             query=query,
+             user_id=user_id,
+             top_k=6
+         )
+
+         response = await query_collection(input_data)
+         results = []
+
+         # Access response directly since it's a Pydantic model
+         for r in response.results:
+             result_dict = {
+                 "text": r.text,
+                 "distance": r.distance,
+                 "metadata": {
+                     "document_id": r.metadata.get("document_id"),
+                     "chunk_index": r.metadata.get("location", {}).get("chunk_index")
+                 }
+             }
+             results.append(result_dict)
+
+         return str(results)
+
+     except Exception as e:
+         print(e)
+         return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
+
+
+ async def query_documents_raw(
+     query: str,
+     config: RunnableConfig,
+     # state: Annotated[dict, InjectedState]
+ ) -> List[SearchResult]:
+     """Retrieve relevant data from the collection and return the raw search results.
+
+     Args:
+         query: The search query to find relevant document passages
+     """
+     # Get collection_id and user_id from config
+     thread_config = config.get("configurable", {})
+     collection_id = thread_config.get("collection_id")
+     user_id = thread_config.get("user_id")
+
+     if not collection_id or not user_id:
+         return "Error: collection_id and user_id are required in the config"
+     try:
+         # Create query input
+         input_data = QueryInput(
+             collection_id=collection_id,
+             query=query,
+             user_id=user_id,
+             top_k=6
+         )
+
+         response = await query_collection(input_data)
+         return response.results
+
+     except Exception as e:
+         print(e)
+         return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"
+
+ memory = MemorySaver()
+ model = ChatOpenAI(model="gpt-4o-mini", streaming=True)
+
+ def state_modifier(state) -> list[BaseMessage]:
+     return trim_messages(
+         state["messages"],
+         token_counter=len,
+         max_tokens=16000,
+         strategy="last",
+         start_on="human",
+         include_system=True,
+         allow_partial=False,
+     )
+
+ agent = create_react_agent(
+     model,
+     tools=[query_documents],
+     checkpointer=memory,
+     state_modifier=state_modifier,
+ )
+
+ class ChatInput(BaseModel):
+     message: str
+     thread_id: Optional[str] = None
+     collection_id: Optional[str] = None
+     user_id: Optional[str] = None
+
+ @app.post("/chat")
+ async def chat(input_data: ChatInput):
+     thread_id = input_data.thread_id or str(uuid.uuid4())
+
+     config = {
+         "configurable": {
+             "thread_id": thread_id,
+             "collection_id": input_data.collection_id,
+             "user_id": input_data.user_id
+         }
+     }
+
+     input_message = HumanMessage(content=input_data.message)
+
+     async def generate():
+         async for event in agent.astream_events(
+             {"messages": [input_message]},
+             config,
+             version="v2"
+         ):
+             kind = event["event"]
+
+             if kind == "on_chat_model_stream":
+                 content = event["data"]["chunk"].content
+                 if content:
+                     yield f"{json.dumps({'type': 'token', 'content': content})}"
+
+             elif kind == "on_tool_start":
+                 tool_input = str(event['data'].get('input', ''))
+                 yield f"{json.dumps({'type': 'tool_start', 'tool': event['name'], 'input': tool_input})}"
+
+             elif kind == "on_tool_end":
+                 tool_output = str(event['data'].get('output', ''))
+                 yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': tool_output})}"
+
+     return EventSourceResponse(
+         generate(),
+         media_type="text/event-stream"
+     )
+
+ async def clean_tool_input(tool_input: str):
+     # Use regex to parse the first key and value
+     pattern = r"{\s*'([^']+)':\s*'([^']+)'"
+     match = re.search(pattern, tool_input)
+     if match:
+         key, value = match.groups()
+         return {key: value}
+     return [tool_input]
+
+ async def clean_tool_response(tool_output: str):
+     """Clean and extract relevant information from tool response if it contains query_documents."""
+     if "query_documents" in tool_output:
+         try:
+             # First safely evaluate the string as a Python literal
+             import ast
+             print(tool_output)
+             # Extract the list string from the content
+             start = tool_output.find("[{")
+             end = tool_output.rfind("}]")
+             if start != -1 and end != -1:
+                 list_str = tool_output[start:end + 2]
+
+                 # Convert string to Python object using ast.literal_eval
+                 results = ast.literal_eval(list_str)
+
+                 # Return only relevant fields
+                 return [{"text": r["text"], "document_id": r["metadata"]["document_id"]}
+                         for r in results]
+
+         except SyntaxError as e:
+             print(f"Syntax error in parsing: {e}")
+             return f"Error parsing document results: {str(e)}"
+         except Exception as e:
+             print(f"General error: {e}")
+             return f"Error processing results: {str(e)}"
+     return tool_output
+
+ @app.post("/chat2")
+ async def chat2(input_data: ChatInput):
+     thread_id = input_data.thread_id or str(uuid.uuid4())
+
+     config = {
+         "configurable": {
+             "thread_id": thread_id,
+             "collection_id": input_data.collection_id,
+             "user_id": input_data.user_id
+         }
+     }
+
+     input_message = HumanMessage(content=input_data.message)
+
+     async def generate():
+         async for event in agent.astream_events(
+             {"messages": [input_message]},
+             config,
+             version="v2"
+         ):
+             kind = event["event"]
+
+             if kind == "on_chat_model_stream":
+                 content = event["data"]["chunk"].content
+                 if content:
+                     yield f"{json.dumps({'type': 'token', 'content': content})}"
+
+             elif kind == "on_tool_start":
+                 tool_name = event['name']
+                 tool_input = event['data'].get('input', '')
+                 clean_input = await clean_tool_input(str(tool_input))
+                 yield f"{json.dumps({'type': 'tool_start', 'tool': tool_name, 'inputs': clean_input})}"
+
+             elif kind == "on_tool_end":
+                 if "query_documents" in event['name']:
+                     print(event)
+                     raw_output = await query_documents_raw(str(event['data'].get('input', '')), config)
+                     try:
+                         serializable_output = [
+                             {
+                                 "text": result.text,
+                                 "distance": result.distance,
+                                 "metadata": result.metadata
+                             }
+                             for result in raw_output
+                         ]
+                         yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': json.dumps(serializable_output)})}"
+                     except Exception as e:
+                         print(e)
+                         yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': str(raw_output)})}"
+                 else:
+                     tool_name = event['name']
+                     raw_output = str(event['data'].get('output', ''))
+                     clean_output = await clean_tool_response(raw_output)
+                     yield f"{json.dumps({'type': 'tool_end', 'tool': tool_name, 'output': clean_output})}"
+
+     return EventSourceResponse(
+         generate(),
+         media_type="text/event-stream"
+     )
+
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy"}
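Both chat endpoints stream Server-Sent Events whose `data` payloads are the JSON strings yielded above. A minimal consumer sketch, assuming the API is reachable on port 9004 (the docker-compose mapping), `requests` is installed, and the collection/user names are placeholders:

```python
import json
import requests

BASE_URL = "http://localhost:9004"  # assumption: docker-compose port mapping

payload = {
    "message": "Summarise the uploaded document.",
    "collection_id": "demo_collection",  # placeholder collection name
    "user_id": "demo_user",              # placeholder user identifier
}

with requests.post(f"{BASE_URL}/chat", json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # EventSourceResponse frames each yielded string as "data: <payload>".
        if not line or not line.startswith("data: "):
            continue
        event = json.loads(line[len("data: "):])
        if event["type"] == "token":
            print(event["content"], end="", flush=True)
        else:
            print(f"\n[{event['type']}] {event.get('tool', '')}")
```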
readme.md ADDED
@@ -0,0 +1,91 @@
+ # Document RAG User API
+
+ This is a FastAPI application for processing and managing document uploads, including PDF and text files. The application allows users to upload files, query collections, and manage their document data.
+
+ ## Features
+
+ - Upload files in various formats (PDF, TXT, etc.)
+ - Efficiently process and store document chunks with metadata
+ - Perform queries on collections using user-defined input
+ - Retrieve and list collections specific to each user
+ - Remove collections as needed
+
+ ## Requirements
+
+ - Python 3.9+ (the Dockerfile uses `python:3.11-slim`)
+ - FastAPI
+ - LanceDB
+ - Pydantic
+ - Pandas
+ - Other dependencies as specified in `requirements.txt`
+
+ ## Installation
+
+ 1. Clone the repository:
+    ```bash
+    git clone <repository-url>
+    cd <repository-directory>
+    ```
+
+ 2. Install the required packages:
+    ```bash
+    pip install -r requirements.txt
+    ```
+
+ 3. Run the application:
+    ```bash
+    uvicorn main:app --reload
+    ```
+
+ ## API Endpoints
+
+ All endpoints below are served under the `/rag` prefix defined in `document_rag_router.py` (a usage sketch follows this list).
+
+ ### Upload Files
+
+ - **POST** `/rag/upload_files`
+   - Upload multiple files.
+   - Parameters:
+     - `files`: List of files to upload.
+     - `collection_name`: Optional name for the collection.
+     - `user_id`: User identifier.
+
+ ### Get Document
+
+ - **GET** `/rag/get_document/{collection_id}/{document_id}`
+   - Retrieve a specific document by its ID from a collection.
+   - Parameters:
+     - `collection_id`: ID of the collection.
+     - `document_id`: ID of the document.
+     - `user_id`: User identifier.
+
+ ### Query Collection
+
+ - **POST** `/rag/query_collection`
+   - Query a collection based on user input.
+   - Request Body:
+     - `collection_id`: ID of the collection.
+     - `query`: Search query.
+     - `top_k`: Optional number of top results to return (default is 3).
+     - `user_id`: User identifier.
+
+ ### List Collections
+
+ - **GET** `/rag/list_collections`
+   - List all collections for a specific user.
+   - Parameters:
+     - `user_id`: User identifier.
+
+ ### Delete Collection
+
+ - **DELETE** `/rag/delete_collection/{collection_id}`
+   - Delete a specific collection.
+   - Parameters:
+     - `collection_id`: ID of the collection to delete.
+     - `user_id`: User identifier.
+
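A brief usage sketch for the list and delete endpoints, assuming a server started with `uvicorn main:app` on the default port 8000, the `requests` package, and an existing collection named `demo_collection` owned by `demo_user` (both names are placeholders):

```python
import requests

BASE_URL = "http://localhost:8000"  # default uvicorn port; adjust for your deployment
USER_ID = "demo_user"               # placeholder user identifier

# GET /rag/list_collections - list this user's collections and their documents.
collections = requests.get(
    f"{BASE_URL}/rag/list_collections", params={"user_id": USER_ID}
)
collections.raise_for_status()
print(collections.json())

# DELETE /rag/delete_collection/{collection_id} - drop a collection by name.
deleted = requests.delete(
    f"{BASE_URL}/rag/delete_collection/demo_collection",
    params={"user_id": USER_ID},
)
deleted.raise_for_status()
print(deleted.json())  # {"message": "... deleted successfully", ...}
```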
+ ## Contributing
+
+ Contributions are welcome! Please open an issue or submit a pull request for any improvements or bug fixes.
+
+ ## License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ fastapi[standard]
+ langchain-core
+ langchain-openai
+ langgraph
+ pydantic
+ pandas
+ lancedb
+ pymupdf
+ langchain-text-splitters
+ sse-starlette
+ typing-extensions
+ tantivy
utils.py ADDED
@@ -0,0 +1,253 @@
+ """
+ Contains utility functions for the LLM and database modules, along with some other miscellaneous helpers.
+ """
+
+ import pymupdf
+ # from docx import Document
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
+ # import tiktoken
+ import base64
+ import hashlib
+ from openai import OpenAI
+ # from dotenv import load_dotenv
+ import os
+ from datetime import datetime
+ from typing import List, Optional, Dict, Any, Tuple
+
+ def generate_file_id(file_bytes: bytes) -> str:
+     """Generate a 4-character unique file ID for a given file."""
+     hash_obj = hashlib.sha256()
+     hash_obj.update(file_bytes[:4096])  # Still hash the first 4096 bytes
+     # Take the first 2 bytes (16 bits) and convert to hex (alphanumeric)
+     file_id = hex(int.from_bytes(hash_obj.digest()[:2], 'big'))[2:].zfill(4)
+     return file_id
+
+
+ def process_pdf_to_chunks(
+     pdf_content: bytes,
+     file_name: str,
+     chunk_size: int = 512,
+     chunk_overlap: int = 20
+ ) -> Tuple[List[Dict[str, Any]], str]:
+     """
+     Process PDF content into chunks with column layout detection and proper image handling
+     """
+     doc = pymupdf.open(stream=pdf_content, filetype="pdf")
+     document_text = ""
+     all_images = []
+     image_positions = []
+     char_to_page_map = []
+     layout_info = {}
+
+     doc_id = generate_file_id(pdf_content)
+
+     def detect_columns(blocks):
+         """Detect if page has multiple columns based on text block positions"""
+         # Note: relies on `page` from the enclosing per-page loop below
+         if not blocks:
+             return 1
+
+         x_positions = [block[0] for block in blocks]
+         x_positions.sort()
+
+         if len(x_positions) > 1:
+             gaps = [x_positions[i+1] - x_positions[i] for i in range(len(x_positions)-1)]
+             significant_gaps = [gap for gap in gaps if gap > page.rect.width * 0.15]
+             return len(significant_gaps) + 1
+         return 1
+
+     def sort_blocks_by_position(blocks, num_columns):
+         """Sort blocks by column and vertical position"""
+         if num_columns == 1:
+             return sorted(blocks, key=lambda b: b[0][1])  # b[0] is the bbox tuple, b[0][1] is the y coordinate
+
+         page_width = page.rect.width
+         column_width = page_width / num_columns
+
+         def get_column(block):
+             bbox = block[0]    # Get the bounding box tuple
+             x_coord = bbox[0]  # Get the x coordinate (first element)
+             return int(x_coord // column_width)
+
+         return sorted(blocks, key=lambda b: (get_column(b), b[0][1]))
+
+     # Process each page
+     for page_num, page in enumerate(doc, 1):
+         blocks = page.get_text_blocks()
+         images = page.get_images()
+
+         # Detect layout
+         num_columns = detect_columns(blocks)
+         layout_info[page_num] = {
+             "columns": num_columns,
+             "width": page.rect.width,
+             "height": page.rect.height
+         }
+
+         # Create elements list with both text and images
+         elements = [(block[:4], block[4], "text") for block in blocks]
+
+         # Add images to elements
+         for img in images:
+             try:
+                 img_rects = page.get_image_rects(img[0])
+                 if img_rects and len(img_rects) > 0:
+                     img_bbox = img_rects[0]
+                     if img_bbox:
+                         img_data = (img_bbox, img[0], "image")
+                         elements.append(img_data)
+             except Exception as e:
+                 print(f"Error processing image: {e}")
+                 continue
+
+         # Sort elements by position
+         sorted_elements = sort_blocks_by_position(elements, num_columns)
+
+         # Process elements in order
+         page_text = ""
+         for element in sorted_elements:
+             if element[2] == "text":
+                 text_content = element[1]
+                 page_text += text_content
+                 char_to_page_map.extend([page_num] * len(text_content))
+             else:
+                 xref = element[1]
+                 base_image = doc.extract_image(xref)
+                 image_bytes = base_image["image"]
+                 # Convert image bytes to base64
+                 image_base64 = base64.b64encode(image_bytes).decode('utf-8')
+                 all_images.append(image_base64)  # Store base64 encoded image
+
+                 image_marker = f"\n<img_{len(all_images)-1}>\n"
+                 image_positions.append((len(all_images)-1, len(document_text) + len(page_text)))
+                 page_text += image_marker
+                 char_to_page_map.extend([page_num] * len(image_marker))
+
+         document_text += page_text
+
+     # Create chunks
+     splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+         # separators=["\n\n", "\n", " ", ""],
+         # keep_separator=True,
+         encoding_name="cl100k_base",
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap
+     )
+
+     text_chunks = splitter.split_text(document_text)
+
+     # Process chunks with metadata
+     processed_chunks = []
+     for chunk_idx, chunk in enumerate(text_chunks):
+         chunk_start = document_text.find(chunk)
+         chunk_end = chunk_start + len(chunk)
+
+         # Get page range and layout info
+         chunk_pages = sorted(set(char_to_page_map[chunk_start:chunk_end]))
+         chunk_layouts = {page: layout_info[page] for page in chunk_pages}
+
+         # Get images for this chunk
+         chunk_images = []
+         for img_idx, img_pos in image_positions:
+             if chunk_start <= img_pos <= chunk_end:
+                 chunk_images.append(all_images[img_idx])  # Already base64 encoded
+
+         # Clean the chunk text
+         # cleaned_chunk = clean_text_for_llm(chunk)
+
+         chunk_dict = {
+             "text": chunk,
+             "metadata": {
+                 "created_date": datetime.now().isoformat(),
+                 "file_name": file_name,
+                 "images": chunk_images,
+                 "document_id": doc_id,
+                 "location": {
+                     "char_start": chunk_start,
+                     "char_end": chunk_end,
+                     "pages": chunk_pages,
+                     "chunk_index": chunk_idx,
+                     "total_chunks": len(text_chunks),
+                     "layout": chunk_layouts
+                 }
+             }
+         }
+         processed_chunks.append(chunk_dict)
+
+     return processed_chunks, doc_id
+
+
+
+ # import re
+ # import unicodedata
+ # from typing import Optional
+
+ # # Compile regex patterns once
+ # HTML_TAG_PATTERN = re.compile(r'<[^>]+>')
+ # MULTIPLE_NEWLINES = re.compile(r'\n\s*\n')
+ # MULTIPLE_SPACES = re.compile(r'\s+')
+
+ # def clean_text_for_llm(text: Optional[str]) -> str:
+ #     """
+ #     Efficiently clean and normalize text for LLM processing.
+ #     """
+ #     # Early returns
+ #     if not text:
+ #         return ""
+ #     if not isinstance(text, str):
+ #         try:
+ #             text = str(text)
+ #         except Exception:
+ #             return ""
+
+ #     # Single-pass character filtering
+ #     chars = []
+ #     prev_char = ''
+ #     space_pending = False
+
+ #     for char in text:
+ #         # Skip null bytes and most control characters
+ #         if char == '\0' or unicodedata.category(char).startswith('C'):
+ #             if char not in '\n\t':
+ #                 continue
+
+ #         # Convert escaped sequences
+ #         if prev_char == '\\':
+ #             if char == 'n':
+ #                 chars[-1] = '\n'
+ #                 continue
+ #             if char == 't':
+ #                 chars[-1] = '\t'
+ #                 continue
+
+ #         # Handle whitespace
+ #         if char.isspace():
+ #             if not space_pending:
+ #                 space_pending = True
+ #             continue
+
+ #         if space_pending:
+ #             chars.append(' ')
+ #             space_pending = False
+
+ #         chars.append(char)
+ #         prev_char = char
+
+ #     # Join characters and perform remaining operations
+ #     text = ''.join(chars)
+
+ #     # Remove HTML tags
+ #     # text = HTML_TAG_PATTERN.sub('', text)
+
+ #     # Normalize Unicode in a single pass
+ #     text = unicodedata.normalize('NFKC', text)
+
+ #     # Clean up newlines
+ #     text = MULTIPLE_NEWLINES.sub('\n', text)
+
+ #     # Final trim
+ #     return text.strip()
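A small driver sketch for `process_pdf_to_chunks`, assuming a local file `sample.pdf` (placeholder name) and the dependencies from `requirements.txt`:

```python
from utils import process_pdf_to_chunks

# Read a local PDF and split it into chunks with positional metadata.
with open("sample.pdf", "rb") as f:  # placeholder path
    chunks, doc_id = process_pdf_to_chunks(f.read(), file_name="sample.pdf")

print(f"document_id={doc_id}, total_chunks={len(chunks)}")
first = chunks[0]
print(first["metadata"]["location"])  # char offsets, pages, chunk_index, layout
print(first["text"][:200])            # start of the first chunk's text
```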