Hi,

This post is about my implementation of an encoder transformer network from scratch, as a follow-up to understanding the attention layer, together with the colab implementation. I use a simplified dataset, so I don’t expect great results; my approach is to build something from scratch in order to understand it in depth. I faced many challenges during the implementation, so I aligned my code with the BertForSequenceClassification model from huggingface. My biggest challenge was getting the network to train at all. It took me several months of intermittent work and a complete deconstruction and reconstruction of the architecture. Minor issues were missing skip connections and some data preparation problems.

Even though this post is six years too late, Transformers are transforming the world via ChatGPT, Bart, or LLaMA. The core of the transformer architecture is the self-attention layer. For a visual explanation of the transformer, have a look at the great post by Jay Alammar. For a full implementation of a transformer from scratch, check Andrej Karpathy’s video. There are also several other implementations worth looking at.

Transformer

The model is separated into multiple parts by their functions. The Encoder Transformer consists of:

  • Embeddings
  • Encoder
  • Pooler
  • Classifier

The Embeddings module transforms the input text tokens, the token positions, and the token types into their vector representations. The Encoder implements the attention mechanism. Using only an encoder simplifies the architecture; a causal mask is generally only needed in the decoder, so this simplified encoder works without one. The Pooler takes the CLS token and applies a fully connected layer. The classifier is just a fully connected layer.

import torch
from torch import nn
import torch.nn.functional as F


class Transformer(nn.Module):

    def __init__(self, config):
        super().__init__()

        self.model_type = 'Transformer'
        self.config = config
        self.embeddings = Embeddings(self.config)
        self.encoder = Encoder(self.config)
        self.pooler = Pooler(self.config)
        self.dropout = nn.Dropout(self.config.dropout)
        self.classifier = nn.Linear(self.config.hidden_dim, self.config.n_classes)

    def forward(self, x, mask=None):  # mask is accepted but not used in this simplified encoder

        embedding_output = self.embeddings(x)
        encoder_outputs = self.encoder(embedding_output)
        first_token_tensor = self.pooler(encoder_outputs)

        first_token_tensor = self.dropout(first_token_tensor)
        logits = self.classifier(first_token_tensor)

        return logits

    def num_parameters(self):
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

Embedding

There are word, positional, and token_type embeddings. As a first implementation, we use the PyTorch embedding lookup table for all three. For an exact implementation of the sinusoidal positional encoding, please have a look here, here, or here. The huggingface implementation simplifies the positional embeddings by learning them instead of using the fixed sinusoidal encoding. Here is an explanation of the differences between the huggingface implementation and the original Bert implementation. The token_type_embedding has no real function for this single-sequence classification task; it exists mainly for compatibility with the huggingface implementation.
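
As a side note, a minimal sketch of the original sinusoidal positional encoding could look like the function below; it is not used in the following code, which sticks to the huggingface-style learned position embeddings.

import math
import torch

def sinusoidal_positional_encoding(max_len: int, hidden_dim: int) -> torch.Tensor:
    # fixed positional encoding from "Attention Is All You Need" (hidden_dim assumed even)
    position = torch.arange(max_len).unsqueeze(1)                 # (max_len, 1)
    div_term = torch.exp(torch.arange(0, hidden_dim, 2) * (-math.log(10000.0) / hidden_dim))
    pe = torch.zeros(max_len, hidden_dim)
    pe[:, 0::2] = torch.sin(position * div_term)                  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)                  # odd dimensions
    return pe                                                     # (max_len, hidden_dim)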

class Embeddings(nn.Module):

    def __init__(self, config):

        super().__init__()

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_dim, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_dim)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_dim)

        self.layernorm = nn.LayerNorm(config.hidden_dim, eps=config.eps)
        self.dropout = nn.Dropout(config.dropout)

        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(self, input_ids):

        input_shape = input_ids.size()
        seq_length = input_shape[1]

        position_ids = self.position_ids[:, : seq_length]
        token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)


        inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = inputs_embeds + token_type_embeddings + position_embeddings

        embeddings = self.layernorm(embeddings)
        embeddings = self.dropout(embeddings)

        return embeddings

Encoder

The Encoder stacks multiple AttentionBlocks. An AttentionBlock consists of multiple modules:

  • MultiHeadAttention
  • AttentionLayerOutput
  • IntermediateLayer
  • AttentionBlockOutput

The original paper adds a skip connection around each of the two sub-layers of an AttentionBlock, and the huggingface implementation does the same after the attention output and after the block output. A skip connection, or residual connection, allows the training of deeper networks without information loss: the layer’s output is added back to its original input, so the signal can effectively skip the layer, as sketched below.
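
Just to illustrate the idea in isolation (an illustrative snippet, not part of the model code):

import torch
from torch import nn

layer = nn.Linear(16, 16)
x = torch.randn(2, 16)
y = x + layer(x)   # skip connection: the layer only has to learn a correction to x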
class Encoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layers = nn.ModuleList([AttentionBlock(config) for _ in range(config.num_hidden_layers)])


    def forward(self, hidden_states):
        for layer in self.layers:
            hidden_states = layer(hidden_states)

        return hidden_states


class AttentionBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        assert config.hidden_dim % config.n_heads == 0
        self.head_dim = config.hidden_dim // config.n_heads

        self.attention_layer = SimplisticMultiHead(config.hidden_dim, config.n_heads, self.head_dim)
        self.layer_output = AttentionLayerOutput(config)
        self.intermediate = AttentionIntermediate(config)
        self.block_output = AttentionBlockOutput(config)

    def forward(self, hidden_states):
        attention_output = self.attention_layer(hidden_states)
        attention_output = self.layer_output(attention_output)
        output = attention_output + hidden_states # skip connection
        output = self.intermediate(output)
        output = self.block_output(output)
        return output

AttentionHead

The “Attention is all you need” paper proposed multi-head attention. This simple mechanism allows the model to learn different representations of the same input. In addition, multi-head attention lets the attention weights be computed in parallel. Afterwards, the outputs of the individual heads are concatenated into one vector.

In self-attention, the scaled dot-product attention derives query, key, and value from the same input tensor. The first matrix multiplication computes the similarity between the query and the key, and the resulting score is scaled by the square root of the head dimension. After a softmax, another matrix multiplication with the value produces the final attention output.
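
For reference, this is the scaled dot-product attention from the paper, with $d_k$ being the head dimension:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V$$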

class SimplisticMultiHead(torch.nn.Module):

    def __init__(self, hidden_dim, n_heads, head_dim):
        super().__init__()

        self.heads = torch.nn.ModuleList([SimplisticHead(hidden_dim, head_dim) for _ in range(n_heads)])

    def forward(self, x):

        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return out


class SimplisticHead(torch.nn.Module):

    def __init__(self, hidden_dim, head_dim):

        super().__init__()

        self.query = torch.nn.Linear(hidden_dim, head_dim, bias=True)
        self.key = torch.nn.Linear(hidden_dim, head_dim, bias=True)
        self.value = torch.nn.Linear(hidden_dim, head_dim, bias=True)
        self.scale = head_dim ** 0.5  # sqrt(d_k) scaling factor from the paper

    def forward(self, x):

        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        score = query @ key.transpose(-2, -1) / self.scale  # scaled dot-product similarity

        score = F.softmax(score, dim=-1)
        out = score @ value

        return out

AttentionLayers

I follow the huggingface implementation, although I don’t fully understand the structure of the Intermediate and Output layers; I would have expected activation functions between all the fully connected layers, not only a layer normalization. The Intermediate layer together with the block output corresponds to the position-wise feed-forward network of the original paper. GELU is used as the activation function. Compared to ReLU, it is a bit smoother.
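
A quick, purely illustrative way to see the difference between the two activations:

import torch
from torch import nn

x = torch.linspace(-3, 3, 7)
print(nn.ReLU()(x))   # hard cutoff at zero
print(nn.GELU()(x))   # smooth transition, with small negative values just below zero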

class AttentionIntermediate(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_dim, config.intermediate_size)
        self.gelu = nn.GELU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.gelu(hidden_states)

        return hidden_states


class AttentionBlockOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_dim)
        self.layernorm = nn.LayerNorm(config.hidden_dim, eps=config.eps)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        return hidden_states


class AttentionLayerOutput(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.layernorm = nn.LayerNorm(config.hidden_dim, eps=config.eps)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.layernorm(hidden_states)
        return hidden_states

Pooler

The Pooler module picks only the CLS token for our classification task and performs a linear transformation combined with a tanh activation. The tokenizer adds the CLS token to the text during the tokenization step (not shown here).
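
For illustration, this is how a huggingface BERT tokenizer adds the special tokens (used here only as an example, not necessarily the exact tokenizer from my experiments):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer("the movie was great")
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))
# ['[CLS]', 'the', 'movie', 'was', 'great', '[SEP]']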

class Pooler(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_dim, config.hidden_dim)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):

        first_token_tensor = hidden_states[:, 0]

        first_token_tensor = self.dense(first_token_tensor)
        first_token_tensor = self.activation(first_token_tensor)

        return first_token_tensor
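
Putting everything together, a minimal smoke test of the model could look like this; the config values are illustrative placeholders, not the settings from my experiments.

from types import SimpleNamespace
import torch

config = SimpleNamespace(
    vocab_size=30522,
    hidden_dim=128,
    pad_token_id=0,
    max_position_embeddings=512,
    type_vocab_size=2,
    eps=1e-12,
    dropout=0.1,
    num_hidden_layers=2,
    n_heads=4,
    intermediate_size=512,
    n_classes=2,
)

model = Transformer(config)
input_ids = torch.randint(0, config.vocab_size, (8, 32))  # batch of 8 sequences, 32 tokens each
logits = model(input_ids)                                 # shape: (8, 2)
print(logits.shape, model.num_parameters())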

All of these parts combine into a simplified implementation of a transformer encoder network. Thank you for your attention.