Ramon Meffert
commited on
Commit
·
492106d
1
Parent(s):
0157dfd
Add evaluation
Browse files- main.py +45 -22
- results/em_scores.csv +60 -0
- results/f1_scores.csv +60 -0
- src/evaluation.py +1 -0
- src/readers/dpr_reader.py +1 -2
- src/readers/longformer_reader.py +1 -1
- src/retrievers/faiss_retriever.py +3 -2
- src/utils/preprocessing.py +3 -1
main.py
CHANGED
@@ -1,3 +1,5 @@
|
|
|
|
|
|
1 |
from dotenv import load_dotenv
|
2 |
# needs to happen as very first thing, otherwise HF ignores env vars
|
3 |
load_dotenv()
|
@@ -5,8 +7,8 @@ load_dotenv()
|
|
5 |
import os
|
6 |
import pandas as pd
|
7 |
|
8 |
-
from dataclasses import dataclass
|
9 |
-
from typing import Dict, cast
|
10 |
from datasets import DatasetDict, load_dataset
|
11 |
|
12 |
from src.readers.base_reader import Reader
|
@@ -24,10 +26,15 @@ from src.utils.preprocessing import context_to_reader_input
|
|
24 |
from src.utils.timing import get_times, timeit
|
25 |
|
26 |
|
|
|
|
|
|
|
27 |
@dataclass
|
28 |
class Experiment:
|
29 |
retriever: Retriever
|
30 |
reader: Reader
|
|
|
|
|
31 |
|
32 |
|
33 |
if __name__ == '__main__':
|
@@ -45,21 +52,25 @@ if __name__ == '__main__':
|
|
45 |
retriever=FaissRetriever(
|
46 |
paragraphs,
|
47 |
FaissRetrieverOptions.dpr("./src/models/dpr.faiss")),
|
48 |
-
reader=DprReader()
|
|
|
49 |
),
|
50 |
"faiss_longformer": Experiment(
|
51 |
retriever=FaissRetriever(
|
52 |
paragraphs,
|
53 |
FaissRetrieverOptions.longformer("./src/models/longformer.faiss")),
|
54 |
-
reader=LongformerReader()
|
|
|
55 |
),
|
56 |
"es_dpr": Experiment(
|
57 |
retriever=ESRetriever(paragraphs),
|
58 |
-
reader=DprReader()
|
|
|
59 |
),
|
60 |
"es_longformer": Experiment(
|
61 |
retriever=ESRetriever(paragraphs),
|
62 |
-
reader=LongformerReader()
|
|
|
63 |
),
|
64 |
}
|
65 |
|
@@ -69,6 +80,8 @@ if __name__ == '__main__':
|
|
69 |
question = questions_test["question"][idx]
|
70 |
answer = questions_test["answer"][idx]
|
71 |
|
|
|
|
|
72 |
retrieve_timer = timeit(f"{experiment_name}.retrieve")
|
73 |
t_retrieve = retrieve_timer(experiment.retriever.retrieve)
|
74 |
|
@@ -80,28 +93,38 @@ if __name__ == '__main__':
|
|
80 |
scores, context = t_retrieve(question, 5)
|
81 |
reader_input = context_to_reader_input(context)
|
82 |
|
83 |
-
#
|
84 |
-
|
85 |
-
answers = t_read(question, reader_input, 5)
|
86 |
-
|
87 |
-
# Calculate softmaxed scores for readable output
|
88 |
-
# sm = torch.nn.Softmax(dim=0)
|
89 |
-
# document_scores = sm(torch.Tensor(
|
90 |
-
# [pred.relevance_score for pred in answers]))
|
91 |
-
# span_scores = sm(torch.Tensor(
|
92 |
-
# [pred.span_score for pred in answers]))
|
93 |
|
94 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
-
# TODO evaluation and storing of results
|
97 |
print()
|
98 |
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
|
|
|
|
|
|
104 |
|
|
|
|
|
|
|
105 |
|
106 |
# TODO evaluation and storing of results
|
107 |
|
|
|
1 |
+
from collections import namedtuple
|
2 |
+
from pprint import pprint
|
3 |
from dotenv import load_dotenv
|
4 |
# needs to happen as very first thing, otherwise HF ignores env vars
|
5 |
load_dotenv()
|
|
|
7 |
import os
|
8 |
import pandas as pd
|
9 |
|
10 |
+
from dataclasses import dataclass, field
|
11 |
+
from typing import Dict, cast, List
|
12 |
from datasets import DatasetDict, load_dataset
|
13 |
|
14 |
from src.readers.base_reader import Reader
|
|
|
26 |
from src.utils.timing import get_times, timeit
|
27 |
|
28 |
|
29 |
+
ExperimentResult = namedtuple('ExperimentResult', ['correct', 'given'])
|
30 |
+
|
31 |
+
|
32 |
@dataclass
|
33 |
class Experiment:
|
34 |
retriever: Retriever
|
35 |
reader: Reader
|
36 |
+
lm: str
|
37 |
+
results: List[ExperimentResult] = field(default_factory=list)
|
38 |
|
39 |
|
40 |
if __name__ == '__main__':
|
|
|
52 |
retriever=FaissRetriever(
|
53 |
paragraphs,
|
54 |
FaissRetrieverOptions.dpr("./src/models/dpr.faiss")),
|
55 |
+
reader=DprReader(),
|
56 |
+
lm="dpr"
|
57 |
),
|
58 |
"faiss_longformer": Experiment(
|
59 |
retriever=FaissRetriever(
|
60 |
paragraphs,
|
61 |
FaissRetrieverOptions.longformer("./src/models/longformer.faiss")),
|
62 |
+
reader=LongformerReader(),
|
63 |
+
lm="longformer"
|
64 |
),
|
65 |
"es_dpr": Experiment(
|
66 |
retriever=ESRetriever(paragraphs),
|
67 |
+
reader=DprReader(),
|
68 |
+
lm="dpr"
|
69 |
),
|
70 |
"es_longformer": Experiment(
|
71 |
retriever=ESRetriever(paragraphs),
|
72 |
+
reader=LongformerReader(),
|
73 |
+
lm="longformer"
|
74 |
),
|
75 |
}
|
76 |
|
|
|
80 |
question = questions_test["question"][idx]
|
81 |
answer = questions_test["answer"][idx]
|
82 |
|
83 |
+
# workaround so we can use the decorator with a dynamic name for
|
84 |
+
# time recording
|
85 |
retrieve_timer = timeit(f"{experiment_name}.retrieve")
|
86 |
t_retrieve = retrieve_timer(experiment.retriever.retrieve)
|
87 |
|
|
|
93 |
scores, context = t_retrieve(question, 5)
|
94 |
reader_input = context_to_reader_input(context)
|
95 |
|
96 |
+
# Requesting 1 answers results in us getting the best answer
|
97 |
+
given_answer = t_read(question, reader_input, 1)[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
|
99 |
+
# Save the results so we can evaluate laters
|
100 |
+
if experiment.lm == "longformer":
|
101 |
+
experiment.results.append(
|
102 |
+
ExperimentResult(answer, given_answer[0]))
|
103 |
+
else:
|
104 |
+
experiment.results.append(
|
105 |
+
ExperimentResult(answer, given_answer.text))
|
106 |
|
|
|
107 |
print()
|
108 |
|
109 |
+
if os.getenv("ENABLE_TIMING", "false").lower() == "true":
|
110 |
+
# Save times
|
111 |
+
times = get_times()
|
112 |
+
df = pd.DataFrame(times)
|
113 |
+
os.makedirs("./results/", exist_ok=True)
|
114 |
+
df.to_csv("./results/timings.csv")
|
115 |
|
116 |
+
f1_results = pd.DataFrame(columns=experiments.keys())
|
117 |
+
em_results = pd.DataFrame(columns=experiments.keys())
|
118 |
+
for experiment_name, experiment in experiments.items():
|
119 |
+
em, f1 = zip(*list(map(
|
120 |
+
lambda r: evaluate(r.correct, r.given), experiment.results
|
121 |
+
)))
|
122 |
+
em_results[experiment_name] = em
|
123 |
+
f1_results[experiment_name] = f1
|
124 |
|
125 |
+
os.makedirs("./results/", exist_ok=True)
|
126 |
+
f1_results.to_csv("./results/f1_scores.csv")
|
127 |
+
em_results.to_csv("./results/em_scores.csv")
|
128 |
|
129 |
# TODO evaluation and storing of results
|
130 |
|
results/em_scores.csv
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,faiss_dpr,faiss_longformer,es_dpr,es_longformer
|
2 |
+
0,0,0,0,0
|
3 |
+
1,0,0,0,0
|
4 |
+
2,0,0,0,0
|
5 |
+
3,0,0,0,0
|
6 |
+
4,0,0,0,0
|
7 |
+
5,0,0,0,0
|
8 |
+
6,0,0,0,0
|
9 |
+
7,0,0,0,0
|
10 |
+
8,0,0,0,0
|
11 |
+
9,0,0,0,0
|
12 |
+
10,0,0,0,0
|
13 |
+
11,0,0,0,0
|
14 |
+
12,0,0,0,0
|
15 |
+
13,0,0,0,0
|
16 |
+
14,0,0,0,0
|
17 |
+
15,0,0,0,0
|
18 |
+
16,0,0,0,0
|
19 |
+
17,0,0,0,0
|
20 |
+
18,0,0,0,1
|
21 |
+
19,0,0,0,0
|
22 |
+
20,0,0,0,0
|
23 |
+
21,0,0,0,0
|
24 |
+
22,0,0,0,0
|
25 |
+
23,0,0,0,0
|
26 |
+
24,0,0,0,0
|
27 |
+
25,0,0,0,0
|
28 |
+
26,0,0,0,0
|
29 |
+
27,0,0,0,0
|
30 |
+
28,0,0,0,0
|
31 |
+
29,0,0,0,0
|
32 |
+
30,0,0,0,1
|
33 |
+
31,0,0,0,0
|
34 |
+
32,0,0,0,0
|
35 |
+
33,0,0,0,0
|
36 |
+
34,0,0,0,0
|
37 |
+
35,0,0,0,0
|
38 |
+
36,0,0,0,1
|
39 |
+
37,0,0,0,1
|
40 |
+
38,0,0,0,0
|
41 |
+
39,0,0,0,0
|
42 |
+
40,0,0,0,0
|
43 |
+
41,0,0,0,0
|
44 |
+
42,0,0,0,0
|
45 |
+
43,0,0,0,0
|
46 |
+
44,0,0,0,1
|
47 |
+
45,0,0,0,0
|
48 |
+
46,0,0,0,1
|
49 |
+
47,0,0,0,0
|
50 |
+
48,0,0,0,0
|
51 |
+
49,0,0,0,0
|
52 |
+
50,0,0,0,0
|
53 |
+
51,0,0,0,0
|
54 |
+
52,0,0,0,0
|
55 |
+
53,0,0,0,0
|
56 |
+
54,0,0,0,0
|
57 |
+
55,0,0,0,0
|
58 |
+
56,0,0,0,1
|
59 |
+
57,0,0,0,0
|
60 |
+
58,0,0,0,0
|
results/f1_scores.csv
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
,faiss_dpr,faiss_longformer,es_dpr,es_longformer
|
2 |
+
0,0.0,0.0,0.13008130081300812,0.7692307692307692
|
3 |
+
1,0.0,0.0,0.0,0.5833333333333334
|
4 |
+
2,0.0,0.0,0.3076923076923077,0.8421052631578948
|
5 |
+
3,0.0,0.0,0.0,0.0
|
6 |
+
4,0.25,0.0,0.25,0.88
|
7 |
+
5,0.2222222222222222,0.08695652173913043,0.2222222222222222,0.5454545454545454
|
8 |
+
6,0.0,0.0,0.0,0.10526315789473685
|
9 |
+
7,0.0,0.0,0.0,0.14545454545454545
|
10 |
+
8,0.0,0.0,0.0,0.7499999999999999
|
11 |
+
9,0.1935483870967742,0.0,0.0,0.3913043478260869
|
12 |
+
10,0.0,0.0,0.10526315789473685,0.0
|
13 |
+
11,0.0,0.0,0.0,0.0
|
14 |
+
12,0.07407407407407407,0.0,0.06896551724137931,0.0
|
15 |
+
13,0.0,0.0,0.0,0.3076923076923077
|
16 |
+
14,0.2222222222222222,0.0,0.29090909090909095,0.7142857142857143
|
17 |
+
15,0.0,0.0,0.0,0.08695652173913043
|
18 |
+
16,0.0,0.0,0.4347826086956522,0.30769230769230765
|
19 |
+
17,0.0,0.0,0.5,0.0
|
20 |
+
18,0.0,0.0,0.0,1.0
|
21 |
+
19,0.0,0.0,0.07692307692307693,0.75
|
22 |
+
20,0.0,0.046511627906976744,0.0,0.7333333333333334
|
23 |
+
21,0.0,0.0,0.0,0.5806451612903226
|
24 |
+
22,0.0,0.0,0.25,0.0
|
25 |
+
23,0.0,0.0,0.7142857142857143,0.6153846153846153
|
26 |
+
24,0.15384615384615383,0.0,0.15384615384615383,0.6666666666666666
|
27 |
+
25,0.15384615384615383,0.0625,0.0909090909090909,0.0
|
28 |
+
26,0.2285714285714286,0.05714285714285715,0.0,0.0
|
29 |
+
27,0.19999999999999998,0.0,0.3636363636363636,0.0
|
30 |
+
28,0.0,0.0,0.3076923076923077,0.16666666666666669
|
31 |
+
29,0.5,0.0,0.07407407407407407,0.4
|
32 |
+
30,0.11764705882352941,0.0,0.0,0.9375
|
33 |
+
31,0.0,0.05405405405405406,0.12121212121212122,0.13953488372093023
|
34 |
+
32,0.0,0.0,0.0,0.6
|
35 |
+
33,0.0,0.0,0.0,0.3333333333333333
|
36 |
+
34,0.07692307692307693,0.06896551724137931,0.07407407407407407,0.8
|
37 |
+
35,0.0,0.0,0.0,0.049999999999999996
|
38 |
+
36,0.0,0.0,0.0,1.0
|
39 |
+
37,0.22222222222222224,0.0,0.0,0.7142857142857143
|
40 |
+
38,0.058823529411764705,0.0,0.0,0.0
|
41 |
+
39,0.33333333333333326,0.05128205128205129,0.33333333333333326,0.33333333333333326
|
42 |
+
40,0.5882352941176471,0.0,0.0,0.0
|
43 |
+
41,0.0,0.0,0.0909090909090909,0.0
|
44 |
+
42,0.0,0.0,0.0,0.0
|
45 |
+
43,0.0,0.0,0.0,0.0588235294117647
|
46 |
+
44,0.0,0.0,0.19999999999999998,0.8888888888888888
|
47 |
+
45,0.0,0.05714285714285714,0.13793103448275865,0.10256410256410256
|
48 |
+
46,0.0,0.07142857142857142,0.0,0.8888888888888888
|
49 |
+
47,0.19999999999999998,0.0,0.5714285714285714,0.9473684210526316
|
50 |
+
48,0.0,0.0,0.0,0.0
|
51 |
+
49,0.0,0.0,0.0,0.0
|
52 |
+
50,0.13333333333333333,0.0,0.125,0.17391304347826086
|
53 |
+
51,0.0,0.0,0.0,0.21052631578947367
|
54 |
+
52,0.0,0.0,0.28571428571428575,0.0
|
55 |
+
53,0.07692307692307691,0.06060606060606061,0.0,0.0
|
56 |
+
54,0.0,0.11111111111111112,0.0,0.6153846153846153
|
57 |
+
55,0.23809523809523808,0.0,0.0,0.19999999999999998
|
58 |
+
56,0.0,0.0,0.0,1.0
|
59 |
+
57,0.0,0.0,0.0,0.0
|
60 |
+
58,0.0,0.0,0.0,0.13333333333333333
|
src/evaluation.py
CHANGED
@@ -74,4 +74,5 @@ def evaluate(answer: Any, prediction: Any):
|
|
74 |
float: overall exact match
|
75 |
float: overall F1-score
|
76 |
"""
|
|
|
77 |
return exact_match(prediction, answer), f1(prediction, answer)
|
|
|
74 |
float: overall exact match
|
75 |
float: overall F1-score
|
76 |
"""
|
77 |
+
print(prediction, answer)
|
78 |
return exact_match(prediction, answer), f1(prediction, answer)
|
src/readers/dpr_reader.py
CHANGED
@@ -13,8 +13,7 @@ class DprReader(Reader):
|
|
13 |
self._tokenizer = DPRReaderTokenizer.from_pretrained(
|
14 |
"facebook/dpr-reader-single-nq-base")
|
15 |
self._model = DPRReader.from_pretrained(
|
16 |
-
"facebook/dpr-reader-single-nq-base"
|
17 |
-
)
|
18 |
|
19 |
def read(self,
|
20 |
query: str,
|
|
|
13 |
self._tokenizer = DPRReaderTokenizer.from_pretrained(
|
14 |
"facebook/dpr-reader-single-nq-base")
|
15 |
self._model = DPRReader.from_pretrained(
|
16 |
+
"facebook/dpr-reader-single-nq-base")
|
|
|
17 |
|
18 |
def read(self,
|
19 |
query: str,
|
src/readers/longformer_reader.py
CHANGED
@@ -24,7 +24,7 @@ class LongformerReader(Reader):
|
|
24 |
num_answers=5) -> List[Tuple]:
|
25 |
answers = []
|
26 |
|
27 |
-
for text in context['texts']:
|
28 |
encoding = self.tokenizer(query, text, return_tensors="pt")
|
29 |
input_ids = encoding["input_ids"]
|
30 |
attention_mask = encoding["attention_mask"]
|
|
|
24 |
num_answers=5) -> List[Tuple]:
|
25 |
answers = []
|
26 |
|
27 |
+
for text in context['texts'][:num_answers]:
|
28 |
encoding = self.tokenizer(query, text, return_tensors="pt")
|
29 |
input_ids = encoding["input_ids"]
|
30 |
attention_mask = encoding["attention_mask"]
|
src/retrievers/faiss_retriever.py
CHANGED
@@ -98,7 +98,8 @@ class FaissRetriever(Retriever):
|
|
98 |
def _embed_question(self, q):
|
99 |
match self.lm:
|
100 |
case "dpr":
|
101 |
-
tok = self.q_tokenizer(
|
|
|
102 |
return self.q_encoder(**tok)[0][0].numpy()
|
103 |
case "longformer":
|
104 |
tok = self.q_tokenizer(q, return_tensors="pt")
|
@@ -110,7 +111,7 @@ class FaissRetriever(Retriever):
|
|
110 |
match self.lm:
|
111 |
case "dpr":
|
112 |
tok = self.ctx_tokenizer(
|
113 |
-
p, return_tensors="pt", truncation=True)
|
114 |
enc = self.ctx_encoder(**tok)[0][0].numpy()
|
115 |
return {"embeddings": enc}
|
116 |
case "longformer":
|
|
|
98 |
def _embed_question(self, q):
|
99 |
match self.lm:
|
100 |
case "dpr":
|
101 |
+
tok = self.q_tokenizer(
|
102 |
+
q, return_tensors="pt", truncation=True, padding=True)
|
103 |
return self.q_encoder(**tok)[0][0].numpy()
|
104 |
case "longformer":
|
105 |
tok = self.q_tokenizer(q, return_tensors="pt")
|
|
|
111 |
match self.lm:
|
112 |
case "dpr":
|
113 |
tok = self.ctx_tokenizer(
|
114 |
+
p, return_tensors="pt", truncation=True, padding=True)
|
115 |
enc = self.ctx_encoder(**tok)[0][0].numpy()
|
116 |
return {"embeddings": enc}
|
117 |
case "longformer":
|
src/utils/preprocessing.py
CHANGED
@@ -17,7 +17,8 @@ def context_to_reader_input(result: Dict[str, List[str]]) \
|
|
17 |
# Prepare result
|
18 |
reader_result = {
|
19 |
'titles': [],
|
20 |
-
'texts': []
|
|
|
21 |
}
|
22 |
|
23 |
for n in range(num_entries):
|
@@ -31,6 +32,7 @@ def context_to_reader_input(result: Dict[str, List[str]]) \
|
|
31 |
|
32 |
reader_result['titles'].append(title)
|
33 |
reader_result['texts'].append(result['text'][n])
|
|
|
34 |
|
35 |
return reader_result
|
36 |
|
|
|
17 |
# Prepare result
|
18 |
reader_result = {
|
19 |
'titles': [],
|
20 |
+
'texts': [],
|
21 |
+
'scores': []
|
22 |
}
|
23 |
|
24 |
for n in range(num_entries):
|
|
|
32 |
|
33 |
reader_result['titles'].append(title)
|
34 |
reader_result['texts'].append(result['text'][n])
|
35 |
+
reader_result['scores'].append(result['text'][n])
|
36 |
|
37 |
return reader_result
|
38 |
|