Transformer and LLM
Modern AI models are becoming increasingly large, demanding substantial computational resources and memory. This creates a gap between the computational demands of these models and the available hardware capabilities. Pruning addresses this gap by reducing model size, memory footprint, and ultimately, energy consumption.
This briefing reviews MIT 6.5940 EfficientML.ai Lecture 11, focusing on deploying neural networks on resource-constrained edge devices, such as microcontrollers and laptops.
🚀 Transformer Basics
The lecture begins with an overview of NLP task categories:
- 🔍 Discriminative tasks: Classifying sentences (e.g., sentiment analysis, text classification).
- 📝 Generative tasks: Generating new tokens (e.g., language modeling, machine translation).
⏳ Pre-Transformer Era
Before transformers, RNNs and CNNs were common in NLP but had limitations:
- RNNs struggled with long-term dependencies.
- CNNs provided limited context.
- Uni-/Bi-directional RNNs were used for specific encoding/decoding tasks.
🛑 Problems with RNNs/LSTMs
- ❌ Difficulty modeling long-term relationships in long sentences.
- ❌ Limited training parallelism due to sequential dependencies.
🏗️ Transformer Architecture

Transformers addressed RNN/CNN issues using:
- 🔡 Tokenizer: Converts words to tokens (e.g., “parallelism” → “parallel”, “ism”).
- 📈 Embedding: Maps tokens into vector representations.
- 🔗 Multi-Head Attention (MHA): Models relationships between all tokens (e.g., aligns words in machine translation).
- 📊 Feed-Forward Network (FFN): Adds non-linearity to capture local features.
- 🔄 Residual Connections: Improves gradient flow and stability.
- 📏 Positional Encoding: Encodes word order, essential due to permutation invariance.
- 🔚 Linear Head: Generates final predictions.
1️⃣ Tokenizer
The tokenizer is the first step in processing text, breaking sentences into smaller units called tokens.
- 🔹 What is a token?: A word, subword, or character, depending on the tokenizer.
- 🔹 Example: The word “unbreakable” might be tokenized into “un,” “break,” and “able.”
- 🔹 Purpose: Converts text into a numerical format the model can process.
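To make the subword idea concrete, here is a minimal sketch of a greedy longest-match tokenizer. The vocabulary `VOCAB` and the helper names `tokenize` and `encode` are hypothetical and chosen only to reproduce the examples above; real tokenizers (e.g., BPE-based ones) learn their vocabulary from data and use different matching rules.

```python
# Hypothetical toy vocabulary; a real tokenizer learns its vocabulary from data.
VOCAB = {"un": 0, "break": 1, "able": 2, "parallel": 3, "ism": 4}

def tokenize(word, vocab=VOCAB):
    """Split `word` into subword pieces by greedy longest-match lookup."""
    pieces, start = [], 0
    while start < len(word):
        # Try the longest remaining candidate first, then shrink it.
        for end in range(len(word), start, -1):
            if word[start:end] in vocab:
                pieces.append(word[start:end])
                start = end
                break
        else:
            raise ValueError(f"no vocabulary piece matches {word[start:]!r}")
    return pieces

def encode(word, vocab=VOCAB):
    """Map a word to the integer ids the model actually consumes."""
    return [vocab[piece] for piece in tokenize(word, vocab)]

print(tokenize("unbreakable"))  # ['un', 'break', 'able']
print(encode("parallelism"))    # [3, 4]
```

The integer ids returned by `encode` are what the embedding layer looks up in the next step.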
2️⃣ Embedding
After tokenization, tokens are converted into vector representations called embeddings.
- 🌟 Purpose: Captures the meaning and context of tokens numerically.
- 🔹 Pre-trained embeddings: Like Word2Vec or GloVe.
- 🔹 End-to-end training: Allows the model to learn task-specific representations during training.
3️⃣ Transformer Block
The Transformer block is the backbone of the architecture. A stack of these blocks forms the model. Each block has key components:
🔗 Multi-Head Attention (MHA)
The heart of the Transformer block, MHA enables the model to:
- 👀 Attend to all tokens in the sequence simultaneously.
- 🌐 Capture long-range dependencies between words.
💡 How MHA Works:

- Projection: The input embeddings are projected into three matrices: Query (Q), Key (K), and Value (V).
- Attention scores: Dot product of Q and K, scaled by $( 1/\sqrt{d_k} )$, where $( d_k )$ is the key dimension.
- Softmax function: Normalizes the scores into attention weights.
- Weighted sum: Aggregates the Value vectors of relevant tokens using the attention weights.
- Multi-head mechanism: Repeats this process with different learned projections, capturing diverse relationships.
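The steps above can be sketched in plain Python. This is an illustrative sketch, not the lecture's implementation: the function names are made up for this example, matrices are lists of rows, and the output projection that normally follows head concatenation is omitted for brevity.

```python
import math

def softmax(scores):
    """Turn raw scores into positive weights that sum to 1."""
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def project(X, W):
    """Multiply each row of X (n x d_in) by W (d_in x d_out)."""
    return [[sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]
            for x in X]

def attention_head(X, W_q, W_k, W_v):
    """One attention head over token embeddings X (one row per token)."""
    Q, K, V = project(X, W_q), project(X, W_k), project(X, W_v)
    d_k = len(K[0])
    outputs = []
    for q in Q:
        # Scaled dot product between this query and every key.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        weights = softmax(scores)
        # Weighted sum of the value vectors.
        outputs.append([sum(w * v[j] for w, v in zip(weights, V))
                        for j in range(len(V[0]))])
    return outputs

def multi_head(X, heads):
    """Run every head (a tuple of W_q, W_k, W_v) and concatenate per token."""
    per_head = [attention_head(X, *h) for h in heads]
    return [sum((out[t] for out in per_head), []) for t in range(len(X))]

# Hypothetical 3-token input with 2-dimensional embeddings and two heads.
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
I = [[1.0, 0.0], [0.0, 1.0]]
S = [[2.0, 0.0], [0.0, 2.0]]
out = multi_head(X, [(I, I, I), (S, I, I)])
print(len(out), len(out[0]))  # 3 tokens, 4 features (2 per head)
```

Each head sees the same tokens but through different projections, which is what lets different heads pick up different relationships.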

⚡ Feed-Forward Network (FFN)
- 🌟 Adds non-linearity to process tokens independently.
- ⚙️ Typically consists of:
- Two fully connected layers.
- An activation function like ReLU or GeLU.
- 🔹 Enhances the model’s ability to capture complex patterns within tokens.

🎛️ Layer Normalization (LN)
- Purpose: Normalizes activations across features for each token.
- 📈 Improves training stability and speeds up convergence.
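As a small illustration of per-token normalization, the sketch below normalizes one token's feature vector to zero mean and (approximately) unit variance, then applies a learnable scale and shift. The parameter names `gamma`, `beta`, and `eps` follow common usage but are assumptions here, not taken from the lecture.

```python
import math

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Normalize one token's features, then scale by gamma and shift by beta."""
    d = len(x)
    gamma = [1.0] * d if gamma is None else gamma
    beta = [0.0] * d if beta is None else beta
    mean = sum(x) / d
    var = sum((v - mean) ** 2 for v in x) / d
    # eps keeps the division stable when the variance is close to zero.
    return [g * (v - mean) / math.sqrt(var + eps) + b
            for v, g, b in zip(x, gamma, beta)]

token = [1.0, 2.0, 3.0, 4.0]
print([round(v, 3) for v in layer_norm(token)])
```

Because the statistics are computed per token across its features, the result does not depend on batch size or on the other tokens in the sequence.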

🔄 Residual Connections
- Purpose: Helps gradients flow directly through the network.
- 🌟 Prevents vanishing gradients and aids in training deep models.

4️⃣ Positional Encoding
- ⚠️ Problem: Attention by itself has no notion of token order: permuting the input tokens simply permutes the outputs.
- ✅ Solution: Positional encoding adds information about token positions.
- 🔹 Methods:
- Absolute: Fixed encoding for each position.
- Relative: Considers distances between tokens (e.g., ALiBi, RoPE).
5️⃣ Linear Head
After the Transformer blocks, the output is passed through:
- A linear layer.
- (Optional) A softmax function to produce predictions.
🔹 Example:
In machine translation, the linear head predicts a probability distribution over the target vocabulary.
🎨 Transformer Design Variants
The Transformer architecture has evolved considerably since its introduction. The key design variants below are illustrated with examples. Innovations include:
- 🧩 Encoder-Decoder, Encoder-Only, Decoder-Only: Variants like BERT (encoder-only) and GPT (decoder-only).
- 📐 Relative Positional Encoding: Replaces absolute encoding for better generalization (e.g., ALiBi, RoPE).
- 📂 KV Cache Optimizations: Reduces memory usage in long sequences (e.g., MQA, GQA).
- 🚪 FFN → GLU: Improves performance using Gated Linear Units (GLU).
1️⃣ Relative Positional Encoding: Rotary Positional Embedding (RoPE)
Rotary Positional Embedding (RoPE) encodes positional information by rotating token embeddings in a 2D space. Unlike absolute positional encodings, RoPE focuses on relative distances between tokens, making it ideal for capturing sequence relationships.

- Embedding Splitting:
- An embedding vector of dimension $( d )$ is split into $( d/2 )$ pairs.
- Each pair represents a 2D coordinate: $( (x_i, y_i) )$.
- Rotation:
- Each 2D pair is rotated by a position-dependent angle using a rotation matrix.
Rotation matrix for position $( m )$ and pair $( i )$:
$[ R_m = \begin{bmatrix} \cos(m \theta_i) & -\sin(m \theta_i) \\ \sin(m \theta_i) & \cos(m \theta_i) \end{bmatrix} ]$
- The rotation angle $( \theta_i )$ for each pair is determined by:
$[ \theta_i = 10000^{-2(i-1)/d} ]$
where $( i = 1, \ldots, d/2 )$ is the pair index.
- Final Transformation:
- For each pair $( (x_i, y_i) )$, the rotated coordinates are computed as:
$[ \begin{bmatrix} x_i' \\ y_i' \end{bmatrix} = R_m \begin{bmatrix} x_i \\ y_i \end{bmatrix} = \begin{bmatrix} x_i \cos(m \theta_i) - y_i \sin(m \theta_i) \\ x_i \sin(m \theta_i) + y_i \cos(m \theta_i) \end{bmatrix} ]$
🔍 Example Walkthrough
- Setup:
- Embedding dimension ($( d )$): 4
- Position ($( m )$): 3
The embedding $( \mathbf{v} )$ is:
$[ \mathbf{v} = [v_1, v_2, v_3, v_4] ]$
Split into pairs:
$[ (v_1, v_2), (v_3, v_4) ]$
- Rotation Angles:
For $( i = 1, 2 )$:
$[ \theta_1 = 10000^{-2(1-1)/4} = 1, \quad \theta_2 = 10000^{-2(2-1)/4} = 0.01 ]$
Compute angles $( m \theta_i )$:
$[ m \theta_1 = 3 \cdot 1 = 3, \quad m \theta_2 = 3 \cdot 0.01 = 0.03 ]$
- Apply Rotation:
Rotate each pair:
$[ (v_1', v_2') = (v_1 \cos(m \theta_1) - v_2 \sin(m \theta_1), v_1 \sin(m \theta_1) + v_2 \cos(m \theta_1)) ]$
$[ (v_3', v_4') = (v_3 \cos(m \theta_2) - v_4 \sin(m \theta_2), v_3 \sin(m \theta_2) + v_4 \cos(m \theta_2)) ]$
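The walkthrough can be checked numerically with a short sketch. The function name `rope` and the sample vector are assumptions made for illustration; the angle formula and the pairing of adjacent dimensions follow the definitions above.

```python
import math

def rope(v, m, base=10000.0):
    """Rotate each adjacent pair of v by m * theta_i, theta_i = base^(-2(i-1)/d)."""
    d = len(v)
    assert d % 2 == 0, "embedding dimension must be even"
    rotated = []
    for i in range(1, d // 2 + 1):  # pair index i = 1 .. d/2
        theta = base ** (-2 * (i - 1) / d)
        x, y = v[2 * (i - 1)], v[2 * (i - 1) + 1]
        c, s = math.cos(m * theta), math.sin(m * theta)
        rotated += [x * c - y * s, x * s + y * c]
    return rotated

# Setup from the walkthrough: d = 4, m = 3, with a hypothetical embedding.
v = [1.0, 0.0, 1.0, 0.0]
print([round(x, 4) for x in rope(v, m=3)])
```

With this input the first pair is rotated by 3 radians and the second by only 0.03 radians, which shows how later pairs rotate much more slowly with position.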
🌀 Rotary Positional Embedding (RoPE) in Large Language Models (LLMs)
Rotary Positional Embedding (RoPE) is applied to the Query (Q) and Key (K) matrices in LLMs, while leaving the Value (V) matrix unaffected. Here’s an explanation of its working and how it helps extend the context window.
🔗 Encoding Relative Position
- Goal: Instead of adding positional information directly to embeddings, RoPE encodes the relative distance between tokens by rotating Q and K embeddings in a 2D space.
- The rotation angle depends on the token’s position in the sequence.
- Key Concept: The phase difference between the rotated Q and K embeddings represents the relative distance between tokens.
📖 Example
- Consider two words in a sentence: “cat” at position $(2)$ and “sat” at position $(5)$.
- RoPE rotates their Q and K embeddings based on their positions.
- The resulting phase difference between these embeddings represents the relative distance $(5 - 2 = 3)$ between the tokens.
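A quick numerical check of this property: the sketch below rotates a query and a key with RoPE and compares the attention score for positions (5, 2) with the score for positions shifted by the same offset. The vectors `q` and `k` and the helper names are hypothetical; only the relative distance of 3 is shared between the two cases.

```python
import math

def rope(v, m, base=10000.0):
    """Rotate each adjacent pair of v by the position-dependent angle m * theta_i."""
    d = len(v)
    rotated = []
    for i in range(1, d // 2 + 1):
        theta = base ** (-2 * (i - 1) / d)
        x, y = v[2 * (i - 1)], v[2 * (i - 1) + 1]
        c, s = math.cos(m * theta), math.sin(m * theta)
        rotated += [x * c - y * s, x * s + y * c]
    return rotated

def score(q, k, q_pos, k_pos):
    """Dot product of the rotated query and key (the unscaled attention score)."""
    return sum(a * b for a, b in zip(rope(q, q_pos), rope(k, k_pos)))

q = [0.3, -1.2, 0.8, 0.5]   # hypothetical query for "sat"
k = [1.0, 0.4, -0.7, 0.9]   # hypothetical key for "cat"

near = score(q, k, q_pos=5, k_pos=2)         # relative distance 3
shifted = score(q, k, q_pos=105, k_pos=102)  # same distance, later in the text
print(near, shifted)
```

The two scores agree up to floating-point error, because rotating both vectors by the same extra angle leaves their dot product unchanged.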
🔍 Extending the Context Window
🧠 Challenge
- LLMs are trained with a limited context length (e.g., 2k tokens for LLaMA).
- During inference, processing longer sequences (e.g., an entire book) is often required.
🛠️ Solution with RoPE
- Rotation Interpolation:
- Rotation angles are defined during training for the original context length.
- To extend the context window, these angles are interpolated to handle longer sequences.
- Example:
- If the original context length is $(2k)$ tokens and we want to process $(4k)$ tokens:
- The rotation angles are scaled down (e.g., divided by $(2)$).
- This spreads the original rotations across the wider sequence, enabling the model to handle the extended length.
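The interpolation idea can be sketched as a scale applied to the position before computing the rotation angles. This is a simplified illustration under stated assumptions: "2k" and "4k" are taken to mean 2048 and 4096, and the helper names are made up for this example; practical recipes may also involve fine-tuning, which is not shown.

```python
def interpolation_scale(train_len, target_len):
    """Factor that maps positions in [0, target_len) back into [0, train_len)."""
    return min(1.0, train_len / target_len)

def rope_angles(position, d, scale=1.0, base=10000.0):
    """Angles m * theta_i for every pair i, with the position multiplied by scale."""
    return [position * scale * base ** (-2 * (i - 1) / d)
            for i in range(1, d // 2 + 1)]

TRAIN_LEN, TARGET_LEN = 2048, 4096  # assumed values for "2k" and "4k"
scale = interpolation_scale(TRAIN_LEN, TARGET_LEN)

# Position 4000 lies outside the trained range, but after scaling it reuses
# the angles the model saw for position 2000 during training.
print(rope_angles(4000, d=4, scale=scale))
print(rope_angles(2000, d=4))
```

Every scaled position now falls within the range of angles seen during training, at the price of neighbouring tokens being separated by smaller angle differences.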
✅ Key Takeaways
- Encodes Relative Distance: RoPE captures the relative relationships between tokens rather than absolute positions.
- Extends Context Window: By interpolating rotation angles, RoPE allows LLMs to process sequences longer than the context length used in training.
- Phase Difference: The relative distance between tokens is embedded through the phase difference between rotated Q and K embeddings.
RoPE enables flexible and efficient processing of longer sequences, making it a crucial component in scaling LLMs for real-world tasks.
✅ Benefits of RoPE
- Captures Relative Positions: RoPE enables the model to focus on relationships between tokens rather than absolute positions.
- Handles Longer Contexts: RoPE can be interpolated to handle sequences longer than those seen during training.
- Maintains Embedding Properties: The rotation preserves the norm of embeddings, ensuring compatibility with downstream tasks.
2️⃣ KV Cache Optimizations: MQA and GQA
The KV cache stores the keys (K) and values (V) of previous tokens during Transformer decoding (e.g., in text generation), so they are not recomputed for every new token. This makes generation faster, but the cache grows with sequence length, so long sequences lead to large memory requirements.

🧠 What Happens in Decoding?
In decoding, a Transformer generates text one token at a time. For each new token, it uses all previously generated tokens to predict the next word. This involves the attention mechanism, which calculates relationships between tokens.
For a given token $( t )$, the attention mechanism works as follows:
- Inputs: Queries ($( Q )$), Keys ($( K )$), and Values ($( V )$).
- Attention Scores: Compute scores between $( Q )$ and $( K )$ as: $[ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) V ]$ where $( d_k )$ is the dimensionality of the key vectors.
- Context Vector: Use the attention scores to compute a weighted sum of $( V )$, capturing information from relevant tokens.
🔍 Why are $( K )$ and $( V )$ recalculated without a KV cache?
In Transformer models, the $( K )$ and $( V )$ matrices are derived from the input embeddings at each layer during inference. Specifically:
Input Token Embedding
For a given token $( t )$, the model produces its embedding $( x_t )$. This embedding serves as the input to the Transformer.
Key and Value Projections
The embedding $( x_t )$ is multiplied by learnable weight matrices $( W_K )$ and $( W_V )$ to produce the $( K )$ and $( V )$ representations:
$[ K_t = x_t W_K, \quad V_t = x_t W_V ]$
These projections capture the token’s representation in the context of the attention mechanism.
Why Recalculate?
During decoding, if you don’t use a KV cache, the model must:
- Reprocess all previously generated tokens to produce their $( K )$ and $( V )$ matrices.
- This involves recomputing the embeddings for all tokens and applying the projections $( W_K )$ and $( W_V )$.
This redundancy can make decoding computationally expensive, especially for long sequences, which is why the KV cache is critical for efficiency.
Without the KV Cache:
- Recompute Keys and Values: For every new token, the model recalculates $( K )$ and $( V )$ for all previously processed tokens.
- Example: If the sequence is “The quick brown,” and the model predicts the next word, it recalculates $( K, V )$ for “The,” “quick,” and “brown” every time.
- Attention Scores: The model computes attention scores between the new query $( Q )$ and all recalculated $( K )$.
With the KV Cache:
The KV cache stores the $( K )$ and $( V )$ matrices for all previously processed tokens. When a new token is generated:
- Query Computation: Compute $( Q )$ for the new token only.
- Retrieve $( K, V )$ from Cache: Access the pre-computed $( K, V )$ matrices for all previous tokens from the cache.
- Attention Scores: Compute attention between $( Q )$ and the cached $( K )$ as: $[ \text{Attention}(Q_t, K_{\text{cached}}, V_{\text{cached}}) = \text{softmax}\left(\frac{Q_t K_{\text{cached}}^T}{\sqrt{d_k}}\right) V_{\text{cached}} ]$
- Update Cache: Add the new $( K )$ and $( V )$ for the current token to the cache.
📖 Step-by-Step Example
Initial State:
Sequence: "The quick brown"
- Stored in KV Cache:
- Keys ($( K_{\text{cached}} )$): Representations of "The," "quick," and "brown."
- Values ($( V_{\text{cached}} )$): Contextual information for "The," "quick," and "brown."
Predicting the Next Token:
- Compute $( Q_t )$: Calculate the query for the new token (e.g., predicting "fox").
- Retrieve from Cache: Use cached $( K )$ and $( V )$ for "The," "quick," and "brown."
- Attention Scores: Compute attention scores between $( Q_t )$ and $( K_{\text{cached}} )$:
$[ \text{Attention}(Q_t, K_{\text{cached}}, V_{\text{cached}}) ]$
- These scores capture how much the new token should attend to each previous token.
- Weighted Sum: Use attention scores to compute a weighted sum of $( V_{\text{cached}} )$, giving the context for predicting the next word.
- Update Cache: Add the new token’s $( K_t )$ and $( V_t )$ to the cache.
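The difference between the two decoding modes can be sketched for a single attention layer with one head. The weight matrices, the `KVCache` class, and the toy embeddings below are hypothetical. One ordering choice is an assumption of this sketch: the new token's key and value are appended before attention is computed, so that the token also attends to itself.

```python
import math

def matvec(x, W):
    """Row vector x (length d_in) times matrix W (d_in x d_out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def attend(q, K, V):
    """softmax(q K^T / sqrt(d_k)) V for a single query vector."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K]
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [sum(e / total * v[j] for e, v in zip(exps, V)) for j in range(len(V[0]))]

# Hypothetical projection weights for a 2-dimensional model.
W_Q = [[1.0, 0.5], [0.0, 1.0]]
W_K = [[0.5, 0.0], [1.0, 1.0]]
W_V = [[1.0, 0.0], [0.0, 2.0]]

def step_without_cache(xs):
    """Recompute K and V for every token so far, then attend from the last one."""
    K = [matvec(x, W_K) for x in xs]
    V = [matvec(x, W_V) for x in xs]
    return attend(matvec(xs[-1], W_Q), K, V)

class KVCache:
    """Keeps K and V of earlier tokens so each step projects only the new token."""
    def __init__(self):
        self.K, self.V = [], []

    def step(self, x):
        self.K.append(matvec(x, W_K))  # one new key
        self.V.append(matvec(x, W_V))  # one new value
        return attend(matvec(x, W_Q), self.K, self.V)

xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy embeddings for three tokens
cache = KVCache()
for t, x in enumerate(xs):
    print(cache.step(x), step_without_cache(xs[: t + 1]))
```

Both paths print the same vectors at every step; the cached path performs one K and one V projection per step instead of one per token in the prefix.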
🚀 Benefits of the KV Cache
- ⚡ Faster Decoding:
- Avoids recalculating $( K )$ and $( V )$ for previous tokens, saving computation time.
- 🧠 Memory-for-Compute Trade-off:
- Stores only the $( K )$ and $( V )$ matrices of past tokens. The cache grows linearly with sequence length, which is what the MQA and GQA variants below aim to reduce.
- 📏 Handles Long Contexts:
- Enables Transformer models to process thousands of tokens without significant slowdowns.
✅ Key Takeaways
- The KV cache stores keys ($( K )$) and values ($( V )$) for all previously processed tokens.
- When decoding new tokens:
- The model retrieves $( K )$ and $( V )$ from the cache, computes attention scores, and updates the cache.
- This optimization reduces redundant computations and enhances efficiency for long-sequence text generation.
By leveraging the KV cache, Transformer models maintain high performance even for lengthy sequences. 🛠️✨
🔹 Multi-Query Attention (MQA):
- Uses N heads for queries but 1 shared head for keys and values, reducing cache size.
- Trade-off: Smaller KV cache but slightly reduced model capacity.
🔹 Grouped-Query Attention (GQA):
- Groups queries into smaller sets, where each group shares a key-value pair.
- Example: With 32 heads and a group size of 8, only 4 key-value pairs are stored instead of 32.
🔍 Example:
For a Transformer with 32 heads:
- MHA: Stores 32 KV pairs per token.
- MQA: Stores 1 KV pair per token.
- GQA: With a group size of 8, stores 4 KV pairs per token.
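The head counts above translate directly into cache size. The sketch below computes the number of stored KV heads for each variant and a rough cache size in bytes. The model dimensions (32 layers, head dimension 128, 4096 tokens, 2 bytes per value) are hypothetical values picked for illustration, and the function names are not from any library.

```python
def kv_heads(n_heads, variant, group_size=None):
    """Number of key/value heads stored per token for each attention variant."""
    if variant == "MHA":
        return n_heads
    if variant == "MQA":
        return 1
    if variant == "GQA":
        if not group_size or n_heads % group_size:
            raise ValueError("group_size must evenly divide n_heads")
        return n_heads // group_size
    raise ValueError(f"unknown variant {variant!r}")

def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_value=2, batch=1):
    """Cache size: keys and values (factor 2) for every layer, head, and token."""
    return 2 * batch * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_value

# Hypothetical model: 32 layers, 32 query heads, head_dim 128, 4096 tokens, fp16.
for variant, group in [("MHA", None), ("GQA", 8), ("MQA", None)]:
    heads = kv_heads(32, variant, group)
    size_mib = kv_cache_bytes(32, heads, 128, 4096) / 2**20
    print(f"{variant}: {heads} KV heads, {size_mib:.0f} MiB")
```

For these assumed dimensions the cache shrinks in proportion to the number of KV heads, so GQA with a group size of 8 needs one eighth of the MHA cache.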
✅ Benefits of MQA and GQA:
- 💾 Reduced Memory Usage: Allows for longer context lengths and efficient inference.
- 🏆 Preserved Quality: GQA typically stays close to MHA in accuracy for large models while approaching the speed and memory savings of MQA.
3️⃣ Gated Linear Units (GLU)
GLU replaces the standard Feed-Forward Network (FFN) in Transformer blocks, introducing a gating mechanism that controls information flow.

🔍 Differences Between FFN and GLU:
- FFN: FFN(x) = W2 * ReLU(W1 * x + b1) + b2
- GLU: GLU(x) = W3 * (sigmoid(W1 * x + b1) ⊗ (W2 * x + b2))
Here, ⊗ is element-wise multiplication.
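The two formulas can be compared side by side in a short sketch. Weight shapes follow the convention `W[out][in]`, the helper names are hypothetical, and `W3` is applied without a bias to match the GLU formula above.

```python
import math

def linear(W, x, b):
    """Compute W x + b, with W given as a list of output rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def ffn(x, W1, b1, W2, b2):
    """FFN(x) = W2 * ReLU(W1 * x + b1) + b2."""
    hidden = [max(0.0, h) for h in linear(W1, x, b1)]
    return linear(W2, hidden, b2)

def glu(x, W1, b1, W2, b2, W3):
    """GLU(x) = W3 * (sigmoid(W1 * x + b1) * (W2 * x + b2)), element-wise product."""
    gate = [1.0 / (1.0 + math.exp(-g)) for g in linear(W1, x, b1)]
    value = linear(W2, x, b2)
    gated = [g * v for g, v in zip(gate, value)]
    return linear(W3, gated, [0.0] * len(W3))

I = [[1.0, 0.0], [0.0, 1.0]]
zeros = [0.0, 0.0]
x = [2.0, -4.0]

print(ffn(x, I, zeros, I, zeros))                # ReLU zeroes the negative feature
print(glu(x, I, zeros, I, zeros, I))             # the gate scales each feature smoothly
print(glu(x, I, [-100.0, -100.0], I, zeros, I))  # a closed gate blocks nearly everything
```

The last call shows the gating effect: a strongly negative gate pre-activation drives the sigmoid toward 0 and suppresses that feature regardless of its value.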
🔍 Example:
In GLU:
- The sigmoid gate selectively controls which parts of the input pass through, allowing more nuanced processing.
- This enables the model to focus on important features.
✅ Benefits of GLU:
- 📈 Enhanced Performance: Improves perplexity and accuracy compared to standard FFN layers.
- 🌟 Selective Information Flow: Leads to better generalization by focusing on meaningful features.
🤖 Large Language Models (LLMs)
LLMs are scaled transformers trained on massive datasets.
🌌 Emergent Abilities
Scaling introduces new capabilities, like solving modified arithmetic and unscrambling words.
🧠 In-Context Learning
LLMs generalize to new tasks without updates:
- 0️⃣ Zero-shot learning: Perform tasks with only a description.
- 1️⃣ Few-shot learning: Perform tasks with a few examples.
🏆 Examples of LLMs
- GPT: From OpenAI (e.g., GPT-1, GPT-3).
- OPT: Meta’s open-source models (125M–175B parameters).
- LLaMA: Meta’s models (LLaMA 1, 2, 3).
- Mistral-7B: Smaller model (7B params) outperforming larger ones using GQA and sliding window attention (SWA).
📜 Chinchilla Law
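For a fixed training compute budget, model size and the number of training tokens should be scaled together (roughly 20 tokens per parameter in the Chinchilla study), rather than growing the model alone. Balancing the two improves accuracy and efficiency.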
🔮 Advanced Topics: Multimodal LLMs
LLMs capable of processing text and images:
- 🖼️ Cross-attention: Visual tokens injected via cross-attention layers (e.g., Flamingo).
- 🔡 Visual tokens as input: Visual info directly fed as tokens (e.g., PaLM-E, RT-2).
Multimodal LLMs are designed to process and understand information from multiple modalities, such as text, images, and even robotic actions. This expands the potential of LLMs beyond text-based tasks to interactive and versatile applications.
1. Cross-Attention to Inject Visual Information (Flamingo Style)
In this approach, the pre-trained LLM remains frozen, and cross-attention layers are added at intermediate levels to incorporate visual information.

Key Components:
- Perceiver Resampler: Transforms variable-sized visual feature maps into a fixed-size set of visual tokens.
- Example: The features of a large image are resampled into a fixed number of visual tokens (e.g., a $( 256 \times 768 )$ matrix, read here as 256 tokens of dimension 768), suitable for cross-attention with text embeddings.
- Gated Cross-Attention Layers: Regulate the flow of visual input using a gating mechanism like $(\text{tanh}(x))$, allowing selective integration of visual features.
Example:
Suppose the input is an image of a “dog sitting on a red couch” alongside the text “What color is the couch?”
- The Perceiver Resampler compresses the image into a fixed-size set of visual tokens.
- Cross-attention enables the model to combine visual tokens with the query text, allowing the model to answer: “The couch is red.”
Models:
- Flamingo: Combines large visual and language models using cross-attention.
- Applications: Captioning images, answering visual queries.
2. Visual Tokens as Input (PaLM-E Style)
This approach treats visual inputs as tokens (similar to text tokens) and directly feeds them into the LLM alongside text.

Key Features:
- Visual inputs (e.g., image patches, sensor states) are tokenized and appended to the text tokens as a sequence.
- The LLM processes both types of tokens using its standard architecture.
Example:
An image of a robot arm lifting a block is tokenized and appended to a text instruction: “Move the block to the left.”
- The model generates action tokens to control the robotic arm directly.
Models:
- PaLM-E: Extends PaLM to handle images, robot states, and even neural 3D representations.
- RT-2: Outputs control signals directly, useful for robotic control tasks.
🌟 Applications of Multimodal LLMs
- Enhanced Personal Assistants: Combine language understanding with image and video interpretation for context-aware interactions.
- Interactive Storytelling and Gaming: Enable immersive experiences where players use both language and visual cues.
- Accessibility Improvements: Support alternative communication methods for users with disabilities.
- Autonomous Driving:
- Multimodal LLMs can interpret complex scenarios, such as a chair flying off a truck.
- Example reasoning:
- Observation: “A chair is flying off a truck on a highway.”
- Action: “Stop the vehicle, move to safety, and report to authorities.”
📊 Mathematical Representation of Multimodal Processing
Cross-Attention Mechanism
Given a text embedding sequence $( \{x_t\} )$ and visual tokens $( \{v_i\} )$:
- Project $( x_t )$ and $( v_i )$ into the same space using learned weights: $[ Q_t = x_t W_Q, \quad K_i = v_i W_K, \quad V_i = v_i W_V ]$
- Compute attention, where $( K )$ and $( V )$ stack all $( K_i )$ and $( V_i )$: $[ \text{Attention}(Q_t, K, V) = \text{softmax}\left(\frac{Q_t K^\top}{\sqrt{d_k}}\right)V ]$
- Combine attended features back into the language model.
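These three steps can be sketched as follows: text tokens provide the queries, visual tokens provide the keys and values, and the attended features are added back to the text stream through a tanh gate. The scalar gate parameter `alpha` and all function names are assumptions for illustration; with `alpha = 0` the gate is closed and the language model's input passes through unchanged.

```python
import math

def matvec(x, W):
    """Row vector x (length d_in) times matrix W (d_in x d_out)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def cross_attention(text, visual, W_Q, W_K, W_V):
    """Each text token queries all visual tokens and returns a weighted sum of V."""
    K = [matvec(v, W_K) for v in visual]
    V = [matvec(v, W_V) for v in visual]
    d_k = len(K[0])
    attended = []
    for x in text:
        q = matvec(x, W_Q)
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        peak = max(scores)
        exps = [math.exp(s - peak) for s in scores]
        total = sum(exps)
        attended.append([sum(e / total * v[j] for e, v in zip(exps, V))
                         for j in range(len(V[0]))])
    return attended

def gated_cross_attention(text, visual, W_Q, W_K, W_V, alpha):
    """Add the attended visual features to the text stream, scaled by tanh(alpha)."""
    gate = math.tanh(alpha)
    attended = cross_attention(text, visual, W_Q, W_K, W_V)
    return [[xi + gate * ai for xi, ai in zip(x, a)] for x, a in zip(text, attended)]

I = [[1.0, 0.0], [0.0, 1.0]]
text = [[1.0, 2.0], [0.5, -1.0]]      # hypothetical text embeddings
visual = [[3.0, 4.0], [0.0, 1.0]]     # hypothetical visual tokens

print(gated_cross_attention(text, visual, I, I, I, alpha=0.0))  # gate closed
print(gated_cross_attention(text, visual, I, I, I, alpha=1.0))  # visual info mixed in
```

Starting with a closed gate is one way to add new layers to a frozen LLM without disturbing its behaviour at the beginning of training.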
Visual Tokens as Input
For visual tokens $( v_i )$, they are concatenated with text tokens $( x_t )$: $[ \text{Input Sequence} = [x_1, x_2, \ldots, x_n, v_1, v_2, \ldots, v_m] ]$
The LLM processes this combined sequence to generate output, whether it’s text, actions, or other modalities.
🧩 Challenges and Future Directions
- Data Alignment: Effectively aligning data across modalities for seamless integration.
- Computational Complexity: Reducing the high compute costs associated with multimodal inputs.
- Bias and Fairness: Ensuring equitable performance across all modalities and user demographics.
Example: Handling a Corner Case in Autonomous Driving
Scenario: An image shows a chair flying off a truck.
- Multimodal LLM processes:
- Image is tokenized into visual features.
- Textual reasoning based on visual context: “This is dangerous.”
- Suggestion: “Stop, ensure safety, and contact authorities.”
🧩 Mixture-of-Experts (MoE)
A technique for scaling LLMs with low inference costs:
- 📊 Router: Activates only a subset of model parameters for each token.
- Balances large model size with efficient computation.
Mixture-of-Experts (MoE) models utilize routing functions to determine which experts to activate for processing a given input token. This selective activation enables MoE models to scale efficiently, handling vast numbers of parameters without proportionally increasing computational costs.

In a Mixture-of-Experts (MoE) model, different tokens are routed to different experts based on a routing function, and the outputs from these selected experts are then aggregated to make the final prediction. Here’s a breakdown of how this process works:
- Routing Tokens to Experts:
Each token $( t )$ in the input sequence is passed through a routing function. This function determines which expert (or set of experts) should process that specific token based on its embedding (or representation).
The routing function outputs a probability distribution over all available experts, indicating which expert is most suitable for the given token.
Usually, the top-k experts (those with the highest probability) are selected to process that token. This means that for each token, you might activate different experts.
- Expert Processing:
The selected expert(s) process the token, generating their corresponding output. Each expert has its own specialized sub-network designed to handle certain types of input.
For example, one expert might specialize in handling sentence structure, while another might focus on specific language patterns or domains of knowledge.
- Aggregation of Outputs:
Once the selected experts process the token, their outputs are combined (often via a weighted sum based on the probability distribution of experts) to produce the final representation for the token.
This combined output is used for the next step in the model’s computation, which might involve making predictions for the current token or passing it through further layers of the model.
- Making Predictions:
After processing through the selected experts and aggregating their outputs, the final token representation is used to predict the next token in the sequence, based on the context provided by the experts’ specialized processing.
Example:
Let’s say we have a sentence where we need to predict the next word (token) based on the previous tokens:
Input sequence: “The quick brown fox jumps”
- Routing:
- The token “quick” might be routed to Expert 1 (specializing in sentence structure).
- The token “brown” might be routed to Expert 2 (specializing in color-related concepts).
- The token “fox” might be routed to Expert 3 (specializing in animals).
- Processing:
- Each expert processes its respective token:
- Expert 1 processes “quick”.
- Expert 2 processes “brown”.
- Expert 3 processes “fox”.
- Aggregation:
- For each token, the outputs of its selected experts are combined (e.g., a weighted sum when more than one expert is active) into that token’s updated representation. Later layers then mix these per-token results across the sequence “The quick brown fox”.
- Prediction:
- The model uses this final aggregated representation to predict the next token (e.g., “jumps”).
By routing different tokens to different experts, MoE models are able to specialize processing for each token, improving the efficiency and flexibility of the model.
🔍 How Routing Functions Work
- Purpose:
- Routing functions act as gatekeepers, deciding which experts (specialized sub-networks) are best suited for processing specific input tokens.
- Input:
- The routing function takes the input token’s representation (usually its hidden state embedding).
- Output:
- It produces a probability distribution over the available experts, representing the likelihood of each expert being the best fit.
- Selection:
- Based on the probability distribution, a subset of experts is selected. Typically, the top-k experts with the highest probabilities are activated.
- Processing:
- The selected experts process the input token, and their outputs are combined (often using a weighted sum) to produce the final representation.
1. Capacity Factor (C):
- Controls the maximum number of tokens each expert can handle: expert capacity = (number of tokens / number of experts) × C.
- A capacity factor of 1 means each expert can take exactly its even share of the tokens.
- Higher capacity factors leave slack for uneven routing, so fewer tokens overflow, at the cost of extra computation and memory.
2. Load Balancing:
- Ensures workloads are evenly distributed among experts.
- Skewed distribution, where some experts are overused while others remain idle, can degrade performance.
3. Routing Mechanisms:
- Different strategies are used for selecting experts:
- Top-k Routing: Select the k experts with the highest probabilities.
- Stochastic Routing: Randomly sample experts based on the probability distribution, promoting exploration and load balancing.
- Learned Routing: Use a trainable neural network for routing, enabling adaptive and complex decision-making.
🔢 Example: Routing in an MoE Model
Scenario:
- 6 tokens and 3 experts with a capacity factor of 1 (each expert can process 6 / 3 × 1 = 2 tokens).
Token Assignments:
- Token 1 → Expert 1 (highest probability)
- Token 2 → Expert 3 (highest probability)
- Token 3 → Expert 2 (highest probability)
- Token 4 → Expert 1 (highest probability)
- Token 5 → Expert 3 (highest probability)
- Token 6 → Expert 2 (highest probability)
Result:
- Each expert processes 2 tokens (balanced workload).
Increasing Capacity Factor:
- With a capacity factor of 1.5, each expert could process up to 3 tokens.
- However, without proper load balancing, some experts might become overloaded, while others remain idle.
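The capacity arithmetic in this example can be sketched as below, assuming that expert capacity is the even share of tokens per expert multiplied by the capacity factor. The overflow policy shown (tokens beyond an expert's capacity are dropped from that expert) is one common choice and an assumption here, as are the function names.

```python
import math

def expert_capacity(num_tokens, num_experts, capacity_factor):
    """Maximum tokens per expert: even share times the capacity factor."""
    return math.ceil(num_tokens * capacity_factor / num_experts)

def dispatch(assignments, num_experts, capacity):
    """Fill each expert up to capacity; return (tokens per expert, dropped tokens)."""
    buckets = [[] for _ in range(num_experts)]
    dropped = []
    for token, expert in enumerate(assignments):
        if len(buckets[expert]) < capacity:
            buckets[expert].append(token)
        else:
            dropped.append(token)  # the expert is full, so the token overflows
    return buckets, dropped

# Balanced routing from the example above (experts are 0-indexed here).
balanced = [0, 2, 1, 0, 2, 1]
print(dispatch(balanced, 3, expert_capacity(6, 3, 1.0)))

# Skewed routing: four tokens prefer expert 0.
skewed = [0, 0, 0, 0, 1, 2]
print(dispatch(skewed, 3, expert_capacity(6, 3, 1.0)))  # capacity 2
print(dispatch(skewed, 3, expert_capacity(6, 3, 1.5)))  # capacity 3
```

With balanced routing nothing overflows at a capacity factor of 1. With skewed routing, raising the factor to 1.5 reduces the number of dropped tokens but does not remove the imbalance itself.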
📊 Mathematical Representation
- Input Token Representation: Let the embedding of token $( t )$ be $( x_t )$.
- Expert Selection:
- Compute the probability distribution for $( x_t )$ using the routing function: $[ p_{i,t} = \text{softmax}(W_r \cdot x_t)_i ]$ where $( W_r )$ are learnable parameters of the routing function, and $( p_{i,t} )$ is the probability of selecting expert $( i )$ for token $( t )$.
- Top-k Selection:
- Select the top-k experts with the highest $( p_{i,t} )$.
- For example, if $( k = 2 )$, choose the top 2 experts for each token.
- Weighted Output:
- Compute the final token representation by combining outputs from selected experts: $[ y_t = \sum_{i \in \text{Top-k}} p_{i,t} \cdot E_i(x_t) ]$ where $( E_i )$ is the $( i )$-th expert’s output function.
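The routing equations can be sketched for a single token as follows. The router weights `W_r` and the three toy experts (which just scale their input) are hypothetical; real experts are feed-forward sub-networks, and some implementations renormalize the top-k probabilities, which this sketch does not do.

```python
import math

def softmax(logits):
    """Probabilities over experts from the router's raw scores."""
    peak = max(logits)
    exps = [math.exp(l - peak) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route(x, W_r, k):
    """p_i = softmax(W_r x)_i, plus the indices of the k most probable experts."""
    logits = [sum(w * xi for w, xi in zip(row, x)) for row in W_r]
    probs = softmax(logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return probs, top

def moe_layer(x, W_r, experts, k=2):
    """y = sum over the top-k experts of p_i * E_i(x); other experts are not run."""
    probs, top = route(x, W_r, k)
    y = [0.0] * len(x)
    for i in top:
        out = experts[i](x)
        y = [yj + probs[i] * oj for yj, oj in zip(y, out)]
    return y, top

def make_expert(scale):
    """Toy expert that multiplies its input by a constant."""
    return lambda x: [scale * v for v in x]

experts = [make_expert(1.0), make_expert(2.0), make_expert(3.0)]
W_r = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]]  # one router row per expert

y, chosen = moe_layer([1.0, 0.0], W_r, experts, k=2)
print(chosen, y)
```

Only the two selected experts are evaluated for this token, which is where the compute savings of MoE come from.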
🌟 Benefits of Routing Functions
- Efficiency:
- Activating only a subset of experts reduces the computation per token. All expert parameters must still be stored, so the total memory footprint does not shrink.
- Specialization:
- Encourages experts to specialize in specific input types, improving performance.
- Scalability:
- MoE models can scale to trillions of parameters, making them suitable for handling highly complex tasks.
🚀 Challenges and Future Directions
- Load Balancing:
- Ensuring even workload distribution across experts.
- Optimization:
- Designing efficient and scalable routing mechanisms.
- Bias and Fairness:
- Avoiding biases in token-to-expert assignments.
Research into MoE routing continues to explore innovative techniques for improving efficiency, load balancing, and overall model performance.
