Transformer Implementation

This document reviews the main themes and key takeaways from Deep Learning Systems: Algorithms and Implementation, a course at Carnegie Mellon University taught by J. Zico Kolter and Tianqi Chen.


This document details the implementation of a Transformer model using NumPy, comparing it to PyTorch's implementation and explaining the underlying concepts.

Key Aspects of Transformer Implementation

Self-Attention Mechanism

  • Core Operation: Self-attention computes an output $Y$ from an input $X$.
  • Formula:
    $$Y = \text{softmax}\left(\frac{(X W_K)(X W_Q)^T}{\sqrt{d}}\right)(X W_V)$$
  • Details:
    • $W_K, W_Q, W_V$ – Weight matrices for Key, Query, and Value.
    • $d$ – Embedding dimension of $X$.
    • $W_{\text{out}}$ – Additional linear projection applied to the result, giving $Y W_{\text{out}}$.
    • Parameters: The weights are trainable; biases are often omitted for simplicity.

Softmax Implementation (NumPy)

  • Softmax:
    $$\text{softmax}(z)_i = \frac{e^{z_i - \max(z)}}{\sum_j e^{z_j - \max(z)}}$$
    • Subtracting $\max(z)$ leaves the result unchanged but keeps the exponentials from overflowing.
    • Normalizes along the last dimension.
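
To illustrate why the maximum is subtracted, here is a minimal sketch comparing the stable softmax with a naive version. The helper `naive_softmax` and the test vectors are illustrative assumptions, not part of the course code.

```python
import numpy as np

def softmax(Z):
    # Same max-subtraction trick as in the formula above
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def naive_softmax(Z):
    # Hypothetical helper for comparison only: no max subtraction
    E = np.exp(Z)
    return E / E.sum(axis=-1, keepdims=True)

z_small = np.array([[1.0, 2.0, 3.0]])
z_large = z_small + 1000.0  # shifting every entry should not change softmax

stable = softmax(z_large)
with np.errstate(over="ignore", invalid="ignore"):
    # exp(1001) overflows to inf, and inf / inf gives nan
    naive = naive_softmax(z_large)

print(np.allclose(stable, softmax(z_small)))  # shift invariance of the stable version
print(np.isnan(naive).all())                  # the naive version breaks down
```

The stable version returns the same probabilities for both inputs, while the naive version produces only NaNs once the inputs are large.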

Self-Attention Layer

  • Takes:
    • Input $X$
    • Mask $M$
    • Weights $W_K, W_Q, W_V, W_{\text{out}}$
  • Process:
    1. Form $K, Q, V$ by multiplying $X$ with $W_K, W_Q, W_V$.
    2. Compute the scaled scores $\frac{K Q^T}{\sqrt{d}}$ and add the mask $M$.
    3. Apply softmax along the last dimension to get the attention weights:
      $$\text{softmax}\left(\frac{K Q^T}{\sqrt{d}} + M\right)$$
    4. Multiply by $V$, then project with $W_{\text{out}}$.
import numpy as np
import torch
import torch.nn as nn

def softmax(Z):
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)
    
def self_attention(X, mask, W_KQV, W_out):
    """
    # Input (X): Shape (B, T, d)
    # B = 1 (Batch size)
    # T = 100 (Sequence length)
    # d = 64 (Embedding dimension)

    # Mask (mask): Shape (T, T)
    # A triangular mask prevents attention to future tokens.

    # Weights (W_KQV): Shape (d, 3d)
    # Projects X into key (K), query (Q), and value (V) matrices.
    # Split along the last axis into three matrices: K, Q, V.

    # Output:
    # K, Q, V – Shape (B, T, d)
    # attn – Shape (B, T, T) (Attention weights)
    # Final output shape – (B, T, d)

    # Matrix Operations:
    # X @ W_KQV projects X into 3d dimensions, then splits into K, Q, V.
    # K @ Q^T – Similar to a correlation matrix, comparing queries and keys.
    # swapaxes(-1, -2) swaps the last two dimensions, performing Q^T (transpose of Q).
    # Mask is added to prevent attention to future steps.
    """
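    # PyTorch packs in_proj_weight as [query; key; value], so the chunk named K here
    # is PyTorch's query projection; K @ Q^T then matches PyTorch's Q @ K^T.
    K,Q,V = np.split(X@W_KQV, 3, axis=-1)
    attn = softmax(K@Q.swapaxes(-1,-2) / np.sqrt(X.shape[-1]) + mask)
    return attn@V@W_out, attn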


T = 5
M = torch.triu(-float("inf")*torch.ones(T,T),1)
"""
tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])
"""

T, d = 100, 64
M = torch.triu(-float("inf")*torch.ones(T,T),1)
X = torch.randn(1,T,d)

# pytorch version
attn = nn.MultiheadAttention(d, 1, bias=False, batch_first=True)
Y_, A_ = attn(X,X,X, attn_mask=M)

# numpy version
Y, A = self_attention(
  X[0].numpy(), M.numpy(), 
  attn.in_proj_weight.detach().numpy().T,
  attn.out_proj.weight.detach().numpy().T
)

print(np.linalg.norm(A - A_[0].detach().numpy()))
print(np.linalg.norm(Y - Y_[0].detach().numpy()))

PyTorch Comparison

  • PyTorch's multi-head attention module (here with a single head and no biases) should match the NumPy implementation up to floating-point round-off.
  • Key Points:
    • PyTorch stores the $K, Q, V$ weights together in one matrix.
    • Masking uses negative infinities to prevent attention to future tokens.
    • Supports batch-first and time-first formats; the examples here use batch_first=True.
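
The effect of the negative-infinity mask can be checked in isolation. The sketch below builds a NumPy analogue of the triangular mask and uses random scores as a stand-in for the scaled $K Q^T$ product; both are illustrative assumptions.

```python
import numpy as np

def softmax(Z):
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

T = 5
# NumPy analogue of the torch.triu mask: -inf strictly above the diagonal, 0 elsewhere
mask = np.triu(np.full((T, T), -np.inf), k=1)

rng = np.random.default_rng(0)
scores = rng.standard_normal((T, T))  # stand-in for K @ Q^T / sqrt(d)
attn = softmax(scores + mask)

# exp(-inf) is exactly 0, so every position above the diagonal gets zero weight
print(np.all(np.triu(attn, k=1) == 0))
# Each row is still a probability distribution over the allowed positions
print(np.allclose(attn.sum(axis=-1), 1.0))
```

Row 0 can only attend to position 0, so its single nonzero weight is 1.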

1. in_proj_weight:

  • Purpose: This is a single weight matrix that combines the projections for key (K), query (Q), and value (V). It is used to transform the input tensor into the three different spaces (K, Q, and V) that are used in multi-head attention.

  • Shape: PyTorch stores in_proj_weight with shape (3d, d), where d is the input/output feature dimension. The NumPy code uses its transpose, of shape (d, 3d), which holds three (d, d) blocks side by side (for K, Q, and V).

  • How it works: The input tensor X (of shape (N, T, d)) is multiplied by the transposed in_proj_weight to produce a tensor of shape (N, T, 3d). This tensor is then split along the last axis into three parts: the key (K), query (Q), and value (V), each of shape (N, T, d).

  • Why it's needed:

    • In multi-head attention, we need to compute the attention scores between the queries and keys, and then use the values to compute the final attention output.
    • This matrix performs the linear transformation (projection) that converts the input tensor X into the necessary components (K, Q, V) for attention. Without it, we wouldn't have the separate K, Q, and V tensors needed for the attention mechanism.
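
As a rough illustration of the packed projection, the sketch below uses a random matrix as a hypothetical stand-in for the transposed in_proj_weight and checks that one packed multiplication followed by a split gives the same result as three separate (d, d) projections. The sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, d = 2, 4, 8

X = rng.standard_normal((N, T, d))
# Hypothetical stand-in for the transposed in_proj_weight, shape (d, 3d)
W_KQV = rng.standard_normal((d, 3 * d))

# One packed projection, then a split along the last axis
K, Q, V = np.split(X @ W_KQV, 3, axis=-1)

# The same weights viewed as three separate (d, d) blocks
W_K, W_Q, W_V = np.split(W_KQV, 3, axis=-1)

print(K.shape)  # (2, 4, 8)
print(np.allclose(K, X @ W_K), np.allclose(Q, X @ W_Q), np.allclose(V, X @ W_V))
```

Packing the three matrices is a convenience: it replaces three matrix multiplications with one larger one without changing the result.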

2. out_proj.weight:

  • Purpose: This matrix is used to transform the final attention output (after attention computation) into the desired output space.

  • Shape: (d, d) where d is the feature dimension. It is used to map the output of the multi-head attention back to the original d-dimensional space.

  • How it works: After computing the attention-weighted sum of the values V for each head, the per-head outputs are concatenated back to shape (N, T, d) and multiplied by the transposed out_proj.weight to produce the final output tensor of shape (N, T, d).

  • Why it's needed:

    • The output of the attention layer consists of multiple heads, each with reduced dimensions. The purpose of out_proj.weight is to combine the outputs of all attention heads (after reshaping and concatenation) and project them back to the original dimension d.
    • Without this weight matrix, the multi-head outputs would only be concatenated side by side, with no learned mixing between heads.
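
One way to see how the output projection combines heads: multiplying the concatenated heads by the projection matrix is the same as letting each head multiply its own block of rows and summing the results. The sketch below checks this with random arrays; `head_out` and `W_out` are hypothetical stand-ins, not values taken from PyTorch.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, d, heads = 2, 4, 8, 2
dh = d // heads

# Hypothetical per-head attention outputs, shape (N, heads, T, d // heads)
head_out = rng.standard_normal((N, heads, T, dh))
# Stand-in for the transposed out_proj.weight, shape (d, d)
W_out = rng.standard_normal((d, d))

# Concatenate heads along the feature axis, then project
concat = head_out.swapaxes(1, 2).reshape(N, T, d)
Y = concat @ W_out

# Equivalent view: head h multiplies rows h*dh ... (h+1)*dh - 1 of W_out
Y_sum = sum(head_out[:, h] @ W_out[h * dh:(h + 1) * dh] for h in range(heads))

print(np.allclose(Y, Y_sum))
```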

Why We Need These Weights in the NumPy Implementation:

In the PyTorch nn.MultiheadAttention module, the internal computations (like multi-head attention, splitting of K, Q, V, etc.) are handled automatically. However, when implementing the same logic from scratch in NumPy, we must manually perform these steps. Specifically:

  • Projection of X into K, Q, and V: We use in_proj_weight to compute the projections from the input tensor X into the K, Q, and V spaces. Without this matrix, we wouldn't be able to create these key components needed for attention calculation.
  • Final Output Projection: We use out_proj.weight to project the final attention output into the original feature space. Without this matrix, the concatenated head outputs would not be mixed or mapped back to the required output space.

In short, in_proj_weight and out_proj.weight are essential for both the transformation of the input into the K, Q, V components and for mapping the output back to the original feature space after attention. These weights encapsulate the linear projections in the attention mechanism, which we have to apply manually in NumPy.


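Mini-Batching

  • Format:
    • Transformers typically use the batch x time x dimension format.
    • RNNs typically use time x batch x dimension.
  • Efficiency:
    • The NumPy implementation relies on batched matrix multiplication (bmm): the @ operator multiplies over the last two axes and treats the leading axes as batch axes.
    • When a higher-order tensor is multiplied by a single weight matrix, it can be flattened to 2D for a standard matrix multiplication.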

What is Batch Matrix Multiplication? When working with neural networks, especially transformers, it's common to process multiple samples at once, a process known as minibatching. During self-attention, for each sample in the minibatch, matrix multiplications must be performed independently. This operation, performed for every sample in the batch, is called batch matrix multiplication.

Why is it Different from Regular Matrix Multiplication? In a regular multi-layer perceptron (MLP) or convolutional neural network (CNN), a batched input tensor is multiplied by a single weight matrix. This is essentially stacking matrices and applying standard matrix multiplication. However, in transformers, each sample requires multiplication by its own matrices (like the K and Q of that sample). This is where true batch matrix multiplication comes in: multiplying a batch of matrices by another batch of matrices.
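
The two cases described above can be put side by side in a short sketch. All arrays here are random placeholders chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
N, T, d = 3, 5, 4
X = rng.standard_normal((N, T, d))

# Case 1: one weight matrix shared by every sample (as in an MLP layer)
W = rng.standard_normal((d, d))
shared = X @ W
# Equivalent to flattening the batch and doing a single 2D multiplication
flat = (X.reshape(N * T, d) @ W).reshape(N, T, d)

# Case 2: every sample has its own pair of matrices (as in K @ Q^T)
K = rng.standard_normal((N, T, d))
Q = rng.standard_normal((N, T, d))
batched = K @ Q.swapaxes(-1, -2)                      # shape (N, T, T)
looped = np.stack([K[i] @ Q[i].T for i in range(N)])  # explicit per-sample loop

print(np.allclose(shared, flat))
print(np.allclose(batched, looped))
```

Case 1 can be reduced to an ordinary matrix product by flattening, whereas case 2 cannot, which is why it needs a batched operation.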

import numpy as np

# Generate random tensor B with shape (10, 3, 5, 4)
# Interpretation:
# - Leading axes (10, 3) are batch axes → 10 * 3 = 30 matrices
# - Each matrix has shape (5, 4)
B = np.random.randn(10, 3, 5, 4)

# Generate random tensor C with shape (10, 3, 4, 3)
# Interpretation:
# - Leading axes (10, 3) are batch axes → 30 matrices
# - Each matrix has shape (4, 3)
C = np.random.randn(10, 3, 4, 3)

# Perform batch matrix multiplication B @ C
# This multiplies corresponding matrices in B and C:
# B[i, j] @ C[i, j] for every i in range(10) and j in range(3)
result = B @ C

# Result shape: (10, 3, 5, 3)
# - Batch axes (10, 3)
# - Each resulting matrix has shape (5, 3)

# Conceptually equivalent to:
# result = np.zeros((10, 3, 5, 3))  # Preallocate space
# for i in range(10):               # Loop over the first batch axis
#     result[i] = B[i] @ C[i]       # B[i] @ C[i] is itself a batch of 3 products

# No explicit Python for-loop is needed; NumPy runs the batched
# multiplication in its compiled internals.

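The self_attention function above already supports batched inputs: @ and swapaxes(-1, -2) act on the last two axes, and the (T, T) mask broadcasts over the batch axis.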

N = 10
M = torch.triu(-float("inf")*torch.ones(T,T),1)
X = torch.randn(N,T,d)
Y_, A_ = attn(X,X,X, attn_mask=M)

Y, A = self_attention(
  X.numpy(), M.numpy(),
  attn.in_proj_weight.detach().numpy().T, 
  attn.out_proj.weight.detach().numpy().T)

print(np.linalg.norm(A - A_.detach().numpy()))
print(np.linalg.norm(Y - Y_.detach().numpy()))

Multi-Head Attention

  • Concept: Apply self-attention multiple times in parallel.
  • Process:
    1. Split $K, Q, V$ along the feature dimension into multiple heads $K_1, K_2, \ldots, K_H$ (and likewise for $Q$ and $V$).
    2. Perform self-attention for each head.
    3. Concatenate outputs and project with $W_{\text{out}}$.
  • Scaling: Each head has dimension $d / H$, where $H$ is the number of heads, so the scores are divided by
    $$\sqrt{\frac{d}{H}}$$
  • PyTorch: Returns attention weights averaged across all heads, whereas the NumPy version returns the attention weights of each head separately.
def multihead_attention(X, mask, heads, W_KQV, W_out):
    # X shape: (N, T, d) where:
    # N = batch size, T = sequence length, d = input dimension
    N, T, d = X.shape

    # Project the input X into K, Q, and V by multiplying with W_KQV
    # X@W_KQV results in a tensor of shape (N, T, 3d) because we have 3 parts for K, Q, V
    K, Q, V = np.split(X @ W_KQV, 3, axis=-1)

    # Reshape K, Q, and V for multi-head attention:
    # We split the last dimension d into 'heads' number of heads,
    # so each head gets d // heads dimensions (d must be divisible by heads).
    # Each of K, Q, and V will now have shape (N, T, heads, d // heads)
    # Then we swap axes to get shape (N, heads, T, d // heads) for easier matrix multiplication.
    K, Q, V = [a.reshape(N, T, heads, d // heads).swapaxes(1, 2) for a in (K, Q, V)]
    
    # Compute attention scores:
    # First, K @ Q^T gives us a tensor of shape (N, heads, T, T), which represents
    # the pairwise similarity scores between each query (Q) and key (K).
    # We scale by sqrt(d//heads) and add the mask to prevent attention to future tokens.
    attn = softmax(K @ Q.swapaxes(-1, -2) / np.sqrt(d // heads) + mask)

    # Perform the attention operation:
    # Multiply the attention scores (attn) with the values (V), then swap axes
    # and reshape back to the original shape (N, T, d) before multiplying with W_out.
    return (attn @ V).swapaxes(1, 2).reshape(N, T, d) @ W_out, attn

# Example usage:
# 'heads' defines the number of attention heads.
heads = 4

# PyTorch MultiheadAttention module example:
# X (batch, time, dimension), M (mask)
attn = nn.MultiheadAttention(d, heads, bias=False, batch_first=True)
Y_, A_ = attn(X, X, X, attn_mask=M)

# Custom numpy multihead_attention:
Y, A = multihead_attention(
  X.numpy(), M.numpy(), heads,  # input tensor, mask, number of heads
  attn.in_proj_weight.detach().numpy().T,  # in_proj_weight is (3d, d)
  attn.out_proj.weight.detach().numpy().T  # out_proj.weight is (d, d)
)

# Compute and print the norm of the differences between PyTorch and custom results:
print(np.linalg.norm(Y - Y_.detach().numpy()))  # Difference in outputs
print(np.linalg.norm(A.mean(1) - A_.detach().numpy()))  # Difference in attention scores averaged across heads
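
The reshape and swapaxes steps in multihead_attention are easy to get wrong, so a small round-trip check can help. The sketch below uses an arbitrary integer-valued tensor (an illustrative assumption) to show which feature columns each head receives and that merging undoes the split.

```python
import numpy as np

N, T, d, heads = 2, 3, 8, 4
dh = d // heads
X = np.arange(N * T * d, dtype=float).reshape(N, T, d)

# Split: (N, T, d) -> (N, T, heads, d // heads) -> (N, heads, T, d // heads)
split = X.reshape(N, T, heads, dh).swapaxes(1, 2)

# Head h holds feature columns h * dh ... (h + 1) * dh - 1 of the original tensor
h = 1
same_slice = np.array_equal(split[:, h], X[:, :, h * dh:(h + 1) * dh])

# Merge: undo the swap, then collapse the two trailing axes back into d
merged = split.swapaxes(1, 2).reshape(N, T, d)

print(split.shape)                     # (2, 4, 3, 2)
print(same_slice, np.array_equal(merged, X))
```

Swapping the axes back before the final reshape matters: reshaping the (N, heads, T, d // heads) tensor directly to (N, T, d) would interleave time steps and heads incorrectly.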



๐Ÿ—๏ธ Transformer Block

  • Structure:
    1. Multi-head attention with residual connection.
    2. Layer normalization.
    3. Feed-forward network (two linear layers + ReLU).
    4. Second residual connection + normalization.
  • Layer Norm:
    • For each token, subtract the mean and divide by the standard deviation, both computed over the feature dimension.
    • Add a small $\epsilon$ to the variance to avoid division by zero.
  • Reference: PyTorch's TransformerEncoderLayer is used as a benchmark.
def layer_norm(Z, eps):
    # Input:
    # Z - input tensor of shape (N, T, d) where:
    # N = batch size, T = sequence length, d = dimension of input
    # eps - small value added to avoid division by zero, usually very small like 1e-5
    
    # Layer normalization:
    # For each token, compute the mean and variance over the feature axis (axis=-1),
    # then subtract the mean and divide by the standard deviation.
    # Each token's feature vector ends up with mean 0 and variance close to 1.
    # This helps stabilize training by normalizing the input to each layer.
    # The learnable scale and shift of nn.LayerNorm are omitted (they start at 1 and 0).
    return (Z - Z.mean(axis=-1, keepdims=True)) / np.sqrt(Z.var(axis=-1, keepdims=True) + eps)
    
def relu(Z):
    # Input:
    # Z - input tensor of shape (N, T, d)
    
    # ReLU activation: Replaces all negative values with 0.
    # This introduces non-linearity into the model.
    return np.maximum(Z, 0)

def transformer(X, mask, heads, W_KQV, W_out, W_ff1, W_ff2, eps):
    # Input:
    # X - input tensor of shape (N, T, d)
    # mask - mask tensor of shape (T, T) for attention masking
    # heads - number of attention heads
    # W_KQV - weight matrix for key, query, and value, shape (d, 3d)
    # W_out - output weight matrix for multi-head attention, shape (d, d)
    # W_ff1 - weight matrix for the first feed-forward layer, shape (d, 128)
    # W_ff2 - weight matrix for the second feed-forward layer, shape (128, d)
    # eps - small value for layer normalization, e.g., 1e-5
    
    # Multi-head attention + residual connection:
    # Step 1: Compute multi-head attention using the input X, mask, and weight matrices W_KQV and W_out.
    # The output of multihead_attention is (N, T, d) shape.
    # Step 2: Add input X to the output of the attention and apply layer normalization.
    Z = layer_norm(multihead_attention(X, mask, heads, W_KQV, W_out)[0] + X, eps)
    
    # Feed-forward network with residual connection:
    # Step 3: Multiply Z by W_ff1, apply the ReLU activation, then multiply by W_ff2.
    # Step 4: Add the result to Z (the layer-normalized tensor from above)
    # and apply layer normalization again.
    return layer_norm(Z + relu(Z @ W_ff1) @ W_ff2, eps)

# Example usage:
# Input tensor X (batch, time, dimension)
# Mask tensor M (sequence length, sequence length)
trans = nn.TransformerEncoderLayer(d, heads, dim_feedforward=128, dropout=0.0, batch_first=True)

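# Initialize biases of the feed-forward layers to zero so the bias-free NumPy version can match
# (the attention biases are already zero-initialized by PyTorch)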
trans.linear1.bias.data.zero_()
trans.linear2.bias.data.zero_()

# PyTorch Transformer Encoder Layer example:
Y_ = trans(X, M)

# Custom transformer implementation using numpy:
Y = transformer(
  X.numpy(), M.numpy(), heads,
  trans.self_attn.in_proj_weight.detach().numpy().T,  # W_KQV, shape (d, 3d)
  trans.self_attn.out_proj.weight.detach().numpy().T,  # W_out, shape (d, d)
  trans.linear1.weight.detach().numpy().T,  # W_ff1, shape (d, 128)
  trans.linear2.weight.detach().numpy().T,  # W_ff2, shape (128, d)
  trans.norm1.eps  # epsilon for layer normalization
)

# Compute the difference between PyTorch and custom transformer outputs:
print(np.linalg.norm(Y - Y_.detach().numpy()))
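
As a quick check of the layer_norm function used in the block above, the sketch below applies it to random data with an arbitrary offset and scale (illustrative assumptions) and inspects the per-token statistics.

```python
import numpy as np

def layer_norm(Z, eps):
    # Same definition as above: normalize over the feature axis of each token
    return (Z - Z.mean(axis=-1, keepdims=True)) / np.sqrt(Z.var(axis=-1, keepdims=True) + eps)

rng = np.random.default_rng(0)
Z = 5.0 + 3.0 * rng.standard_normal((2, 4, 64))  # arbitrary offset and scale
out = layer_norm(Z, 1e-5)

token_means = out.mean(axis=-1)  # shape (2, 4): one value per token
token_vars = out.var(axis=-1)    # shape (2, 4)

print(np.allclose(token_means, 0.0, atol=1e-6))
print(np.allclose(token_vars, 1.0, atol=1e-4))
```

The variance is slightly below 1 rather than exactly 1 because of the eps term in the denominator.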


Overall Implementation

  • The entire Transformer block is implemented in a few short NumPy functions, because nearly every step reduces to a (batched) matrix multiplication.