Setup

ds_name = "conll2003"
model_name = "google/electra-small-discriminator"

max_len = 512
bs = 16
val_bs = bs*2

lr = 3e-5
ds = load_dataset(ds_name)

Dataloaders

splits = get_splits(ds)
ds = concatenate_datasets([ds['train'], ds['validation']])
ds[0]
{'chunk_tags': [11, 21, 11, 12, 21, 22, 11, 12, 0],
 'id': '0',
 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0],
 'pos_tags': [22, 42, 16, 21, 35, 37, 16, 21, 7],
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.']}
task = 'ner'  # one of 'ner', 'pos', 'chunk'
label_vocab = ds.features[f"{task}_tags"].feature.names
label_vocab
['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
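The tags are stored as integer ids into this vocabulary. As a quick sanity check (not part of the pipeline), we can print the first example with its tags decoded:

for tok, tag in zip(ds[0]['tokens'], ds[0]['ner_tags']):
    print(f"{tok:>10}  {label_vocab[tag]}")  # e.g. 'EU' -> 'B-ORG', 'German' -> 'B-MISC'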
tokenizer = AutoTokenizer.from_pretrained(model_name)
label_all_tokens = True

The preprocessing is explained in the HuggingFace example notebook.

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, label in enumerate(examples[f"{task}_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            # Special tokens have a word id that is None. We set the label to -100 so they are automatically
            # ignored in the loss function.
            if word_idx is None:
                label_ids.append(-100)
            # We set the label for the first token of each word.
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            # For the other tokens in a word, we set the label to either the current label or -100, depending on
            # the label_all_tokens flag.
            else:
                label_ids.append(label[word_idx] if label_all_tokens else -100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs
ds = ds.map(tokenize_and_align_labels, batched=True)
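To see what the alignment did, we can print the subword tokens of a processed example next to their aligned labels. This is only an illustrative check; the exact tokenization depends on the model's tokenizer:

row = ds[0]
toks = tokenizer.convert_ids_to_tokens(row['input_ids'])
for tok, lab in zip(toks, row['labels']):
    # -100 marks special tokens (and, if label_all_tokens=False, non-first subwords); the loss ignores them
    print(f"{tok:>12}  {label_vocab[lab] if lab != -100 else 'IGNORED'}")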

dblock = DataBlock(
    blocks = [TokenClassificationBlock(tokenizer=tokenizer, label_vocab=label_vocab)],
    get_x=KeyGetter(['input_ids', 'attention_mask', 'token_type_ids', 'labels']),
    splitter=RandomSplitter())
 
%%time
dls = dblock.dataloaders(ds, bs=bs, val_bs=val_bs, num_workers=2)
CPU times: user 49.8 s, sys: 1.55 s, total: 51.3 s
Wall time: 57.4 s
dls.show_batch(max_n=4)
text tags
0 u. n. relief officials said they were not aware that the tanks advancing on arbil were manned by iraqi troops as they advanced from kdp - controlled areas and raised kdp flags. B-ORG, B-ORG, B-ORG, B-ORG, O, O, O, O, O, O, O, O, O, O, O, O, B-LOC, B-LOC, O, O, O, B-MISC, O, O, O, O, O, O, O, O, O, O, O, O, B-ORG, B-ORG, O, O,
1 13 - thomas enqvist ( sweden ) vs. stephane simian ( france ) O, O, B-PER, I-PER, I-PER, O, B-LOC, O, O, O, B-PER, I-PER, I-PER, O, B-LOC, O,
2 san francisco 56 73. 434 14 1 / 2 B-ORG, I-ORG, O, O, O, O, O, O, O, O, O,
3 daniels pharmaceuticals manufactures prescription pharmaceutical products, the largest of which is levoxyl, a synthetic thyroid hormone for treating hypothyroidism. B-ORG, I-ORG, O, O, O, O, O, O, O, O, O, O, B-MISC, B-MISC, B-MISC, O, O, O, O, O, O, O, O, O, O, O, O, O, O,

Metrics

from datasets import load_metric
seqeval = load_metric("seqeval")

print(seqeval)
Metric(name: "seqeval", features: {'predictions': Sequence(feature=Value(dtype='string', id='label'), length=-1, id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='label'), length=-1, id='sequence')}, usage: """
Produces labelling scores along with its sufficient statistics
from a source against one or more references.

Args:
    predictions: List of List of predicted labels (Estimated targets as returned by a tagger)
    references: List of List of reference labels (Ground truth (correct) target values)
    suffix: True if the IOB prefix is after type, False otherwise. default: False
    scheme: Specify target tagging scheme. Should be one of ["IOB1", "IOB2", "IOE1", "IOE2", "IOBES", "BILOU"].
        default: None
    mode: Whether to count correct entity labels with incorrect I/B tags as true positives or not.
        If you want to only count exact matches, pass mode="strict". default: None.
    sample_weight: Array-like of shape (n_samples,), weights for individual samples. default: None
    zero_division: Which value to substitute as a metric value when encountering zero division. Should be on of 0, 1,
        "warn". "warn" acts as 0, but the warning is raised.

Returns:
    'scores': dict. Summary of the scores for overall and per type
        Overall:
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': F1 score, also known as balanced F-score or F-measure,
        Per type:
            'precision': precision,
            'recall': recall,
            'f1': F1 score, also known as balanced F-score or F-measure
Examples:

    >>> predictions = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
    >>> references = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
    >>> seqeval = datasets.load_metric("seqeval")
    >>> results = seqeval.compute(predictions=predictions, references=references)
    >>> print(list(results.keys()))
    ['MISC', 'PER', 'overall_precision', 'overall_recall', 'overall_f1', 'overall_accuracy']
    >>> print(results["overall_f1"])
    0.5
    >>> print(results["PER"]["f1"])
    1.0
""", stored examples: 0)
from fasthugs.metrics import MetricCallback
from typing import Tuple

class Seqeval(MetricCallback):

    def __init__(self, label_list, scores:Tuple[str]=('accuracy', 'f1', 'precision', 'recall')):
        self.metric = load_metric('seqeval')
        store_attr()
        self._register_value_funcs()

    @staticmethod
    def _get_score(obj, score, **kwargs):
        return obj.res[f"overall_{score}"]

    def preprocess(self, predictions, labels):

        # Remove ignored index (special tokens)
        true_predictions = [
            [self.label_list[p] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        true_labels = [
            [self.label_list[l] for (p, l) in zip(prediction, label) if l != -100]
            for prediction, label in zip(predictions, labels)
        ]
        return true_predictions, true_labels
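Before wiring the callback into a learner, the filtering done by preprocess can be checked on a toy pair of prediction/label id sequences (illustrative values only; positions labelled -100 are dropped before scoring):

preds  = [[0, 3, 4, 0]]           # predicted label ids for one sequence
labels = [[-100, 3, 4, -100]]     # -100 marks special tokens to ignore
true_preds  = [[label_vocab[p] for p, l in zip(pr, lb) if l != -100] for pr, lb in zip(preds, labels)]
true_labels = [[label_vocab[l] for p, l in zip(pr, lb) if l != -100] for pr, lb in zip(preds, labels)]
seqeval.compute(predictions=true_preds, references=true_labels)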
wandb.init(reinit=True, project="fasthugs", entity="fastai_community",
           name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG);

Training

model = AutoModelForTokenClassification.from_pretrained(model_name, num_labels=len(label_vocab))
learn = TransLearner(dls, model, loss_func=noop, cbs=Seqeval(label_vocab))

The final layers of the model are initialized with random weights, so we can verify that the untrained model performs no better than random guessing:

learn.validate()
(#5) [2.1915693283081055,0.12161642694685589,0.03735191637630662,0.022816832887375335,0.10290729566648382]
cbs = []
learn.fit_one_cycle(2, lr, wd=0.01, cbs=cbs)
epoch train_loss valid_loss accuracy f1 precision recall time
0 0.197947 0.173893 0.952574 0.776549 0.774937 0.778168 02:01
1 0.147690 0.134610 0.962232 0.837942 0.829961 0.846078 02:01