tybrs commited on
Commit
1273a08
·
verified ·
1 Parent(s): 7aab96e

Upload bias_auc.py

Browse files

Added bias auc metric

Files changed (1) hide show
  1. bias_auc.py +135 -0
bias_auc.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import evaluate
2
+ import datasets
3
+ from datasets.features import Sequence, Value, ClassLabel
4
+ from sklearn.metrics import roc_auc_score
5
+ import numpy as np
6
+
7
+
8
+ _DESCRIPTION = """\
9
+ Suite of threshold-agnostic metrics that provide a nuanced view
10
+ of this unintended bias, by considering the various ways that a
11
+ classifier’s score distribution can vary across designated groups.
12
+
13
+ The following are computed:
14
+
15
+ - BNSP (Background Negative, Subgroup Positive); and
16
+ - BPSN (Background Positive, Subgroup Negative) AUC metrics
17
+
18
+ """
19
+
20
+ _CITATION = """\
21
+ @inproceedings{borkan2019nuanced,
22
+ title={Nuanced metrics for measuring unintended bias with real data for text classification},
23
+ author={Borkan, Daniel and Dixon, Lucas and Sorensen, Jeffrey and Thain, Nithum and Vasserman, Lucy},
24
+ booktitle={Companion proceedings of the 2019 world wide web conference},
25
+ pages={491--500},
26
+ year={2019}
27
+ }
28
+ """
29
+
30
+ _KWARGS_DESCRIPTION = """\
31
+ target list[list[str]]: list containing list of group targeted for each item
32
+ label list[int]: list containing label index for each item
33
+ output list[list[float]]: list of model output values for each
34
+ """
35
+
36
+ class BiasAUC(evaluate.EvaluationModule):
37
+ def _info(self):
38
+ return datasets.MetricInfo(
39
+ description=_DESCRIPTION,
40
+ citation=_CITATION,
41
+ inputs_description=_KWARGS_DESCRIPTION,
42
+ features=datasets.Features(
43
+ {
44
+ 'target': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
45
+ 'label': Value(dtype='int64', id=None),
46
+ 'output': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
47
+ }
48
+ ),
49
+ reference_urls=["https://arxiv.org/abs/1903.04561"],
50
+ )
51
+
52
+ def _genreate_subgroup(self, targets, labels, outputs, subgroup, target_class=None):
53
+ """Returns label and output score from `targets` and `labels`
54
+ if `subgroup` is in list of targeted groups found in `targets`
55
+ """
56
+ target_class = target_class if target_class is not None else np.asarray(outputs).shape[-1] - 1
57
+ for target, label, result in zip(targets, labels, outputs):
58
+ if subgroup in target:
59
+ yield label, result[target_class]
60
+
61
+ def _genreate_bpsn(self, targets, labels, outputs, subgroup, target_class=None):
62
+ """Returns label and output score from `targets` and `labels`
63
+ if (1) `subgroup` is in list of targeted groups found in `targets` and
64
+ label is not the same as `target_class`; or (2) `subgroup` is not in list of
65
+ targeted groups found in `targets` and label is the same as `target_class`
66
+ """
67
+ target_class = target_class if target_class is not None else np.asarray(outputs).shape[-1] - 1
68
+ for target, label, result in zip(targets, labels, outputs):
69
+ if not target:
70
+ continue
71
+ # background positive
72
+ if subgroup not in target and label == target_class:
73
+ yield label, result[target_class]
74
+ # subgroup negative
75
+ elif subgroup in target and label != target_class:
76
+ yield label, result[target_class]
77
+
78
+ def _genreate_bnsp(self, targets, labels, outputs, subgroup, target_class=None):
79
+ """Returns label and output score from `targets` and `labels`
80
+ if (1) `subgroup` is not in list of targeted groups found in `targets` and
81
+ label is the same as `target_class`; or (2) `subgroup` is in list of
82
+ targeted groups found in `targets` and label is not the same as `target_class`
83
+ """
84
+ # get the index from class
85
+ target_class = target_class if target_class is not None else np.asarray(outputs).shape[-1] - 1
86
+ for target, label, result in zip(targets, labels, outputs):
87
+ if not target:
88
+ continue
89
+ # background negative
90
+ if subgroup not in target and label != target_class:
91
+ yield label, result[target_class]
92
+ # subgroup positive
93
+ elif subgroup in target and label == target_class:
94
+ yield label, result[target_class]
95
+
96
+ def _auc_by_group(self, target, label, output, subgroup):
97
+ """ Compute bias AUC metrics
98
+ """
99
+
100
+ y_trues, y_preds = zip(*self._genreate_subgroup(target, label, output, subgroup))
101
+ subgroup_auc_score = roc_auc_score(y_trues, y_preds)
102
+
103
+ y_trues, y_preds = zip(*self._genreate_bpsn(target, label, output, subgroup))
104
+ bpsn_auc_score = roc_auc_score(y_trues, y_preds)
105
+
106
+ y_trues, y_preds = zip(*self._genreate_bnsp(target, label, output, subgroup))
107
+ bnsp_auc_score = roc_auc_score(y_trues, y_preds)
108
+
109
+
110
+ return {'Subgroup' : subgroup_auc_score,
111
+ 'BPSN' : bpsn_auc_score,
112
+ 'BNSP' : bnsp_auc_score}
113
+
114
+ def _update_overall(self, result, labels, outputs, power_value=-5):
115
+ """Compute the generalized mean of Bias AUCs"""
116
+ result['Overall'] = {}
117
+ for metric in ['Subgroup', 'BPSN', 'BNSP']:
118
+ metric_values = np.array([result[community][metric] for community in result
119
+ if community != 'Overall'])
120
+ metric_values **= power_value
121
+ mean_value = np.power(np.sum(metric_values)/(len(result) - 1), 1/power_value)
122
+ result['Overall'][f"{metric} generalized mean"] = mean_value
123
+ y_preds = [output[1] for output in outputs]
124
+ result['Overall']["Overall AUC"] = roc_auc_score(labels, y_preds)
125
+ return result
126
+
127
+
128
+ def _compute(self, target, label, output, subgroups=None):
129
+ if subgroups is None:
130
+ subgroups = set(group for group_list in target for group in group_list)
131
+ result = {subgroup : self._auc_by_group(target, label, output, subgroup)
132
+ for subgroup in subgroups}
133
+ result = self._update_overall(result, label, output)
134
+ return result
135
+