Hi,

This post is a short overview of a work project in which I trained a language model for invoices. This so-called base model is then fine-tuned for text classification on customer data. Due to data privacy, a non-disclosure agreement, ISO 27001 and SOC 2, I’m not allowed to publish any results. Believe me, it works like 🚀✨🪐.

A language model is trained on large amounts of textual data to understand the patterns and structure of language. The primary goal of a language model is to predict the probability of the next word or sequence of words in a sentence given the previous words.

Language models can be used for a variety of natural language processing (NLP) tasks, such as text classification, machine translation, text summarization, speech recognition, and sentiment analysis. There are many types of language models, ranging from simple n-gram models to more complex neural network-based models such as recurrent neural networks (RNNs) and transformers.
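Next-word probabilities can be made concrete with a minimal bigram (2-gram) model in plain Python. The tiny invoice-like corpus is made up for illustration. The model only counts which word follows which and normalizes those counts into conditional probabilities P(next word | previous word).

```python
from collections import Counter, defaultdict


def train_bigram_model(sentences):
    """Estimate P(next word | previous word) from raw word-pair counts."""
    counts = defaultdict(Counter)
    for sentence in sentences:
        # <s> and </s> mark the start and the end of a sentence.
        words = ["<s>"] + sentence.lower().split() + ["</s>"]
        for prev, nxt in zip(words, words[1:]):
            counts[prev][nxt] += 1
    # Normalize the counts of each context word into a probability distribution.
    return {
        prev: {word: count / sum(nexts.values()) for word, count in nexts.items()}
        for prev, nexts in counts.items()
    }


# Hypothetical toy corpus.
corpus = [
    "invoice number 42",
    "invoice date 2023-01-31",
    "invoice number 43",
    "total amount due",
]
model = train_bigram_model(corpus)
print(model["invoice"])  # {'number': 0.6666666666666666, 'date': 0.3333333333333333}
```

Neural language models such as transformers learn the same kind of conditional distribution. They condition on much longer contexts and use learned representations instead of raw counts.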

The transformer architecture currently dominates language modeling. Depending on the task, a transformer consists of an encoder, a decoder, or both. Transformers are usually trained on large quantities of unlabeled text using self-supervised learning. Training on that much data requires substantial compute, so training a language model can get expensive very quickly. Therefore, the most practical way to obtain a task-specific transformer is often to take a pre-trained model from Hugging Face and fine-tune it on your own data.

In my work with invoices, however, fine-tuning a pre-trained model did not work well. The best text-classification results came from fine-tuning a French base model on German invoices. Nevertheless, the overall F1 score did not justify the effort. I assume that the content and structure of invoices differ too much from the pre-training data (e.g. no continuous text and many numbers). Additionally, the tokenizers of pre-trained models are not optimized for invoices. They split invoice text into more tokens, so a context window of fixed size covers less of each invoice, which makes training less effective.

I worked on text classification of invoices for multiple clients. I trained a base model on a few million invoices (mostly German and English) and fine-tuned it for each client on around 2,000 to 50,000 invoices with 70 to 2,000 labels. Initially, I used the Longformer architecture (Beltagy et al. 2020), but a bug prevented the model deployment. So, despite its limitations, I switched to the BERT architecture (Devlin et al. 2019). Hugging Face also provides a tutorial for training language models.

Tokenizer

A tokenizer converts raw text into smaller units, such as words or subwords, that a machine learning model can process. It takes a string of text as input and outputs a sequence of tokens, each of which is mapped to an id in the tokenizer's vocabulary. A subword tokenizer breaks words down into smaller subword units. This is useful for handling out-of-vocabulary (OOV) words, i.e. words that are not in the vocabulary, because they can still be represented as a sequence of known subwords.
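Here is a minimal sketch of how a subword vocabulary handles an unseen word. It uses greedy longest-match-first segmentation, the way WordPiece (the original BERT tokenizer) splits words at inference time, but without WordPiece's `##` continuation markers. The vocabulary is a hypothetical toy set. The BPE tokenizer used in this post applies its learned merges instead, but the effect is similar: a German compound that is missing from the vocabulary is still represented by known pieces.

```python
def segment(word, vocab, unk="[UNK]"):
    """Split a word into the longest known subwords, scanning left to right."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        # Shrink the candidate until it is a known subword.
        while end > start and word[start:end] not in vocab:
            end -= 1
        if end == start:  # not even a single character is known
            return [unk]
        pieces.append(word[start:end])
        start = end
    return pieces


# Hypothetical toy vocabulary.
vocab = {"rechnung", "s", "betrag", "netto", "brutto"}
print(segment("rechnungsbetrag", vocab))  # ['rechnung', 's', 'betrag']
```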

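Byte-Pair Encoding (BPE) was originally a compression algorithm: it repeatedly replaces the most common pair of consecutive bytes with a byte that does not occur in the data (Gage 1994). Sennrich et al. (2016) adapted it for subword tokenization. Starting from single characters, BPE repeatedly merges the most frequent pair of adjacent symbols into a new symbol, which is added to the vocabulary.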
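The merge-learning loop of BPE can be sketched in a few lines of plain Python. This is a simplified illustration on characters; the Hugging Face tokenizer used below works on bytes and is far more efficient. The word frequencies follow the toy example of Sennrich et al. (2016), without their end-of-word marker.

```python
from collections import Counter


def learn_bpe(words, num_merges):
    """Learn up to num_merges BPE merges; returns (merges, segmented word counts)."""
    # Every word starts as a tuple of single characters.
    vocab = Counter(tuple(word) for word in words)
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs, weighted by word frequency.
        pairs = Counter()
        for symbols, freq in vocab.items():
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        if not pairs:
            break
        # Most frequent pair; ties are broken by first occurrence.
        best = max(pairs, key=pairs.get)
        merges.append(best)
        # Replace every occurrence of the best pair by the merged symbol.
        new_vocab = Counter()
        for symbols, freq in vocab.items():
            merged, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    merged.append(symbols[i] + symbols[i + 1])
                    i += 2
                else:
                    merged.append(symbols[i])
                    i += 1
            new_vocab[tuple(merged)] += freq
        vocab = new_vocab
    return merges, vocab


words = ["low"] * 5 + ["lower"] * 2 + ["newest"] * 6 + ["widest"] * 3
merges, segmented = learn_bpe(words, num_merges=3)
print(merges)  # [('e', 's'), ('es', 't'), ('l', 'o')]
```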

First, we define the BPE tokenizer together with the preprocessing steps for the incoming text. For normalization, we apply Unicode normalization (NFD) and convert the text to lower case. Pre-tokenization converts the text into a byte-level representation and then splits it into words and punctuation. Finally, a byte-level decoder converts tokenized output back into the original text.

from tokenizers import Tokenizer, models, normalizers, pre_tokenizers, trainers
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.normalizers import NFD, Lowercase
from tokenizers.pre_tokenizers import ByteLevel, Whitespace

# An untrained BPE model; vocabulary and merges are learned later.
tokenizer = Tokenizer(models.BPE())

# Normalization: unicode decomposition (NFD), then lower-casing.
tokenizer.normalizer = normalizers.Sequence([
    NFD(),
    Lowercase(),
])

# Our tokenizer also needs a pre-tokenizer responsible for converting the input to a ByteLevel representation
# (a space becomes "Ġ"); the Whitespace pre-tokenizer then splits it into words and punctuation.
tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
    ByteLevel(add_prefix_space=False),
    Whitespace(),
])

# Decoding: map byte-level tokens back to the original text.
tokenizer.decoder = ByteLevelDecoder()

Next, we define the vocabulary size, the special tokens, and the initial alphabet, which contains all 256 byte-level symbols. The batch iterator feeds the streaming training data to the trainer.

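from itertools import islice

from tqdm import tqdm
from tokenizers import trainers
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode

# train_dataset (a streaming dataset with a "text" column) and train_length
# (the number of training documents) come from the data pipeline below.

def batch_iterator(dataset, batch_size=10):
    """Yield lists of raw texts until the streaming dataset is exhausted."""
    iterator = iter(dataset)
    with tqdm(total=train_length) as progress:
        # islice stops cleanly at the end of the data instead of raising StopIteration.
        while batch := [sample["text"] for sample in islice(iterator, batch_size)]:
            progress.update(len(batch))
            yield batch


vocab_size = 32768

# The 256 byte values mapped to printable unicode characters (GPT-2 byte-level alphabet).
# Using them as the initial alphabet guarantees that every input can be encoded.
byte_to_unicode_map = bytes_to_unicode()
base_vocab = list(byte_to_unicode_map.values())

trainer = trainers.BpeTrainer(vocab_size=vocab_size,
                              show_progress=True,
                              initial_alphabet=base_vocab,
                              special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"])

tokenizer.train_from_iterator(batch_iterator(train_dataset), trainer=trainer)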
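The `initial_alphabet` is the byte-level table from GPT-2: every byte value is mapped to a printable character, so text never needs an unknown token. The plain-Python reimplementation of that table below is for illustration only; in the code above, it is imported from Transformers. It explains the odd-looking symbols in the tokenizer output that follows. A space byte becomes `Ġ`, and the two UTF-8 bytes of a precomposed `ü` become `Ã` and `¼`.

```python
def byte_to_unicode_table():
    """Map each of the 256 byte values to a printable unicode character (as in GPT-2)."""
    # Printable latin-1 bytes map to the character with the same code point ...
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values, code_points = printable[:], printable[:]
    # ... all other bytes (control characters, space, ...) get code points from 256 upwards.
    shift = 0
    for b in range(256):
        if b not in byte_values:
            byte_values.append(b)
            code_points.append(256 + shift)
            shift += 1
    return {b: chr(c) for b, c in zip(byte_values, code_points)}


def to_byte_level(text):
    """Byte-level representation of a string: UTF-8 bytes mapped through the table."""
    table = byte_to_unicode_table()
    return "".join(table[b] for b in text.encode("utf-8"))


print(to_byte_level(" für"))  # ĠfÃ¼r
```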

Here is an example of the tokenizer output:

output = tokenizer.encode("Die Authentifikation wird mit OpenLDAP erledigt, die Einrichtung oder Aktualisierung von Systemen mit Yum und Servern für DNS, DHCP und TFTP.")

print(output.tokens)

>> ['die', 'Ġauthent', 'if', 'ikation', 'Ġwird', 'Ġmit', 'Ġopen', 'l', 'da', 'p', 'Ġerledigt', ',', 'Ġdie', 'Ġeinrichtung', 'Ġoder', 'Ġaktualisierung', 'Ġvon', 'Ġsystemen', 'Ġmit', 'Ġy', 'um', 'Ġund', 'Ġservern', 'ĠfÃ', '¼', 'r', 'Ġdns', ',', 'Ġd', 'hc', 'p', 'Ġund', 'Ġt', 'ft', 'p', '.']


print(output.ids)

>> [373, 12466, 997, 2887, 468, 341, 4256, 80, 609, 84, 11738, 16, 282, 9128, 550, 19260, 355, 18058, 341, 1312, 349, 309, 20238, 348, 125, 86, 31306, 16, 264, 25171, 84, 309, 328, 367, 84, 18]
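The following sections use the tokenizer like a Transformers tokenizer: they call it on batches, read `tokenizer.mask_token_id`, and pass it to the Trainer and the fill-mask pipeline. The raw `tokenizers.Tokenizer` trained above does not offer that interface, and the post does not show the conversion. A minimal sketch, assuming the special tokens defined above, wraps a tokenizer with `PreTrainedTokenizerFast`. To keep the example runnable on its own, it trains a tiny stand-in tokenizer on two made-up lines. In the real pipeline, you would pass the trained invoice tokenizer as `tokenizer_object` instead.

```python
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from transformers import PreTrainedTokenizerFast

special_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"]

# Hypothetical stand-in for the trained invoice tokenizer.
raw_tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
raw_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
raw_tokenizer.train_from_iterator(
    ["rechnung nummer 42 betrag netto", "rechnung datum betrag brutto"],
    trainer=trainers.BpeTrainer(vocab_size=200, special_tokens=special_tokens),
)

# Wrap it so that Transformers components know which token is the mask, padding, ... token.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=raw_tokenizer,
    unk_token="[UNK]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    pad_token="[PAD]",
    mask_token="[MASK]",
    model_max_length=512,
)

# The special tokens are added first, in list order, so [MASK] gets id 4.
print(tokenizer.mask_token_id)  # 4
```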

Data pipeline

The training data is stored in multiple Parquet files and was split into a training and an evaluation dataset in a preprocessing step, with 1 % of the data held out for evaluation. Since the data does not fit into memory, it is streamed from disk. Each text is truncated to the defined context length and padded. The data collator for masked language modeling then randomly masks tokens in every batch, which produces the inputs and labels for training.

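from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling

# The preprocessing step wrote both splits to separate directories. Plain data files
# without a split definition end up in the "train" split, hence split="train" for both.
train_dataset = load_dataset(
    'parquet', data_dir="train_data/",
    streaming=True,
    split="train"
    )

eval_dataset = load_dataset(
    'parquet', data_dir="eval_data/",
    streaming=True,
    split="train"
    )

# max_length, use_mlm and mlm_probability are defined in the model training section;
# tokenizer is the trained tokenizer wrapped as a Transformers tokenizer.
def tokenize(batch):
    # Truncate to the context length and pad to the longest text in the batch.
    return tokenizer(batch['text'], padding=True, truncation=True, max_length=max_length)

# Tokenize lazily while streaming; drop the raw text, which cannot become a tensor.
train_dataset = train_dataset.map(tokenize, batched=True, remove_columns=['text'])
eval_dataset = eval_dataset.map(tokenize, batched=True, remove_columns=['text'])

train_dataset = train_dataset.with_format("torch")
eval_dataset = eval_dataset.with_format("torch")

# Pads each training batch and randomly masks tokens for masked language modeling.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=use_mlm, mlm_probability=mlm_probability
)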

Model Training

At this point, the data is processed, streaming is set up, and the tokenizer is trained, so the model training can finally start. I follow the BERT architecture (Devlin et al. 2019) and use its initial setup and hyperparameters. The model is trained with masked language modeling: 20 % of the tokens are randomly selected for prediction (BERT used 15 %). Of the selected tokens, 80 % are replaced with the [MASK] token, 10 % are replaced with a random token, and 10 % are left unchanged. Hugging Face's DataCollatorForLanguageModeling implements this scheme. Wettig et al. (2023) examined how the masking rate affects model performance.
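The 80-10-10 rule can be sketched for a single sequence in plain Python. This is an illustration of the scheme, not Hugging Face's implementation. That implementation works on whole tensors, but it also uses the label -100 for positions the loss should ignore, as the sketch does. The token ids and vocabulary size in the usage line are hypothetical.

```python
import random


def mask_tokens(token_ids, mask_id, vocab_size, special_ids, mlm_probability=0.2, rng=random):
    """Apply BERT-style 80-10-10 masking to one sequence of token ids.

    Returns (inputs, labels); labels are -100 (ignored by the loss) for
    positions that were not selected for prediction.
    """
    inputs = list(token_ids)
    labels = [-100] * len(token_ids)
    for i, token in enumerate(token_ids):
        # Never select special tokens; select every other token with mlm_probability.
        if token in special_ids or rng.random() >= mlm_probability:
            continue
        labels[i] = token  # the model has to predict the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = mask_id  # 80 %: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.randrange(vocab_size)  # 10 %: replace with a random token
        # remaining 10 %: keep the original token
    return inputs, labels


inputs, labels = mask_tokens([373, 12466, 997, 2887], mask_id=4, vocab_size=32768,
                             special_ids={0, 1, 2, 3, 4}, rng=random.Random(0))
```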

The following example randomly masks some tokens of an input sentence. The model is trained to predict the masked tokens from the context of the whole sentence.

import random

# tokenizer: the trained tokenizer wrapped as a Transformers tokenizer (encode returns a list of ids).
print(f"Mask Token id: {tokenizer.mask_token_id}")
>> Mask Token id: 4

output = tokenizer.encode("Die Authentifikation wird mit OpenLDAP erledigt, die Einrichtung oder Aktualisierung von Systemen mit Yum und Servern für DNS, DHCP und TFTP.")

# Overwrite seven randomly chosen positions with the id of the [MASK] token.
masked = random.sample(range(len(output)), 7)
for mask in masked:
    output[mask] = tokenizer.mask_token_id

print(f"Masked encoding: {tokenizer.decode(output)}")
>> Masked encoding: [MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f[MASK]r dns, dhcp und tftp.

I’m not a big fan of using too many libraries, but I didn’t have enough time to set up the BERT model in plain PyTorch. So I take the happy dependency path and use the Transformers library. I may write another post describing the transition from the Transformers library to plain PyTorch.

I use a smaller variant of the BERT-base configuration. It keeps the hidden size (768) and intermediate size (3072) of BERT-base but has eight layers with eight attention heads each, instead of twelve. A context size of 512 tokens truncates several invoices, but some experiments indicated that the effect on model performance is negligible. To understand the attention mechanism better, please have a look at my short blog post.
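For orientation, here is a plain-Python sketch of scaled dot-product attention, softmax(Q Kᵀ / √d_k) V. Each of the eight heads applies this operation to its own projections of the hidden states. BERT splits `hidden_size` evenly across the heads, so each head works on 768 / 8 = 96 dimensions. Real implementations use batched tensor operations; the tiny vectors here are only for illustration.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(queries, keys, values):
    """Scaled dot-product attention for lists of query, key and value vectors."""
    d_k = len(keys[0])
    outputs = []
    for q in queries:
        # Similarity of this query with every key, scaled by sqrt(d_k).
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
        weights = softmax(scores)
        # The output is the attention-weighted sum of the value vectors.
        outputs.append([sum(w * v[j] for w, v in zip(weights, values))
                        for j in range(len(values[0]))])
    return outputs


head_dim = 768 // 8  # hidden_size // num_attention_heads = 96
```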

from transformers import BertConfig, BertForMaskedLM

use_mlm = True
mlm_probability = 0.2 # still keeping the 80 - 10 - 10 rule

max_length=512
block_size = 512
max_position_embeddings = 512
hidden_size = 768
num_hidden_layers = 8
num_attention_heads = 8
intermediate_size = 3072
drop_out = 0.1


config = BertConfig(
#   attention_window = [block_size]*num_attention_heads,
#   mask_token_id = 4,
    bos_token_id = 1,
    sep_token_id = 2,
    pad_token_id = 3, # [PAD]; BertConfig defaults to 0, which is [UNK] in this vocabulary
    eos_token_id = 2,

    max_position_embeddings = max_position_embeddings,

    hidden_size = hidden_size,
    num_hidden_layers = num_hidden_layers,
    num_attention_heads = num_attention_heads,
    intermediate_size = intermediate_size,

    hidden_act = 'gelu',
    hidden_dropout_prob = drop_out,
    attention_probs_dropout_prob = drop_out,

    type_vocab_size = 2,
    initializer_range = 0.02,
    layer_norm_eps = 1e-12,

    vocab_size = vocab_size,

    use_cache = True,
    classifier_dropout = None,
    onnx_export = False)


model = BertForMaskedLM(config=config)
print(f"n of parameters: {model.num_parameters():_}")
>> n of parameters: 82_820_774
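Where the parameters come from can be checked with a short count for `BertForMaskedLM`. That model has no pooling layer, and its output projection shares its weights with the word embeddings, so only the output bias adds vocabulary-sized parameters. Roughly a third of the total sits in the embedding matrix. With `vocab_size=32,768`, the formula gives 82,889,984 parameters. The printed 82,820,774 matches a vocabulary of 32,678 tokens, so the exact figure depends on the vocabulary size the config actually receives.

```python
def bert_mlm_parameter_count(vocab_size, hidden=768, layers=8, intermediate=3072,
                             max_positions=512, type_vocab=2):
    """Parameters of BertForMaskedLM (no pooler, decoder weights tied to the embeddings)."""
    layer_norm = 2 * hidden  # weight and bias
    # Word, position and token-type embeddings plus their LayerNorm.
    embeddings = (vocab_size + max_positions + type_vocab) * hidden + layer_norm
    # Query, key, value and output projections (each with bias) plus LayerNorm.
    attention = 4 * (hidden * hidden + hidden) + layer_norm
    # Two dense layers of the feed-forward block plus LayerNorm.
    feed_forward = (hidden * intermediate + intermediate) + (intermediate * hidden + hidden) + layer_norm
    encoder = layers * (attention + feed_forward)
    # MLM head: dense transform + LayerNorm + output bias (output weights are tied).
    mlm_head = (hidden * hidden + hidden) + layer_norm + vocab_size
    return embeddings + encoder + mlm_head


print(f"{bert_mlm_parameter_count(32_768):_}")  # 82_889_984
```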

The model has about 82.8 million parameters. Depending on the data size and the GPUs, training takes less than 1.5 weeks on 4× T4 GPUs. The model trains for five epochs with the AdamW optimizer (Loshchilov & Hutter 2019), using the learning rate and weight decay from the BERT paper. The per-device batch size is chosen to make maximal use of the GPU memory. With two gradient accumulation steps, each weight update uses 64 samples per GPU (32 × 2), i.e. 256 samples across the four GPUs. To speed up training, we use mixed-precision (fp16) training.

import torch
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback


learning_rate = 1e-4 # bert
weight_decay = 1e-2 # bert
lr_scheduler_type = "linear"

num_train_epochs = 5
train_batch_size = 32 # per GPU
eval_batch_size = 32

gradient_accumulation_steps = 2
eval_accumulation_steps = 2

warmup_steps = 1_000

adam_beta1 = 0.9 # bert
adam_beta2 = 0.999 # bert
adam_epsilon = 1e-8 # bert
max_grad_norm = 1.0 # bert

# A streaming dataset has no length, so the Trainer needs the number of optimizer steps.
# Each optimizer step consumes train_batch_size * gradient_accumulation_steps samples
# per GPU (assumes all visible GPUs are used).
n_gpus = max(1, torch.cuda.device_count())
max_steps = num_train_epochs * train_length // (train_batch_size * gradient_accumulation_steps * n_gpus)

training_args = TrainingArguments(
    output_dir=model_path,
    overwrite_output_dir=True,

    learning_rate=learning_rate,
    weight_decay=weight_decay,
    lr_scheduler_type=lr_scheduler_type,
    warmup_steps=warmup_steps,
    num_train_epochs=num_train_epochs, # ignored when max_steps is set
    adam_beta1=adam_beta1,
    adam_beta2=adam_beta2,
    adam_epsilon=adam_epsilon,
    max_grad_norm=max_grad_norm,

    evaluation_strategy="steps", # named eval_strategy in recent Transformers versions
    eval_steps=5_000,
    max_steps=max_steps,

    per_device_train_batch_size=train_batch_size, # depends on memory
    per_device_eval_batch_size=eval_batch_size,

    gradient_accumulation_steps=gradient_accumulation_steps,
    eval_accumulation_steps=eval_accumulation_steps,

    save_strategy="steps",
    save_steps=5_000,
    save_total_limit=3,

    prediction_loss_only=False,
    report_to="tensorboard",

    log_level="warning",
    logging_strategy="steps",

    fp16=True, # mixed-precision training
    fp16_full_eval=True,

    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,

    push_to_hub=False,
    dataloader_pin_memory=True,
)

# Stop when the evaluation loss has not improved by more than 0.02 for three evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3,
                                       early_stopping_threshold=0.02)

callbacks = [early_stopping]

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    # compute_metrics=compute_metrics,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    callbacks=callbacks,
)
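With `lr_scheduler_type="linear"`, the Trainer uses a linear schedule (Transformers' `get_linear_schedule_with_warmup`). When warmup steps are passed to `TrainingArguments`, the learning rate rises linearly from 0 to the peak value during warmup. It then decays linearly to 0 at the last step. A plain-Python sketch of that schedule, with hypothetical step counts in the example:

```python
def linear_schedule_lr(step, peak_lr, warmup_steps, total_steps):
    """Learning rate at an optimizer step: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    remaining = total_steps - step
    return peak_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


# Hypothetical run: 1,000 warmup steps out of 11,000 optimizer steps in total.
for step in (0, 500, 1_000, 6_000, 11_000):
    print(step, linear_schedule_lr(step, 1e-4, 1_000, 11_000))
```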

Finally, everything is set up, and we can train the model. Depending on the data, the model and your budget, you can enjoy your holidays, and hopefully the training will have finished by the time you come back.

trainer.train()

trainer.save_model(f"{model_path}/main/")

As a final step, we can evaluate the model output with the fill-mask pipeline. Since I can’t share any data, I use the output from my Kaggle notebook.

from transformers import pipeline

original_text = "Die Authentifikation wird mit OpenLDAP erledigt, die Einrichtung oder Aktualisierung von Systemen mit Yum und Servern für DNS, DHCP und TFTP."

masked_text = "[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f[MASK]r dns, dhcp und tftp."

# Load the saved model and tokenizer and return the three most likely tokens for each [MASK].
mask_filler = pipeline("fill-mask", model=f"{model_path}/main/")
mask_filler(masked_text, top_k=3)

>> [[{'score': 0.8014500737190247,
>>    'token': 373,
>>    'token_str': 'die',
>>    'sequence': 'die[MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.08444128930568695,
>>    'token': 517,
>>    'token_str': 'eine',
>>    'sequence': 'eine[MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.028047803789377213,
>>    'token': 1384,
>>    'token_str': 'diese',
>>    'sequence': 'diese[MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'}],
>>  [{'score': 0.03110463358461857,
>>    'token': 4354,
>>    'token_str': 'zert',
>>    'sequence': '[MASK]zertifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.030949348583817482,
>>    'token': 1160,
>>    'token_str': ' ant',
>>    'sequence': '[MASK] antifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.02660573646426201,
>>    'token': 12466,
>>    'token_str': ' authent',
>>    'sequence': '[MASK] authentifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'}],
>>  [{'score': 0.021829063072800636,
>>    'token': 1202,
>>    'token_str': ' per',
>>    'sequence': '[MASK][MASK]ifikation wird mit perldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.018226258456707,
>>    'token': 307,
>>    'token_str': ' p',
>>    'sequence': '[MASK][MASK]ifikation wird mit pldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.013632726855576038,
>>    'token': 276,
>>    'token_str': ' m',
>>    'sequence': '[MASK][MASK]ifikation wird mit mldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'}],
>>  [{'score': 0.46616730093955994,
>>    'token': 489,
>>    'token_str': ' einer',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung einer aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.26918938755989075,
>>    'token': 288,
>>    'token_str': ' der',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung der aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.05999871343374252,
>>    'token': 533,
>>    'token_str': ' zur',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung zur aktualisierung von systemen mit[MASK]um und[MASK] f�[MASK]r dns, dhcp und tftp.'}],
>>  [{'score': 0.19769684970378876,
>>    'token': 2150,
>>    'token_str': ' nov',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit novum und[MASK] f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.03822920098900795,
>>    'token': 1504,
>>    'token_str': ' vol',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit volum und[MASK] f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.030768193304538727,
>>    'token': 17401,
>>    'token_str': ' quant',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit quantum und[MASK] f�[MASK]r dns, dhcp und tftp.'}],
>>  [{'score': 0.0999603122472763,
>>    'token': 386,
>>    'token_str': ' dem',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und dem f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.07762913405895233,
>>    'token': 288,
>>    'token_str': ' der',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und der f�[MASK]r dns, dhcp und tftp.'},
>>   {'score': 0.055170487612485886,
>>    'token': 332,
>>    'token_str': ' den',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und den f�[MASK]r dns, dhcp und tftp.'}],
>>  [{'score': 0.690095841884613,
>>    'token': 125,
>>    'token_str': '�',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f��r dns, dhcp und tftp.'},
>>   {'score': 0.029322009533643723,
>>    'token': 12,
>>    'token_str': '(',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f�(r dns, dhcp und tftp.'},
>>   {'score': 0.01887478120625019,
>>    'token': 4585,
>>    'token_str': ' pet',
>>    'sequence': '[MASK][MASK]ifikation wird mit[MASK]ldap erledigt, die einrichtung[MASK] aktualisierung von systemen mit[MASK]um und[MASK] f� petr dns, dhcp und tftp.'}]]

Fine-tuning

To fine-tune the language model, you can reuse the training script above. The pre-trained weights are loaded into a classification model: BertForSequenceClassification replaces the masked-language-modeling head with a classification head. The encoder weights are initialized from the pre-trained model. The classification head, and the pooling layer that BertForMaskedLM does not have, are initialized randomly. In addition, the data collator has to be adapted, and we compute some metrics for the evaluation. That’s all.

import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from transformers import BertForSequenceClassification, DataCollatorWithPadding

def compute_metrics(eval_pred):
    # The Trainer passes (logits, labels); the predicted class is the arg-max logit.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    accuracy = accuracy_score(y_true=labels, y_pred=predictions)
    # "weighted" averages the per-class scores, weighted by the number of samples per class.
    recall = recall_score(y_true=labels, y_pred=predictions, average='weighted')
    precision = precision_score(y_true=labels, y_pred=predictions, average='weighted')
    f1 = f1_score(y_true=labels, y_pred=predictions, average='weighted')
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Pad every batch to a multiple of the block size (512 tokens).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=block_size)
# Load the pre-trained model saved above; num_label is the number of classes of the client data.
model = BertForSequenceClassification.from_pretrained(f"{model_path}/main/", num_labels=num_label)
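With 70 to 2,000 labels per client, the class distribution is usually imbalanced, which is why `compute_metrics` uses `average='weighted'`. Here is a plain-Python sketch of what weighted F1 computes. Each class's F1 score is weighted by its support, i.e. the number of true samples of that class. Division by zero is treated as 0, which is also what scikit-learn's default does, though scikit-learn emits a warning. The labels in the example are made up.

```python
from collections import Counter


def weighted_f1(y_true, y_pred):
    """F1 score per class, averaged with weights equal to each class's support in y_true."""
    support = Counter(y_true)
    total = 0.0
    for label in set(y_true) | set(y_pred):
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        total += support[label] * f1  # classes that only appear in y_pred get weight 0
    return total / len(y_true)


# Class 0: F1 = 0.8 (support 3), class 1: F1 = 2/3 (support 1) -> (3 * 0.8 + 2/3) / 4
print(weighted_f1([0, 0, 0, 1], [0, 0, 1, 1]))
```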

Thank you for your attention.