from fastapi import UploadFile, File, Form, HTTPException, APIRouter
from typing import List, Optional, Dict, Tuple
import lancedb
from lancedb.pydantic import LanceModel, Vector
from lancedb.embeddings import get_registry
import pandas as pd
from utils import process_pdf_to_chunks
import hashlib
import uuid
import json
from datetime import datetime
from pydantic import BaseModel
import logging

# Create router
router = APIRouter(
    prefix="/rag",
    tags=["rag"]
)

# Initialize LanceDB and embedding model
db = lancedb.connect("/tmp/db")
model = get_registry().get("sentence-transformers").create(
    name="Snowflake/snowflake-arctic-embed-xs",
    device="cpu"
)

def get_user_collection(user_id: str, collection_name: str) -> str:
    """Generate user-specific collection name"""
    return f"{user_id}_{collection_name}"

class DocumentChunk(LanceModel):
    text: str = model.SourceField()
    vector: Vector(model.ndims()) = model.VectorField()
    document_id: str
    chunk_index: int
    file_name: str
    file_type: str
    created_date: str
    collection_id: str
    user_id: str
    metadata_json: str
    char_start: int
    char_end: int
    page_numbers: List[int]
    images: List[str]

class QueryInput(BaseModel):
    collection_id: str
    query: str
    top_k: Optional[int] = 3
    user_id: str

class SearchResult(BaseModel):
    text: str
    distance: float
    metadata: Dict  # Added metadata field

class SearchResponse(BaseModel):
    results: List[SearchResult]

async def process_file(file: UploadFile, collection_id: str, user_id: str) -> Tuple[List[dict], str]:
    """Process a single file and return chunks with metadata"""
    content = await file.read()
    file_type = file.filename.split('.')[-1].lower()
    chunks = []
    doc_id = ""
    if file_type == 'pdf':
        chunks, doc_id = process_pdf_to_chunks(
            pdf_content=content,
            file_name=file.filename
        )
    elif file_type == 'txt':
        doc_id = hashlib.sha256(content).hexdigest()[:4]
        text_content = content.decode('utf-8')
        chunks = [{
            "text": text_content,
            "metadata": {
                "created_date": datetime.now().isoformat(),
                "file_name": file.filename,
                "document_id": doc_id,
                "user_id": user_id,
                "location": {
                    "chunk_index": 0,
                    "char_start": 0,
                    "char_end": len(text_content),
                    "pages": [1],
                    "total_chunks": 1
                },
                "images": []
            }
        }]
    else:
        # Only PDF and plain-text uploads are handled; reject anything else so an
        # unsupported file cannot silently produce an empty document.
        raise ValueError(f"Unsupported file type: {file_type}")
    return chunks, doc_id
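
# Route registration (assumed): the module creates `router` above but no endpoint
# decorators appear in the file as shown; the "/upload" path is an illustrative assumption.
@router.post("/upload")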
async def upload_files(
    files: List[UploadFile] = File(...),
    collection_name: Optional[str] = Form(None),
    user_id: str = Form(...)
):
    """Upload one or more files, chunk them, and store the chunks in a user-scoped collection."""
    try:
        collection_id = get_user_collection(
            user_id,
            collection_name if collection_name else f"col_{uuid.uuid4().hex[:8]}"
        )
        all_chunks = []
        doc_ids = {}
        for file in files:
            try:
                chunks, doc_id = await process_file(file, collection_id, user_id)
                for chunk in chunks:
                    chunk_data = {
                        "text": chunk["text"],
                        "document_id": chunk["metadata"]["document_id"],
                        "chunk_index": chunk["metadata"]["location"]["chunk_index"],
                        "file_name": chunk["metadata"]["file_name"],
                        "file_type": file.filename.split('.')[-1].lower(),
                        "created_date": chunk["metadata"]["created_date"],
                        "collection_id": collection_id,
                        "user_id": user_id,
                        "metadata_json": json.dumps(chunk["metadata"]),
                        "char_start": chunk["metadata"]["location"]["char_start"],
                        "char_end": chunk["metadata"]["location"]["char_end"],
                        "page_numbers": chunk["metadata"]["location"]["pages"],
                        "images": chunk["metadata"].get("images", [])
                    }
                    all_chunks.append(chunk_data)
                doc_ids[doc_id] = file.filename
            except Exception as e:
                logging.error(f"Error processing file {file.filename}: {str(e)}")
                raise HTTPException(
                    status_code=400,
                    detail=f"Error processing file {file.filename}: {str(e)}"
                )
        # Open the collection's table, creating it on the first upload.
        try:
            table = db.open_table(collection_id)
        except Exception as e:
            # A missing table is expected the first time this collection is used.
            logging.info(f"Table {collection_id} not found, creating it: {str(e)}")
            try:
                table = db.create_table(
                    collection_id,
                    schema=DocumentChunk,
                    mode="create"
                )
                # Create FTS index on the text column for hybrid search support
                # table.create_fts_index(
                #     field_names="text",
                #     replace=True,
                #     tokenizer_name="en_stem",  # Use English stemming
                #     lower_case=True,  # Convert text to lowercase
                #     remove_stop_words=True,  # Remove common words like "the", "is", "at"
                #     writer_heap_size=1024 * 1024 * 1024  # 1GB heap size
                # )
            except Exception as e:
                logging.error(f"Error creating table: {str(e)}")
                raise HTTPException(
                    status_code=500,
                    detail=f"Error creating database table: {str(e)}"
                )
        try:
            df = pd.DataFrame(all_chunks)
            table.add(data=df)
        except Exception as e:
            logging.error(f"Error adding data to table: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error adding data to database: {str(e)}"
            )
        return {
            "message": f"Successfully processed {len(files)} files",
            "collection_id": collection_id,
            "total_chunks": len(all_chunks),
            "user_id": user_id,
            "document_ids": doc_ids
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during file upload: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )
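
# Route registration (assumed): no decorator survives in the original; the path below is
# illustrative, with user_id arriving as a query parameter.
@router.get("/documents/{collection_id}/{document_id}")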
async def get_document(
    collection_id: str,
    document_id: str,
    user_id: str
):
    """Return every chunk of a document, in order, from a user's collection."""
    try:
        table = db.open_table(f"{user_id}_{collection_id}")
    except Exception as e:
        logging.error(f"Error opening table: {str(e)}")
        raise HTTPException(
            status_code=404,
            detail=f"Collection not found: {str(e)}"
        )
    try:
        chunks = table.to_pandas()
        doc_chunks = chunks[
            (chunks['document_id'] == document_id) &
            (chunks['user_id'] == user_id)
        ].sort_values('chunk_index')
        if len(doc_chunks) == 0:
            raise HTTPException(
                status_code=404,
                detail=f"Document {document_id} not found in collection {collection_id}"
            )
        return {
            "document_id": document_id,
            "file_name": doc_chunks.iloc[0]['file_name'],
            "chunks": [
                {
                    "text": row['text'],
                    "metadata": json.loads(row['metadata_json'])
                }
                for _, row in doc_chunks.iterrows()
            ]
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Error retrieving document: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving document: {str(e)}"
        )
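
# Route registration (assumed): "/query" is an illustrative path, not confirmed by the original file.
@router.post("/query")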
async def query_collection(input_data: QueryInput):
    """Run a vector search over a user's collection and return the top matches."""
    try:
        collection_id = get_user_collection(input_data.user_id, input_data.collection_id)
        try:
            table = db.open_table(collection_id)
        except Exception as e:
            logging.error(f"Error opening table: {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Collection not found: {str(e)}"
            )
        try:
            results = (
                table.search(input_data.query)
                .where(f"user_id = '{input_data.user_id}'")
                .limit(input_data.top_k)
                .to_list()
            )
        except Exception as e:
            logging.error(f"Error searching collection: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error searching collection: {str(e)}"
            )
        return SearchResponse(results=[
            SearchResult(
                text=r['text'],
                distance=float(r['_distance']),
                metadata=json.loads(r['metadata_json'])
            )
            for r in results
        ])
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error during query: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )
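
# Route registration (assumed): "/collections" is illustrative; user_id arrives as a query parameter.
@router.get("/collections")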
async def list_collections(user_id: str):
    """List a user's collections along with the unique documents stored in each."""
    try:
        all_collections = db.table_names()
        user_collections = [
            c for c in all_collections
            if c.startswith(f"{user_id}_")
        ]
        # Get documents for each collection
        collections_info = []
        for collection_name in user_collections:
            try:
                table = db.open_table(collection_name)
                df = table.to_pandas()
                # Group by document_id to get unique documents
                documents = df.groupby('document_id').agg({
                    'file_name': 'first',
                    'created_date': 'first'
                }).reset_index()
                collections_info.append({
                    "collection_id": collection_name.replace(f"{user_id}_", ""),
                    "documents": [
                        {
                            "document_id": row['document_id'],
                            "file_name": row['file_name'],
                            "created_date": row['created_date']
                        }
                        for _, row in documents.iterrows()
                    ]
                })
            except Exception as e:
                logging.error(f"Error processing collection {collection_name}: {str(e)}")
                continue
        return {"collections": collections_info}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
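
# Route registration (assumed): the DELETE path below is illustrative, not taken from the original file.
@router.delete("/collections/{collection_id}")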
async def delete_collection(collection_id: str, user_id: str):
    """Delete a user's collection after confirming it exists."""
    try:
        full_collection_id = f"{user_id}_{collection_id}"
        # Check if collection exists
        try:
            table = db.open_table(full_collection_id)
        except Exception as e:
            logging.error(f"Collection not found: {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Collection {collection_id} not found"
            )
        # Ownership is already enforced by construction: full_collection_id is built from
        # user_id above, so this check cannot fail; it is kept only as a defensive guard.
        if not full_collection_id.startswith(f"{user_id}_"):
            logging.error(f"Unauthorized deletion attempt for collection {collection_id} by user {user_id}")
            raise HTTPException(
                status_code=403,
                detail="Not authorized to delete this collection"
            )
        try:
            db.drop_table(full_collection_id)
        except Exception as e:
            logging.error(f"Error deleting collection {collection_id}: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error deleting collection: {str(e)}"
            )
        return {
            "message": f"Collection {collection_id} deleted successfully",
            "collection_id": collection_id
        }
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Unexpected error deleting collection {collection_id}: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )

async def query_collection_tool(input_data: QueryInput):
    """Tool-friendly wrapper around query_collection that returns a compact string of results."""
    try:
        response = await query_collection(input_data)
        results = []
        # Access response directly since it's a Pydantic model
        for r in response.results:
            result_dict = {
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": r.metadata.get("document_id"),
                    "chunk_index": r.metadata.get("location", {}).get("chunk_index")
                }
            }
            results.append(result_dict)
        return str(results)
    except Exception as e:
        logging.error(f"Unexpected error during query: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )
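
# Minimal usage sketch (assumption: in the original project the FastAPI app that mounts this
# router lives in a separate main module; the host and port below are illustrative only).
if __name__ == "__main__":
    import uvicorn
    from fastapi import FastAPI

    app = FastAPI()
    app.include_router(router)  # exposes the endpoints above under the /rag prefix
    uvicorn.run(app, host="0.0.0.0", port=8000)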