Luigi committed on
Commit
b15d017
·
1 Parent(s): 7f7d4ff

Clone https://huggingface.co/spaces/yentinglin/Taiwan-LLaMa2/raw/main/conversation.py

Browse files
Files changed (1) hide show
  1. conversation.py +274 -0
conversation.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Cloned from
2
+ # https://huggingface.co/spaces/yentinglin/Taiwan-LLaMa2/raw/main/conversation.py
3
+
4
+ """
5
+ Conversation prompt template.
6
+ Now we support
7
+ - Vicuna
8
+ - Koala
9
+ - OpenAssistant/oasst-sft-1-pythia-12b
10
+ - StabilityAI/stablelm-tuned-alpha-7b
11
+ - databricks/dolly-v2-12b
12
+ - THUDM/chatglm-6b
13
+ - Alpaca/LLaMa
14
+ """
15
+
16
import dataclasses
from enum import Enum, auto
from typing import Any, List, Optional, Tuple
19
+
20
+
21
class SeparatorStyle(Enum):
    """Enumerates the prompt-layout conventions used by Conversation.get_prompt()."""

    SINGLE = auto()        # one separator string between every turn
    TWO = auto()           # two separators alternating after each turn
    DOLLY = auto()         # Dolly-v2 instruction/response layout
    OASST_PYTHIA = auto()  # OpenAssistant/Pythia special-token layout
28
+
29
+
30
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Accumulates (role, message) turns and renders them into a single prompt
    string whose layout is selected by ``sep_style``.
    """

    # System preamble prepended to every rendered prompt.
    system: str
    # The two speaker names, e.g. ("Human", "Assistant").
    roles: List[str]
    # [role, message] pairs; a None/empty message marks a pending reply slot.
    messages: List[List[str]]
    # Number of leading messages that are few-shot examples; they are skipped
    # by to_gradio_chatbot().
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    # Second separator for the TWO/DOLLY styles. Annotation fixed: the
    # original declared plain `str` while defaulting to None.
    sep2: Optional[str] = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        """Render the conversation as one prompt string per ``sep_style``.

        Raises:
            ValueError: if ``sep_style`` is not a known SeparatorStyle.
        """
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += self.sep + " " + role + ": " + message
                else:
                    # Open slot: emit "role:" so the model completes the reply.
                    ret += self.sep + " " + role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            # Alternate sep/sep2 after successive turns (sep2 typically the
            # EOS marker ending an assistant reply).
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    if i % 2 == 1:
                        # Blank lines after each completed response turn.
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.OASST_PYTHIA:
            # Roles are special tokens themselves, so no ": " glue is added.
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        """Append one [role, message] turn (message may be None)."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Convert post-offset turns into gradio's [[user, assistant], ...] pairs."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a copy with a freshly-built message list.

        NOTE(review): ``skip_next`` is not carried over (the copy resets it to
        False) — this matches the original behavior.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        """Return a serializable snapshot (``sep_style`` omitted, as before)."""
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }
120
+
121
+
122
# One-shot template: seeds the history with a single Q/A exemplar so the model
# sees the expected "### Role: message" layout before the live turn begins.
conv_one_shot = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.",
        ),
    ),
    # The two seeded messages above are few-shot context, not live chat turns.
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)
157
+
158
+
159
# Vicuna v1.1 template: "USER: ... ASSISTANT: ..." turns, each assistant
# reply terminated by the "</s>" EOS marker (TWO-separator style).
conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions. You are built by NTU Miulab by Yen-Ting Lin for research purpose.",
    # Commented-out alternative system prompt in Traditional Chinese; it reads:
    # "A chat between a curious user and an AI assistant. You are an assistant.
    # Please give helpful, detailed and polite answers to the user's questions."
    # system="一位好奇的用戶和一個人工智能助理之間的聊天。你是一位助理。請對用戶的問題提供有用、詳細和有禮貌的答案。",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)
170
+
171
# Story template: identical to the Vicuna layout but ends assistant replies
# with "<|endoftext|>" instead of "</s>".
conv_story = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)
181
+
182
# Koala v1 template: "USER"/"GPT" roles after a fixed conversation banner.
conv_koala_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)
191
+
192
# Dolly-v2 template: "### Instruction" / "### Response" sections, with
# "### End" closing each response (DOLLY separator style).
conv_dolly = Conversation(
    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
    roles=("### Instruction", "### Response"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.DOLLY,
    sep="\n\n",
    sep2="### End",
)
201
+
202
# OpenAssistant/Pythia template: roles are special tokens, no system text,
# every message followed by "<|endoftext|>".
conv_oasst = Conversation(
    system="",
    roles=("<|prompter|>", "<|assistant|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="<|endoftext|>",
)
210
+
211
# StableLM Tuned (alpha) template: special-token roles after a long
# "<|SYSTEM|>" preamble; uses the OASST layout with an empty separator.
conv_stablelm = Conversation(
    system="""<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""",
    roles=("<|USER|>", "<|ASSISTANT|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="",
)
224
+
225
# Registry of named templates. Keys are the names accepted by callers that
# look templates up by string.
conv_templates = {
    "conv_one_shot": conv_one_shot,
    "vicuna_v1.1": conv_vicuna_v1_1,
    "koala_v1": conv_koala_v1,
    "dolly": conv_dolly,
    "oasst": conv_oasst,
    # Fix: conv_stablelm is defined in this module and returned by
    # get_default_conv_template, but was missing from the registry.
    "stablelm": conv_stablelm,
}
232
+
233
+
234
def get_default_conv_template(model_name):
    """Pick the conversation template matching *model_name* (case-insensitive).

    Falls back to the generic one-shot template when no known model family
    substring is found.
    """
    name = model_name.lower()
    if "vicuna" in name or "output" in name:
        return conv_vicuna_v1_1
    if "koala" in name:
        return conv_koala_v1
    if "dolly-v2" in name:
        return conv_dolly
    if "oasst" in name and "pythia" in name:
        return conv_oasst
    if "stablelm" in name:
        return conv_stablelm
    return conv_one_shot
247
+
248
+
249
def compute_skip_echo_len(model_name, conv, prompt):
    """Return how many characters of *prompt* the model echoes back.

    The caller strips this many leading characters from the model output.

    Args:
        model_name: model identifier; matched case-insensitively by substring.
        conv: the active Conversation; only consulted for ChatGLM models.
        prompt: the full prompt string sent to the model.
    """

    def _len_without(special_toks):
        # The echoed text omits these control tokens, so subtract their
        # total length from the prompt length. (Deduplicates three
        # identical loops from the original.)
        return len(prompt) - sum(prompt.count(tok) * len(tok) for tok in special_toks)

    model_name = model_name.lower()
    if "chatglm" in model_name:
        # ChatGLM echoes only the latest user message (second-to-last entry).
        return len(conv.messages[-2][1]) + 1
    if "dolly-v2" in model_name:
        return _len_without(["### Instruction:", "### Response:", "### End"])
    if "oasst" in model_name and "pythia" in model_name:
        return _len_without(["<|prompter|>", "<|assistant|>", "<|endoftext|>"])
    if "stablelm" in model_name:
        return _len_without(["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"])
    # Default: whole prompt plus one, minus 3 per "</s>" occurrence
    # (preserved verbatim from the original heuristic).
    return len(prompt) + 1 - prompt.count("</s>") * 3
271
+
272
+
273
if __name__ == "__main__":
    # Fix: the original printed `default_conversation.get_prompt()`, but
    # `default_conversation` is never defined in this module, so running the
    # file raised NameError. Print the generic one-shot template instead.
    print(conv_one_shot.get_prompt())