import time

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Any, Dict, List, Optional
from huggingface_hub import InferenceClient

# Initialize FastAPI app
app = FastAPI()

# Initialize Hugging Face Inference Client for the hosted Mixtral model
# (a Hugging Face token may be required for gated or rate-limited access)
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")

# Define request and response models
class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None


class ChatResponse(BaseModel):
    id: str
    object: str
    created: int
    # Each choice holds an index and a nested message dict, so values are not all strings
    choices: List[Dict[str, Any]]
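
# Illustration only: a JSON body that ChatRequest accepts, e.g.
# {
#   "model": "Mixtral-8x7B-Instruct-v0.1",
#   "messages": [
#     {"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": "Hello"}
#   ],
#   "temperature": 0.7,
#   "max_tokens": 128
# }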

# Helper to format prompt
def format_prompt(messages: List[Message]) -> str:
    prompt = "<s>"
    for msg in messages:
        if msg.role == "system":
            prompt += f"[INST] {msg.content} [/INST]"
        elif msg.role == "user":
            prompt += f"[INST] {msg.content} [/INST]"
        elif msg.role == "assistant":
            prompt += f" {msg.content}</s>"
    return prompt
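
# Illustration only: for the messages
#   [user: "Hello", assistant: "Hi there", user: "How are you?"]
# format_prompt returns:
#   "<s>[INST] Hello [/INST] Hi there</s>[INST] How are you? [/INST]"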

# Generate text using the Hugging Face model
def generate_response(formatted_prompt, max_tokens, temperature, top_p):
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_p=top_p,
        repetition_penalty=1.2,
        do_sample=True,
        seed=42,
    )
    output = ""
    # Stream tokens from the Inference API and concatenate them into one string
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True)
    for response in stream:
        output += response.token.text
    return output.strip()

@app.get("/")
def greet_json():
    return {"Hello": "World!"}

@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    # Validate model name
    if request.model != "Mixtral-8x7B-Instruct-v0.1":
        raise HTTPException(status_code=400, detail="Invalid model specified")

    # Format prompt from the chat messages
    formatted_prompt = format_prompt(request.messages)

    # Generate response (note: the `stream` and `stop` fields are accepted but not used here)
    response_text = generate_response(
        formatted_prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p
    )

    # Format API response
    response = ChatResponse(
        id="chatcmpl-001",
        object="chat.completion",
        created=int(time.time()),  # Unix timestamp of when the completion was created
        choices=[
            {"index": 0, "message": {"role": "assistant", "content": response_text}}
        ]
    )
    return response
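
# --- Example usage (illustration only, not part of the original app) --------
# A minimal sketch of calling the OpenAI-style endpoint in-process with
# FastAPI's TestClient. It assumes `httpx` is installed (required by
# TestClient) and that huggingface_hub can reach the hosted Mixtral model,
# which may require a Hugging Face token.
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    test_client = TestClient(app)
    payload = {
        "model": "Mixtral-8x7B-Instruct-v0.1",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello in one short sentence."},
        ],
        "max_tokens": 64,
    }
    result = test_client.post("/v1/chat/completions", json=payload)
    # Expected shape: {"id": "...", "object": "chat.completion", "created": ...,
    #                  "choices": [{"index": 0, "message": {...}}]}
    print(result.json())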