Spaces:
Runtime error
Runtime error
Wonderplex
commited on
Commit
·
8b07d8c
1
Parent(s):
f1c7954
fixed parsing errors and extra_info (#56)
Browse files- app.py +35 -5
- message_classes.py +1 -1
- sotopia_pi_generate.py +49 -49
- start_app.sh +4 -0
- utils.py +5 -5
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import os
|
2 |
from collections import defaultdict
|
3 |
import json
|
|
|
4 |
|
5 |
import gradio as gr
|
6 |
|
@@ -25,6 +26,21 @@ RELATIONSHIP_PROFILES = "profiles/relationship_profiles.jsonl"
|
|
25 |
|
26 |
ACTION_TYPES = ['none', 'action', 'non-verbal communication', 'speak', 'leave']
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
@cache
|
29 |
def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILES, relationship_file=RELATIONSHIP_PROFILES):
|
30 |
with open(env_file, 'r') as f:
|
@@ -126,13 +142,27 @@ def create_bot_info(bot_agent_dropdown):
|
|
126 |
return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
|
127 |
|
128 |
def create_user_goal(environment_dropdown):
|
129 |
-
|
130 |
-
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
|
133 |
def create_bot_goal(environment_dropdown):
|
134 |
_, environment_dict, _, _ = get_sotopia_profiles()
|
135 |
text = environment_dict[environment_dropdown].agent_goals[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
return gr.Textbox(label="Bot Agent Goal", lines=4, value=text)
|
137 |
|
138 |
def sotopia_info_accordion(accordion_visible=True):
|
@@ -147,7 +177,7 @@ def sotopia_info_accordion(accordion_visible=True):
|
|
147 |
interactive=True,
|
148 |
)
|
149 |
model_name_dropdown = gr.Dropdown(
|
150 |
-
choices=
|
151 |
value=DEFAULT_MODEL_SELECTION,
|
152 |
interactive=True,
|
153 |
label="Model Selection"
|
@@ -215,7 +245,7 @@ def chat_tab():
|
|
215 |
|
216 |
context = get_context_prompt(bot_agent, user_agent, environment)
|
217 |
dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
|
218 |
-
prompt_history = f"{context}
|
219 |
agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
|
220 |
return agent_action.to_natural_language()
|
221 |
|
|
|
1 |
import os
|
2 |
from collections import defaultdict
|
3 |
import json
|
4 |
+
from typing import Literal
|
5 |
|
6 |
import gradio as gr
|
7 |
|
|
|
26 |
|
27 |
ACTION_TYPES = ['none', 'action', 'non-verbal communication', 'speak', 'leave']
|
28 |
|
29 |
+
MODEL_OPTIONS = [
|
30 |
+
"gpt-3.5-turbo",
|
31 |
+
"gpt-4",
|
32 |
+
"gpt-4-turbo",
|
33 |
+
"cmu-lti/sotopia-pi-mistral-7b-BC_SR",
|
34 |
+
"cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit",
|
35 |
+
"mistralai/Mistral-7B-Instruct-v0.1"
|
36 |
+
# "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
37 |
+
# "togethercomputer/llama-2-7b-chat",
|
38 |
+
# "togethercomputer/llama-2-70b-chat",
|
39 |
+
# "togethercomputer/mpt-30b-chat",
|
40 |
+
# "together_ai/togethercomputer/llama-2-7b-chat",
|
41 |
+
# "together_ai/togethercomputer/falcon-7b-instruct",
|
42 |
+
]
|
43 |
+
|
44 |
@cache
|
45 |
def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILES, relationship_file=RELATIONSHIP_PROFILES):
|
46 |
with open(env_file, 'r') as f:
|
|
|
142 |
return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
|
143 |
|
144 |
def create_user_goal(environment_dropdown):
|
145 |
+
_, environment_dict, _, _ = get_sotopia_profiles()
|
146 |
+
text = environment_dict[environment_dropdown].agent_goals[0]
|
147 |
+
text = text.replace('(', '').replace(')', '')
|
148 |
+
if "<extra_info>" in text:
|
149 |
+
text = text.replace("<extra_info>", "\n\n")
|
150 |
+
text = text.replace("</extra_info>", "\n")
|
151 |
+
if "<strategy_hint>" in text:
|
152 |
+
text = text.replace("<strategy_hint>", "\n\n")
|
153 |
+
text = text.replace("</strategy_hint>", "\n")
|
154 |
+
return gr.Textbox(label="User Agent Goal", lines=4, value=text)
|
155 |
|
156 |
def create_bot_goal(environment_dropdown):
|
157 |
_, environment_dict, _, _ = get_sotopia_profiles()
|
158 |
text = environment_dict[environment_dropdown].agent_goals[1]
|
159 |
+
text = text.replace('(', '').replace(')', '')
|
160 |
+
if "<extra_info>" in text:
|
161 |
+
text = text.replace("<extra_info>", "\n\n")
|
162 |
+
text = text.replace("</extra_info>", "\n")
|
163 |
+
if "<strategy_hint>" in text:
|
164 |
+
text = text.replace("<strategy_hint>", "\n\n")
|
165 |
+
text = text.replace("</strategy_hint>", "\n")
|
166 |
return gr.Textbox(label="Bot Agent Goal", lines=4, value=text)
|
167 |
|
168 |
def sotopia_info_accordion(accordion_visible=True):
|
|
|
177 |
interactive=True,
|
178 |
)
|
179 |
model_name_dropdown = gr.Dropdown(
|
180 |
+
choices=MODEL_OPTIONS,
|
181 |
value=DEFAULT_MODEL_SELECTION,
|
182 |
interactive=True,
|
183 |
label="Model Selection"
|
|
|
245 |
|
246 |
context = get_context_prompt(bot_agent, user_agent, environment)
|
247 |
dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
|
248 |
+
prompt_history = f"{context}{dialogue_history}"
|
249 |
agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
|
250 |
return agent_action.to_natural_language()
|
251 |
|
message_classes.py
CHANGED
@@ -120,7 +120,7 @@ class AgentAction(Message):
|
|
120 |
case "none":
|
121 |
return "did nothing"
|
122 |
case "speak":
|
123 |
-
return f
|
124 |
case "non-verbal communication":
|
125 |
return f"[{self.action_type}] {self.argument}"
|
126 |
case "action":
|
|
|
120 |
case "none":
|
121 |
return "did nothing"
|
122 |
case "speak":
|
123 |
+
return f"{self.argument}"
|
124 |
case "non-verbal communication":
|
125 |
return f"[{self.action_type}] {self.argument}"
|
126 |
case "action":
|
sotopia_pi_generate.py
CHANGED
@@ -28,6 +28,9 @@ from utils import format_docstring
|
|
28 |
from langchain_callback_handler import LoggingCallbackHandler
|
29 |
|
30 |
HF_TOKEN_KEY_FILE="./hf_token.key"
|
|
|
|
|
|
|
31 |
|
32 |
OutputType = TypeVar("OutputType", bound=object)
|
33 |
log = logging.getLogger("generate")
|
@@ -44,59 +47,54 @@ def generate_action(
|
|
44 |
"""
|
45 |
Using langchain to generate an example episode
|
46 |
"""
|
47 |
-
try:
|
48 |
# Normal case, model as agent
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
except Exception:
|
77 |
-
|
|
|
78 |
|
79 |
@cache
|
80 |
-
def prepare_model(model_name
|
81 |
compute_type = torch.float16
|
82 |
-
if os.path.exists(hf_token_key_file):
|
83 |
-
with open (hf_token_key_file, 'r') as f:
|
84 |
-
hf_token = f.read().strip()
|
85 |
-
else:
|
86 |
-
hf_token = os.environ["HF_TOKEN"]
|
87 |
|
88 |
if model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR':
|
89 |
-
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1",
|
90 |
model = AutoModelForCausalLM.from_pretrained(
|
91 |
"mistralai/Mistral-7B-Instruct-v0.1",
|
92 |
cache_dir="./.cache",
|
93 |
-
device_map='cuda'
|
94 |
-
token=hf_token
|
95 |
)
|
96 |
model = PeftModel.from_pretrained(model, model_name).to("cuda")
|
97 |
|
98 |
elif model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit':
|
99 |
-
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1",
|
100 |
model = AutoModelForCausalLM.from_pretrained(
|
101 |
"mistralai/Mistral-7B-Instruct-v0.1",
|
102 |
cache_dir="./.cache",
|
@@ -106,18 +104,17 @@ def prepare_model(model_name, hf_token_key_file=HF_TOKEN_KEY_FILE):
|
|
106 |
bnb_4bit_use_double_quant=True,
|
107 |
bnb_4bit_quant_type="nf4",
|
108 |
bnb_4bit_compute_dtype=compute_type,
|
109 |
-
)
|
110 |
-
token=hf_token
|
111 |
)
|
112 |
-
model = PeftModel.from_pretrained(model, model_name).to("cuda")
|
113 |
|
114 |
elif model_name == 'mistralai/Mistral-7B-Instruct-v0.1':
|
115 |
-
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1",
|
|
|
116 |
model = AutoModelForCausalLM.from_pretrained(
|
117 |
"mistralai/Mistral-7B-Instruct-v0.1",
|
118 |
cache_dir="./.cache",
|
119 |
-
device_map='cuda'
|
120 |
-
token=hf_token
|
121 |
)
|
122 |
|
123 |
else:
|
@@ -146,7 +143,6 @@ def obtain_chain_hf(
|
|
146 |
return_full_text=False,
|
147 |
do_sample=True,
|
148 |
num_beams=3,
|
149 |
-
length_penalty=-1.0,
|
150 |
)
|
151 |
hf = HuggingFacePipeline(pipeline=pipe)
|
152 |
chain = LLMChain(llm=hf, prompt=chat_prompt_template)
|
@@ -171,6 +167,8 @@ def generate(
|
|
171 |
input_values["format_instructions"] = output_parser.get_format_instructions()
|
172 |
result = chain.predict([logging_handler], **input_values)
|
173 |
prompt = logging_handler.retrive_prompt()
|
|
|
|
|
174 |
try:
|
175 |
parsed_result = output_parser.parse(result)
|
176 |
except KeyboardInterrupt:
|
@@ -183,6 +181,7 @@ def generate(
|
|
183 |
reformat_parsed_result = format_bad_output(
|
184 |
result, format_instructions=output_parser.get_format_instructions()
|
185 |
)
|
|
|
186 |
parsed_result = output_parser.parse(reformat_parsed_result)
|
187 |
log.info(f"Generated result: {parsed_result}")
|
188 |
return parsed_result
|
@@ -223,7 +222,7 @@ def obtain_chain(
|
|
223 |
"""
|
224 |
Using langchain to sample profiles for participants
|
225 |
"""
|
226 |
-
if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit"]:
|
227 |
return obtain_chain_hf(
|
228 |
model_name=model_name,
|
229 |
template=template,
|
@@ -247,10 +246,11 @@ def obtain_chain(
|
|
247 |
return chain
|
248 |
|
249 |
def _return_fixed_model_version(model_name: str) -> str:
|
250 |
-
|
251 |
"gpt-3.5-turbo": "gpt-3.5-turbo-0613",
|
252 |
"gpt-3.5-turbo-finetuned": "ft:gpt-3.5-turbo-0613:academicscmu::8nY2zgdt",
|
253 |
"gpt-3.5-turbo-ft-MF": "ft:gpt-3.5-turbo-0613:academicscmu::8nuER4bO",
|
254 |
"gpt-4": "gpt-4-0613",
|
255 |
"gpt-4-turbo": "gpt-4-1106-preview",
|
256 |
-
}
|
|
|
|
28 |
from langchain_callback_handler import LoggingCallbackHandler
|
29 |
|
30 |
HF_TOKEN_KEY_FILE="./hf_token.key"
|
31 |
+
if os.path.exists(HF_TOKEN_KEY_FILE):
|
32 |
+
with open(HF_TOKEN_KEY_FILE, "r") as f:
|
33 |
+
os.environ["HF_TOKEN"] = f.read().strip()
|
34 |
|
35 |
OutputType = TypeVar("OutputType", bound=object)
|
36 |
log = logging.getLogger("generate")
|
|
|
47 |
"""
|
48 |
Using langchain to generate an example episode
|
49 |
"""
|
50 |
+
# try:
|
51 |
# Normal case, model as agent
|
52 |
+
template = """
|
53 |
+
Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
|
54 |
+
You can find {agent}'s goal (or background) in the 'Here is the context of the interaction' field.
|
55 |
+
Note that {agent}'s goal is only visible to you.
|
56 |
+
You should try your best to achieve {agent}'s goal in a way that align with their character traits.
|
57 |
+
Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).\n
|
58 |
+
{history}.
|
59 |
+
You are at Turn #{turn_number}. Your available action types are
|
60 |
+
{action_list}.
|
61 |
+
Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
|
62 |
|
63 |
+
Please only generate a JSON string including the action type and the argument.
|
64 |
+
Your action should follow the given format:
|
65 |
+
{format_instructions}
|
66 |
+
"""
|
67 |
+
return generate(
|
68 |
+
model_name=model_name,
|
69 |
+
template=template,
|
70 |
+
input_values=dict(
|
71 |
+
agent=agent,
|
72 |
+
turn_number=str(turn_number),
|
73 |
+
history=history,
|
74 |
+
action_list=" ".join(action_types),
|
75 |
+
),
|
76 |
+
output_parser=PydanticOutputParser(pydantic_object=AgentAction),
|
77 |
+
temperature=temperature,
|
78 |
+
)
|
79 |
+
# except Exception as e:
|
80 |
+
# print(e)
|
81 |
+
# return AgentAction(action_type="none", argument="")
|
82 |
|
83 |
@cache
|
84 |
+
def prepare_model(model_name):
|
85 |
compute_type = torch.float16
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
if model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR':
|
88 |
+
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", model_max_length=4096)
|
89 |
model = AutoModelForCausalLM.from_pretrained(
|
90 |
"mistralai/Mistral-7B-Instruct-v0.1",
|
91 |
cache_dir="./.cache",
|
92 |
+
device_map='cuda'
|
|
|
93 |
)
|
94 |
model = PeftModel.from_pretrained(model, model_name).to("cuda")
|
95 |
|
96 |
elif model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit':
|
97 |
+
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", model_max_length=4096)
|
98 |
model = AutoModelForCausalLM.from_pretrained(
|
99 |
"mistralai/Mistral-7B-Instruct-v0.1",
|
100 |
cache_dir="./.cache",
|
|
|
104 |
bnb_4bit_use_double_quant=True,
|
105 |
bnb_4bit_quant_type="nf4",
|
106 |
bnb_4bit_compute_dtype=compute_type,
|
107 |
+
)
|
|
|
108 |
)
|
109 |
+
model = PeftModel.from_pretrained(model, model_name[0:-5]).to("cuda")
|
110 |
|
111 |
elif model_name == 'mistralai/Mistral-7B-Instruct-v0.1':
|
112 |
+
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", model_max_length=4096)
|
113 |
+
tokenizer.model_max_length = 4096
|
114 |
model = AutoModelForCausalLM.from_pretrained(
|
115 |
"mistralai/Mistral-7B-Instruct-v0.1",
|
116 |
cache_dir="./.cache",
|
117 |
+
device_map='cuda'
|
|
|
118 |
)
|
119 |
|
120 |
else:
|
|
|
143 |
return_full_text=False,
|
144 |
do_sample=True,
|
145 |
num_beams=3,
|
|
|
146 |
)
|
147 |
hf = HuggingFacePipeline(pipeline=pipe)
|
148 |
chain = LLMChain(llm=hf, prompt=chat_prompt_template)
|
|
|
167 |
input_values["format_instructions"] = output_parser.get_format_instructions()
|
168 |
result = chain.predict([logging_handler], **input_values)
|
169 |
prompt = logging_handler.retrive_prompt()
|
170 |
+
print(f"Prompt:\n {prompt}")
|
171 |
+
print(f"Result:\n {result}")
|
172 |
try:
|
173 |
parsed_result = output_parser.parse(result)
|
174 |
except KeyboardInterrupt:
|
|
|
181 |
reformat_parsed_result = format_bad_output(
|
182 |
result, format_instructions=output_parser.get_format_instructions()
|
183 |
)
|
184 |
+
print(f"Reformatted result:\n {reformat_parsed_result}")
|
185 |
parsed_result = output_parser.parse(reformat_parsed_result)
|
186 |
log.info(f"Generated result: {parsed_result}")
|
187 |
return parsed_result
|
|
|
222 |
"""
|
223 |
Using langchain to sample profiles for participants
|
224 |
"""
|
225 |
+
if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit", "mistralai/Mistral-7B-Instruct-v0.1"]:
|
226 |
return obtain_chain_hf(
|
227 |
model_name=model_name,
|
228 |
template=template,
|
|
|
246 |
return chain
|
247 |
|
248 |
def _return_fixed_model_version(model_name: str) -> str:
|
249 |
+
model_version_map = {
|
250 |
"gpt-3.5-turbo": "gpt-3.5-turbo-0613",
|
251 |
"gpt-3.5-turbo-finetuned": "ft:gpt-3.5-turbo-0613:academicscmu::8nY2zgdt",
|
252 |
"gpt-3.5-turbo-ft-MF": "ft:gpt-3.5-turbo-0613:academicscmu::8nuER4bO",
|
253 |
"gpt-4": "gpt-4-0613",
|
254 |
"gpt-4-turbo": "gpt-4-1106-preview",
|
255 |
+
}
|
256 |
+
return model_version_map[model_name] if model_name in model_version_map else model_name
|
start_app.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
export OPENAI_API_KEY=$(cat openai_api.key)
|
2 |
+
export HF_TOKEN=$(cat hf_token.key)
|
3 |
+
|
4 |
+
python app.py
|
utils.py
CHANGED
@@ -44,11 +44,11 @@ def dialogue_history_prompt(message, history, user_agent, bot_agent):
|
|
44 |
user_turn_idx = idx * 2
|
45 |
bot_turn_idx = idx * 2 + 1
|
46 |
if not bot_message.startswith("["): # if action type == speak, need to add 'said: ' to be consistent with the dialog prompt
|
47 |
-
bot_message =
|
48 |
-
dialogue_history = f"{dialogue_history}\n\nTurn #{user_turn_idx}
|
49 |
-
|
50 |
-
dialogue_history = f"{dialogue_history}\n\nTurn #{
|
51 |
-
return dialogue_history,
|
52 |
|
53 |
def format_docstring(docstring: str) -> str:
|
54 |
"""Format a docstring for use in a prompt template."""
|
|
|
44 |
user_turn_idx = idx * 2
|
45 |
bot_turn_idx = idx * 2 + 1
|
46 |
if not bot_message.startswith("["): # if action type == speak, need to add 'said: ' to be consistent with the dialog prompt
|
47 |
+
bot_message = 'said:"' + bot_message + '"'
|
48 |
+
dialogue_history = f"""{dialogue_history}\n\nTurn #{user_turn_idx} {user_agent.name} said: "{user_message}"\n\nTurn #{bot_turn_idx}: {bot_agent.name}: {bot_message}"""
|
49 |
+
curr_turn_idx = len(history) * 2
|
50 |
+
dialogue_history = f"""{dialogue_history}\n\nTurn #{curr_turn_idx} {user_agent.name} said: "{message}"\n"""
|
51 |
+
return dialogue_history, curr_turn_idx + 1
|
52 |
|
53 |
def format_docstring(docstring: str) -> str:
|
54 |
"""Format a docstring for use in a prompt template."""
|