TinyML Distributed Training Part 2
Modern AI models are becoming increasingly large, demanding substantial computational resources and memory. This creates a gap between the computational demands of these models and the available hardware capabilities. Pruning addresses this gap by reducing model size, memory footprint, and ultimately, energy consumption.
Introduction
This document summarizes key concepts and techniques for efficient distributed training of machine learning models, as presented in MIT’s EfficientML.ai Lecture 20, focusing on hybrid parallelism, communication bottlenecks, gradient compression, and delayed updates. The core challenge addressed is how to scale training to large models and datasets while mitigating communication overhead. 🚀
1. Hybrid (Mixed) Parallelism and Auto-Parallelization
Review of Parallelization Techniques
- Data Parallelism: Splits the data across nodes, each holding a full copy of the model. Good utilization but high memory cost. Optimizations such as ZeRO and FSDP shard the optimizer state, gradients, and weights. 📊
“Separate pieces of data; the data is split across multiple nodes, and each GPU maintains its own copy.”
- Pipeline Parallelism: Splits the model by layers, distributing groups of layers across nodes. Low utilization, low memory cost, medium communication. 🛠
“Splitting the layers of the model…low utilization, low memory cost, medium communication.”
- Tensor Parallelism: Splits individual model tensors across nodes. High utilization, low memory cost, but high communication. 🤝
“You are splitting the model within a tensor…high utilization, low memory cost, but high communication.”
- Sequence Parallelism: Splits input sequences (e.g., tokens in NLP) across nodes. Communication is concentrated in the attention layers. Uses techniques such as all-to-all communication (DeepSpeed Ulysses) and ring attention. ✂️
“Split the tokens across multiple GPUs…actual communication happens during the attention layer.”
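To make the difference between splitting the data and splitting a tensor concrete, here is a minimal single-process sketch. It assumes a single matrix product `A × B` stands in for one layer, and the helper names (`row_sharded_matmul`, `column_sharded_matmul`) are hypothetical. "Devices" are simulated by slicing Python lists; no real communication takes place.

```python
def matmul(a, b):
    """Plain matrix product of two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def row_sharded_matmul(a, b, num_devices):
    """Data-parallel style: each device gets a slice of A's rows and a full copy of B."""
    shard = (len(a) + num_devices - 1) // num_devices
    out = []
    for d in range(num_devices):
        local_rows = a[d * shard:(d + 1) * shard]
        out.extend(matmul(local_rows, b))  # results are stacked row-wise
    return out


def column_sharded_matmul(a, b, num_devices):
    """Tensor-parallel style: each device holds a slice of B's columns.

    Every device needs all of A, and the partial outputs must be
    concatenated column-wise, which on real hardware is a communication step.
    """
    shard = (len(b[0]) + num_devices - 1) // num_devices
    out = [[] for _ in a]
    for d in range(num_devices):
        local_cols = [row[d * shard:(d + 1) * shard] for row in b]
        partial = matmul(a, local_cols)
        for out_row, part_row in zip(out, partial):
            out_row.extend(part_row)
    return out


A = [[1, 2], [3, 4], [5, 6], [7, 8]]   # activations: 4 samples, 2 features
B = [[1, 0, 2, 1], [0, 1, 1, 3]]       # weights: 2 inputs, 4 outputs
reference = matmul(A, B)
by_rows = row_sharded_matmul(A, B, num_devices=2)
by_cols = column_sharded_matmul(A, B, num_devices=2)
```

Both shardings reproduce the unsharded product. The row-wise version replicates the weights on every device (the memory cost of data parallelism), while the column-wise version splits the weights but needs an extra gather of the outputs (the communication cost of tensor parallelism).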
Hybrid/Mixed Parallelism
Combining different parallelization strategies to leverage the strengths of each while mitigating their weaknesses. Examples include:
- 2D Parallelism (data + pipeline): Data parallelism in the outer loop and pipeline parallelism in the inner loop.
“At the outer loop, we are using data parallelism…in the inner loop, we are using pipeline parallelism.”
- 2D Parallelism (pipeline + tensor): Pipeline parallelism in the outer loop and tensor parallelism in the inner loop.
“Performing pipeline parallelism in the outer loop…in the inner loop, we are doing tensor parallelism.”

- 2D Sequence Parallelism: All-to-all repartitioning within a node, where bandwidth is high, combined with ring attention between nodes.
“Within the node, we can perform all-to-all communication using the Ulysses approach, and between nodes, we can use ring attention.”

- 3D Parallelism: Combining pipeline, tensor, and data parallelism. Tensor (model) parallelism is kept within a server because of its high communication cost.
“Combining pipeline parallelism, tensor parallelism, and data parallelism…within a server, we are doing model parallelism since tensor parallelism is the most communication-heavy.”
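One way to picture 3D parallelism is as a mapping from a flat GPU rank to a (data, pipeline, tensor) coordinate. The sketch below assumes a rank ordering in which the tensor-parallel index varies fastest, so each tensor-parallel group occupies consecutive ranks and therefore fits inside one server. The function names and the 8-GPU server size are illustrative assumptions, not the API of any particular framework.

```python
def rank_to_coords(rank, tp_size, pp_size, dp_size):
    """Map a flat rank to (dp, pp, tp) with the tensor-parallel index innermost."""
    world_size = tp_size * pp_size * dp_size
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} outside world of size {world_size}")
    tp = rank % tp_size
    pp = (rank // tp_size) % pp_size
    dp = rank // (tp_size * pp_size)
    return dp, pp, tp


def tensor_parallel_groups(tp_size, pp_size, dp_size):
    """List the ranks that share one tensor-parallel group (same dp and pp index)."""
    world_size = tp_size * pp_size * dp_size
    return [list(range(start, start + tp_size))
            for start in range(0, world_size, tp_size)]


GPUS_PER_SERVER = 8          # assumed server size, for illustration only
TP, PP, DP = 8, 2, 2         # 32 GPUs in total

groups = tensor_parallel_groups(TP, PP, DP)
# For each tensor-parallel group, collect the set of servers its ranks live on.
servers_per_group = [{rank // GPUS_PER_SERVER for rank in group} for group in groups]
```

With this ordering every entry of `servers_per_group` contains a single server, so the communication-heavy tensor-parallel traffic stays on the fast intra-server links, while pipeline and data parallelism span servers.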

Auto-Parallelization
Choosing the best combination of parallelization strategies by hand is difficult. Auto-parallelization aims to find it automatically.
- Alpa Compiler: Automates distributed training using a hierarchical search space. 🖥
“Unified compiler for distributed training…given a computational graph, Alpa proposes a hierarchical search space.”
Hierarchical Search Space
- Inter-operator Parallelism: Partitioning the computation graph into stages (groups of layers) using dynamic programming.
“Inter-operator pass…to estimate the cost.”
- Intra-operator Parallelism: Partitioning individual operators (layers) using different techniques (e.g., data parallel, tensor parallel).
“Given A × B, we can split A across GPUs and replicate B.”
- Cost-Based Optimization: Alpa uses a cost function to optimize the strategy, often matching or exceeding manually tuned systems.
“Given a computational graph and device cluster, Alpa optimizes orchestration.”
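The inter-operator idea of grouping layers into stages with dynamic programming can be illustrated with a much simpler problem than the one Alpa actually solves. The sketch below assumes each layer has a single known cost and looks for the contiguous split into `num_stages` stages that minimizes the slowest stage. The function name and the cost values are hypothetical, and the real compiler also chooses device sub-meshes and uses costs produced by the intra-operator pass.

```python
from functools import lru_cache


def partition_stages(layer_costs, num_stages):
    """Split layers into contiguous stages, minimizing the cost of the slowest stage."""
    n = len(layer_costs)
    if not 1 <= num_stages <= n:
        raise ValueError("need between 1 and len(layer_costs) stages")

    # prefix[i] = total cost of layers[0:i], so a stage cost is a difference.
    prefix = [0]
    for cost in layer_costs:
        prefix.append(prefix[-1] + cost)

    @lru_cache(maxsize=None)
    def best(start, stages_left):
        """Best (bottleneck, cut points) for layers[start:] using stages_left stages."""
        if stages_left == 1:
            return prefix[n] - prefix[start], (n,)
        best_cost, best_cuts = float("inf"), ()
        # The first stage is layers[start:end]; leave one layer per remaining stage.
        for end in range(start + 1, n - stages_left + 2):
            first = prefix[end] - prefix[start]
            rest_cost, rest_cuts = best(end, stages_left - 1)
            bottleneck = max(first, rest_cost)
            if bottleneck < best_cost:
                best_cost, best_cuts = bottleneck, (end,) + rest_cuts
        return best_cost, best_cuts

    bottleneck, cuts = best(0, num_stages)
    stages, start = [], 0
    for end in cuts:
        stages.append(layer_costs[start:end])
        start = end
    return bottleneck, stages


layer_costs = [4, 2, 3, 7, 1, 5]   # made-up per-layer costs
bottleneck, stages = partition_stages(layer_costs, num_stages=3)
```

The slowest stage determines pipeline throughput, which is why this toy objective minimizes the maximum stage cost rather than the total.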
2. Bandwidth and Latency Bottlenecks
Communication as the Bottleneck 📡
- Synchronization: Synchronous SGD requires workers to exchange gradients at every iteration.
“Synchronization leads to high communication frequency after each iteration.”
- Large Transfer Sizes: Larger models require transferring more gradients.
“Models are getting larger, increasing the communication burden.”
Impact of Network Latency
- Latency Scale:
  - Within a rack: Microseconds/milliseconds.
  - Within a data center: Milliseconds.
  - Across the world: Hundreds of milliseconds/seconds.
“Latency varies across scales: within a rack (microseconds), across the world (hundreds of milliseconds).”
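A rough cost model helps separate the two bottlenecks. The sketch below uses the common approximation that sending a message takes `latency + bytes / bandwidth`, and that a ring all-reduce over N workers performs 2(N − 1) steps, each moving 1/N of the buffer. The specific latency and bandwidth figures are illustrative assumptions, not measurements from the lecture.

```python
def transfer_time(num_bytes, latency_s, bandwidth_bytes_per_s):
    """Time to send one message: a fixed latency plus size divided by bandwidth."""
    return latency_s + num_bytes / bandwidth_bytes_per_s


def ring_allreduce_time(num_bytes, num_workers, latency_s, bandwidth_bytes_per_s):
    """Estimate for ring all-reduce: 2(N-1) steps, each sending 1/N of the buffer."""
    steps = 2 * (num_workers - 1)
    chunk = num_bytes / num_workers
    return steps * transfer_time(chunk, latency_s, bandwidth_bytes_per_s)


GRADIENT_BYTES = 1e9   # assumed: about 1 GB of gradients per iteration
WORKERS = 8

# Assumed link characteristics, chosen only to contrast the two regimes.
in_rack = ring_allreduce_time(GRADIENT_BYTES, WORKERS,
                              latency_s=1e-5, bandwidth_bytes_per_s=1e10)
cross_world = ring_allreduce_time(GRADIENT_BYTES, WORKERS,
                                  latency_s=0.2, bandwidth_bytes_per_s=1e8)

# With a tiny message the bandwidth term vanishes and latency dominates.
tiny_message = ring_allreduce_time(8, WORKERS,
                                   latency_s=0.2, bandwidth_bytes_per_s=1e8)

print(f"in rack: {in_rack:.3f} s, cross world: {cross_world:.1f} s, "
      f"tiny message cross world: {tiny_message:.2f} s")
```

The bandwidth term shrinks when fewer bytes are sent, which is what gradient compression targets. The latency term does not shrink with message size, which is why a different remedy is needed for high-latency links.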
3. Gradient Compression to Overcome Bandwidth Bottleneck
Techniques
- Gradient Pruning: Transmit only the top-K gradients by magnitude; prune the smaller ones. 🌱
“Prune gradients by setting smaller values to zero.”
- Local Gradient Accumulation: Accumulate the pruned (unsent) gradients locally until they are large enough to transmit, with momentum correction applied to the accumulated values.
“Accumulate gradients locally instead of transmitting immediately.”
- Gradient Quantization: Reduce gradient precision (e.g., 1-bit SGD).
“Quantize gradients to lower precision.”
- PowerSGD: Approximate each gradient matrix with low-rank factors, which can be aggregated directly with all-reduce.
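Gradient pruning combined with local accumulation can be sketched in a few lines. The example below assumes gradients are flat Python lists and leaves out momentum correction for brevity. The class and function names are hypothetical and do not come from any library.

```python
def top_k_sparsify(values, k):
    """Split values into (sparse, residual).

    sparse keeps the k largest-magnitude entries and zeros the rest;
    residual keeps exactly the entries that were zeroed out.
    """
    order = sorted(range(len(values)), key=lambda i: abs(values[i]), reverse=True)
    keep = set(order[:k])
    sparse = [v if i in keep else 0.0 for i, v in enumerate(values)]
    residual = [0.0 if i in keep else v for i, v in enumerate(values)]
    return sparse, residual


class SparseGradientWorker:
    """Keeps the untransmitted part of each gradient and adds it to the next one."""

    def __init__(self, size, k):
        self.k = k
        self.residual = [0.0] * size

    def compress(self, grad):
        # Add what was held back earlier, then send only the top-k entries.
        accumulated = [r + g for r, g in zip(self.residual, grad)]
        sparse, self.residual = top_k_sparsify(accumulated, self.k)
        return sparse


worker = SparseGradientWorker(size=4, k=1)
sent_1 = worker.compress([0.5, -0.125, 0.25, 0.0625])
sent_2 = worker.compress([0.125, -0.125, 0.25, 0.0])
```

In the second step the third coordinate is transmitted even though its fresh gradient (0.25) is not the largest on its own: the value held back from the first step pushes the accumulated entry to 0.5. Nothing is discarded, since every pruned value stays in `worker.residual` until it is eventually sent.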
4. Delayed Gradient Update to Overcome Latency Bottleneck
- Delayed Gradient Averaging (DGA): Overlap computation with communication by allowing “stale” gradients. ⏳
“Workers can proceed with computation while awaiting communication.”
- Correction Term: Adjust for gradient staleness. When a delayed averaged gradient arrives, subtract the stale local gradient that was applied earlier and add the averaged gradient in its place.
“Subtract stale gradients and add fresh updates.”
- Performance: Significant speedups (e.g., up to 7.5x on Raspberry Pi clusters). 📈
“Effective even in high-latency environments.”
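The delayed update with its correction term can be simulated in a single process. The sketch below is a toy under stated assumptions: each worker minimizes its own scalar quadratic loss `0.5 * (w - target)**2`, the averaged gradient arrives exactly `delay` steps late, and the function names are hypothetical. It illustrates the idea and is not the lecture's implementation.

```python
def run_dga(targets, delay, steps, lr=0.1, init=0.0):
    """Simulate delayed gradient averaging; returns each worker's final weight."""
    n = len(targets)
    weights = [init] * n
    history = []  # history[t][i] = local gradient of worker i at step t
    for t in range(steps):
        grads = [w - tgt for w, tgt in zip(weights, targets)]
        history.append(grads)
        # Each worker applies its own gradient immediately, without waiting.
        weights = [w - lr * g for w, g in zip(weights, grads)]
        if t >= delay:
            # The average from `delay` steps ago has now arrived: remove the
            # stale local gradient applied back then and add the average instead.
            stale = history[t - delay]
            average = sum(stale) / n
            weights = [w - lr * (average - g_old)
                       for w, g_old in zip(weights, stale)]
    return weights


def run_sync_sgd(targets, steps, lr=0.1, init=0.0):
    """Baseline: every step waits for the average gradient across workers."""
    w = init
    for _ in range(steps):
        average = sum(w - tgt for tgt in targets) / len(targets)
        w -= lr * average
    return w


targets = [1.0, 2.0, 6.0]   # made-up per-worker optima; their mean is 3.0
sync_weight = run_sync_sgd(targets, steps=200)
dga_no_delay = run_dga(targets, delay=0, steps=200)
dga_delayed = run_dga(targets, delay=3, steps=200)
mean_delayed = sum(dga_delayed) / len(dga_delayed)
```

With `delay=0` the correction cancels the local step exactly, and the simulation reduces to synchronous SGD. With a positive delay the individual workers drift apart by the gradients still in flight. In this toy problem, where gradients are linear in the weights, their average still follows the synchronous trajectory.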
