Social Bias NER with BERT
In the previous blog, we learned that classifying entire sentences as "biased" or "fair" may be too coarse an abstraction for effective training. Instead, what if we label individual words with semantic parts of speech such as generalizations, unfairness, and stereotypes?
In this article, I'll walk through building such a named-entity recognition (NER) model for social bias entities. The model is a core contribution of our Ethical Spectacle Research GUS-Net paper, to be published in September 2024 ;).
Model: 🤖Try it Out | 🔬Model Repo
Notebooks: ✍️Synthetic Data Annotation ipynb | 🏫️ Model Training ipynb
Related Events: 📅Social Bias Hackathon (Sept '24) | ☕️Coding Workshops
Synthetic Dataset
To train BERT to classify words/tokens as generalizations, unfairness, or stereotypes, we need a dataset where words have been labeled with our entities... The problem is, no dataset with our entities exists.
To build our dataset from scratch, we'll have to build an annotation pipeline. In the past, this would probably be done with humans, but thanks to LM frameworks like LangChain or Stanford DSPy, we can build a similar team of annotator agents.
For our task, we'll use DSPy to create LM programs (i.e. agents) for annotating a sentence with each entity individually. Then we'll aggregate them into a list of lists, where each sub-list contains entity labels for each word. Here's an example of the labels output by our synthetic data pipeline:
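The aggregation step can be sketched in plain Python. This is a minimal illustration rather than the pipeline's actual code: the `LABELS` order, the `aggregate` and `to_multi_hot` helpers, and the example tags are all assumptions, and the per-entity tag lists simply stand in for whatever the DSPy annotators return.

```python
# Hypothetical label order; the real pipeline may use a different one.
LABELS = ["O", "B-STEREO", "I-STEREO", "B-GEN", "I-GEN", "B-UNFAIR", "I-UNFAIR"]


def aggregate(per_entity_tags):
    """Merge one BIO tag list per entity into one list of labels per word."""
    n_words = len(next(iter(per_entity_tags.values())))
    merged = []
    for i in range(n_words):
        # Collect every non-'O' tag that any annotator assigned to word i.
        tags = [t[i] for t in per_entity_tags.values() if t[i] != "O"]
        merged.append(tags or ["O"])
    return merged


def to_multi_hot(word_tags):
    """Turn a word's tag list into a fixed-size 0/1 vector over LABELS."""
    return [1 if label in word_tags else 0 for label in LABELS]


# Made-up annotator output for a three-word sentence (one list per entity).
per_entity = {
    "GEN": ["B-GEN", "I-GEN", "O"],
    "UNFAIR": ["O", "O", "O"],
    "STEREO": ["B-STEREO", "I-STEREO", "I-STEREO"],
}
word_labels = aggregate(per_entity)
vectors = [to_multi_hot(tags) for tags in word_labels]
```

Because each word keeps a list of tags instead of a single tag, nested entities (a generalization inside a stereotype, for example) survive the merge.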
We've annotated 3.5k records (sentences of varying bias/fair sentiments and varying targets), with labels as depicted above. Now we can move forward with building a model that can learn from our dataset, and label words in unseen sentences with our entities.
Note: I'm still undecided on publishing the dataset, since it could be misused. Instead, you can use our pipeline to annotate your own dataset in this notebook. It uses compiled DSPy modules that include examples of their entity labels in each prompt, plus guardrails (suggestion/retry logic).
Model and Training Architecture
To check out the code we used for training this model, open this notebook. Note that you won't be able to run it without loading a dataset (see previous section).
We're going to be training BertForTokenClassification, a class from the HuggingFace transformers library. It will load a model and allow us to use it as a PyTorch nn.Module, configured automatically for our specified number of classes (2 for each entity, i.e. a B- and an I- tag, + 1 for 'O').
Token Classification: The list of tokens that make up each text sequence will be passed into BertForTokenClassification, and it will process each token's individual classifications in parallel. Essentially, you can think of this as a multi-label classification on each token of the sequence (where each token can fall in multiple classes). The encodings BERT creates for each token still include information about the tokens on either side (i.e. the context). It's worth noting that typically, token classification is done via multi-class classification, where each token can be assigned only one label.
Loss Function: This is where things start to get wonky. We're labeling each token with a list of multiple potential labels. This means that when processing each sequence, our prediction output is actually a list of lists (2 dimensional tensor) that looks like this:
- Tokenized text: ['nursing', 'homes', 'are', ...]
- Labels: [[0,0,0,1,0,0,0,0,0], [0,0,0,0,1,0,0,0,1], [1,0,0,0,0,0,0,0,0],...]
For loss calculation we reduce both the predicted labels tensor and the true labels tensor to one dimension, like this:
- Labels: [0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,0,0,0,0,0,0,0,0...]
Since they will still be equal length, we can apply binary cross entropy and calculate our loss for the entire sequence.
Good to know: BERT's tokenizer applies padding tokens so that every input to BERT is a fixed size vector (in our case, 128 token length). We also have to pad the first dimension of our true labels to the same size, but with -100, to be ignored during our loss calculation. This is done during a "tokenize and align labels" preprocessing step.
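The flatten, mask, and BCE steps described above can be sketched as follows. This is only an illustration with made-up logits and a tiny 3-label setup. It assumes the -100 rows are filtered out before calling PyTorch's `binary_cross_entropy_with_logits`, which may differ from how the notebook handles them.

```python
import torch
import torch.nn.functional as F

IGNORE = -100  # padding value for label rows we want to skip

# One sequence of 4 tokens with 3 labels each (sizes are illustrative).
logits = torch.tensor([[2.0, -1.0, 0.5],
                       [0.0, 3.0, -2.0],
                       [1.0, 1.0, 1.0],
                       [0.3, 0.3, 0.3]])
labels = torch.tensor([[1, 0, 0],
                       [0, 1, 0],
                       [IGNORE, IGNORE, IGNORE],   # e.g. a padding token
                       [IGNORE, IGNORE, IGNORE]])

# Reduce both tensors to one dimension: (tokens * labels,)
flat_logits = logits.view(-1)
flat_labels = labels.view(-1)

# Keep only the positions whose label is a real 0/1 value.
active = flat_labels != IGNORE
loss = F.binary_cross_entropy_with_logits(
    flat_logits[active], flat_labels[active].float()
)
```

Here only the first two tokens (6 label slots) contribute to `loss`; the two padded rows are dropped entirely.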
Optimizer
We'll use AdamW, which is common for training BERT. It will apply our learning rate to the gradients calculated from our loss (during backpropagation) to update our model's weights after every batch in training. It also implements weight decay and momentum smoothing when calculating weights.
Learning Rate Scheduler
By using a linear schedule with warmup, we'll avoid making large adjustments to weights before the model has seen our training data, and we'll make more precise changes to weights over the course of training. We'll end up with a learning rate schedule that looks like this:
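The shape of that schedule can be sketched in a few lines of plain Python. This follows the same ramp-up/ramp-down rule as `get_linear_schedule_with_warmup` in transformers, which I'm assuming is close to what the notebook uses. The `linear_warmup_factor` helper and the warmup/total step counts are illustrative, not the notebook's values.

```python
def linear_warmup_factor(step, warmup_steps, total_steps):
    """Multiplier applied to the base learning rate at a given step."""
    if step < warmup_steps:
        # Ramp up linearly from 0 to 1 during warmup.
        return step / max(1, warmup_steps)
    # Then decay linearly from 1 down to 0 at the final step.
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 5e-5               # the learning rate mentioned later in this article
warmup, total = 100, 1000    # illustrative step counts
schedule = [base_lr * linear_warmup_factor(s, warmup, total)
            for s in range(total + 1)]
```

The learning rate starts at 0, peaks at `base_lr` once warmup ends, and then shrinks steadily, so the last updates are the smallest.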
Evaluation Metrics
Evaluating this multi-label NER model is not as straightforward as it was for the sequence classification model, where we had two binary datapoints (i.e. true label, predicted label) to compare to each other. Now, we have more options...
- Label-level: In the previous step, we flattened the model's output tensor to one dimension, which also makes it suitable for binary evaluation metrics like accuracy. The problem is that this leads to a disproportionate number of true negatives relative to true positives in our evaluation dataset. We would rather focus on metrics that emphasize true positives, because the majority of the time, not all (or any) of our entities will be present.
- Token-level: Instead, we could leave the true and pred tensors 2d (i.e. list of lists), and then compare each list exactly to the other. This would avoid the flattening bias (tons of true negatives), but if both lists of labels do not match exactly, it is considered a failure (i.e. no credit for being close).
- Entity-level: Or we could check entities based on their fullness. Often, our entities span many words, and by interpreting the B- and I- tags, we can check that the entity spans the correct number of tokens/words. In this case, a prediction would be incorrect if the labels do not span the full entity, also no credit for being close.
To pick, we first need to decide on an objective. Do we care about getting close? In some applications, a false positive is very dangerous. In our case, we can use a combination of metrics for a sense of each level.
In multi-label NER, we're most interested in measuring the fraction of the labels for each token that were correctly predicted. This gives credit for "close" predictions.
Hamming Loss
I think of hamming loss as a mix of label-level and token-level evaluation. It compares the two lists of labels for each token, returning the portion of labels that were incorrect. In a case like this, hamming loss is lower than if we compared the lists exactly.
- True: [0,0,0,0,0,1,0,1,0]
- Pred: [0,0,0,0,0,1,0,0,0] # 1 of 9 labels is incorrect
- Exact Match Loss: 1 (100% incorrect), Hamming Loss: 0.1111 (11% incorrect)
I like hamming loss for this evaluation, because it's indirectly linked to entity fullness, like in entity-level eval (but without the processing of the BIO format).
Remember, we're comparing a list of labels like the ones above for every token in the text sequence (imagine it working through the sentence one word at a time). Luckily, hamming loss can handle this and will provide a macro average of the losses for all the tokens in the sequence.
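To make the arithmetic concrete, here's a small plain-Python sketch of both losses. The `hamming_loss` and `exact_match_loss` helpers are written for illustration only (in practice a library metric would do this), and they assume every token has the same number of labels, so averaging per token and averaging over all labels give the same number.

```python
def hamming_loss(true_rows, pred_rows):
    """Fraction of individual labels that differ, over all tokens."""
    wrong = total = 0
    for t_row, p_row in zip(true_rows, pred_rows):
        wrong += sum(t != p for t, p in zip(t_row, p_row))
        total += len(t_row)
    return wrong / total


def exact_match_loss(true_rows, pred_rows):
    """Fraction of tokens whose full label list is not an exact match."""
    misses = sum(t_row != p_row for t_row, p_row in zip(true_rows, pred_rows))
    return misses / len(true_rows)


# The single-token example from above.
true_rows = [[0, 0, 0, 0, 0, 1, 0, 1, 0]]
pred_rows = [[0, 0, 0, 0, 0, 1, 0, 0, 0]]

token_hamming = hamming_loss(true_rows, pred_rows)      # 1 wrong label out of 9
token_exact = exact_match_loss(true_rows, pred_rows)    # the whole token counts as wrong
```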
Precision, Recall, F1
Lucky for us, torchmetrics has multi-label metrics that will calculate the recall, precision and F1 scores across all of our individual entities simultaneously. We can isolate the metrics for each entity in our list of potential labels, or average them together.
- Precision: TP / (TP + FP) - Proportion of positive predictions that were actually correct.
- Recall: TP / (TP + FN) - Proportion of actual positives that our model detected.
- F1 Score: 2 x ((Precision x Recall) / (Precision + Recall)) - Harmonic mean of precision and recall; think of it like an accuracy score that ignores true negatives and balances false positives against false negatives.
These scores will give us a good understanding of how effective our model is at predicting the entities (instead of how effective it is at NOT predicting the entities). In evaluation, we'll look at the macro-scores for the whole model (all entities averaged).
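Here's a rough plain-Python sketch of what "per-entity, then macro-averaged" means. It is not the torchmetrics implementation: the `prf` and `macro_scores` helpers are hypothetical, the two-label data is made up, and returning 0.0 when a denominator is zero is my own assumption.

```python
def prf(true_col, pred_col):
    """Precision, recall and F1 for one label (one column of the label matrix)."""
    tp = sum(t == 1 and p == 1 for t, p in zip(true_col, pred_col))
    fp = sum(t == 0 and p == 1 for t, p in zip(true_col, pred_col))
    fn = sum(t == 1 and p == 0 for t, p in zip(true_col, pred_col))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def macro_scores(true_rows, pred_rows):
    """Score each label separately, then take the unweighted mean."""
    n_labels = len(true_rows[0])
    per_label = [
        prf([row[j] for row in true_rows], [row[j] for row in pred_rows])
        for j in range(n_labels)
    ]
    macro = tuple(sum(s[k] for s in per_label) / n_labels for k in range(3))
    return per_label, macro


# Four tokens, two labels each (made-up data).
true_rows = [[1, 0], [1, 1], [0, 1], [0, 0]]
pred_rows = [[1, 0], [0, 1], [0, 1], [1, 0]]
per_label, macro = macro_scores(true_rows, pred_rows)
```

Note that true negatives never appear in any of the three formulas, which is exactly why these scores are not inflated by all the "not present" labels.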
Confusion Matrix
Mostly because they're cool, but also because they help us visually compare each entity's individual performance, we'll make confusion matrices during evaluation. I'll merge the inputs for B- and I- tags, so we'll just have 4 matrices: 'O', 'GEN', 'UNFAIR', 'STEREO'. Each matrix will depict the rates of true and false positives and negatives, though we should expect a lot of true negatives. Here's a key breaking down the confusion matrix:
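To make the four cells concrete, here's a rough sketch of how the counts for one entity could be tallied. The label order, the `merge_bio` and `confusion` helpers, and the choice to merge B- and I- tags with a logical OR are all assumptions on my part; the notebook may merge them differently.

```python
# Hypothetical label order for the 7 classes.
LABELS = ["O", "B-STEREO", "I-STEREO", "B-GEN", "I-GEN", "B-UNFAIR", "I-UNFAIR"]


def merge_bio(row):
    """Collapse a 7-label row into presence flags for O, GEN, UNFAIR, STEREO."""
    by_name = dict(zip(LABELS, row))
    merged = {"O": int(by_name["O"])}
    for entity in ("GEN", "UNFAIR", "STEREO"):
        # Assumption: an entity is "present" if either its B- or I- tag is set.
        merged[entity] = int(by_name["B-" + entity] or by_name["I-" + entity])
    return merged


def confusion(true_rows, pred_rows, entity):
    """Count TP / FP / FN / TN for one merged entity across all tokens."""
    counts = {"TP": 0, "FP": 0, "FN": 0, "TN": 0}
    for t_row, p_row in zip(true_rows, pred_rows):
        t, p = merge_bio(t_row)[entity], merge_bio(p_row)[entity]
        # First letter: was the prediction right? Second: what was predicted?
        counts[("T" if t == p else "F") + ("P" if p else "N")] += 1
    return counts


# Three tokens: B-GEN, I-GEN, O (true) vs. B-GEN, O, O (predicted).
true_rows = [[0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0], [1, 0, 0, 0, 0, 0, 0]]
pred_rows = [[0, 0, 0, 1, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]
gen_counts = confusion(true_rows, pred_rows, "GEN")
o_counts = confusion(true_rows, pred_rows, "O")
```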
Training the Model
Preprocessing:
Tokenize: Before we start training, we have to preprocess our training text sequences into lists of tokens that BERT has been trained on. The tokenizer also pads the input up to a set max token length, effectively making every sentence the same length for BERT. We're building a sentence level NER model, so, we'll use 128 max len.
Align Labels: The labels in our training set are for each word. However, during tokenization, many words are split into sub-words. We also have to correctly parse the label lists in our training set, making sure that tokens that are part of a sub-word inherit the label from their parent word.
We then redefine these as a list of lists like we saw in the previous step, where each token in a sequence has one fixed size vector of labels. Finally, we need to make the labels vector the same length as the tokens fixed size vector, by adding padding labels that align with the padding tokens. We can use label vectors of -100, which will be ignored during the loss calculation.
Good to know: You'll want to be careful of the first token of every sequence, CLS (added by the BERT tokenizer). It could misalign your labels if you don't add a vector of -100's before your labels. The notebook aligns tokens and labels based on word_ids (output by the tokenizer) which represent CLS and padding tokens the same way: None. This works as a mask for adding lists of -100 for the tokens we wish to ignore.
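The alignment logic can be sketched without loading a tokenizer. In the snippet below, the hard-coded `word_ids` list stands in for what the tokenizer reports, and the `align_labels` helper and 7-label vectors are illustrative assumptions rather than the notebook's exact code.

```python
IGNORE = -100
NUM_LABELS = 7  # assumed: O plus a B- and I- tag for each of 3 entities


def align_labels(word_labels, word_ids, num_labels=NUM_LABELS):
    """Give every token a label vector, using -100 rows for special tokens."""
    aligned = []
    for word_id in word_ids:
        if word_id is None:
            # CLS, SEP and PAD have no parent word, so they are ignored in the loss.
            aligned.append([IGNORE] * num_labels)
        else:
            # Sub-word tokens inherit the label vector of their parent word.
            aligned.append(list(word_labels[word_id]))
    return aligned


# Two words; the second one is split into two sub-word tokens.
# Token layout: [CLS], word 0, word 1 (piece 1), word 1 (piece 2), [SEP], [PAD]
word_ids = [None, 0, 1, 1, None, None]
word_labels = [
    [0, 0, 0, 1, 0, 0, 0],
    [0, 0, 0, 0, 1, 0, 0],
]
aligned = align_labels(word_labels, word_ids)
```

The result has one row per token (6 here), so it lines up with the padded input IDs position by position.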
Baseline:
The most similar NER model I've seen is Nbias, which labels tokens with a single entity: BIAS. That's not quite similar enough to effectively compare to our model, so I'll take a guess at some training parameters, and use the results as baseline metrics that we can move forward to optimize.
Interpreting Results: Off to a great start! Our key metric, hamming loss, has a strong (low) baseline of 0.1137 (11% of predictions were incorrect). However, we can see that doesn't mean we're predicting entities entirely accurately. Our model detected 43% of the potential presences (i.e. recall). On the other hand, when it did label "Present", 65% of those were correct (i.e. precision). All things considered, these still aren't bad scores. One important thing to notice is the confusion matrices. They're as expected, weighted heavily towards "Not-present" (excluding 'O'). Perhaps predictably, they're biased towards not-present predictions. All of our new entities have more false negatives than false positives. In fact, the UNFAIR entity didn't even have a single positive prediction!!
Promoting Entity Prediction
The obvious way to increase positive classifications would be to lower the threshold at which probabilities become labels. I'll save you some time: I tried a few different thresholds, and 0.5 seems to be the best. Decreasing it didn't improve precision or recall; it added some true positives, but even more false positives. A simple change in threshold won't cut it.
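For reference, this is roughly what "the threshold at which probabilities become labels" looks like in code. It's a plain-Python sketch with made-up logits; the `predict_labels` helper and the choice of `>=` over `>` are assumptions.

```python
import math


def sigmoid(z):
    """Squash a logit into a probability between 0 and 1."""
    return 1 / (1 + math.exp(-z))


def predict_labels(logits, threshold=0.5):
    """Mark a label as present when its probability reaches the threshold."""
    return [int(sigmoid(z) >= threshold) for z in logits]


logits = [2.0, -0.4, -3.0]                     # made-up logits for one token
default_preds = predict_labels(logits)          # threshold 0.5
lenient_preds = predict_labels(logits, 0.3)     # lower threshold, more positives
```

Lowering the threshold can only turn 0s into 1s, so any extra true positives come bundled with whatever false positives sit in the same probability range.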
Instead, we can modify our loss function. I started by creating an ignore mask (a list corresponding to tokens/labels to ignore), and defined it to ignore all tokens that had an 'O' label in the training set. This created a drastic increase in entity predictions (excluding 'O'), but still too many false positives, ultimately with worse metrics across the board. We have one more neat option: Focal Loss.
Focal Loss applies binary cross entropy, but with a modifier that adjusts the loss based on the prediction's probability. A label the model predicts correctly with high confidence has a reduced impact on the total loss; conversely, a label the model is unsure about, or gets wrong, has a larger impact.
Focal Loss Implementation
In the notebook you'll find a focal loss implementation that I threw together from a few examples on the pytorch forum. Our implementation:
- Creates a mask of active tokens/labels (i.e. not CLS/SEP/PAD/-100's).
- Performs BCE (prediction logits vs true labels).
- Converts pred logits to probabilities.
- Calculates the probability for each class being present (p_t), used in identifying which labels are the "hardest" to predict.
- Calculates the modulation factors alpha_t and focal modulation.
- Multiplies alpha_t, focal modulation, and BCE loss.
- Zeroes the losses for CLS, SEP, PAD or -100 tokens (using the mask).
- Sums the active loss values.
We can tune our focal loss with these two arguments:
- alpha - Emphasis on positive predictions.
- gamma - Emphasis on the "harder" entities (lower probability).
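The steps and arguments above can be sketched as a single PyTorch function. To be clear, this is my own minimal version of a standard binary focal loss, not a copy of the notebook's implementation: the function signature, the `reduction` argument, and the way -100 is clamped before BCE are assumptions.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, labels, alpha=0.75, gamma=3.0,
               reduction="sum", ignore_index=-100):
    """Binary focal loss over multi-label token logits.

    logits, labels: tensors of shape (batch, seq_len, num_labels).
    Label rows filled with `ignore_index` (CLS/SEP/PAD) are masked out.
    """
    active = (labels != ignore_index).float()     # 1 where the label counts
    targets = labels.clamp(min=0).float()         # swap -100 for 0 so BCE is defined
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    probs = torch.sigmoid(logits)
    # p_t: probability the model assigns to the true value of each label.
    p_t = probs * targets + (1 - probs) * (1 - targets)
    # alpha weights positive labels, (1 - alpha) weights negative ones.
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t) ** gamma shrinks the loss of easy, confidently-correct labels.
    loss = alpha_t * (1 - p_t) ** gamma * bce * active
    if reduction == "mean":
        return loss.sum() / active.sum().clamp(min=1)
    return loss.sum()


# One sequence: a confident token, an unsure token, and a padding token.
logits = torch.tensor([[[4.0, -4.0], [0.0, 0.0], [1.0, 1.0]]])
labels = torch.tensor([[[1, 0], [1, 0], [-100, -100]]])
total = focal_loss(logits, labels)
```

In this toy example almost all of `total` comes from the unsure token (logits of 0.0), while the confidently correct token contributes close to nothing.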
After trying a few different alpha and gamma values, here are the results of training with alpha = 0.75 and gamma = 3:
Interpreting Results: It worked!! Big improvements across the board. We detected 79% of total presences, and 73% of our positive predictions were correct!!! To me, that's a great sign we've found the right loss function. However, we're still not doing well predicting 'UNFAIR' accurately so it might be worth revisiting the gamma, and possibly increasing it.
Sum vs. Mean Pooling Loss
What if, instead of adding all the loss values together, we take the mean of the losses for each label? This still encodes the same information, but smooths outliers and keeps the loss on a consistent scale batch-to-batch. This can be especially important in cases of high loss variance, as is common when using focal loss or multi-label classification. Switching to mean pooling (with all the same training params) saw another increase in accuracy!
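A tiny plain-Python example shows the "consistent scale" point. The per-label loss values and the `pool` helper are made up for illustration.

```python
def pool(losses, mode):
    """Combine per-label losses by summing or by averaging."""
    return sum(losses) if mode == "sum" else sum(losses) / len(losses)


per_label = [0.2, 0.1, 0.3, 0.2]     # made-up per-label losses
few_labels = per_label * 2           # a batch with 8 active labels
many_labels = per_label * 8          # a batch with 32 active labels

sum_few, sum_many = pool(few_labels, "sum"), pool(many_labels, "sum")
mean_few, mean_many = pool(few_labels, "mean"), pool(many_labels, "mean")
```

With sum pooling, the batch with more active labels produces a loss four times larger even though the per-label quality is identical; with mean pooling, both batches give the same value.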
Notes on Training:
I noticed, just by trying out our model manually on a few sentences, that we were often predicting both B- and I- tag entities for the same tokens. It made me wonder if the model had found a clever way to achieve the lowest loss distribution by predicting both tags for an entity. I reduced the alpha to 0.65 and it seemed to clear the doubles up.
As a sanity check, I tried a few different batch sizes and learning rates. Nothing made an improvement, but higher learning rates demonstrated potential to be just as effective as what I ended up using: 5e-5.
Conclusion
In the sections above we:
- Built a synthetically annotated dataset, with the social bias entities: Generalizations (GEN), Unfairness (UNFAIR), and Stereotypes (STEREO).
- Optimized training and evaluation for multi-label token classification with BERT.
- Trained the model with focal loss, to account for the disproportionate representations of labels in the training data and our multi-label approach.
And it worked!! Our model has fit the training data well: only 6.6% of the labels predicted on the test set were incorrect (hamming loss)! More importantly, it detected 76% of the potential entity presences (recall). Finally, when the model labeled an entity as present, it was accurate 82% of the time (precision).
For all the complexity it added to the training process, doing multi-label token classification (instead of multi-class) gave us some unique results. We can see nested entities, and patterns emerge. I've noticed stereotypes often contain nested generalizations, being assigned some unfairness. Here's an example of an explicitly biased statement.
Resources:
Build your own NER model: You can follow this guide to train a NER model on your own entities, or your own definitions. Here's the notebook for building a dataset💻, and here's the notebook for training💻.
🎨HuggingFace Space (https://huggingface.co/spaces/maximuspowers/bias-detection-ner) | 🔬Model Repo | ⚠️Binary Sequence Classification of Bias