Wonderplex committed on
Commit
8b07d8c
·
1 Parent(s): f1c7954

fixed parsing errors and extra_info (#56)

Browse files
Files changed (5) hide show
  1. app.py +35 -5
  2. message_classes.py +1 -1
  3. sotopia_pi_generate.py +49 -49
  4. start_app.sh +4 -0
  5. utils.py +5 -5
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  from collections import defaultdict
3
  import json
 
4
 
5
  import gradio as gr
6
 
@@ -25,6 +26,21 @@ RELATIONSHIP_PROFILES = "profiles/relationship_profiles.jsonl"
25
 
26
  ACTION_TYPES = ['none', 'action', 'non-verbal communication', 'speak', 'leave']
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  @cache
29
  def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILES, relationship_file=RELATIONSHIP_PROFILES):
30
  with open(env_file, 'r') as f:
@@ -126,13 +142,27 @@ def create_bot_info(bot_agent_dropdown):
126
  return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
127
 
128
  def create_user_goal(environment_dropdown):
129
- _, environment_dict, _, _ = get_sotopia_profiles()
130
- text = environment_dict[environment_dropdown].agent_goals[0]
131
- return gr.Textbox(label="User Agent Goal", lines=4, value=text)
 
 
 
 
 
 
 
132
 
133
  def create_bot_goal(environment_dropdown):
134
  _, environment_dict, _, _ = get_sotopia_profiles()
135
  text = environment_dict[environment_dropdown].agent_goals[1]
 
 
 
 
 
 
 
136
  return gr.Textbox(label="Bot Agent Goal", lines=4, value=text)
137
 
138
  def sotopia_info_accordion(accordion_visible=True):
@@ -147,7 +177,7 @@ def sotopia_info_accordion(accordion_visible=True):
147
  interactive=True,
148
  )
149
  model_name_dropdown = gr.Dropdown(
150
- choices=["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit", "mistralai/Mistral-7B-Instruct-v0.1", "gpt-3.5-turbo", "gpt-4-turbo"],
151
  value=DEFAULT_MODEL_SELECTION,
152
  interactive=True,
153
  label="Model Selection"
@@ -215,7 +245,7 @@ def chat_tab():
215
 
216
  context = get_context_prompt(bot_agent, user_agent, environment)
217
  dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
218
- prompt_history = f"{context}\n\n{dialogue_history}"
219
  agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
220
  return agent_action.to_natural_language()
221
 
 
1
  import os
2
  from collections import defaultdict
3
  import json
4
+ from typing import Literal
5
 
6
  import gradio as gr
7
 
 
26
 
27
  ACTION_TYPES = ['none', 'action', 'non-verbal communication', 'speak', 'leave']
28
 
29
+ MODEL_OPTIONS = [
30
+ "gpt-3.5-turbo",
31
+ "gpt-4",
32
+ "gpt-4-turbo",
33
+ "cmu-lti/sotopia-pi-mistral-7b-BC_SR",
34
+ "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit",
35
+ "mistralai/Mistral-7B-Instruct-v0.1"
36
+ # "mistralai/Mixtral-8x7B-Instruct-v0.1",
37
+ # "togethercomputer/llama-2-7b-chat",
38
+ # "togethercomputer/llama-2-70b-chat",
39
+ # "togethercomputer/mpt-30b-chat",
40
+ # "together_ai/togethercomputer/llama-2-7b-chat",
41
+ # "together_ai/togethercomputer/falcon-7b-instruct",
42
+ ]
43
+
44
  @cache
45
  def get_sotopia_profiles(env_file=ENVIRONMENT_PROFILES, agent_file=AGENT_PROFILES, relationship_file=RELATIONSHIP_PROFILES):
46
  with open(env_file, 'r') as f:
 
142
  return gr.Textbox(label="Bot Agent Profile", lines=4, value=text)
143
 
144
  def create_user_goal(environment_dropdown):
145
+ _, environment_dict, _, _ = get_sotopia_profiles()
146
+ text = environment_dict[environment_dropdown].agent_goals[0]
147
+ text = text.replace('(', '').replace(')', '')
148
+ if "<extra_info>" in text:
149
+ text = text.replace("<extra_info>", "\n\n")
150
+ text = text.replace("</extra_info>", "\n")
151
+ if "<strategy_hint>" in text:
152
+ text = text.replace("<strategy_hint>", "\n\n")
153
+ text = text.replace("</strategy_hint>", "\n")
154
+ return gr.Textbox(label="User Agent Goal", lines=4, value=text)
155
 
156
  def create_bot_goal(environment_dropdown):
157
  _, environment_dict, _, _ = get_sotopia_profiles()
158
  text = environment_dict[environment_dropdown].agent_goals[1]
159
+ text = text.replace('(', '').replace(')', '')
160
+ if "<extra_info>" in text:
161
+ text = text.replace("<extra_info>", "\n\n")
162
+ text = text.replace("</extra_info>", "\n")
163
+ if "<strategy_hint>" in text:
164
+ text = text.replace("<strategy_hint>", "\n\n")
165
+ text = text.replace("</strategy_hint>", "\n")
166
  return gr.Textbox(label="Bot Agent Goal", lines=4, value=text)
167
 
168
  def sotopia_info_accordion(accordion_visible=True):
 
177
  interactive=True,
178
  )
179
  model_name_dropdown = gr.Dropdown(
180
+ choices=MODEL_OPTIONS,
181
  value=DEFAULT_MODEL_SELECTION,
182
  interactive=True,
183
  label="Model Selection"
 
245
 
246
  context = get_context_prompt(bot_agent, user_agent, environment)
247
  dialogue_history, next_turn_idx = dialogue_history_prompt(message, history, user_agent, bot_agent)
248
+ prompt_history = f"{context}{dialogue_history}"
249
  agent_action = generate_action(model_selection, prompt_history, next_turn_idx, ACTION_TYPES, bot_agent.name, TEMPERATURE)
250
  return agent_action.to_natural_language()
251
 
message_classes.py CHANGED
@@ -120,7 +120,7 @@ class AgentAction(Message):
120
  case "none":
121
  return "did nothing"
122
  case "speak":
123
- return f'said: "{self.argument}"'
124
  case "non-verbal communication":
125
  return f"[{self.action_type}] {self.argument}"
126
  case "action":
 
120
  case "none":
121
  return "did nothing"
122
  case "speak":
123
+ return f"{self.argument}"
124
  case "non-verbal communication":
125
  return f"[{self.action_type}] {self.argument}"
126
  case "action":
sotopia_pi_generate.py CHANGED
@@ -28,6 +28,9 @@ from utils import format_docstring
28
  from langchain_callback_handler import LoggingCallbackHandler
29
 
30
  HF_TOKEN_KEY_FILE="./hf_token.key"
 
 
 
31
 
32
  OutputType = TypeVar("OutputType", bound=object)
33
  log = logging.getLogger("generate")
@@ -44,59 +47,54 @@ def generate_action(
44
  """
45
  Using langchain to generate an example episode
46
  """
47
- try:
48
  # Normal case, model as agent
49
- template = """
50
- Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
51
- You can find {agent}'s goal (or background) in the 'Here is the context of the interaction' field.
52
- Note that {agent}'s goal is only visible to you.
53
- You should try your best to achieve {agent}'s goal in a way that align with their character traits.
54
- Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).
55
- {history}.
56
- You are at Turn #{turn_number}. Your available action types are
57
- {action_list}.
58
- Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
59
 
60
- Please only generate a JSON string including the action type and the argument.
61
- Your action should follow the given format:
62
- {format_instructions}
63
- """
64
- return generate(
65
- model_name=model_name,
66
- template=template,
67
- input_values=dict(
68
- agent=agent,
69
- turn_number=str(turn_number),
70
- history=history,
71
- action_list=" ".join(action_types),
72
- ),
73
- output_parser=PydanticOutputParser(pydantic_object=AgentAction),
74
- temperature=temperature,
75
- )
76
- except Exception:
77
- return AgentAction(action_type="none", argument="")
 
78
 
79
  @cache
80
- def prepare_model(model_name, hf_token_key_file=HF_TOKEN_KEY_FILE):
81
  compute_type = torch.float16
82
- if os.path.exists(hf_token_key_file):
83
- with open (hf_token_key_file, 'r') as f:
84
- hf_token = f.read().strip()
85
- else:
86
- hf_token = os.environ["HF_TOKEN"]
87
 
88
  if model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR':
89
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
90
  model = AutoModelForCausalLM.from_pretrained(
91
  "mistralai/Mistral-7B-Instruct-v0.1",
92
  cache_dir="./.cache",
93
- device_map='cuda',
94
- token=hf_token
95
  )
96
  model = PeftModel.from_pretrained(model, model_name).to("cuda")
97
 
98
  elif model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit':
99
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
100
  model = AutoModelForCausalLM.from_pretrained(
101
  "mistralai/Mistral-7B-Instruct-v0.1",
102
  cache_dir="./.cache",
@@ -106,18 +104,17 @@ def prepare_model(model_name, hf_token_key_file=HF_TOKEN_KEY_FILE):
106
  bnb_4bit_use_double_quant=True,
107
  bnb_4bit_quant_type="nf4",
108
  bnb_4bit_compute_dtype=compute_type,
109
- ),
110
- token=hf_token
111
  )
112
- model = PeftModel.from_pretrained(model, model_name).to("cuda")
113
 
114
  elif model_name == 'mistralai/Mistral-7B-Instruct-v0.1':
115
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
 
116
  model = AutoModelForCausalLM.from_pretrained(
117
  "mistralai/Mistral-7B-Instruct-v0.1",
118
  cache_dir="./.cache",
119
- device_map='cuda',
120
- token=hf_token
121
  )
122
 
123
  else:
@@ -146,7 +143,6 @@ def obtain_chain_hf(
146
  return_full_text=False,
147
  do_sample=True,
148
  num_beams=3,
149
- length_penalty=-1.0,
150
  )
151
  hf = HuggingFacePipeline(pipeline=pipe)
152
  chain = LLMChain(llm=hf, prompt=chat_prompt_template)
@@ -171,6 +167,8 @@ def generate(
171
  input_values["format_instructions"] = output_parser.get_format_instructions()
172
  result = chain.predict([logging_handler], **input_values)
173
  prompt = logging_handler.retrive_prompt()
 
 
174
  try:
175
  parsed_result = output_parser.parse(result)
176
  except KeyboardInterrupt:
@@ -183,6 +181,7 @@ def generate(
183
  reformat_parsed_result = format_bad_output(
184
  result, format_instructions=output_parser.get_format_instructions()
185
  )
 
186
  parsed_result = output_parser.parse(reformat_parsed_result)
187
  log.info(f"Generated result: {parsed_result}")
188
  return parsed_result
@@ -223,7 +222,7 @@ def obtain_chain(
223
  """
224
  Using langchain to sample profiles for participants
225
  """
226
- if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit"]:
227
  return obtain_chain_hf(
228
  model_name=model_name,
229
  template=template,
@@ -247,10 +246,11 @@ def obtain_chain(
247
  return chain
248
 
249
  def _return_fixed_model_version(model_name: str) -> str:
250
- return {
251
  "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
252
  "gpt-3.5-turbo-finetuned": "ft:gpt-3.5-turbo-0613:academicscmu::8nY2zgdt",
253
  "gpt-3.5-turbo-ft-MF": "ft:gpt-3.5-turbo-0613:academicscmu::8nuER4bO",
254
  "gpt-4": "gpt-4-0613",
255
  "gpt-4-turbo": "gpt-4-1106-preview",
256
- }[model_name]
 
 
28
  from langchain_callback_handler import LoggingCallbackHandler
29
 
30
  HF_TOKEN_KEY_FILE="./hf_token.key"
31
+ if os.path.exists(HF_TOKEN_KEY_FILE):
32
+ with open(HF_TOKEN_KEY_FILE, "r") as f:
33
+ os.environ["HF_TOKEN"] = f.read().strip()
34
 
35
  OutputType = TypeVar("OutputType", bound=object)
36
  log = logging.getLogger("generate")
 
47
  """
48
  Using langchain to generate an example episode
49
  """
50
+ # try:
51
  # Normal case, model as agent
52
+ template = """
53
+ Imagine you are {agent}, your task is to act/speak as {agent} would, keeping in mind {agent}'s social goal.
54
+ You can find {agent}'s goal (or background) in the 'Here is the context of the interaction' field.
55
+ Note that {agent}'s goal is only visible to you.
56
+ You should try your best to achieve {agent}'s goal in a way that align with their character traits.
57
+ Additionally, maintaining the conversation's naturalness and realism is essential (e.g., do not repeat what other people has already said before).\n
58
+ {history}.
59
+ You are at Turn #{turn_number}. Your available action types are
60
+ {action_list}.
61
+ Note: You can "leave" this conversation if 1. you have achieved your social goals, 2. this conversation makes you uncomfortable, 3. you find it uninteresting/you lose your patience, 4. or for other reasons you want to leave.
62
 
63
+ Please only generate a JSON string including the action type and the argument.
64
+ Your action should follow the given format:
65
+ {format_instructions}
66
+ """
67
+ return generate(
68
+ model_name=model_name,
69
+ template=template,
70
+ input_values=dict(
71
+ agent=agent,
72
+ turn_number=str(turn_number),
73
+ history=history,
74
+ action_list=" ".join(action_types),
75
+ ),
76
+ output_parser=PydanticOutputParser(pydantic_object=AgentAction),
77
+ temperature=temperature,
78
+ )
79
+ # except Exception as e:
80
+ # print(e)
81
+ # return AgentAction(action_type="none", argument="")
82
 
83
  @cache
84
+ def prepare_model(model_name):
85
  compute_type = torch.float16
 
 
 
 
 
86
 
87
  if model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR':
88
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", model_max_length=4096)
89
  model = AutoModelForCausalLM.from_pretrained(
90
  "mistralai/Mistral-7B-Instruct-v0.1",
91
  cache_dir="./.cache",
92
+ device_map='cuda'
 
93
  )
94
  model = PeftModel.from_pretrained(model, model_name).to("cuda")
95
 
96
  elif model_name == 'cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit':
97
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", model_max_length=4096)
98
  model = AutoModelForCausalLM.from_pretrained(
99
  "mistralai/Mistral-7B-Instruct-v0.1",
100
  cache_dir="./.cache",
 
104
  bnb_4bit_use_double_quant=True,
105
  bnb_4bit_quant_type="nf4",
106
  bnb_4bit_compute_dtype=compute_type,
107
+ )
 
108
  )
109
+ model = PeftModel.from_pretrained(model, model_name[0:-5]).to("cuda")
110
 
111
  elif model_name == 'mistralai/Mistral-7B-Instruct-v0.1':
112
+ tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", model_max_length=4096)
113
+ tokenizer.model_max_length = 4096
114
  model = AutoModelForCausalLM.from_pretrained(
115
  "mistralai/Mistral-7B-Instruct-v0.1",
116
  cache_dir="./.cache",
117
+ device_map='cuda'
 
118
  )
119
 
120
  else:
 
143
  return_full_text=False,
144
  do_sample=True,
145
  num_beams=3,
 
146
  )
147
  hf = HuggingFacePipeline(pipeline=pipe)
148
  chain = LLMChain(llm=hf, prompt=chat_prompt_template)
 
167
  input_values["format_instructions"] = output_parser.get_format_instructions()
168
  result = chain.predict([logging_handler], **input_values)
169
  prompt = logging_handler.retrive_prompt()
170
+ print(f"Prompt:\n {prompt}")
171
+ print(f"Result:\n {result}")
172
  try:
173
  parsed_result = output_parser.parse(result)
174
  except KeyboardInterrupt:
 
181
  reformat_parsed_result = format_bad_output(
182
  result, format_instructions=output_parser.get_format_instructions()
183
  )
184
+ print(f"Reformatted result:\n {reformat_parsed_result}")
185
  parsed_result = output_parser.parse(reformat_parsed_result)
186
  log.info(f"Generated result: {parsed_result}")
187
  return parsed_result
 
222
  """
223
  Using langchain to sample profiles for participants
224
  """
225
+ if model_name in ["cmu-lti/sotopia-pi-mistral-7b-BC_SR", "cmu-lti/sotopia-pi-mistral-7b-BC_SR_4bit", "mistralai/Mistral-7B-Instruct-v0.1"]:
226
  return obtain_chain_hf(
227
  model_name=model_name,
228
  template=template,
 
246
  return chain
247
 
248
  def _return_fixed_model_version(model_name: str) -> str:
249
+ model_version_map = {
250
  "gpt-3.5-turbo": "gpt-3.5-turbo-0613",
251
  "gpt-3.5-turbo-finetuned": "ft:gpt-3.5-turbo-0613:academicscmu::8nY2zgdt",
252
  "gpt-3.5-turbo-ft-MF": "ft:gpt-3.5-turbo-0613:academicscmu::8nuER4bO",
253
  "gpt-4": "gpt-4-0613",
254
  "gpt-4-turbo": "gpt-4-1106-preview",
255
+ }
256
+ return model_version_map[model_name] if model_name in model_version_map else model_name
start_app.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ export OPENAI_API_KEY=$(cat openai_api.key)
2
+ export HF_TOKEN=$(cat hf_token.key)
3
+
4
+ python app.py
utils.py CHANGED
@@ -44,11 +44,11 @@ def dialogue_history_prompt(message, history, user_agent, bot_agent):
44
  user_turn_idx = idx * 2
45
  bot_turn_idx = idx * 2 + 1
46
  if not bot_message.startswith("["): # if action type == speak, need to add 'said: ' to be consistent with the dialog prompt
47
- bot_message = "said :" + bot_message
48
- dialogue_history = f"{dialogue_history}\n\nTurn #{user_turn_idx}: {user_agent.name}: {user_message}\n\nTurn #{bot_turn_idx}: {bot_agent.name}: {bot_message}"
49
- last_turn_idx = len(history) * 2
50
- dialogue_history = f"{dialogue_history}\n\nTurn #{last_turn_idx+1}: {user_agent.name}: {message}\n."
51
- return dialogue_history, last_turn_idx + 2
52
 
53
  def format_docstring(docstring: str) -> str:
54
  """Format a docstring for use in a prompt template."""
 
44
  user_turn_idx = idx * 2
45
  bot_turn_idx = idx * 2 + 1
46
  if not bot_message.startswith("["): # if action type == speak, need to add 'said: ' to be consistent with the dialog prompt
47
+ bot_message = 'said:"' + bot_message + '"'
48
+ dialogue_history = f"""{dialogue_history}\n\nTurn #{user_turn_idx} {user_agent.name} said: "{user_message}"\n\nTurn #{bot_turn_idx}: {bot_agent.name}: {bot_message}"""
49
+ curr_turn_idx = len(history) * 2
50
+ dialogue_history = f"""{dialogue_history}\n\nTurn #{curr_turn_idx} {user_agent.name} said: "{message}"\n"""
51
+ return dialogue_history, curr_turn_idx + 1
52
 
53
  def format_docstring(docstring: str) -> str:
54
  """Format a docstring for use in a prompt template."""