# Dash-inc's picture
# Update main.py
# 0d39048 verified
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from pydantic import BaseModel
from fastapi.staticfiles import StaticFiles
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
import uuid
import time
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pymongo import MongoClient
from urllib.parse import quote_plus
from langchain_groq import ChatGroq
from aura_sr import AuraSR
from io import BytesIO
from PIL import Image
import requests
import os
import logging
from dotenv import load_dotenv
# Load environment variables
load_dotenv()

# Validate environment variables.
# Explicit raises are used rather than `assert`, because assertions are
# removed when Python runs with -O and the checks would silently vanish.
if not all(os.getenv(name) for name in ('MONGO_USER', 'MONGO_PASSWORD', 'MONGO_HOST')):
    raise RuntimeError("MongoDB credentials missing!")
if not os.getenv('LLM_API_KEY'):
    raise RuntimeError("LLM API Key missing!")
if not os.getenv('BFL_API_KEY'):
    raise RuntimeError("BFL API Key missing!")

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the Hugging Face cache directory to a writable location
# NOTE(review): this runs after `aura_sr` has already been imported at the top
# of the file; if the Hugging Face libraries read HF_HOME at import time this
# assignment may come too late to take effect — confirm.
os.environ['HF_HOME'] = '/tmp/huggingface_cache'

app = FastAPI()

# Middleware for CORS
# NOTE(review): every origin, method and header is allowed — confirm this is
# intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Globals
executor = ThreadPoolExecutor(max_workers=5)
llm = None
# Both the username and the password are percent-escaped so that reserved
# characters (':', '/', '@', ...) in either one cannot corrupt the URI.
mongo_uri = (
    f"mongodb+srv://{quote_plus(os.getenv('MONGO_USER'))}:"
    f"{quote_plus(os.getenv('MONGO_PASSWORD'))}@{os.getenv('MONGO_HOST')}/"
)
mongo_client = MongoClient(mongo_uri)
db = mongo_client["Flux"]
collection = db["chat_histories"]
chat_sessions = {}

# Temporary directory for storing images
image_storage_dir = tempfile.mkdtemp()
app.mount("/images", StaticFiles(directory=image_storage_dir), name="images")

# Initialize AuraSR during startup
aura_sr = None
@app.on_event("startup")
async def startup():
    """Create the Groq chat model and load the AuraSR upscaler.

    The two models are initialised independently so that a failure in one
    does not prevent the other from loading. A failure is logged together
    with its traceback and leaves the corresponding global (``llm`` or
    ``aura_sr``) set to ``None``.
    """
    global llm, aura_sr
    try:
        llm = ChatGroq(
            model="llama-3.3-70b-versatile",
            temperature=0.7,
            max_tokens=1024,
            api_key=os.getenv('LLM_API_KEY'),
        )
    except Exception as e:
        # Startup boundary: log and keep serving the endpoints that still work.
        logger.exception("Error initializing models (ChatGroq): %s", e)
    try:
        aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2")
    except Exception as e:
        logger.exception("Error initializing models (AuraSR): %s", e)
@app.on_event("shutdown")
def shutdown():
    """Release process-wide resources when the application stops."""
    # Close the MongoDB connection first, then wait for the worker pool.
    mongo_client.close()
    executor.shutdown(wait=True)
# Pydantic models
class ImageRequest(BaseModel):
    # Request body of POST /generate-image. Every field is a required string.

    # Values interpolated into the instruction text sent to the LLM.
    subject: str
    style: str
    color_theme: str
    elements: str
    # Accepted, but not referenced by the /generate-image handler in this file.
    color_mode: str
    lighting_conditions: str
    framing_style: str
    material_details: str
    # Text the generated image is asked to display.
    text: str
    background_details: str
    # Stored in the chat history as the "user" message.
    user_prompt: str
    # Session identifier used as "session_id" for chat-history documents.
    chat_id: str
class UpscaleRequest(BaseModel):
    # Request body of POST /upscale-image.

    # Location of the image to download and upscale.
    image_url: str
# Helper functions
def generate_chat_id():
    """Create a new random chat id and register it in ``chat_sessions``.

    :return: The new id as a UUID4 string.
    """
    new_id = str(uuid.uuid4())
    # Each session entry maps to the shared chat-history collection.
    chat_sessions.update({new_id: collection})
    return new_id
def get_chat_history(chat_id):
    """Return the stored conversation for ``chat_id`` as a single string.

    Each stored message becomes one line: ``User: ...`` when its role is
    ``"user"``, ``AI: ...`` for any other role.

    :param chat_id: Session id to look up in the chat-history collection.
    :return: The message lines joined with newlines.
    """
    transcript = []
    for entry in collection.find({"session_id": chat_id}):
        speaker = "User" if entry['role'] == "user" else "AI"
        transcript.append(f"{speaker}: {entry['content']}")
    return "\n".join(transcript)
def save_image_locally(image, filename):
    """Write ``image`` as a PNG into the image storage directory.

    :param image: Object exposing ``save(path, format=...)``.
    :param filename: File name to use inside the storage directory.
    :return: Path of the written file.
    """
    destination = os.path.join(image_storage_dir, filename)
    image.save(destination, format="PNG")
    return destination
def make_request_with_retries(url, headers, payload, retries=3, delay=2):
    """
    Makes an HTTP POST request with retries in case of failure.

    A single session is shared by all attempts so connections are reused.

    :param url: The URL for the request.
    :param headers: Headers to include in the request.
    :param payload: Payload to send as the JSON body.
    :param retries: Total number of attempts; values below 1 are treated as 1
        so the function never returns without having tried.
    :param delay: Seconds to wait between attempts.
    :return: Response JSON from the server.
    :raises HTTPException: Status 500 once every attempt has failed; the last
        request error is attached as the exception's cause.
    """
    attempts = max(1, retries)
    with requests.Session() as session:
        for attempt in range(1, attempts + 1):
            try:
                response = session.post(url, headers=headers, json=payload, timeout=30)
                response.raise_for_status()
                return response.json()
            except requests.exceptions.RequestException as e:
                if attempt == attempts:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Request failed after {attempts} attempts: {str(e)}",
                    ) from e
                logger.warning(
                    "POST %s failed (attempt %d of %d): %s", url, attempt, attempts, e
                )
                time.sleep(delay)
def fetch_image(url):
    """Download ``url`` and open the response body as a PIL image.

    :param url: Address of the image to download.
    :return: The image opened from the downloaded bytes.
    :raises HTTPException: Status 400 if the download or the decoding fails.
    """
    try:
        with requests.Session() as http:
            reply = http.get(url, timeout=30)
            reply.raise_for_status()
            buffer = BytesIO(reply.content)
            return Image.open(buffer)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error fetching image: {str(e)}")
def poll_for_image_result(request_id, headers, timeout=60, interval=0.5):
    """Poll the BFL result endpoint until the image is ready.

    :param request_id: Id returned by the image-generation request.
    :param headers: Headers (including the API key) sent with every poll.
    :param timeout: Maximum number of seconds to keep polling.
    :param interval: Seconds to sleep before each poll.
    :return: The ``sample`` value of the finished result (``None`` if absent).
    :raises HTTPException: Status 500 if the service reports an error or the
        result is not ready before ``timeout`` elapses.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    # One session for the whole poll loop so the connection is reused.
    with requests.Session() as session:
        while time.monotonic() < deadline:
            time.sleep(interval)
            try:
                result = session.get(
                    "https://api.bfl.ml/v1/get_result",
                    headers=headers,
                    params={"id": request_id},
                    timeout=10,
                ).json()
            except (requests.exceptions.RequestException, ValueError) as e:
                # A single failed or malformed poll should not abort the whole
                # generation; try again until the deadline.
                logger.warning("Polling result %s failed, retrying: %s", request_id, e)
                continue
            status = result.get("status") if isinstance(result, dict) else None
            if status == "Ready":
                return (result.get("result") or {}).get("sample")
            if status == "Error":
                raise HTTPException(
                    status_code=500,
                    detail=f"Image generation failed: {result.get('error', 'Unknown error')}",
                )
    raise HTTPException(status_code=500, detail="Image generation timed out.")
@app.post("/new-chat", response_model=dict)
async def new_chat():
    """Start a new chat session and return its id as ``{"chat_id": ...}``."""
    return {"chat_id": generate_chat_id()}
def _build_refinement_prompt(request):
    """Return the instruction text asking the LLM to write the image prompt."""
    # NOTE(review): the chat history, `user_prompt` and `color_mode` are not
    # part of this text — confirm whether they should be.
    return f"""
You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.
Image Specifications:
- **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
- **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
- **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
- **Camera and Lighting**:
- Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
- **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
- **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
- **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
- **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
- Negative Prompt: Avoid grainy, blurry, or deformed outputs.
- **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
"""


def _generate_image_blocking(request):
    """Run the blocking refine-prompt, generate, download and save pipeline."""
    refined_prompt = llm.invoke(_build_refinement_prompt(request)).content.strip()

    # Record both sides of the exchange in the chat history.
    collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
    collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})

    headers = {
        "accept": "application/json",
        "x-key": os.getenv('BFL_API_KEY'),
        "Content-Type": "application/json"
    }
    payload = {
        "prompt": refined_prompt,
        "width": 1024,
        "height": 1024,
        "guidance_scale": 1,
        "num_inference_steps": 50,
        "max_sequence_length": 512,
    }
    response = make_request_with_retries("https://api.bfl.ml/v1/flux-pro-1.1", headers, payload)
    if "id" not in response:
        raise HTTPException(status_code=500, detail="Error generating image: ID missing from response")

    image_url = poll_for_image_result(response["id"], headers)
    if not image_url:
        # Without this guard the missing URL would surface as a misleading
        # "Error fetching image" client error.
        raise HTTPException(status_code=500, detail="Error generating image: result URL missing from response")

    image = fetch_image(image_url)
    filename = f"generated_{uuid.uuid4()}.png"
    filepath = save_image_locally(image, filename)
    return {
        "status": "Image generated successfully",
        "file_path": filepath,
        "file_url": f"/images/{filename}",
    }


@app.post("/generate-image", response_model=dict)
async def generate_image(request: ImageRequest):
    """Generate an image from the structured request and store it locally.

    The LLM turns the request fields into a refined prompt, the exchange is
    saved to the chat history, the prompt is submitted to the BFL API, and the
    finished image is downloaded and saved under the ``/images`` mount.

    :param request: Image specification plus the chat id it belongs to.
    :return: Dict with ``status``, ``file_path`` and ``file_url``.
    :raises HTTPException: Status 500 if the language model is not initialized
        or generation fails; status 400 if the finished image cannot be fetched.
    """
    if llm is None:
        raise HTTPException(status_code=500, detail="Language model not initialized.")
    # The pipeline is entirely blocking (LLM call, MongoDB writes, HTTP calls
    # and sleep-based polling), so it runs in a worker thread to keep the
    # event loop free for other requests.
    return await run_in_threadpool(_generate_image_blocking, request)
@app.post("/upscale-image", response_model=dict)
async def upscale_image(request: UpscaleRequest):
    """Download an image, upscale it with AuraSR and store the result locally.

    :param request: Body carrying the ``image_url`` to upscale.
    :return: Dict with ``status``, ``file_path`` and ``file_url``.
    :raises HTTPException: Status 500 if the upscaling model is not
        initialized; status 400 if the source image cannot be fetched.
    """
    import asyncio  # local import: only this handler needs it

    if aura_sr is None:
        raise HTTPException(status_code=500, detail="Upscaling model not initialized.")

    # NOTE(review): `image_url` is client-supplied and fetched as-is by the
    # server — confirm whether the allowed hosts should be restricted.
    img = await run_in_threadpool(fetch_image, request.image_url)

    def perform_upscaling():
        upscaled_image = aura_sr.upscale_4x_overlapped(img)
        filename = f"upscaled_{uuid.uuid4()}.png"
        return save_image_locally(upscaled_image, filename)

    # Await the executor future directly instead of parking a second worker
    # thread on `future.result()`; the shared executor still runs the job.
    filepath = await asyncio.wrap_future(executor.submit(perform_upscaling))
    return {
        "status": "Upscaling successful",
        "file_path": filepath,
        "file_url": f"/images/{os.path.basename(filepath)}",
    }
@app.get("/")
async def root():
    """Health-check endpoint confirming the service is reachable."""
    status_body = {"message": "API is up and running!"}
    return status_body