import time
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from huggingface_hub import InferenceClient

# Initialize FastAPI app
app = FastAPI()

# Initialize Hugging Face Inference Client
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")


# Define request and response models
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None


class ChatResponse(BaseModel):
    id: str
    object: str
    created: int
    # Each choice contains an integer index and a nested message dict,
    # so the values are not plain strings; use Any to accept the structure.
    choices: List[Dict[str, Any]]


# Helper to format the chat history into a Mixtral-style instruction prompt
def format_prompt(messages: List[Message]) -> str:
    prompt = ""
    for msg in messages:
        # System and user turns are both wrapped in [INST] ... [/INST];
        # assistant turns are appended as plain completion text.
        if msg.role in ("system", "user"):
            prompt += f"[INST] {msg.content} [/INST]"
        elif msg.role == "assistant":
            prompt += f" {msg.content}"
    return prompt


# Generate text using the Hugging Face model
def generate_response(formatted_prompt: str, max_tokens: int, temperature: float, top_p: float) -> str:
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_p=top_p,
        repetition_penalty=1.2,
        do_sample=True,
        seed=42,
    )

    output = ""
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True)
    for response in stream:
        output += response.token.text

    return output.strip()


@app.get("/")
def greet_json():
    return {"Hello": "World!"}


@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    # Validate model name
    if request.model != "Mixtral-8x7B-Instruct-v0.1":
        raise HTTPException(status_code=400, detail="Invalid model specified")

    # Format prompt
    formatted_prompt = format_prompt(request.messages)

    # Generate response
    response_text = generate_response(
        formatted_prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p,
    )

    # Format API response
    response = ChatResponse(
        id="chatcmpl-001",
        object="chat.completion",
        created=int(time.time()),  # Current Unix timestamp
        choices=[
            {"index": 0, "message": {"role": "assistant", "content": response_text}}
        ],
    )
    return response
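

# --- Illustrative usage sketch ---
# A minimal client-side example, assuming the app is served locally on port 8000
# (e.g. `uvicorn main:app --port 8000`; the module name "main" is hypothetical)
# and that the `requests` package is available. Any HTTP client would work the same way.
if __name__ == "__main__":
    import requests

    payload = {
        "model": "Mixtral-8x7B-Instruct-v0.1",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "max_tokens": 128,
    }
    resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
    # Print the assistant's reply from the first choice
    print(resp.json()["choices"][0]["message"]["content"])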