tybrs commited on
Commit
a87b5c5
·
verified ·
1 Parent(s): de109df

Update bias_auc.py

Browse files
Files changed (1) hide show
  1. bias_auc.py +20 -18
bias_auc.py CHANGED
@@ -120,25 +120,25 @@ class BiasAUC(evaluate.Metric):
120
  # subgroup positive
121
  elif subgroup in target and label == target_class:
122
  yield label, result[target_class]
123
-
 
 
 
 
 
 
 
 
 
124
  def _auc_by_group(self, target, label, output, subgroup):
125
  """ Compute bias AUC metrics
126
  """
127
-
128
- y_trues, y_preds = zip(*self._genreate_subgroup(target, label, output, subgroup))
129
- subgroup_auc_score = roc_auc_score(y_trues, y_preds)
130
-
131
- y_trues, y_preds = zip(*self._genreate_bpsn(target, label, output, subgroup))
132
- bpsn_auc_score = roc_auc_score(y_trues, y_preds)
133
-
134
- y_trues, y_preds = zip(*self._genreate_bnsp(target, label, output, subgroup))
135
- bnsp_auc_score = roc_auc_score(y_trues, y_preds)
136
 
137
-
138
- return {'Subgroup' : subgroup_auc_score,
139
- 'BPSN' : bpsn_auc_score,
140
- 'BNSP' : bnsp_auc_score}
141
-
142
  def _update_overall(self, result, labels, outputs, power_value=-5):
143
  """Compute the generalized mean of Bias AUCs"""
144
  result['Overall'] = {}
@@ -149,9 +149,11 @@ class BiasAUC(evaluate.Metric):
149
  mean_value = np.power(np.sum(metric_values)/(len(result) - 1), 1/power_value)
150
  result['Overall'][f"{metric} generalized mean"] = mean_value
151
  y_preds = [output[1] for output in outputs]
152
- result['Overall']["Overall AUC"] = roc_auc_score(labels, y_preds)
153
- return result
154
-
 
 
155
 
156
  def _compute(self, target, label, output, subgroups=None):
157
  if subgroups is None:
 
120
  # subgroup positive
121
  elif subgroup in target and label == target_class:
122
  yield label, result[target_class]
123
+
124
+ def _get_auc_score(self, gen_func, *args, **kwargs):
125
+ try:
126
+ y_trues, y_preds = zip(*self.gen_func(args))
127
+ score = roc_auc_score(y_trues, y_preds)
128
+ except ValueError:
129
+ print(f"Sample not sufficient for target clases '{args[-1]}' subgroup metric (need correct and incorrect predictions for '{args[-1]}')")
130
+ score = np.nan
131
+ return score
132
+
133
  def _auc_by_group(self, target, label, output, subgroup):
134
  """ Compute bias AUC metrics
135
  """
136
+ return {
137
+ 'Subgroup' : self._get_auc_score(self._genreate_subgroup, target, label, output, subgroup),
138
+ 'BPSN' : self._get_auc_score(self._genreate_bpsn, target, label, output, subgroup),
139
+ 'BNSP' : self._get_auc_score(self._genreate_bnsp, target, label, output, subgroup)
140
+ }
 
 
 
 
141
 
 
 
 
 
 
142
  def _update_overall(self, result, labels, outputs, power_value=-5):
143
  """Compute the generalized mean of Bias AUCs"""
144
  result['Overall'] = {}
 
149
  mean_value = np.power(np.sum(metric_values)/(len(result) - 1), 1/power_value)
150
  result['Overall'][f"{metric} generalized mean"] = mean_value
151
  y_preds = [output[1] for output in outputs]
152
+ try:
153
+ result['Overall']["Overall AUC"] = roc_auc_score(labels, y_preds)
154
+ except ValueError:
155
+ result['Overall']["Overall AUC"] = np.nan
156
+ return result
157
 
158
  def _compute(self, target, label, output, subgroups=None):
159
  if subgroups is None: