feature/major backend update with agent
Browse files
.gitignore
CHANGED
@@ -5,4 +5,5 @@ __pycache__/utils.cpython-38.pyc
|
|
5 |
|
6 |
notebooks/
|
7 |
*.pyc
|
8 |
-
local_tests/
|
|
|
|
5 |
|
6 |
notebooks/
|
7 |
*.pyc
|
8 |
+
local_tests/
|
9 |
+
.vscode/
|
app.py
CHANGED
@@ -64,9 +64,9 @@ async def chat(query, history):
|
|
64 |
async for event in result:
|
65 |
print(event)
|
66 |
if event["event"] == "on_chat_model_stream":
|
67 |
-
print("line 66")
|
68 |
if start_streaming == False:
|
69 |
-
print("line 68")
|
70 |
start_streaming = True
|
71 |
history[-1] = (query, "")
|
72 |
|
@@ -77,17 +77,26 @@ async def chat(query, history):
|
|
77 |
answer_yet = parse_output_llm_with_sources(answer_yet)
|
78 |
history[-1] = (query, answer_yet)
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
elif (
|
81 |
event["name"] == "retrieve_documents"
|
82 |
and event["event"] == "on_chain_end"
|
83 |
):
|
84 |
try:
|
85 |
-
print(
|
|
|
86 |
docs = event["data"]["output"]["documents"]
|
87 |
docs_html = []
|
88 |
-
for i,
|
89 |
-
docs_html.append(make_html_source(
|
|
|
90 |
docs_html = "".join(docs_html)
|
|
|
91 |
except Exception as e:
|
92 |
print(f"Error getting documents: {e}")
|
93 |
print(event)
|
@@ -97,9 +106,9 @@ async def chat(query, history):
|
|
97 |
display_output,
|
98 |
) in steps_display.items():
|
99 |
if event["name"] == event_name:
|
100 |
-
print("line 99")
|
101 |
if event["event"] == "on_chain_start":
|
102 |
-
print("line 101")
|
103 |
answer_yet = event_description
|
104 |
history[-1] = (query, answer_yet)
|
105 |
|
|
|
64 |
async for event in result:
|
65 |
print(event)
|
66 |
if event["event"] == "on_chat_model_stream":
|
67 |
+
# print("line 66")
|
68 |
if start_streaming == False:
|
69 |
+
# print("line 68")
|
70 |
start_streaming = True
|
71 |
history[-1] = (query, "")
|
72 |
|
|
|
77 |
answer_yet = parse_output_llm_with_sources(answer_yet)
|
78 |
history[-1] = (query, answer_yet)
|
79 |
|
80 |
+
elif (
|
81 |
+
event["name"] == "answer_rag_wrong"
|
82 |
+
and event["event"] == "on_chain_stream"
|
83 |
+
):
|
84 |
+
history[-1] = (query, event["data"]["chunk"]["answer"])
|
85 |
+
|
86 |
elif (
|
87 |
event["name"] == "retrieve_documents"
|
88 |
and event["event"] == "on_chain_end"
|
89 |
):
|
90 |
try:
|
91 |
+
# print(event)
|
92 |
+
# print("line 84")
|
93 |
docs = event["data"]["output"]["documents"]
|
94 |
docs_html = []
|
95 |
+
for i, doc in enumerate(docs, 1):
|
96 |
+
docs_html.append(make_html_source(i, doc))
|
97 |
+
# print(docs_html)
|
98 |
docs_html = "".join(docs_html)
|
99 |
+
# print(docs_html)
|
100 |
except Exception as e:
|
101 |
print(f"Error getting documents: {e}")
|
102 |
print(event)
|
|
|
106 |
display_output,
|
107 |
) in steps_display.items():
|
108 |
if event["name"] == event_name:
|
109 |
+
# print("line 99")
|
110 |
if event["event"] == "on_chain_start":
|
111 |
+
# print("line 101")
|
112 |
answer_yet = event_description
|
113 |
history[-1] = (query, answer_yet)
|
114 |
|
celsius_csrd_chatbot/agent.py
CHANGED
@@ -39,16 +39,12 @@ def route_intent(state):
|
|
39 |
return "intent_esrs"
|
40 |
|
41 |
elif esrs == "wrong_esrs":
|
42 |
-
return "
|
43 |
|
44 |
else:
|
45 |
return "retrieve_documents"
|
46 |
|
47 |
|
48 |
-
def make_id_dict(values):
|
49 |
-
return {k: k for k in values}
|
50 |
-
|
51 |
-
|
52 |
def make_graph_agent(llm, vectorstore):
|
53 |
workflow = StateGraph(GraphState)
|
54 |
|
@@ -70,11 +66,7 @@ def make_graph_agent(llm, vectorstore):
|
|
70 |
workflow.set_entry_point("categorize_esrs")
|
71 |
|
72 |
# CONDITIONAL EDGES
|
73 |
-
workflow.add_conditional_edges(
|
74 |
-
"categorize_esrs",
|
75 |
-
route_intent,
|
76 |
-
make_id_dict(["intent_esrs", "retrieve_documents", "answer_rag_wrong"]),
|
77 |
-
)
|
78 |
|
79 |
# Define the edges
|
80 |
workflow.add_edge("intent_esrs", "retrieve_documents")
|
|
|
39 |
return "intent_esrs"
|
40 |
|
41 |
elif esrs == "wrong_esrs":
|
42 |
+
return "answer_rag_wrong"
|
43 |
|
44 |
else:
|
45 |
return "retrieve_documents"
|
46 |
|
47 |
|
|
|
|
|
|
|
|
|
48 |
def make_graph_agent(llm, vectorstore):
|
49 |
workflow = StateGraph(GraphState)
|
50 |
|
|
|
66 |
workflow.set_entry_point("categorize_esrs")
|
67 |
|
68 |
# CONDITIONAL EDGES
|
69 |
+
workflow.add_conditional_edges("categorize_esrs", route_intent)
|
|
|
|
|
|
|
|
|
70 |
|
71 |
# Define the edges
|
72 |
workflow.add_edge("intent_esrs", "retrieve_documents")
|
celsius_csrd_chatbot/chains/answer_rag.py
CHANGED
@@ -36,6 +36,7 @@ answering_template = """
|
|
36 |
10. Method Focus: When addressing "how" questions, emphasize methods and procedures over outcomes.
|
37 |
11. Selective Usage: You're not obligated to use every passage; include only those relevant to the question.
|
38 |
12. Insufficient Information: If documents lack necessary details, indicate that you don't have enough information.
|
|
|
39 |
|
40 |
Question: {query}
|
41 |
Answer:
|
|
|
36 |
10. Method Focus: When addressing "how" questions, emphasize methods and procedures over outcomes.
|
37 |
11. Selective Usage: You're not obligated to use every passage; include only those relevant to the question.
|
38 |
12. Insufficient Information: If documents lack necessary details, indicate that you don't have enough information.
|
39 |
+
13. Never mention these guidelines as a source attribution in your response.
|
40 |
|
41 |
Question: {query}
|
42 |
Answer:
|
celsius_csrd_chatbot/chains/esrs_categorization.py
CHANGED
@@ -5,7 +5,7 @@ def make_esrs_categorization_node():
|
|
5 |
|
6 |
def categorize_message(state):
|
7 |
query = state["query"]
|
8 |
-
pattern = r"ESRS \d
|
9 |
esrs_truth = [
|
10 |
"ESRS 1",
|
11 |
"ESRS 2",
|
@@ -25,7 +25,6 @@ def make_esrs_categorization_node():
|
|
25 |
if matches:
|
26 |
true_matches = [match for match in matches if match in esrs_truth]
|
27 |
output = {"esrs_type": true_matches if true_matches else "wrong_esrs"}
|
28 |
-
|
29 |
else:
|
30 |
output = {"esrs_type": "none"}
|
31 |
|
|
|
5 |
|
6 |
def categorize_message(state):
|
7 |
query = state["query"]
|
8 |
+
pattern = r"ESRS \d+[A-Z0-9]*"
|
9 |
esrs_truth = [
|
10 |
"ESRS 1",
|
11 |
"ESRS 2",
|
|
|
25 |
if matches:
|
26 |
true_matches = [match for match in matches if match in esrs_truth]
|
27 |
output = {"esrs_type": true_matches if true_matches else "wrong_esrs"}
|
|
|
28 |
else:
|
29 |
output = {"esrs_type": "none"}
|
30 |
|
celsius_csrd_chatbot/chains/esrs_intent.py
CHANGED
@@ -23,51 +23,41 @@ class ESRSAnalysis(BaseModel):
|
|
23 |
"ESRS S3",
|
24 |
"ESRS S4",
|
25 |
"ESRS G1",
|
26 |
-
"
|
27 |
] = Field(
|
28 |
-
description="""
|
29 |
-
Given a user question choose which documents would be most relevant for answering their question :
|
30 |
-
|
31 |
-
- ESRS 1 is for questions about general principles for preparing and presenting sustainability information in accordance with CSRD
|
32 |
-
- ESRS 2 is for questions about general disclosures related to sustainability reporting, including governance, strategy, impact, risk, opportunity management, and metrics and targets
|
33 |
-
- ESRS E1 is for questions about climate change, global warming, GES and energy
|
34 |
-
- ESRS E2 is for questions about air, water, and soil pollution, and dangerous substances
|
35 |
-
- ESRS E3 is for questions about water and marine resources
|
36 |
-
- ESRS E4 is for questions about biodiversity, nature, wildlife and ecosystems
|
37 |
-
- ESRS E5 is for questions about resource use and circular economy
|
38 |
-
- ESRS S1 is for questions about workforce and labor issues, job security, fair pay, and health and safety
|
39 |
-
- ESRS S2 is for questions about workers in the value chain, workers' treatment
|
40 |
-
- SRS S3 is for questions about affected communities, impact on local communities
|
41 |
-
- ESRS S4 is for questions about consumers and end users, customer privacy, safety, and inclusion
|
42 |
-
- ESRS G1 is for questions about governance, risk management, internal control, and business conduct
|
43 |
-
- none is for questions that do not fit into any of the above categories
|
44 |
-
|
45 |
-
Follow these guidelines :
|
46 |
-
|
47 |
-
- Some questions could be related to multiple ESRS. In such case, choose the most appropriate one.
|
48 |
-
- Remember, if the question is not related to any ESRS, the output should be 'none'.
|
49 |
-
""",
|
50 |
)
|
51 |
|
52 |
|
53 |
def make_esrs_intent_chain(llm):
|
54 |
-
parser = PydanticOutputParser(pydantic_object=ESRSAnalysis)
|
55 |
prompt_template = """
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
62 |
Question: '{query}'
|
63 |
Answer:
|
64 |
"""
|
65 |
-
|
66 |
-
prompt = PromptTemplate(
|
67 |
-
template=prompt_template,
|
68 |
-
input_variables=["query"],
|
69 |
-
partial_variables={"format_instructions": parser.get_format_instructions()},
|
70 |
-
)
|
71 |
chain = {"query": itemgetter("query")} | prompt | llm | parser
|
72 |
|
73 |
return chain
|
@@ -78,7 +68,9 @@ def make_esrs_intent_node(llm):
|
|
78 |
def intent_message(state):
|
79 |
query = state["query"]
|
80 |
categorization_chain = make_esrs_intent_chain(llm)
|
81 |
-
output =
|
|
|
|
|
82 |
|
83 |
return output
|
84 |
|
|
|
23 |
"ESRS S3",
|
24 |
"ESRS S4",
|
25 |
"ESRS G1",
|
26 |
+
"no_intent",
|
27 |
] = Field(
|
28 |
+
description="""The ESRS type that the user query refers to.""",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
)
|
30 |
|
31 |
|
32 |
def make_esrs_intent_chain(llm):
|
|
|
33 |
prompt_template = """
|
34 |
+
Please analyze the question and indicate if it refers to a specific ESRS.
|
35 |
+
|
36 |
+
Follow these definitions in order to choose the appropriate ESRS :
|
37 |
+
- ESRS 1 is for questions about general principles for preparing and presenting sustainability information in accordance with CSRD
|
38 |
+
- ESRS 2 is for questions about general disclosures related to sustainability reporting, including governance, strategy, impact, risk, opportunity management, and metrics and targets
|
39 |
+
- ESRS E1 is for questions about climate change, global warming, GES and energy
|
40 |
+
- ESRS E2 is for questions about air, water, and soil pollution, and dangerous substances
|
41 |
+
- ESRS E3 is for questions about water and marine resources
|
42 |
+
- ESRS E4 is for questions about biodiversity, nature, wildlife and ecosystems
|
43 |
+
- ESRS E5 is for questions about resource use and circular economy
|
44 |
+
- ESRS S1 is for questions about workforce and labor issues, job security, fair pay, and health and safety
|
45 |
+
- ESRS S2 is for questions about workers in the value chain, workers' treatment
|
46 |
+
- ESRS S3 is for questions about affected communities, impact on local communities
|
47 |
+
- ESRS S4 is for questions about consumers and end users, customer privacy, safety, and inclusion
|
48 |
+
- ESRS G1 is for questions about governance, risk management, internal control, and business conduct
|
49 |
+
- no_intent is for questions that do not fit into any of the above categories
|
50 |
+
|
51 |
+
Keep in mind these guidelines :
|
52 |
+
- Some questions could be related to multiple ESRS. In such case, choose the most appropriate one.
|
53 |
+
|
54 |
+
The output needs to respect a JSON format with 'esrs_type' as the key and the appropriate ESRS as the value.
|
55 |
|
56 |
Question: '{query}'
|
57 |
Answer:
|
58 |
"""
|
59 |
+
parser = PydanticOutputParser(pydantic_object=ESRSAnalysis, method="json_mode")
|
60 |
+
prompt = PromptTemplate(template=prompt_template, input_variables=["query"])
|
|
|
|
|
|
|
|
|
61 |
chain = {"query": itemgetter("query")} | prompt | llm | parser
|
62 |
|
63 |
return chain
|
|
|
68 |
def intent_message(state):
|
69 |
query = state["query"]
|
70 |
categorization_chain = make_esrs_intent_chain(llm)
|
71 |
+
output = {
|
72 |
+
"esrs_type": [categorization_chain.invoke({"query": query}).esrs_type]
|
73 |
+
}
|
74 |
|
75 |
return output
|
76 |
|
celsius_csrd_chatbot/chains/retriever.py
CHANGED
@@ -1,16 +1,15 @@
|
|
1 |
def make_retriever_node(vectorstore, k=10):
|
2 |
-
|
3 |
def retrieve_documents(state):
|
4 |
sources = state["esrs_type"]
|
5 |
query = state["query"]
|
6 |
-
if sources == "none":
|
7 |
-
|
8 |
else:
|
9 |
-
|
|
|
|
|
|
|
10 |
docs = []
|
11 |
-
docs_retrieved = vectorstore.similarity_search_with_score(
|
12 |
-
query=query, filter=filters_full, k=k
|
13 |
-
)
|
14 |
for doc in docs_retrieved:
|
15 |
doc_append = doc[0]
|
16 |
doc_append.metadata["similarity_score"] = doc[1]
|
|
|
1 |
def make_retriever_node(vectorstore, k=10):
|
|
|
2 |
def retrieve_documents(state):
|
3 |
sources = state["esrs_type"]
|
4 |
query = state["query"]
|
5 |
+
if sources == "none" or sources == "no_intent":
|
6 |
+
docs_retrieved = vectorstore.similarity_search_with_score(query=query, k=k)
|
7 |
else:
|
8 |
+
filters = {"ESRS_filter": {"$in": sources}}
|
9 |
+
docs_retrieved = vectorstore.similarity_search_with_score(
|
10 |
+
query=query, filter=filters, k=k
|
11 |
+
)
|
12 |
docs = []
|
|
|
|
|
|
|
13 |
for doc in docs_retrieved:
|
14 |
doc_append = doc[0]
|
15 |
doc_append.metadata["similarity_score"] = doc[1]
|