Token classification

This page documents 🤗 Transformers v4.22.2.

Token classification assigns a label to individual tokens in a sentence. One of the most common token classification tasks is Named Entity Recognition (NER). NER attempts to find a label for each entity in a sentence, such as a person, location, or organization.

This guide will show you how to fine-tune DistilBERT on the WNUT 17 dataset to detect new entities.

See the token classification task page for more information about other forms of token classification and their associated models, datasets, and metrics.

Load WNUT 17 dataset

Load the WNUT 17 dataset from the 🤗 Datasets library:

>>> from datasets import load_dataset

>>> wnut = load_dataset("wnut_17")

Then take a look at an example:

>>> wnut["train"][0]
{'id': '0',
 'ner_tags': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0],
 'tokens': ['@paulwalk', 'It', "'s", 'the', 'view', 'from', 'where', 'I', "'m", 'living', 'for', 'two', 'weeks', '.', 'Empire', 'State', 'Building', '=', 'ESB', '.', 'Pretty', 'bad', 'storm', 'here', 'last', 'evening', '.']
}

Each number in ner_tags is the index of an entity label. Convert the numbers to label names to see what they stand for:

>>> label_list = wnut["train"].features["ner_tags"].feature.names
>>> label_list
[
    "O",
    "B-corporation",
    "I-corporation",
    "B-creative-work",
    "I-creative-work",
    "B-group",
    "I-group",
    "B-location",
    "I-location",
    "B-person",
    "I-person",
    "B-product",
    "I-product",
]
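A quick way to see the mapping in action is to build lookup tables from label_list and apply them to the example above. This is a plain-Python sketch: it hard-codes the label names and the first training example exactly as printed above so it runs without downloading anything, and the names id2label and label2id are simply this sketch's choice.

```python
# Label names as printed by wnut["train"].features["ner_tags"].feature.names
label_list = [
    "O",
    "B-corporation", "I-corporation",
    "B-creative-work", "I-creative-work",
    "B-group", "I-group",
    "B-location", "I-location",
    "B-person", "I-person",
    "B-product", "I-product",
]

# Lookup tables in both directions: index -> name and name -> index
id2label = {i: name for i, name in enumerate(label_list)}
label2id = {name: i for i, name in enumerate(label_list)}

# ner_tags and tokens of wnut["train"][0], copied from the output above
ner_tags = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 8, 8, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0]
tokens = ['@paulwalk', 'It', "'s", 'the', 'view', 'from', 'where', 'I', "'m", 'living',
          'for', 'two', 'weeks', '.', 'Empire', 'State', 'Building', '=', 'ESB', '.',
          'Pretty', 'bad', 'storm', 'here', 'last', 'evening', '.']

# Print only the tokens that carry an entity label (anything other than "O")
for token, tag in zip(tokens, ner_tags):
    if tag != label2id["O"]:
        print(token, id2label[tag])
```

len(label_list) is 13, which is the number of labels the classification head has to predict.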

Each label name describes an entity type, such as a corporation, location, or person. The letter that prefixes each label indicates the token’s position within the entity:

  • B- indicates the beginning of an entity.
  • I- indicates a token is contained inside the same entity (e.g., the State token is a part of an entity like Empire State Building).
  • O indicates the token doesn’t correspond to any entity.
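Read together, the B- and I- tags delimit whole entities. Below is a minimal sketch of turning a BIO-tagged sequence back into entity spans, assuming the tag strings from label_list. bio_to_spans is a hypothetical helper, and it treats an I- tag that does not continue an open entity of the same type as the start of a new one (one common convention, not something WNUT 17 prescribes).

```python
from typing import List, Optional, Tuple


def bio_to_spans(tokens: List[str], tags: List[str]) -> List[Tuple[str, str]]:
    """Group BIO-tagged tokens into (entity_type, text) spans."""
    spans: List[Tuple[str, str]] = []
    current_type: Optional[str] = None
    current_tokens: List[str] = []
    for token, tag in zip(tokens, tags):
        if tag.startswith("I-") and tag[2:] == current_type:
            # Continuation of the entity that is currently open
            current_tokens.append(token)
            continue
        # Any other tag closes the open entity, if there is one
        if current_type is not None:
            spans.append((current_type, " ".join(current_tokens)))
            current_type, current_tokens = None, []
        if tag != "O":
            # A B- tag (or an unmatched I- tag) starts a new entity
            current_type, current_tokens = tag[2:], [token]
    # Flush an entity that runs to the end of the sentence
    if current_type is not None:
        spans.append((current_type, " ".join(current_tokens)))
    return spans


tokens = ["Empire", "State", "Building", "=", "ESB", "."]
tags = ["B-location", "I-location", "I-location", "O", "B-location", "O"]
print(bio_to_spans(tokens, tags))
# [('location', 'Empire State Building'), ('location', 'ESB')]
```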

Preprocess

Load the DistilBERT tokenizer to process the tokens:

>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

Since the input has already been split into words, set is_split_into_words=True to tokenize the words into subwords:

>>> example = wnut["train"][0]
>>> tokenized_input = tokenizer(example["tokens"], is_split_into_words=True)
>>> tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
>>> tokens
['[CLS]', '@', 'paul', '##walk', 'it', "'", 's', 'the', 'view', 'from', 'where', 'i', "'", 'm', 'living', 'for', 'two', 'weeks', '.', 'empire', 'state', 'building', '=', 'es', '##b', '.', 'pretty', 'bad', 'storm', 'here', 'last', 'evening', '.', '[SEP]']
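The ## prefix marks a subword that continues the previous piece of the same word. DistilBERT uses BERT's WordPiece tokenizer, which splits each word greedily into the longest vocabulary entries it can find. The toy sketch below shows that greedy longest-match-first split; the tiny vocabulary is made up for illustration and is not DistilBERT's real vocabulary, and the real tokenizer also lowercases and splits off punctuation (such as the @ above) before this step.

```python
from typing import List, Set


def wordpiece(word: str, vocab: Set[str], unk: str = "[UNK]") -> List[str]:
    """Greedy longest-match-first split of one word into subwords."""
    pieces: List[str] = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece  # non-initial pieces carry the ## prefix
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]  # no piece fits: the whole word becomes unknown
        pieces.append(match)
        start = end
    return pieces


# Toy vocabulary (hypothetical), chosen to reproduce two splits shown above
toy_vocab = {"paul", "##walk", "es", "##b", "empire"}
print(wordpiece("paulwalk", toy_vocab))  # ['paul', '##walk']
print(wordpiece("esb", toy_vocab))       # ['es', '##b']
```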

Adding the special tokens [CLS] and [SEP], together with subword tokenization, creates a mismatch between the input tokens and the labels: a single word with a single label may now be split into two or more subwords (for example, ESB becomes es and ##b above). You will need to realign the tokens and labels by:

  1. Mapping all tokens to their corresponding word with the word_ids method.
  2. Assigning the label -100 to the special tokens [CLS] and [SEP] so the PyTorch loss function ignores them.
  3. Only labeling the first token of a given word. Assign -100 to other subtokens from the same word.
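The value -100 works because it is the default ignore_index of PyTorch's torch.nn.CrossEntropyLoss: positions with that label contribute nothing, and the mean is taken over the remaining positions only. The dependency-free sketch below imitates that behaviour in plain Python rather than calling PyTorch; masked_cross_entropy is a hypothetical name, and returning 0.0 when every position is ignored is this sketch's own choice.

```python
import math
from typing import List

IGNORE_INDEX = -100  # same value as CrossEntropyLoss's default ignore_index


def masked_cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean token-level cross-entropy, skipping positions labelled -100."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # special tokens and non-first subtokens are skipped
        # Negative log-softmax of the gold class: logsumexp(row) - row[label]
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
        count += 1
    # This sketch returns 0.0 when every position is ignored
    return total / count if count else 0.0


logits = [[2.0, 0.0], [0.5, 1.5], [3.0, -1.0]]
labels = [-100, 1, -100]  # only the middle position counts
loss = masked_cross_entropy(logits, labels)
print(loss)
```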

Here is how you can create a function that realigns the tokens and labels, and truncates sequences so they are no longer than DistilBERT’s maximum input length:

>>> def tokenize_and_align_labels(examples):
...     tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)
...
...     labels = []
...     for i, label in enumerate(examples["ner_tags"]):
...         word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
...         previous_word_idx = None
...         label_ids = []
...         for word_idx in word_ids:
...             if word_idx is None:  # Set the special tokens to -100.
...                 label_ids.append(-100)
...             elif word_idx != previous_word_idx:  # Only label the first token of a given word.
...                 label_ids.append(label[word_idx])
...             else:  # Other subtokens of the same word are also set to -100.
...                 label_ids.append(-100)
...             previous_word_idx = word_idx
...         labels.append(label_ids)
...
...     tokenized_inputs["labels"] = labels
...     return tokenized_inputs
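To see what the function produces without downloading a tokenizer, its inner loop can be run on a word_ids list written by hand. The list below is an assumption: it mirrors the subword split printed earlier for the fragment "Empire State Building = ESB" (where ESB became es and ##b), wrapped in [CLS] and [SEP]. align_labels is a hypothetical helper that repeats the loop from tokenize_and_align_labels for a single sequence.

```python
from typing import List, Optional


def align_labels(word_ids: List[Optional[int]], word_labels: List[int]) -> List[int]:
    """Same rules as tokenize_and_align_labels, for one sequence."""
    label_ids: List[int] = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:  # [CLS] and [SEP]
            label_ids.append(-100)
        elif word_idx != previous_word_idx:  # first subtoken of a word
            label_ids.append(word_labels[word_idx])
        else:  # later subtokens of the same word
            label_ids.append(-100)
        previous_word_idx = word_idx
    return label_ids


# Words: Empire State Building = ESB  ->  ner_tags 7 8 8 0 7
word_labels = [7, 8, 8, 0, 7]
# Tokens: [CLS] empire state building = es ##b [SEP]  (hand-written word_ids)
word_ids = [None, 0, 1, 2, 3, 4, 4, None]
print(align_labels(word_ids, word_labels))
# [-100, 7, 8, 8, 0, 7, -100, -100]
```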

Use the 🤗 Datasets map function to tokenize and align the labels over the entire dataset. You can speed up map by setting batched=True to process multiple elements of the dataset at once:

>>> tokenized_wnut = wnut.map(tokenize_and_align_labels, batched=True)

Use DataCollatorForTokenClassification to create a batch of examples. It also dynamically pads your text and labels to the length of the longest element in each batch, so every sequence in a batch has the same length; labels are padded with -100 so the padding is ignored by the loss. While it is possible to pad your text in the tokenizer function by setting padding=True, dynamic padding is more efficient because each batch is padded only as much as it needs.
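A plain-Python sketch of what dynamic padding does for one batch: input_ids are right-padded with a pad token id, attention_mask marks the real tokens, and labels are padded with -100 (the default label_pad_token_id of DataCollatorForTokenClassification). The pad id 0 is an assumption (it is the [PAD] id in the distilbert-base-uncased vocabulary), pad_batch is a hypothetical name, and the token ids in the usage example are made up.

```python
from typing import Dict, List


def pad_batch(features: List[Dict[str, List[int]]], pad_token_id: int = 0,
              label_pad_id: int = -100) -> Dict[str, List[List[int]]]:
    """Right-pad every sequence to the longest one in this batch only."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 for real tokens, 0 for padding
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        # Padded label positions get -100 so the loss skips them
        batch["labels"].append(f["labels"] + [label_pad_id] * n_pad)
    return batch


# Two tokenized examples of different lengths (made-up token ids)
features = [
    {"input_ids": [5, 6, 7, 8], "labels": [-100, 0, 7, -100]},
    {"input_ids": [5, 9, 8], "labels": [-100, 0, -100]},
]
batch = pad_batch(features)
print(batch["input_ids"])  # [[5, 6, 7, 8], [5, 9, 8, 0]]
```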

PyTorch
>>> from transformers import DataCollatorForTokenClassification

>>> data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
TensorFlow
>>> from transformers import DataCollatorForTokenClassification

>>> data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, return_tensors="tf")

Train

PyTorch

Load DistilBERT with AutoModelForTokenClassification along with the number of expected labels (13, one per entry in label_list):

>>> from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

>>> model = AutoModelForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=13)

If you aren’t familiar with fine-tuning a model with the Trainer, take a look at the basic tutorial here!

At this point, only three steps remain:

  1. Define your training hyperparameters in TrainingArguments.
  2. Pass the training arguments to Trainer along with the model, dataset, tokenizer, and data collator.
  3. Call train() to fine-tune your model.

>>> training_args = TrainingArguments(
...     output_dir="./results",
...     evaluation_strategy="epoch",
...     learning_rate=2e-5,
...     per_device_train_batch_size=16,
...     per_device_eval_batch_size=16,
...     num_train_epochs=3,
...     weight_decay=0.01,
... )

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=tokenized_wnut["train"],
...     eval_dataset=tokenized_wnut["test"],
...     tokenizer=tokenizer,
...     data_collator=data_collator,
... )

>>> trainer.train()
TensorFlow

To fine-tune a model in TensorFlow, first load DistilBERT with TFAutoModelForTokenClassification along with the number of expected labels (13, one per entry in label_list):

>>> from transformers import TFAutoModelForTokenClassification

>>> model = TFAutoModelForTokenClassification.from_pretrained("distilbert-base-uncased", num_labels=13)

Then convert your datasets to the tf.data.Dataset format with prepare_tf_dataset(). It is a method of the model, which is why the model has to be loaded first:

>>> tf_train_set = model.prepare_tf_dataset(
...     tokenized_wnut["train"],
...     shuffle=True,
...     batch_size=16,
...     collate_fn=data_collator,
... )

>>> tf_validation_set = model.prepare_tf_dataset(
...     tokenized_wnut["validation"],
...     shuffle=False,
...     batch_size=16,
...     collate_fn=data_collator,
... )

If you aren’t familiar with fine-tuning a model with Keras, take a look at the basic tutorial here!

Set up an optimizer, a learning rate schedule, and some training hyperparameters:

>>> from transformers import create_optimizer

>>> batch_size = 16
>>> num_train_epochs = 3
>>> num_train_steps = (len(tokenized_wnut["train"]) // batch_size) * num_train_epochs
>>> optimizer, lr_schedule = create_optimizer(
...     init_lr=2e-5,
...     num_train_steps=num_train_steps,
...     weight_decay_rate=0.01,
...     num_warmup_steps=0,
... )

Configure the model for training with compile:

>>> import tensorflow as tf

>>> model.compile(optimizer=optimizer)
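With num_warmup_steps=0, the schedule returned by create_optimizer decays the learning rate linearly from init_lr towards zero over num_train_steps (a polynomial decay with power 1 under its default settings; check the create_optimizer reference for other options). The arithmetic can be sketched in plain Python. num_train_steps and linear_lr are hypothetical helpers, and 1000 is a placeholder dataset size, not the real size of the WNUT 17 training split.

```python
def num_train_steps(num_examples: int, batch_size: int, epochs: int) -> int:
    """Same formula as above: whole batches per epoch times epochs."""
    return (num_examples // batch_size) * epochs


def linear_lr(step: int, init_lr: float, total_steps: int) -> float:
    """Linear decay from init_lr to 0 with no warmup."""
    step = min(step, total_steps)  # the rate stays at its end value afterwards
    return init_lr * (1 - step / total_steps)


total = num_train_steps(num_examples=1000, batch_size=16, epochs=3)  # placeholder size
print(total)  # 186
for step in (0, total // 2, total):
    print(step, linear_lr(step, 2e-5, total))
```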

Call fit to fine-tune the model:

>>> model.fit(x=tf_train_set, validation_data=tf_validation_set, epochs=num_train_epochs)

For a more in-depth example of how to fine-tune a model for token classification, take a look at the corresponding PyTorch notebook or TensorFlow notebook.