TinyML Distributed Training Part 1

Modern AI models are becoming increasingly large, demanding substantial computational resources and memory. This creates a gap between the computational demands of these models and the available hardware capabilities. Distributed training addresses this gap by spreading the training workload across multiple GPUs, which shortens training time and makes it possible to train models that do not fit on a single device.

🚀 Introduction

This document summarizes key concepts from an MIT lecture on distributed training for machine learning models, focusing on techniques to accelerate the training process. The lecture highlights the growing need for distributed training due to the increasing size and complexity of modern models, particularly in areas like natural language processing (NLP) and computer vision. The content covers various parallelization strategies, communication primitives, memory optimization techniques, and emerging methods for handling long sequence data.

1️⃣ Background and Motivation

🔧 The Problem:

  • Modern machine learning models, particularly large language models (LLMs), require vast computational resources and training time.
  • Training models like GPT-3 on a single GPU would take hundreds of years, making research and development cycles impractically slow.

    “Without distributed training, a single GPU would take 355 years to finish GPT-3!”

⚡ Need for Speed:

  • The lecture emphasizes the need to shorten the design cycle to improve developer productivity and time-to-market.
  • Distributed training using multiple GPUs is presented as the solution to this challenge.

    “Developers/Researchers’ time is more valuable than hardware.”
    “The development and research cycle will be greatly boosted.”

📈 Scale:

  • The lecture cites the Summit supercomputer, where 1,536 GPUs reduced the training time of a video model from 2 days to 14 minutes.

    “Speedup the training by 200x, from 2 days to 14 minutes.”

📊 Exponential Growth:

  • Both model size (in parameters) and computation cost are increasing exponentially, particularly in NLP.

    “NLP model size is increasing exponentially.”
    “Better models always come with higher computational costs.”

🛠️ GPU Limitations:

  • Even the most advanced GPUs, such as the Nvidia A100, cannot fit large model weights into their memory, requiring new strategies.

    “Even the best GPU CANNOT fit the model weights into memory!”

2️⃣ Parallelization Methods

The lecture introduces four key parallelization techniques to train deep learning models efficiently:

🗂️ Data Parallelism:

  • Splits the training data across multiple GPUs.
  • Each GPU holds a copy of the entire model but processes only a subset of the training data.

    “Splitting the training data to multiple GPUs… each GPU is sharing the same model.”
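
The idea can be illustrated with a minimal sketch, assuming a toy one-parameter model y = w·x with a mean-squared-error loss. The model, data, and helper names (`grad_mse`, `shard`) are illustrative assumptions, not taken from the lecture. Each simulated worker computes a gradient on its own shard, and averaging those gradients reproduces the full-batch gradient when the shards are equal in size.

```python
# Toy data parallelism: every worker holds the same weight w but sees only
# its own shard of the batch. All names here are illustrative.

def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

def shard(data, num_workers):
    """Split a list into num_workers equal contiguous shards."""
    size = len(data) // num_workers
    return [data[i * size:(i + 1) * size] for i in range(num_workers)]

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]  # generated by y = 2x
w = 0.5
num_workers = 4

# Each worker computes a local gradient on its shard with the shared weight.
local_grads = [
    grad_mse(w, xs_shard, ys_shard)
    for xs_shard, ys_shard in zip(shard(xs, num_workers), shard(ys, num_workers))
]

# Averaging the local gradients (the job of an all-reduce in a real system)
# matches the gradient a single device would compute on the whole batch.
avg_grad = sum(local_grads) / num_workers
full_grad = grad_mse(w, xs, ys)
```

Because every worker applies the same averaged gradient, the model copies stay identical after each step.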

🏗️ Pipeline Parallelism:

  • Partitions the model’s layers across multiple GPUs.
  • Each GPU is responsible for a different set of layers and passes its activations to the next GPU, so data flows through the GPUs in a pipelined manner.
  • Makes it possible to train models that are too large to fit in one GPU’s memory.

    “We are partitioning the model… GPUs are processing different layers of the model.”

🔢 Tensor Parallelism:

  • Partitions individual layers or matrix operations (tensors) of the model across multiple GPUs.
  • Splits computations within a single layer, like matrix multiplication, across devices.

    “We’re partitioning different parts of the matrix multiplication across GPUs.”

📜 Sequence Parallelism:

  • A newer technique (added to the lecture this year) that targets long-context data (e.g., documents, videos) by splitting the input sequence (tokens) across multiple GPUs.
  • Each GPU handles the whole model but processes only a segment of the input sequence.

    “Each GPU has the whole model but is only responsible for processing part of the sequence.”

3️⃣ Data Parallelism in Depth

🖧 Parameter Server Approach:

  • Uses a central parameter server to store model weights and aggregate gradients from worker nodes (GPUs).
  • Workers download the model, compute local gradients, send them back to the parameter server, and then receive updated weights.
  • The central server acts as a synchronization point between training nodes.

    “The parameter server receives the gradients from workers and sends back the aggregated results.”

🛑 Communication Bottleneck:

  • The parameter server model has high bandwidth requirements, which scale linearly with the number of workers, potentially becoming a bottleneck.

    “The bandwidth of the central parameter server grows linearly with the number of workers, which can be a bottleneck.”
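
A small simulation can make the linear growth concrete. This is a sketch under stated assumptions: the function `parameter_server_step`, the plain SGD update, and the toy gradients are all hypothetical, and traffic is counted in number of values rather than bytes.

```python
def parameter_server_step(weights, worker_grads, lr):
    """One synchronous step: reduce gradients on the server, update, broadcast.

    Returns the new weights and the number of values that crossed the
    server's network link during the step.
    """
    num_workers = len(worker_grads)
    # Push & sum: the server receives one full gradient vector per worker.
    values_received = sum(len(g) for g in worker_grads)
    avg_grad = [sum(col) / num_workers for col in zip(*worker_grads)]
    # The update happens once, on the server (plain SGD is an assumption).
    new_weights = [w - lr * g for w, g in zip(weights, avg_grad)]
    # Replicate & pull: the server sends the full weights to every worker.
    values_sent = num_workers * len(new_weights)
    return new_weights, values_received + values_sent

weights = [1.0, 2.0, 3.0]
traffic = {}
for n in (2, 4, 8):
    grads = [[0.1, 0.2, 0.3] for _ in range(n)]  # identical toy gradients
    new_weights, traffic[n] = parameter_server_step(weights, grads, lr=1.0)
```

Doubling the number of workers doubles the traffic through the single server link, which is the bottleneck described above.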

💡 Distributed training is the backbone of scaling machine learning to modern, complex problems. By understanding these techniques, developers can leverage the full power of multi-GPU systems.

4️⃣ Communication Primitives

🔄 One-to-One:

  • Send and Receive: Transfers data between two specific processes.

🔁 One-to-Many and Many-to-One:

  • Scatter: Splits a tensor into chunks and sends a different chunk to each worker.
  • Gather: Collects tensors from all workers onto a single worker.
  • Reduce: Like gather, but aggregates the collected tensors (e.g., sum, average) into one result.
  • Broadcast: Sends identical copies of a tensor to all workers.

🌐 Many-to-Many:

  • All-Reduce: Performs a reduction (e.g., sum, average) across all workers and leaves the result on every worker.
  • All-Gather: Gathers tensors from all workers so that every worker ends up holding the full collection.
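
The semantics of these primitives can be sketched with plain Python lists, where each list index stands for one worker and worker 0 plays the root. This is only an illustration of what each primitive delivers, not how a communication library implements it; the function names are chosen for this sketch.

```python
# Each "worker" is an index into a list; worker 0 acts as the root.

def broadcast(value, n):
    """Root sends an identical copy to every worker."""
    return [list(value) for _ in range(n)]

def scatter(value, n):
    """Root splits a tensor into n equal chunks, one per worker."""
    size = len(value) // n
    return [value[i * size:(i + 1) * size] for i in range(n)]

def gather(per_worker):
    """Root collects every worker's tensor, concatenated in rank order."""
    return [x for chunk in per_worker for x in chunk]

def reduce_sum(per_worker):
    """Root receives the elementwise sum of every worker's tensor."""
    return [sum(col) for col in zip(*per_worker)]

def all_gather(per_worker):
    """Like gather, but every worker ends up with the concatenated result."""
    return [gather(per_worker) for _ in per_worker]

def all_reduce_sum(per_worker):
    """Like reduce, but every worker ends up with the reduced result."""
    return [reduce_sum(per_worker) for _ in per_worker]

data = [[1, 2], [3, 4], [5, 6]]  # one tensor per worker
```

For example, `reduce_sum(data)` yields one summed tensor at the root, while `all_reduce_sum(data)` yields the same summed tensor once per worker.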

🖧 Parameter Server and Communication Primitives:

  • Replicate & Pull: Uses a broadcast primitive.
  • Push & Sum: Uses a reduce primitive.

⚙️ Optimizing All-Reduce:

  1. Naive All-Reduce:
    • Sequential reduction operations across all workers.
    • Complexity: Time: O(N), Bandwidth: O(N).

      “Each worker requires receiving the data from all other workers, so both time and bandwidth are O(N), which is not good.”

  2. Ring All-Reduce:
    • Sends data in a ring pattern between adjacent nodes.
    • Complexity: Time: O(N), Bandwidth: O(1).

      “We only need to communicate with our neighbors, so bandwidth needs are minimal.”

  3. Parallel All-Reduce:
    • Performs simultaneous reduce operations at each node.
    • Complexity: Time: O(1), Bandwidth: O(N²).
  4. Recursive Halving All-Reduce:
    • Combines the advantages of ring and parallel all-reduce.
    • Complexity: Time: O(log(N)).

      “For N workers, all-reduce can finish within log(N) steps.”
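
As an illustration of the ring pattern, here is a sketch of the chunked variant that is common in practice (a reduce-scatter phase followed by an all-gather phase). The lecture may describe a simpler form, and the function name and the requirement that the vector length be divisible by the worker count are assumptions of this sketch. In every step each worker sends only one chunk to its right neighbor, so per-worker traffic per step does not grow with the number of workers.

```python
def ring_all_reduce(vectors):
    """Sum-all-reduce over a simulated ring; vectors[i] belongs to worker i.

    Assumes each vector length is divisible by the number of workers.
    """
    n = len(vectors)
    chunk = len(vectors[0]) // n
    # bufs[i][c] is worker i's current copy of chunk c.
    bufs = [[v[c * chunk:(c + 1) * chunk] for c in range(n)] for v in vectors]

    # Phase 1 (reduce-scatter): N-1 steps. In step s, worker i sends chunk
    # (i - s) mod N to its right neighbor, which adds it to its own copy.
    for s in range(n - 1):
        sends = [(i, (i - s) % n, list(bufs[i][(i - s) % n])) for i in range(n)]
        for i, c, payload in sends:
            dst = (i + 1) % n
            bufs[dst][c] = [a + b for a, b in zip(bufs[dst][c], payload)]

    # Phase 2 (all-gather): N-1 steps. Worker i now owns the fully reduced
    # chunk (i + 1) mod N and forwards completed chunks around the ring.
    for s in range(n - 1):
        sends = [(i, (i + 1 - s) % n, list(bufs[i][(i + 1 - s) % n]))
                 for i in range(n)]
        for i, c, payload in sends:
            bufs[(i + 1) % n][c] = payload

    # Flatten each worker's chunks back into a single vector.
    return [[x for c in b for x in c] for b in bufs]

vectors = [
    [1, 2, 3, 4, 5, 6],
    [10, 20, 30, 40, 50, 60],
    [100, 200, 300, 400, 500, 600],
]
result = ring_all_reduce(vectors)
```

The number of steps still grows linearly with the number of workers (2(N-1) in this variant), which matches the O(N) time listed above.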

🛠️ Parameter Server vs. Recursive Halving All-Reduce

📉 Parameter Server Approach:
  • How It Works:
    • A central controller (parameter server) manages the training process.
    • Workers: Compute gradients and send them to the parameter server.
    • Parameter Server: Aggregates results and broadcasts updated weights back to workers.
  • Challenges:
    • Bandwidth requirements grow linearly with the number of workers (O(N)).
      • Broadcast: The parameter server sends the model to all workers.
      • Reduce: The parameter server receives gradients from all workers.
    • This linear scaling becomes a bottleneck for large-scale distributed training.
🔄 All-Reduce:
  • Decentralized Alternative:
    • Eliminates the central server.
    • Each worker aggregates gradients using all-reduce, leading to better scalability.
  • Implementation Techniques:
  1. Naive Sequential All-Reduce:
    • Description: Performs a single reduce operation at each step.
    • Complexity: Time: O(N), Bandwidth: O(N).
    • Drawback: Inefficient for large numbers of workers.
  2. Ring All-Reduce:
    • Description: Workers communicate with neighbors in a ring topology.
    • Complexity: Time: O(N), Bandwidth: O(1).
    • Advantage: Bandwidth efficiency.

      “Each worker communicates only with its neighbor, minimizing bandwidth requirements.”

  3. Parallel All-Reduce:
    • Description: Simultaneously performs reduce operations at all nodes.
    • Complexity: Time: O(1), Bandwidth: O(N²).
    • Drawback: High bandwidth overhead.
  4. Recursive Halving All-Reduce (Butterfly All-Reduce):
    • Description: Combines the advantages of ring and parallel methods.
    • Complexity: Time: O(log N).
    • Process: In step i, each worker exchanges data with the worker at offset 2^i, so the offsets double from one step to the next.
    • Advantage: Efficient and scalable for large-scale training.

      “This technique completes in log(N) steps, avoiding bottlenecks seen in other methods.”
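
The log(N) step count can be checked with a simplified butterfly exchange. This sketch makes two assumptions that are not from the lecture: the number of workers is a power of two, and partners swap their whole buffers at every step, whereas production implementations typically exchange only part of the data per step. The function name is illustrative.

```python
def butterfly_all_reduce(vectors):
    """Sum-all-reduce in log2(N) exchange steps; N must be a power of two.

    Returns the per-worker results and the number of exchange steps used.
    """
    n = len(vectors)
    assert n > 0 and n & (n - 1) == 0, "number of workers must be a power of two"
    bufs = [list(v) for v in vectors]
    steps = 0
    offset = 1
    while offset < n:
        # Worker i pairs with worker i XOR offset; both add the partner's
        # buffer to their own. The snapshot makes the exchange simultaneous.
        snapshot = [list(b) for b in bufs]
        for i in range(n):
            partner = i ^ offset
            bufs[i] = [a + b for a, b in zip(snapshot[i], snapshot[partner])]
        offset *= 2
        steps += 1
    return bufs, steps

# Eight workers, each holding the vector [i, 2*i].
reduced, num_steps = butterfly_all_reduce([[i, 2 * i] for i in range(8)])
```

With eight workers, every worker holds the full sum after three exchange steps, since log2(8) = 3.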

🌟 Why Recursive Halving All-Reduce is Preferred:
  • Scalability: Efficient for large worker counts, thanks to its logarithmic time complexity.
  • Bandwidth Efficiency: Avoids centralized bottlenecks by leveraging decentralized communication.
  • Practical Adoption: Many labs (including the presenter’s) favor this method over parameter servers.

    “The lab does not use a centralized parameter server but instead relies on recursive halving all-reduce for its scalability and efficiency.”

5️⃣ Reducing Memory in Data Parallelism (ZeRO and FSDP)

🧮 Memory Breakdown:

  • Per Parameter:
    • Model Weights: 2 bytes (fp16).
    • Gradients: 2 bytes (fp16).
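    • Optimizer States: 12 bytes for Adam with mixed precision (fp32 master copy of the weights, momentum, and variance, at 4 bytes each).
    • Total: 16 bytes per parameter, before counting activations.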

🛠️ ZeRO (Zero Redundancy Optimizer):

  • Family of techniques to reduce memory consumption by sharding the training state across GPUs instead of replicating it on every GPU:
    • ZeRO-1: Shards optimizer states.
    • ZeRO-2: Shards optimizer states and gradients.
    • ZeRO-3: Shards optimizer states, gradients, and model weights.
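
The effect of each stage can be worked out from the per-parameter byte counts above. The sketch below assumes mixed-precision Adam (2 + 2 + 12 bytes per parameter) and ignores activations and temporary buffers; the function name and the example of a 7.5-billion-parameter model on 64 GPUs are illustrative choices.

```python
def bytes_per_gpu(num_params, num_gpus, stage):
    """Training-state memory per GPU in bytes, assuming mixed-precision Adam.

    stage 0 is plain data parallelism; stages 1-3 are ZeRO-1, ZeRO-2, ZeRO-3.
    Activations and temporary buffers are not counted.
    """
    weights = 2 * num_params   # fp16 weights
    grads = 2 * num_params     # fp16 gradients
    optim = 12 * num_params    # fp32 master weights, momentum, variance
    if stage >= 1:
        optim /= num_gpus      # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= num_gpus      # ZeRO-2: also shard gradients
    if stage >= 3:
        weights /= num_gpus    # ZeRO-3: also shard weights
    return weights + grads + optim

params, gpus = 7_500_000_000, 64
per_stage_gb = {
    stage: bytes_per_gpu(params, gpus, stage) / 1e9 for stage in range(4)
}
```

Under these assumptions the per-GPU footprint drops from 120 GB with plain data parallelism to roughly 31.4 GB (ZeRO-1), 16.6 GB (ZeRO-2), and 1.9 GB (ZeRO-3).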

📦 Fully Sharded Data Parallel (FSDP):

  • PyTorch implementation of ZeRO-3.
  • Essential for training very large models.

    “In PyTorch, ZeRO-3 (developed by Microsoft) is implemented as Fully Sharded Data Parallel (FSDP), a widely used library.”

6️⃣ Pipeline Parallelism (Revisited)

🏗️ Model Partitioning:

  • Splits model layers across GPUs, enabling the model to be larger than the capacity of one GPU.

💤 Naive Implementation:

  • Results in underutilization of GPUs, with only one GPU active at a time.

⚡ Micro-Batching:

  • Splits a single large batch into smaller micro-batches.
  • Allows for pipelined execution and much higher GPU utilization.
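
The gain from micro-batching can be estimated with a toy schedule. The sketch assumes a forward-only pipeline in which every stage takes one time step per micro-batch; the function names and the choice of 4 stages are illustrative.

```python
def pipeline_schedule(num_stages, num_microbatches):
    """Forward-pass schedule: grid[t][s] is the micro-batch on stage s at step t.

    A cell is None when the stage is idle at that step.
    """
    total_steps = num_stages + num_microbatches - 1
    grid = []
    for t in range(total_steps):
        row = []
        for s in range(num_stages):
            mb = t - s  # micro-batch mb reaches stage s at time step mb + s
            row.append(mb if 0 <= mb < num_microbatches else None)
        grid.append(row)
    return grid

def utilization(grid):
    """Fraction of (step, stage) slots that are doing useful work."""
    slots = [cell for row in grid for cell in row]
    return sum(cell is not None for cell in slots) / len(slots)

naive = utilization(pipeline_schedule(num_stages=4, num_microbatches=1))
pipelined = utilization(pipeline_schedule(num_stages=4, num_microbatches=8))
```

With one batch and 4 stages only a quarter of the slots are busy; splitting the batch into 8 micro-batches raises utilization to 8/11, or about 73%. In this simplified model utilization equals m / (m + p - 1) for m micro-batches and p stages.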

7️⃣ Tensor Parallelism (Revisited)

🔬 Fine-Grained Partitioning:

  • Further divides layers (e.g., fully connected layers, attention layers) by splitting matrix multiplication operations across GPUs.

🖥️ Tensor Parallelism for MLP Layers

Tensor parallelism distributes the model’s tensors (weights and activations) across multiple GPUs, minimizing communication while maximizing computational efficiency. Here’s how it’s applied to MLP layers:

🔢 First Linear Layer Partitioning:
  • Weight Matrix Partitioning:
    • The 2x6 weight matrix is split column-wise across GPUs.
    • Each GPU holds a 2x3 portion.
  • Input Activation:
    • The 1x2 input activation is broadcast to all GPUs.
    • Each GPU receives a full copy of the input activation.
  • Local Computation:
    • Each GPU performs matrix multiplication with its portion of the weight matrix.
    • Produces two 1x3 slices of the layer output, Y1 and Y2, one per GPU.
    • 🛑 No communication is needed at this step.
🔢 Second Linear Layer Partitioning:
  • Weight Matrix Partitioning:
    • The 6x2 weight matrix is split row-wise across GPUs.
    • Each GPU holds a 3x2 portion.
  • Matrix Multiplication:
    • Each GPU multiplies its partial output matrix (Y1 or Y2) with its portion of the second weight matrix.
    • Results in partial sums Z1 and Z2.
    • 🛑 No communication is required here either.
🔄 All-Reduce Operation:
  • Final Output Combination:
    • To produce the complete 1x2 output, partial sums (Z1 and Z2) are combined using an all-reduce operation.
    • This sums the partial outputs from all GPUs.
💡 Key Idea:
  • Efficiency: Computation happens locally on each GPU without requiring communication during matrix multiplications.
  • Communication Points:
    • At the start (input broadcast).
    • At the end (all-reduce).
  • This design maximizes efficiency for two-layer MLPs.
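
The column-then-row split can be verified numerically with the same shapes (1x2 input, 2x6 and 6x2 weights, two GPUs). This is a sketch: the matrix values are made up, and ReLU is assumed as the elementwise activation between the two layers, which the notes above do not specify. An elementwise activation can be applied to each column slice locally, which is why no communication is needed between the layers.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def relu(m):
    """Elementwise ReLU (the activation is an assumption of this sketch)."""
    return [[max(0, v) for v in row] for row in m]

def split_cols(m, parts):
    """Split a matrix into equal column blocks."""
    width = len(m[0]) // parts
    return [[row[p * width:(p + 1) * width] for row in m] for p in range(parts)]

def split_rows(m, parts):
    """Split a matrix into equal row blocks."""
    height = len(m) // parts
    return [m[p * height:(p + 1) * height] for p in range(parts)]

x = [[1, -2]]                                   # 1x2 input, copied to both GPUs
w1 = [[1, 2, -3, 4, -5, 6],
      [6, -5, 4, 3, 2, -1]]                     # 2x6, split column-wise
w2 = [[1, 0], [0, 1], [2, -1],
      [-1, 2], [3, 1], [1, 3]]                  # 6x2, split row-wise

# Single-device reference.
reference = matmul(relu(matmul(x, w1)), w2)

# Two simulated GPUs: each uses only its own shard of w1 and w2.
partials = []
for w1_shard, w2_shard in zip(split_cols(w1, 2), split_rows(w2, 2)):
    y_local = relu(matmul(x, w1_shard))         # 1x3 slice of the hidden layer
    partials.append(matmul(y_local, w2_shard))  # 1x2 partial sum

# All-reduce (sum) combines the partial sums into the full 1x2 output.
z = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
```

The summed result `z` equals the single-device `reference`, even though neither simulated GPU ever held a full weight matrix.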

🖥️ Tensor Parallelism for Attention Layers

Tensor parallelism is also applied to the attention mechanism, involving QKV (Query, Key, Value) transformations, attention computation, and output projection. Here’s the breakdown:

🧠 QKV Projection Layer Partitioning:
  • Weight Matrix Partitioning:
    • The Q, K, and V projection weight matrices are split column-wise across GPUs.
    • With two GPUs, each GPU holds half of the hidden/head dimension.
  • Input Activation:
    • The input activation X is broadcast to all GPUs.
    • Each GPU receives a copy of the input activation.
  • Local Computation:
    • Each GPU computes its partial Q, K, and V matrices.
    • Example: GPU1 computes Q1, K1, V1; GPU2 computes Q2, K2, V2.
🧮 Attention Calculation:
  • Local Attention Computation:
    • Each GPU calculates attention using its partial Q, K, and V.
    • This involves:
      • Computing Q × Kᵀ.
      • Applying softmax.
      • Multiplying the result with V.
    • Results in partial attention outputs Y1 and Y2.
    • 🛑 No communication is needed here.
🔢 Output Projection Layer Partitioning:
  • Weight Matrix Partitioning:
    • The output projection matrix is split row-wise across GPUs.
  • Local Computation:
    • Each GPU multiplies its attention output (Y1 or Y2) with its portion of the projection matrix.
    • Results in partial sums Z1 and Z2.
    • 🛑 No communication is needed here either.
🔄 All-Reduce Operation:
  • Final Output Combination:
    • Partial sums (Z1 and Z2) are combined using an all-reduce operation.
    • Produces the final output matrix Z.
💡 Key Idea:
  • Efficiency: Matrix multiplications occur locally on each GPU, minimizing inter-GPU communication.
  • Communication Points:
    • At the start (input broadcast).
    • At the end (all-reduce).
  • The structure of Q, K, and V allows for parallel processing with minimal communication overhead.
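
The same check works for attention when each GPU owns whole heads. The sketch below assumes 3 tokens, a hidden size of 4, and 2 heads of dimension 2 (one head per GPU), with made-up weights and no masking; all helper names are illustrative. GPU g takes the columns of the Q, K, and V projections that belong to its head and the matching rows of the output projection.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax_rows(m):
    """Row-wise softmax with max subtraction for numerical stability."""
    out = []
    for row in m:
        peak = max(row)
        exps = [math.exp(v - peak) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def attention(q, k, v):
    """Single-head scaled dot-product attention (no mask)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    scores = [[s * scale for s in row] for row in matmul(q, transpose(k))]
    return matmul(softmax_rows(scores), v)

def cols(m, start, stop):
    """Select columns [start, stop) of a matrix."""
    return [row[start:stop] for row in m]

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

x = [[1.0, 0.0, 2.0, -1.0], [0.5, 1.0, -1.0, 0.0], [0.0, -0.5, 1.0, 1.5]]
wq = [[0.1, 0.2, 0.3, 0.4], [0.0, -0.1, 0.2, 0.1],
      [0.3, 0.0, -0.2, 0.1], [0.2, 0.1, 0.0, -0.3]]
wk = [[0.2, -0.1, 0.0, 0.3], [0.1, 0.4, -0.2, 0.0],
      [0.0, 0.2, 0.1, -0.1], [-0.3, 0.1, 0.2, 0.2]]
wv = [[0.5, 0.1, -0.2, 0.0], [0.0, 0.3, 0.1, 0.2],
      [-0.1, 0.0, 0.4, 0.1], [0.2, -0.2, 0.0, 0.3]]
wo = [[0.1, 0.0, 0.2, -0.1], [0.3, 0.1, 0.0, 0.2],
      [-0.2, 0.2, 0.1, 0.0], [0.0, -0.1, 0.3, 0.1]]
head_dim, num_gpus = 2, 2

# Single-device reference: run both heads, concatenate, then project.
q, k, v = matmul(x, wq), matmul(x, wk), matmul(x, wv)
heads = [
    attention(cols(q, h * head_dim, (h + 1) * head_dim),
              cols(k, h * head_dim, (h + 1) * head_dim),
              cols(v, h * head_dim, (h + 1) * head_dim))
    for h in range(num_gpus)
]
concat = [h0 + h1 for h0, h1 in zip(*heads)]
reference = matmul(concat, wo)

# Tensor parallel: GPU g holds columns [lo, hi) of wq/wk/wv and rows
# [lo, hi) of wo, so it runs one whole head without communication.
z = None
for g in range(num_gpus):
    lo, hi = g * head_dim, (g + 1) * head_dim
    y_local = attention(matmul(x, cols(wq, lo, hi)),
                        matmul(x, cols(wk, lo, hi)),
                        matmul(x, cols(wv, lo, hi)))
    partial = matmul(y_local, wo[lo:hi])
    z = partial if z is None else add(z, partial)  # stands in for all-reduce
```

Splitting along head boundaries matters here: the softmax is computed per head, so a GPU that owns a complete head never needs another GPU's scores.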

8️⃣ Sequence Parallelism (New Technique): Distributing Long Sequences Across GPUs

Sequence parallelism is a method for handling long sequences of data (e.g., long documents, videos, or robot trajectories) by distributing sequence processing across multiple GPUs. This approach is particularly valuable for large language models with extensive context requirements.

🌟 Core Idea

  • 🔀 Splitting by Token Dimension: Unlike data parallelism (which splits data by batch), sequence parallelism splits the data along the token dimension.
  • 🤝 Shared Weights Across GPUs: Each GPU processes a portion of the sequence, but the model weights are shared across GPUs.
  • 🤔 Attention Challenge:
    • The attention mechanism requires each token to attend to all other tokens.
    • Simple sequence splitting limits the attention range, as tokens on one GPU can only attend to other tokens on the same GPU.
    • To overcome this, inter-GPU communication ensures tokens across all GPUs can attend to one another.

🛠️ Two Methods for Implementing Sequence Parallelism

1️⃣ Re-partitioning Data in Attention Layers (DeepSpeed Approach):

  • Initial Data Split:
    • The sequence is divided into segments, with each GPU processing a different segment.
    • Example: Three GPUs handle three chapters, with GPU1 processing Chapter 1, GPU2 Chapter 2, and GPU3 Chapter 3.
  • QKV Transformation:
    • Each GPU computes Q, K, and V matrices for its segment using shared model weights.
  • All-to-All Communication:
    • After QKV transformation, GPUs exchange Q, K, and V matrices for all tokens related to each attention head.
    • Each GPU gathers the required matrices for a specific attention head over the entire sequence.
  • Attention Calculation:
    • With all-to-all communication completed, each GPU computes attention for a specific head over all sequence tokens.
  • Partitioning Strategy:
    • Sequence segments are distributed across GPUs.
    • Attention heads are partitioned across GPUs. In the example above, the number of GPUs equals both the number of sequence partitions and the number of attention heads.
  • Limitations:
    • 🚨 All-to-all communication introduces significant overhead.
    • 📉 Parallelism is limited by the number of attention heads, which is fixed by the model architecture and cannot be changed after training.
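
The re-partitioning step can be sketched with labeled placeholders instead of real tensors. This is a simplified simulation, assuming one attention head per GPU and equal-length sequence segments; `all_to_all` and `seq_to_head_sharding` are names made up for this sketch and are not DeepSpeed APIs. Each string such as "t3h1" stands for the Q (or K, or V) vector of token 3 for head 1.

```python
def all_to_all(send):
    """send[src][dst] is the message from src to dst; returns recv[dst][src]."""
    n = len(send)
    return [[send[src][dst] for src in range(n)] for dst in range(n)]

def seq_to_head_sharding(local_q):
    """Switch from sequence sharding to head sharding.

    local_q[g][t][h] is the value on GPU g for its local token t and head h.
    Returns per_head[g][t]: head g's values for every token in the sequence.
    Assumes one attention head per GPU.
    """
    n = len(local_q)
    # GPU src sends GPU dst the entries for head dst over src's local tokens.
    send = [[[tok[dst] for tok in local_q[src]] for dst in range(n)]
            for src in range(n)]
    recv = all_to_all(send)
    # Concatenating the received pieces in source order restores token order.
    return [[val for piece in recv[g] for val in piece] for g in range(n)]

num_gpus, tokens_per_gpu = 3, 2
local_q = [
    [[f"t{g * tokens_per_gpu + i}h{h}" for h in range(num_gpus)]
     for i in range(tokens_per_gpu)]
    for g in range(num_gpus)
]
per_head = seq_to_head_sharding(local_q)
```

After the exchange, GPU 1 holds head 1 for tokens 0 through 5 and can compute full-sequence attention for that head locally.
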
2️⃣ Ring Attention:

  • Ring Structure:
    • GPUs are arranged in a ring.
  • Rotated Communication:
    • In each cycle, GPUs send the keys and values (KV) of their sequence portion to the next GPU in the ring.
    • Queries remain on their respective GPUs.
  • Attention Calculation:
    • Each GPU calculates attention using its queries and the received KVs.
    • Initially, GPUs process attention locally for their sequence partition.
    • Over subsequent cycles, KVs rotate around the ring, enabling queries on each GPU to attend to all parts of the sequence.
  • Iterative Attention:
    • Through multiple communication cycles, all sequence tokens can attend to one another.
  • Advantages:
    • 🚀 Higher Parallelism: Not limited by the number of attention heads.
    • 📈 Extended Context: Improves context length for both training and inference.
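
The rotation can be simulated on one machine to check that it reproduces full attention. The sketch below makes several assumptions: a single head, no causal mask, no 1/sqrt(d) scaling, and a running-softmax accumulation (tracking the maximum score and the normalizer per query) to merge the contribution of each arriving KV block. Function names and data are illustrative.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def full_attention(q, k, v):
    """Reference: unscaled softmax(q k^T) v computed on a single device."""
    out = []
    for qi in q:
        scores = [dot(qi, kj) for kj in k]
        peak = max(scores)
        weights = [math.exp(s - peak) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[d] for w, vj in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out

def ring_attention(q_blocks, k_blocks, v_blocks):
    """Each simulated GPU keeps its queries while KV blocks rotate around a ring."""
    n = len(q_blocks)
    dim = len(v_blocks[0][0])
    # Per-query running state: max score, softmax denominator, weighted sum of V.
    state = [[(-math.inf, 0.0, [0.0] * dim) for _ in qb] for qb in q_blocks]
    kv = list(zip(k_blocks, v_blocks))  # kv[g] is the block GPU g holds now
    for _ in range(n):
        for g in range(n):
            k_blk, v_blk = kv[g]
            for i, qi in enumerate(q_blocks[g]):
                peak, denom, acc = state[g][i]
                scores = [dot(qi, kj) for kj in k_blk]
                new_peak = max(peak, max(scores))
                rescale = math.exp(peak - new_peak)  # shrink the old partial sums
                weights = [math.exp(s - new_peak) for s in scores]
                denom = denom * rescale + sum(weights)
                acc = [a * rescale + sum(w * vj[d] for w, vj in zip(weights, v_blk))
                       for d, a in enumerate(acc)]
                state[g][i] = (new_peak, denom, acc)
        # Every GPU passes its current KV block to the next GPU in the ring.
        kv = [kv[(g - 1) % n] for g in range(n)]
    return [[[a / denom for a in acc] for _, denom, acc in blk] for blk in state]

def blocks(m, size=2):
    """Split a sequence of rows into contiguous blocks of the given size."""
    return [m[i:i + size] for i in range(0, len(m), size)]

q = [[0.1 * t, 1.0 - 0.2 * t] for t in range(6)]
k = [[0.3 * t - 0.5, 0.1 * t] for t in range(6)]
v = [[float(t), float(t * t)] for t in range(6)]

out_blocks = ring_attention(blocks(q), blocks(k), blocks(v))
ring_out = [row for blk in out_blocks for row in blk]
reference = full_attention(q, k, v)
```

No GPU ever holds more than one KV block at a time, yet after N rotations every query has attended to every token, and `ring_out` matches `reference` up to floating-point error.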

🔗 Combining Methods for 2D Parallelism

  • Hybrid Strategy:
    • Intra-Node Communication: DeepSpeed-style re-partitioning is used within nodes due to higher intra-node bandwidth.
    • Inter-Node Communication: Ring attention is employed between nodes to optimize bandwidth efficiency.
  • Benefit:
    • This 2D parallelism strategy balances the strengths of both methods, leveraging high bandwidth within nodes and maximizing parallelism across nodes.