from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Dict
from huggingface_hub import InferenceClient
import time

# Initialize FastAPI app
app = FastAPI()

# Initialize Hugging Face Inference Client
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
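# Note (assumption, not part of the original listing): for gated models or private
# Spaces the client can also be given an explicit token, e.g.
#   InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1", token=os.environ.get("HF_TOKEN"))
# (requires `import os`); the anonymous client above is kept unchanged here.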
# Define request and response models
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None

class ChatResponse(BaseModel):
    id: str
    object: str
    created: int
    choices: List[Dict]  # each choice nests a message dict, so values are not plain strings
# Helper to format prompt
def format_prompt(messages: List[Message]) -> str:
    prompt = "<s>"
    for msg in messages:
        if msg.role == "system":
            prompt += f"[INST] {msg.content} [/INST]"
        elif msg.role == "user":
            prompt += f"[INST] {msg.content} [/INST]"
        elif msg.role == "assistant":
            prompt += f" {msg.content}</s>"
    return prompt
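# Example (illustrative only, not part of the original listing): a system + user
# exchange is flattened into a single Mixtral-style instruction string:
#   format_prompt([Message(role="system", content="You are helpful."),
#                  Message(role="user", content="Hi!")])
#   -> "<s>[INST] You are helpful. [/INST][INST] Hi! [/INST]"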
# Generate text using the Hugging Face model
def generate_response(formatted_prompt, max_tokens, temperature, top_p):
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_p=top_p,
        repetition_penalty=1.2,
        do_sample=True,
        seed=42,
    )
    output = ""
    # Stream tokens from the Inference API and accumulate them into one string
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True)
    for response in stream:
        output += response.token.text
    return output.strip()
# Root endpoint (assumption: registered at "/", as in the standard FastAPI Space template)
@app.get("/")
def greet_json():
    return {"Hello": "World!"}
# Chat completions endpoint (assumption: exposed at the OpenAI-style path below)
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
    # Validate model name
    if request.model != "Mixtral-8x7B-Instruct-v0.1":
        raise HTTPException(status_code=400, detail="Invalid model specified")

    # Format prompt
    formatted_prompt = format_prompt(request.messages)

    # Generate response
    response_text = generate_response(
        formatted_prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p,
    )

    # Format API response
    response = ChatResponse(
        id="chatcmpl-001",
        object="chat.completion",
        created=int(time.time()),  # Unix timestamp of when the completion was created
        choices=[
            {"index": 0, "message": {"role": "assistant", "content": response_text}}
        ],
    )
    return response
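# ---------------------------------------------------------------------------
# Usage sketch (assumptions: the file is saved as app.py, served on port 7860
# as in a typical Space, and the route path "/v1/chat/completions" used above
# matches your deployment):
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   curl http://localhost:7860/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "Mixtral-8x7B-Instruct-v0.1",
#          "messages": [{"role": "user", "content": "Hello!"}],
#          "max_tokens": 64}'
# ---------------------------------------------------------------------------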