Ramon Meffert committed on
Commit
492106d
·
1 Parent(s): 0157dfd

Add evaluation

Browse files
main.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  from dotenv import load_dotenv
2
  # needs to happen as very first thing, otherwise HF ignores env vars
3
  load_dotenv()
@@ -5,8 +7,8 @@ load_dotenv()
5
  import os
6
  import pandas as pd
7
 
8
- from dataclasses import dataclass
9
- from typing import Dict, cast
10
  from datasets import DatasetDict, load_dataset
11
 
12
  from src.readers.base_reader import Reader
@@ -24,10 +26,15 @@ from src.utils.preprocessing import context_to_reader_input
24
  from src.utils.timing import get_times, timeit
25
 
26
 
 
 
 
27
  @dataclass
28
  class Experiment:
29
  retriever: Retriever
30
  reader: Reader
 
 
31
 
32
 
33
  if __name__ == '__main__':
@@ -45,21 +52,25 @@ if __name__ == '__main__':
45
  retriever=FaissRetriever(
46
  paragraphs,
47
  FaissRetrieverOptions.dpr("./src/models/dpr.faiss")),
48
- reader=DprReader()
 
49
  ),
50
  "faiss_longformer": Experiment(
51
  retriever=FaissRetriever(
52
  paragraphs,
53
  FaissRetrieverOptions.longformer("./src/models/longformer.faiss")),
54
- reader=LongformerReader()
 
55
  ),
56
  "es_dpr": Experiment(
57
  retriever=ESRetriever(paragraphs),
58
- reader=DprReader()
 
59
  ),
60
  "es_longformer": Experiment(
61
  retriever=ESRetriever(paragraphs),
62
- reader=LongformerReader()
 
63
  ),
64
  }
65
 
@@ -69,6 +80,8 @@ if __name__ == '__main__':
69
  question = questions_test["question"][idx]
70
  answer = questions_test["answer"][idx]
71
 
 
 
72
  retrieve_timer = timeit(f"{experiment_name}.retrieve")
73
  t_retrieve = retrieve_timer(experiment.retriever.retrieve)
74
 
@@ -80,28 +93,38 @@ if __name__ == '__main__':
80
  scores, context = t_retrieve(question, 5)
81
  reader_input = context_to_reader_input(context)
82
 
83
- # workaround so we can use the decorator with a dynamic name for
84
- # time recording
85
- answers = t_read(question, reader_input, 5)
86
-
87
- # Calculate softmaxed scores for readable output
88
- # sm = torch.nn.Softmax(dim=0)
89
- # document_scores = sm(torch.Tensor(
90
- # [pred.relevance_score for pred in answers]))
91
- # span_scores = sm(torch.Tensor(
92
- # [pred.span_score for pred in answers]))
93
 
94
- # print_answers(answers, scores, context)
 
 
 
 
 
 
95
 
96
- # TODO evaluation and storing of results
97
  print()
98
 
99
- times = get_times()
 
 
 
 
 
100
 
101
- df = pd.DataFrame(times)
102
- os.makedirs("./results/", exist_ok=True)
103
- df.to_csv("./results/timings.csv")
 
 
 
 
 
104
 
 
 
 
105
 
106
  # TODO evaluation and storing of results
107
 
 
1
+ from collections import namedtuple
2
+ from pprint import pprint
3
  from dotenv import load_dotenv
4
  # needs to happen as very first thing, otherwise HF ignores env vars
5
  load_dotenv()
 
7
  import os
8
  import pandas as pd
9
 
10
+ from dataclasses import dataclass, field
11
+ from typing import Dict, cast, List
12
  from datasets import DatasetDict, load_dataset
13
 
14
  from src.readers.base_reader import Reader
 
26
  from src.utils.timing import get_times, timeit
27
 
28
 
29
+ ExperimentResult = namedtuple('ExperimentResult', ['correct', 'given'])
30
+
31
+
32
  @dataclass
33
  class Experiment:
34
  retriever: Retriever
35
  reader: Reader
36
+ lm: str
37
+ results: List[ExperimentResult] = field(default_factory=list)
38
 
39
 
40
  if __name__ == '__main__':
 
52
  retriever=FaissRetriever(
53
  paragraphs,
54
  FaissRetrieverOptions.dpr("./src/models/dpr.faiss")),
55
+ reader=DprReader(),
56
+ lm="dpr"
57
  ),
58
  "faiss_longformer": Experiment(
59
  retriever=FaissRetriever(
60
  paragraphs,
61
  FaissRetrieverOptions.longformer("./src/models/longformer.faiss")),
62
+ reader=LongformerReader(),
63
+ lm="longformer"
64
  ),
65
  "es_dpr": Experiment(
66
  retriever=ESRetriever(paragraphs),
67
+ reader=DprReader(),
68
+ lm="dpr"
69
  ),
70
  "es_longformer": Experiment(
71
  retriever=ESRetriever(paragraphs),
72
+ reader=LongformerReader(),
73
+ lm="longformer"
74
  ),
75
  }
76
 
 
80
  question = questions_test["question"][idx]
81
  answer = questions_test["answer"][idx]
82
 
83
+ # workaround so we can use the decorator with a dynamic name for
84
+ # time recording
85
  retrieve_timer = timeit(f"{experiment_name}.retrieve")
86
  t_retrieve = retrieve_timer(experiment.retriever.retrieve)
87
 
 
93
  scores, context = t_retrieve(question, 5)
94
  reader_input = context_to_reader_input(context)
95
 
96
+ # Requesting 1 answer results in us getting the best answer
97
+ given_answer = t_read(question, reader_input, 1)[0]
 
 
 
 
 
 
 
 
98
 
99
+ # Save the results so we can evaluate later
100
+ if experiment.lm == "longformer":
101
+ experiment.results.append(
102
+ ExperimentResult(answer, given_answer[0]))
103
+ else:
104
+ experiment.results.append(
105
+ ExperimentResult(answer, given_answer.text))
106
 
 
107
  print()
108
 
109
+ if os.getenv("ENABLE_TIMING", "false").lower() == "true":
110
+ # Save times
111
+ times = get_times()
112
+ df = pd.DataFrame(times)
113
+ os.makedirs("./results/", exist_ok=True)
114
+ df.to_csv("./results/timings.csv")
115
 
116
+ f1_results = pd.DataFrame(columns=experiments.keys())
117
+ em_results = pd.DataFrame(columns=experiments.keys())
118
+ for experiment_name, experiment in experiments.items():
119
+ em, f1 = zip(*list(map(
120
+ lambda r: evaluate(r.correct, r.given), experiment.results
121
+ )))
122
+ em_results[experiment_name] = em
123
+ f1_results[experiment_name] = f1
124
 
125
+ os.makedirs("./results/", exist_ok=True)
126
+ f1_results.to_csv("./results/f1_scores.csv")
127
+ em_results.to_csv("./results/em_scores.csv")
128
 
129
  # TODO evaluation and storing of results
130
 
results/em_scores.csv ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,faiss_dpr,faiss_longformer,es_dpr,es_longformer
2
+ 0,0,0,0,0
3
+ 1,0,0,0,0
4
+ 2,0,0,0,0
5
+ 3,0,0,0,0
6
+ 4,0,0,0,0
7
+ 5,0,0,0,0
8
+ 6,0,0,0,0
9
+ 7,0,0,0,0
10
+ 8,0,0,0,0
11
+ 9,0,0,0,0
12
+ 10,0,0,0,0
13
+ 11,0,0,0,0
14
+ 12,0,0,0,0
15
+ 13,0,0,0,0
16
+ 14,0,0,0,0
17
+ 15,0,0,0,0
18
+ 16,0,0,0,0
19
+ 17,0,0,0,0
20
+ 18,0,0,0,1
21
+ 19,0,0,0,0
22
+ 20,0,0,0,0
23
+ 21,0,0,0,0
24
+ 22,0,0,0,0
25
+ 23,0,0,0,0
26
+ 24,0,0,0,0
27
+ 25,0,0,0,0
28
+ 26,0,0,0,0
29
+ 27,0,0,0,0
30
+ 28,0,0,0,0
31
+ 29,0,0,0,0
32
+ 30,0,0,0,1
33
+ 31,0,0,0,0
34
+ 32,0,0,0,0
35
+ 33,0,0,0,0
36
+ 34,0,0,0,0
37
+ 35,0,0,0,0
38
+ 36,0,0,0,1
39
+ 37,0,0,0,1
40
+ 38,0,0,0,0
41
+ 39,0,0,0,0
42
+ 40,0,0,0,0
43
+ 41,0,0,0,0
44
+ 42,0,0,0,0
45
+ 43,0,0,0,0
46
+ 44,0,0,0,1
47
+ 45,0,0,0,0
48
+ 46,0,0,0,1
49
+ 47,0,0,0,0
50
+ 48,0,0,0,0
51
+ 49,0,0,0,0
52
+ 50,0,0,0,0
53
+ 51,0,0,0,0
54
+ 52,0,0,0,0
55
+ 53,0,0,0,0
56
+ 54,0,0,0,0
57
+ 55,0,0,0,0
58
+ 56,0,0,0,1
59
+ 57,0,0,0,0
60
+ 58,0,0,0,0
results/f1_scores.csv ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,faiss_dpr,faiss_longformer,es_dpr,es_longformer
2
+ 0,0.0,0.0,0.13008130081300812,0.7692307692307692
3
+ 1,0.0,0.0,0.0,0.5833333333333334
4
+ 2,0.0,0.0,0.3076923076923077,0.8421052631578948
5
+ 3,0.0,0.0,0.0,0.0
6
+ 4,0.25,0.0,0.25,0.88
7
+ 5,0.2222222222222222,0.08695652173913043,0.2222222222222222,0.5454545454545454
8
+ 6,0.0,0.0,0.0,0.10526315789473685
9
+ 7,0.0,0.0,0.0,0.14545454545454545
10
+ 8,0.0,0.0,0.0,0.7499999999999999
11
+ 9,0.1935483870967742,0.0,0.0,0.3913043478260869
12
+ 10,0.0,0.0,0.10526315789473685,0.0
13
+ 11,0.0,0.0,0.0,0.0
14
+ 12,0.07407407407407407,0.0,0.06896551724137931,0.0
15
+ 13,0.0,0.0,0.0,0.3076923076923077
16
+ 14,0.2222222222222222,0.0,0.29090909090909095,0.7142857142857143
17
+ 15,0.0,0.0,0.0,0.08695652173913043
18
+ 16,0.0,0.0,0.4347826086956522,0.30769230769230765
19
+ 17,0.0,0.0,0.5,0.0
20
+ 18,0.0,0.0,0.0,1.0
21
+ 19,0.0,0.0,0.07692307692307693,0.75
22
+ 20,0.0,0.046511627906976744,0.0,0.7333333333333334
23
+ 21,0.0,0.0,0.0,0.5806451612903226
24
+ 22,0.0,0.0,0.25,0.0
25
+ 23,0.0,0.0,0.7142857142857143,0.6153846153846153
26
+ 24,0.15384615384615383,0.0,0.15384615384615383,0.6666666666666666
27
+ 25,0.15384615384615383,0.0625,0.0909090909090909,0.0
28
+ 26,0.2285714285714286,0.05714285714285715,0.0,0.0
29
+ 27,0.19999999999999998,0.0,0.3636363636363636,0.0
30
+ 28,0.0,0.0,0.3076923076923077,0.16666666666666669
31
+ 29,0.5,0.0,0.07407407407407407,0.4
32
+ 30,0.11764705882352941,0.0,0.0,0.9375
33
+ 31,0.0,0.05405405405405406,0.12121212121212122,0.13953488372093023
34
+ 32,0.0,0.0,0.0,0.6
35
+ 33,0.0,0.0,0.0,0.3333333333333333
36
+ 34,0.07692307692307693,0.06896551724137931,0.07407407407407407,0.8
37
+ 35,0.0,0.0,0.0,0.049999999999999996
38
+ 36,0.0,0.0,0.0,1.0
39
+ 37,0.22222222222222224,0.0,0.0,0.7142857142857143
40
+ 38,0.058823529411764705,0.0,0.0,0.0
41
+ 39,0.33333333333333326,0.05128205128205129,0.33333333333333326,0.33333333333333326
42
+ 40,0.5882352941176471,0.0,0.0,0.0
43
+ 41,0.0,0.0,0.0909090909090909,0.0
44
+ 42,0.0,0.0,0.0,0.0
45
+ 43,0.0,0.0,0.0,0.0588235294117647
46
+ 44,0.0,0.0,0.19999999999999998,0.8888888888888888
47
+ 45,0.0,0.05714285714285714,0.13793103448275865,0.10256410256410256
48
+ 46,0.0,0.07142857142857142,0.0,0.8888888888888888
49
+ 47,0.19999999999999998,0.0,0.5714285714285714,0.9473684210526316
50
+ 48,0.0,0.0,0.0,0.0
51
+ 49,0.0,0.0,0.0,0.0
52
+ 50,0.13333333333333333,0.0,0.125,0.17391304347826086
53
+ 51,0.0,0.0,0.0,0.21052631578947367
54
+ 52,0.0,0.0,0.28571428571428575,0.0
55
+ 53,0.07692307692307691,0.06060606060606061,0.0,0.0
56
+ 54,0.0,0.11111111111111112,0.0,0.6153846153846153
57
+ 55,0.23809523809523808,0.0,0.0,0.19999999999999998
58
+ 56,0.0,0.0,0.0,1.0
59
+ 57,0.0,0.0,0.0,0.0
60
+ 58,0.0,0.0,0.0,0.13333333333333333
src/evaluation.py CHANGED
@@ -74,4 +74,5 @@ def evaluate(answer: Any, prediction: Any):
74
  float: overall exact match
75
  float: overall F1-score
76
  """
 
77
  return exact_match(prediction, answer), f1(prediction, answer)
 
74
  float: overall exact match
75
  float: overall F1-score
76
  """
77
+ print(prediction, answer)
78
  return exact_match(prediction, answer), f1(prediction, answer)
src/readers/dpr_reader.py CHANGED
@@ -13,8 +13,7 @@ class DprReader(Reader):
13
  self._tokenizer = DPRReaderTokenizer.from_pretrained(
14
  "facebook/dpr-reader-single-nq-base")
15
  self._model = DPRReader.from_pretrained(
16
- "facebook/dpr-reader-single-nq-base"
17
- )
18
 
19
  def read(self,
20
  query: str,
 
13
  self._tokenizer = DPRReaderTokenizer.from_pretrained(
14
  "facebook/dpr-reader-single-nq-base")
15
  self._model = DPRReader.from_pretrained(
16
+ "facebook/dpr-reader-single-nq-base")
 
17
 
18
  def read(self,
19
  query: str,
src/readers/longformer_reader.py CHANGED
@@ -24,7 +24,7 @@ class LongformerReader(Reader):
24
  num_answers=5) -> List[Tuple]:
25
  answers = []
26
 
27
- for text in context['texts']:
28
  encoding = self.tokenizer(query, text, return_tensors="pt")
29
  input_ids = encoding["input_ids"]
30
  attention_mask = encoding["attention_mask"]
 
24
  num_answers=5) -> List[Tuple]:
25
  answers = []
26
 
27
+ for text in context['texts'][:num_answers]:
28
  encoding = self.tokenizer(query, text, return_tensors="pt")
29
  input_ids = encoding["input_ids"]
30
  attention_mask = encoding["attention_mask"]
src/retrievers/faiss_retriever.py CHANGED
@@ -98,7 +98,8 @@ class FaissRetriever(Retriever):
98
  def _embed_question(self, q):
99
  match self.lm:
100
  case "dpr":
101
- tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
 
102
  return self.q_encoder(**tok)[0][0].numpy()
103
  case "longformer":
104
  tok = self.q_tokenizer(q, return_tensors="pt")
@@ -110,7 +111,7 @@ class FaissRetriever(Retriever):
110
  match self.lm:
111
  case "dpr":
112
  tok = self.ctx_tokenizer(
113
- p, return_tensors="pt", truncation=True)
114
  enc = self.ctx_encoder(**tok)[0][0].numpy()
115
  return {"embeddings": enc}
116
  case "longformer":
 
98
  def _embed_question(self, q):
99
  match self.lm:
100
  case "dpr":
101
+ tok = self.q_tokenizer(
102
+ q, return_tensors="pt", truncation=True, padding=True)
103
  return self.q_encoder(**tok)[0][0].numpy()
104
  case "longformer":
105
  tok = self.q_tokenizer(q, return_tensors="pt")
 
111
  match self.lm:
112
  case "dpr":
113
  tok = self.ctx_tokenizer(
114
+ p, return_tensors="pt", truncation=True, padding=True)
115
  enc = self.ctx_encoder(**tok)[0][0].numpy()
116
  return {"embeddings": enc}
117
  case "longformer":
src/utils/preprocessing.py CHANGED
@@ -17,7 +17,8 @@ def context_to_reader_input(result: Dict[str, List[str]]) \
17
  # Prepare result
18
  reader_result = {
19
  'titles': [],
20
- 'texts': []
 
21
  }
22
 
23
  for n in range(num_entries):
@@ -31,6 +32,7 @@ def context_to_reader_input(result: Dict[str, List[str]]) \
31
 
32
  reader_result['titles'].append(title)
33
  reader_result['texts'].append(result['text'][n])
 
34
 
35
  return reader_result
36
 
 
17
  # Prepare result
18
  reader_result = {
19
  'titles': [],
20
+ 'texts': [],
21
+ 'scores': []
22
  }
23
 
24
  for n in range(num_entries):
 
32
 
33
  reader_result['titles'].append(title)
34
  reader_result['texts'].append(result['text'][n])
35
+ reader_result['scores'].append(result['text'][n])
36
 
37
  return reader_result
38