from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends
from pydantic import BaseModel
from fastapi.staticfiles import StaticFiles
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
import uuid
import time
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pymongo import MongoClient
from urllib.parse import quote_plus
from langchain_groq import ChatGroq
from aura_sr import AuraSR
from io import BytesIO
from PIL import Image
import requests
import os
import logging
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Validate environment variables
assert os.getenv('MONGO_USER') and os.getenv('MONGO_PASSWORD') and os.getenv('MONGO_HOST'), "MongoDB credentials missing!"
assert os.getenv('LLM_API_KEY'), "LLM API Key missing!"
assert os.getenv('BFL_API_KEY'), "BFL API Key missing!"

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set the Hugging Face cache directory to a writable location
os.environ['HF_HOME'] = '/tmp/huggingface_cache'
app = FastAPI()

# Middleware for CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
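# Note: allow_origins=["*"] accepts requests from any origin, which is convenient
# for development but should be narrowed to known frontend origins in production.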
# Globals
executor = ThreadPoolExecutor(max_workers=5)
llm = None
mongo_client = MongoClient(f"mongodb+srv://{os.getenv('MONGO_USER')}:{quote_plus(os.getenv('MONGO_PASSWORD'))}@{os.getenv('MONGO_HOST')}/")
db = mongo_client["Flux"]
collection = db["chat_histories"]
chat_sessions = {}

# Temporary directory for storing images
image_storage_dir = tempfile.mkdtemp()
app.mount("/images", StaticFiles(directory=image_storage_dir), name="images")
# Initialize AuraSR during startup | |
aura_sr = None | |
async def startup(): | |
global llm, aura_sr | |
try: | |
llm = ChatGroq( | |
model="llama-3.3-70b-versatile", | |
temperature=0.7, | |
max_tokens=1024, | |
api_key=os.getenv('LLM_API_KEY'), | |
) | |
aura_sr = AuraSR.from_pretrained("fal/AuraSR-v2") | |
except Exception as e: | |
logger.error(f"Error initializing models: {e}") | |
@app.on_event("shutdown")  # release resources when the app stops
def shutdown():
    mongo_client.close()
    executor.shutdown()
# Pydantic models
class ImageRequest(BaseModel):
    subject: str
    style: str
    color_theme: str
    elements: str
    color_mode: str
    lighting_conditions: str
    framing_style: str
    material_details: str
    text: str
    background_details: str
    user_prompt: str
    chat_id: str

class UpscaleRequest(BaseModel):
    image_url: str
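# Illustrative request bodies (field values are made up; the routes are the assumed
# paths used in the endpoint decorators below):
#   POST /generate-image
#     {"subject": "a red fox", "style": "watercolor", "color_theme": "warm autumn",
#      "elements": "falling leaves", "color_mode": "full color",
#      "lighting_conditions": "soft morning light", "framing_style": "close-up",
#      "material_details": "textured paper", "text": "Hello",
#      "background_details": "a misty forest", "user_prompt": "a fox in autumn",
#      "chat_id": "<id returned by /new_chat>"}
#   POST /upscale-image
#     {"image_url": "https://example.com/generated.png"}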
# Helper functions
def generate_chat_id():
    chat_id = str(uuid.uuid4())
    chat_sessions[chat_id] = collection
    return chat_id

def get_chat_history(chat_id):
    messages = collection.find({"session_id": chat_id})
    return "\n".join(
        f"User: {msg['content']}" if msg['role'] == "user" else f"AI: {msg['content']}"
        for msg in messages
    )

def save_image_locally(image, filename):
    filepath = os.path.join(image_storage_dir, filename)
    image.save(filepath, format="PNG")
    return filepath
def make_request_with_retries(url, headers, payload, retries=3, delay=2):
    """
    Makes an HTTP POST request with retries in case of failure.

    :param url: The URL for the request.
    :param headers: Headers to include in the request.
    :param payload: Payload to include in the request.
    :param retries: Number of retries on failure.
    :param delay: Delay between retries, in seconds.
    :return: Response JSON from the server.
    """
    for attempt in range(retries):
        try:
            with requests.Session() as session:
                response = session.post(url, headers=headers, json=payload, timeout=30)
                response.raise_for_status()
                return response.json()
        except requests.exceptions.RequestException as e:
            if attempt < retries - 1:
                time.sleep(delay)
                continue
            else:
                raise HTTPException(status_code=500, detail=f"Request failed after {retries} attempts: {str(e)}")
def fetch_image(url):
    try:
        with requests.Session() as session:
            response = session.get(url, timeout=30)
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error fetching image: {str(e)}")
def poll_for_image_result(request_id, headers):
    timeout = 60
    start_time = time.time()
    while time.time() - start_time < timeout:
        time.sleep(0.5)
        with requests.Session() as session:
            result = session.get(
                "https://api.bfl.ml/v1/get_result",
                headers=headers,
                params={"id": request_id},
                timeout=10
            ).json()
        if result["status"] == "Ready":
            return result["result"].get("sample")
        elif result["status"] == "Error":
            raise HTTPException(status_code=500, detail=f"Image generation failed: {result.get('error', 'Unknown error')}")
    raise HTTPException(status_code=500, detail="Image generation timed out.")
@app.post("/new_chat")  # route path assumed; not present in the extracted source
async def new_chat():
    chat_id = generate_chat_id()
    return {"chat_id": chat_id}
@app.post("/generate-image")  # route path assumed; not present in the extracted source
async def generate_image(request: ImageRequest):
    # Note: chat_history is retrieved here but not currently interpolated into the prompt.
    chat_history = get_chat_history(request.chat_id)
    prompt = f"""
    You are a professional assistant responsible for crafting a clear and visually compelling prompt for an image generation model. Your task is to generate a high-quality prompt for creating both the **main subject** and the **background** of the image.

    Image Specifications:
    - **Subject**: Focus on **{request.subject}**, highlighting its defining features, expressions, and textures.
    - **Style**: Emphasize the **{request.style}**, capturing its key characteristics.
    - **Background**: Create a background with **{request.background_details}** that complements and enhances the subject. Ensure it aligns with the color theme and overall composition.
    - **Camera and Lighting**:
        - Lighting: Apply **{request.lighting_conditions}**, emphasizing depth, highlights, and shadows to accentuate the subject and harmonize the background.
    - **Framing**: Use a **{request.framing_style}** to enhance the composition around both the subject and the background.
    - **Materials**: Highlight textures like **{request.material_details}**, with realistic details and natural imperfections on the subject and background.
    - **Key Elements**: Include **{request.elements}** to enrich the subject’s details and add cohesive elements to the background.
    - **Color Theme**: Follow the **{request.color_theme}** to set the mood and tone for the entire scene.
    - Negative Prompt: Avoid grainy, blurry, or deformed outputs.
    - **Text to Include in Image**: Clearly display the text **"{request.text}"** as part of the composition (e.g., on a card, badge, or banner) attached to the subject in a realistic and contextually appropriate way.
    """
    refined_prompt = llm.invoke(prompt).content.strip()

    # Persist both sides of the exchange in MongoDB
    collection.insert_one({"session_id": request.chat_id, "role": "user", "content": request.user_prompt})
    collection.insert_one({"session_id": request.chat_id, "role": "ai", "content": refined_prompt})

    headers = {
        "accept": "application/json",
        "x-key": os.getenv('BFL_API_KEY'),
        "Content-Type": "application/json"
    }
    payload = {
        "prompt": refined_prompt,
        "width": 1024,
        "height": 1024,
        "guidance_scale": 1,
        "num_inference_steps": 50,
        "max_sequence_length": 512,
    }
    response = make_request_with_retries("https://api.bfl.ml/v1/flux-pro-1.1", headers, payload)
    if "id" not in response:
        raise HTTPException(status_code=500, detail="Error generating image: ID missing from response")

    image_url = poll_for_image_result(response["id"], headers)
    image = fetch_image(image_url)
    filename = f"generated_{uuid.uuid4()}.png"
    filepath = save_image_locally(image, filename)
    return {
        "status": "Image generated successfully",
        "file_path": filepath,
        "file_url": f"/images/{filename}",
    }
@app.post("/upscale-image")  # route path assumed; not present in the extracted source
async def upscale_image(request: UpscaleRequest):
    if aura_sr is None:
        raise HTTPException(status_code=500, detail="Upscaling model not initialized.")
    img = await run_in_threadpool(fetch_image, request.image_url)

    def perform_upscaling():
        upscaled_image = aura_sr.upscale_4x_overlapped(img)
        filename = f"upscaled_{uuid.uuid4()}.png"
        return save_image_locally(upscaled_image, filename)

    # Run the heavy upscaling on the shared executor and wait for it without blocking the event loop
    future = executor.submit(perform_upscaling)
    filepath = await run_in_threadpool(lambda: future.result())
    return {
        "status": "Upscaling successful",
        "file_path": filepath,
        "file_url": f"/images/{os.path.basename(filepath)}",
    }
@app.get("/")  # health-check endpoint
async def root():
    return {"message": "API is up and running!"}