from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
from huggingface_hub import InferenceClient
import time
import uuid

# Initialize FastAPI app
app = FastAPI()

# Initialize Hugging Face Inference Client
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
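# Note (deployment assumption, not from the original source): anonymous calls
# to the serverless Inference API are rate-limited; huggingface_hub resolves a
# token from the HF_TOKEN environment variable or the local login cache if set.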

# Define request and response models
class Message(BaseModel):
    role: str
    content: str

class ChatRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: int = 256
    stream: Optional[bool] = False
    stop: Optional[List[str]] = None

class Choice(BaseModel):
    index: int
    message: Message

class ChatResponse(BaseModel):
    id: str
    object: str
    created: int
    choices: List[Choice]

# Helper to format messages into the Mixtral instruct prompt format
def format_prompt(messages: List[Message]) -> str:
    prompt = "<s>"
    for msg in messages:
        # System and user turns are both wrapped in [INST] ... [/INST];
        # assistant turns are appended verbatim and closed with </s>.
        if msg.role in ("system", "user"):
            prompt += f"[INST] {msg.content} [/INST]"
        elif msg.role == "assistant":
            prompt += f" {msg.content}</s>"
    return prompt
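
# Illustrative example (assumed; not in the original source): format_prompt
# renders [user: "Hi there", assistant: "Hello!", user: "How are you?"] as
#   <s>[INST] Hi there [/INST] Hello!</s>[INST] How are you? [/INST]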

# Generate text via the Hugging Face Inference API
def generate_response(formatted_prompt, max_tokens, temperature, top_p):
    # text-generation-inference rejects temperature <= 0 when sampling, so clamp it
    temperature = max(float(temperature), 1e-2)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_p=top_p,
        repetition_penalty=1.2,
        do_sample=True,
        seed=42,
    )
    # stream=True with details=True yields token-level chunks; accumulate them
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True)
    output = ""
    for response in stream:
        output += response.token.text
    return output.strip()

@app.get("/")
def greet_json():
    return {"Hello": "World!"}

@app.post("/v1/chat/completions", response_model=ChatResponse)
async def chat_completions(request: ChatRequest):
    # Validate model name
    if request.model != "Mixtral-8x7B-Instruct-v0.1":
        raise HTTPException(status_code=400, detail="Invalid model specified")
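
    # Note: `stream` and `stop` are accepted for OpenAI-API compatibility but
    # are not forwarded to the backend in this sketch.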

    # Format prompt
    formatted_prompt = format_prompt(request.messages)

    # Generate response
    response_text = generate_response(
        formatted_prompt,
        max_tokens=request.max_tokens,
        temperature=request.temperature,
        top_p=request.top_p,
    )

    # Format API response
    return ChatResponse(
        id=f"chatcmpl-{uuid.uuid4().hex}",  # unique id per completion
        object="chat.completion",
        created=int(time.time()),  # current Unix timestamp
        choices=[Choice(index=0, message=Message(role="assistant", content=response_text))],
    )
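
# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original app). Start the server
# with `uvicorn app:app --host 0.0.0.0 --port 7860`, then call the endpoint.
# The `requests` dependency and the host/port are assumptions.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/v1/chat/completions",
#       json={
#           "model": "Mixtral-8x7B-Instruct-v0.1",
#           "messages": [{"role": "user", "content": "Say hello."}],
#           "max_tokens": 64,
#       },
#   )
#   print(resp.json()["choices"][0]["message"]["content"])
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    # Convenience entry point (an assumption; Hugging Face Spaces usually start
    # the ASGI server externally). Enables `python app.py` for local testing.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)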