File size: 2,766 Bytes
812f70a
 
 
0e02ca5
 
 
 
 
 
c6bb1bc
0e02ca5
1a6b000
ba913e7
0e02ca5
764f34e
 
 
 
f00ac1d
1a6b000
f00ac1d
764f34e
ba913e7
0e02ca5
1a6b000
0e02ca5
 
 
 
 
 
d259634
 
0e02ca5
 
 
 
 
f00ac1d
 
 
 
 
0e02ca5
 
 
 
 
 
d259634
0e02ca5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a6b000
0e02ca5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# https://www.gradio.app/guides/using-hugging-face-integrations

import gradio as gr
import logging
import html
import time
import torch
from   threading import Thread
from   transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Model
# Hub identifier handed to both the tokenizer and the model loader below.
model_name = "augmxnt/shisa-7b-v1"

# UI Settings
# Strings shown by the Gradio chat interface (title, description, textbox
# placeholder and clickable example prompts).
title = "Shisa 7B"
description = "Test out Shisa 7B in either English or Japanese."
# Bilingual placeholder; the Japanese half reads "Please type here".
placeholder = "Type Here / ここに入力してください" 
# Example prompts: one English, two Japanese ("You are an enthusiastic
# Pokemon fan." and "Where is a recommended ramen shop in Tokyo?").
examples = [
    "What's the best ramen in Tokyo?",
    "あなたは熱狂的なポケモンファンです。",
    "東京でおすすめのラーメン屋ってどこ?",
]

# LLM Settings
# System prompt, in Japanese: "You are a helpful assistant."
system_prompt = 'あなたは役に立つアシスタントです。'
# Message list in role/content dict form, seeded with the system prompt.
chat_history = [{"role": "system", "content": system_prompt}]
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the causal LM with bfloat16 dtype, automatic device placement and
# 8-bit loading enabled.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    load_in_8bit=True,
    # Alternative quantization setting, currently disabled:
    # load_in_4bit=True
)
# Module-level streamer bound to the tokenizer; configured to skip the prompt
# and special tokens, with a 10-second timeout.
# NOTE(review): presumably the timeout bounds the wait for each next chunk --
# confirm against the transformers TextIteratorStreamer documentation.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

def _history_to_messages(history):
    """Convert the UI-provided chat history into role/content message dicts.

    Two history layouts are accepted:
    * a sequence of ``[user_text, assistant_text]`` pairs, and
    * a sequence of ``{"role": ..., "content": ...}`` dicts.

    Entries that are empty, non-textual, or of an unrecognised shape are
    skipped so that a malformed turn cannot break prompt construction.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            for role, content in zip(("user", "assistant"), turn):
                if isinstance(content, str) and content:
                    messages.append({"role": role, "content": content})
    return messages


def chat(message, history):
    """Stream a model reply to ``message`` given the prior ``history``.

    Args:
        message: The new user message text.
        history: Previous turns of this conversation as supplied by the
            caller (see ``_history_to_messages`` for accepted layouts).

    Yields:
        The reply accumulated so far, growing with every decoded chunk, so
        the caller can re-render the full partial answer each time.
    """
    # Build the prompt from this conversation only. Appending to a
    # module-level list would merge every user's messages into one shared,
    # ever-growing conversation and would never include assistant replies.
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(_history_to_messages(history))
    messages.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    # for multi-gpu, find the device of the first parameter of the model
    first_param_device = next(model.parameters()).device
    input_ids = input_ids.to(first_param_device)

    # One streamer per request: a shared streamer would let overlapping
    # requests read tokens from the same queue.
    request_streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        inputs=input_ids,
        streamer=request_streamer,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7,
        repetition_penalty=1.15,
        top_p=0.95,
        eos_token_id=tokenizer.eos_token_id,
    )
    # https://www.gradio.app/main/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face
    # Generation runs in a background thread while this generator consumes
    # the streamer and yields the growing reply.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    partial_message = ""
    for new_token in request_streamer:
        partial_message += new_token
        yield partial_message


# Build the two custom components first (chat log, then input box) so the
# ChatInterface call below reads as a flat list of options.
chat_log = gr.Chatbot(height=400)
message_box = gr.Textbox(placeholder=placeholder, container=False, scale=7)

# Chat UI wired to the streaming `chat` callback.
chat_interface = gr.ChatInterface(
    fn=chat,
    chatbot=chat_log,
    textbox=message_box,
    title=title,
    description=description,
    theme="soft",
    examples=examples,
    cache_examples=False,
    undo_btn="Delete Previous",
    clear_btn="Clear",
)

# The chat interface is rendered inside an explicit `with gr.Blocks()` context
# because Gradio fails on autoreload otherwise. Pattern taken from:
# https://huggingface.co/spaces/ysharma/Explore_llamav2_with_TGI/blob/main/app.py#L219
with gr.Blocks() as demo:
    chat_interface.render()
    gr.Markdown("You can try asking this question in Japanese or English. We limit output to 200 tokens.")

# Enable queuing, then start the app.
queued_demo = demo.queue()
queued_demo.launch()