ML Training Optimization: FLOPs, Profiling, and Learning Strategies
Fun disclaimer: Used GPT to get all the beautiful visual gradients, but the content is mine!
When training large-scale machine learning models, optimization goes beyond just hyperparameter tuning. This guide covers the essential aspects of efficient ML training: computational constraints, performance profiling, and learning strategies that can save you significant costs and time.
1. FLOPs and Chinchilla Scaling Law
When training large-scale ML models, you typically have FLOPs (Floating Point Operations) constraints. The Chinchilla scaling law provides crucial guidance on how to allocate your compute budget effectively.
Chinchilla Scaling Law
For a fixed compute budget (FLOPs), you need to decide between having more parameters (bigger model) or training the model for longer (showing it more data).
Two Critical Cases to Avoid
1. Compute Inefficient Training — the model is larger than the budget supports, so it never sees enough tokens and the FLOPs spent on the extra parameters buy little improvement in loss.
2. Data Inefficient Training — the model is smaller than the budget warrants, so you burn through far more data than necessary for diminishing returns.
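As a rough illustration of the trade-off, the sketch below uses two published rules of thumb from the Chinchilla paper: training compute C ≈ 6·N·D FLOPs for N parameters and D tokens, and a compute-optimal ratio of roughly 20 tokens per parameter. The exact constants vary by architecture and data, so treat the output as ballpark guidance rather than a recipe.

```python
# Rough compute-optimal sizing based on Chinchilla rules of thumb:
#   training compute C ~ 6 * N * D FLOPs, and D_opt ~ 20 * N_opt tokens.
# Both constants are approximations, not exact values for every model.

def chinchilla_optimal(flops_budget: float) -> tuple[float, float]:
    """Return (params, tokens) that roughly balance a fixed FLOP budget."""
    # Solve C = 6 * N * (20 * N)  =>  N = sqrt(C / 120)
    n_params = (flops_budget / 120) ** 0.5
    n_tokens = 20 * n_params
    return n_params, n_tokens

if __name__ == "__main__":
    for budget in (1e21, 1e22, 1e23):
        n, d = chinchilla_optimal(budget)
        print(f"{budget:.0e} FLOPs -> ~{n / 1e9:.1f}B params, ~{d / 1e9:.0f}B tokens")
```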
2. Profiling Your Code
Profiling your training code is essential for maximizing GPU utilization and getting the best performance for your investment. This is different from hyperparameter tuning, which focuses on model learning rather than computational efficiency.
Key Bottlenecks to Monitor
I/O Bottleneck
Don't assume that just because your GPU can handle a larger batch size, you should use it. PyTorch data loaders work on CPU threads, and if your GPU finishes processing batch 1 but your data loader isn't ready with batch 2, your GPU sits idle.
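One common way to keep the data loader ahead of the GPU in PyTorch is to use multiple CPU worker processes with pinned memory and prefetching. The sketch below uses a toy in-memory dataset, and the worker/prefetch settings are illustrative starting points, not tuned values.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy in-memory dataset used purely for illustration.
dataset = TensorDataset(torch.randn(10_000, 128), torch.randint(0, 10, (10_000,)))

# Keep the GPU fed: worker processes decode/collate batches ahead of time,
# pinned memory speeds up host-to-device copies, and persistent workers
# avoid re-spawning processes every epoch.
loader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,          # tune to your CPU core count and storage speed
    pin_memory=True,
    prefetch_factor=2,      # batches pre-loaded per worker
    persistent_workers=True,
)
```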
Memory Bottleneck
Good profiling reveals what's consuming your memory. Common culprits include:
- Per-layer activations
- Gradients storage
- Temporary tensor assignments
- Optimizer states
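A lightweight way to see which of these is dominating is to log allocated and peak GPU memory at a few points in the training step. The model and sizes below are toy placeholders.

```python
import torch
from torch import nn

def log_gpu_memory(tag: str) -> None:
    """Print current and peak allocated GPU memory in MiB."""
    current = torch.cuda.memory_allocated() / 2**20
    peak = torch.cuda.max_memory_allocated() / 2**20
    print(f"[{tag}] allocated={current:.0f} MiB, peak={peak:.0f} MiB")

# Toy model to show where memory goes during one step (placeholder sizes).
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).cuda()
optimizer = torch.optim.Adam(model.parameters())
x = torch.randn(512, 4096, device="cuda")

out = model(x)
log_gpu_memory("after forward")    # weights + per-layer activations
out.sum().backward()
log_gpu_memory("after backward")   # + gradients
optimizer.step()
log_gpu_memory("after step")       # + optimizer states (Adam moments)
torch.cuda.reset_peak_memory_stats()
```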
Memory Optimization Techniques
Gradient Checkpointing
- Trades computation for memory
- Recomputes activations during backward pass
- Can reduce memory by 50-80%
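A minimal sketch of gradient checkpointing with torch.utils.checkpoint.checkpoint_sequential is shown below, assuming a purely sequential model; the layer sizes and segment count are placeholders, and the use_reentrant flag applies to recent PyTorch versions.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

# Hypothetical deep sequential network; sizes are placeholders.
model = nn.Sequential(
    *[nn.Sequential(nn.Linear(1024, 1024), nn.ReLU()) for _ in range(24)]
)
x = torch.randn(64, 1024, requires_grad=True)

# Split the network into 4 segments: only segment-boundary activations are
# stored, and the rest are recomputed during the backward pass.
out = checkpoint_sequential(model, 4, x, use_reentrant=False)
out.sum().backward()
```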
Mixed Precision
- Runs most forward and backward computation in FP16 (or BF16)
- Keeps FP32 master weights (with loss scaling for FP16) so updates stay numerically stable
- Reduces memory by ~50%
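A minimal mixed-precision training step using PyTorch autocast and a GradScaler might look like the sketch below; the model, optimizer, and loss are placeholders for your own setup.

```python
import torch
from torch import nn

model = nn.Linear(1024, 10).cuda()                 # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()               # loss scaling for FP16 gradients
loss_fn = nn.CrossEntropyLoss()

def train_step(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = loss_fn(model(x), y)                # forward runs largely in FP16
    scaler.scale(loss).backward()                  # scaled to avoid FP16 underflow
    scaler.step(optimizer)                         # unscales grads, updates FP32 weights
    scaler.update()
    return loss.detach()

x = torch.randn(256, 1024, device="cuda")
y = torch.randint(0, 10, (256,), device="cuda")
print(train_step(x, y))
```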
CPU ↔ GPU Transfer Bottleneck
Moving data between CPU and GPU is often a major bottleneck due to bandwidth limitations. Common scenarios that cause this issue:
- Using `.item()` to extract scalar values
- Checkpointing weights to CPU
- Frequent data transfers during training
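Two simple mitigations are asynchronous host-to-device copies (non_blocking=True, paired with pinned memory in the data loader) and logging the loss only every N steps instead of calling `.item()` on every iteration, since each `.item()` forces a CPU-GPU synchronisation. The loop below uses a random placeholder loss just to show the pattern.

```python
import torch

device = torch.device("cuda")

# Asynchronous host-to-device copies; pairs with pin_memory=True in the DataLoader
# so the copy can overlap with GPU compute instead of blocking on it.
def to_device(x: torch.Tensor, y: torch.Tensor):
    return x.to(device, non_blocking=True), y.to(device, non_blocking=True)

# Keep the running loss on the GPU and transfer it only when logging.
running_loss = torch.zeros((), device=device)
log_every = 100

for step in range(1, 1001):
    loss = torch.randn((), device=device).abs()    # placeholder for your real loss
    running_loss += loss.detach()                  # stays on the GPU, no sync
    if step % log_every == 0:
        avg = (running_loss / log_every).item()    # single sync every log_every steps
        print(f"step {step}: avg loss {avg:.4f}")
        running_loss.zero_()
```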
Kernel Overhead
Launching many small kernels creates overhead: each launch costs CPU time, and if the kernels are tiny the GPU finishes them faster than the CPU can enqueue new ones, leaving the GPU partially idle.
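In PyTorch 2.x, torch.compile can fuse many small pointwise operations into fewer, larger kernels, which cuts per-launch CPU overhead. The module below is a hypothetical example chosen because its forward pass contains several tiny element-wise ops.

```python
import torch
from torch import nn

class GatedMLP(nn.Module):
    """Hypothetical block whose forward pass has several small pointwise ops."""
    def __init__(self, dim: int = 1024):
        super().__init__()
        self.up = nn.Linear(dim, 4 * dim)
        self.gate = nn.Linear(dim, 4 * dim)
        self.down = nn.Linear(4 * dim, dim)

    def forward(self, x):
        # sigmoid, tanh, and multiply each launch separate kernels in eager mode
        return self.down(torch.sigmoid(self.gate(x)) * torch.tanh(self.up(x)))

model = GatedMLP().cuda()
# torch.compile traces the module and fuses pointwise ops into fewer kernels.
compiled = torch.compile(model)
out = compiled(torch.randn(256, 1024, device="cuda"))
```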
Profiling Priority
Always profile your code first to identify bottlenecks before focusing on accuracy improvements. This approach will save you significant costs.
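A minimal torch.profiler setup that captures CPU and CUDA activity, plus memory, might look like the sketch below; the model and training step are toy placeholders for your own code.

```python
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile, schedule

model = nn.Linear(1024, 10).cuda()                 # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

def train_step() -> None:
    x = torch.randn(256, 1024, device="cuda")      # placeholder batch
    y = torch.randint(0, 10, (256,), device="cuda")
    optimizer.zero_grad(set_to_none=True)
    loss_fn(model(x), y).backward()
    optimizer.step()

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler_logs"),
    profile_memory=True,
    record_shapes=True,
) as prof:
    for _ in range(6):
        train_step()
        prof.step()                                # advance the wait/warmup/active schedule

# Text summary of the most expensive ops; the full trace is viewable in TensorBoard.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```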
3. Learning Strategies
Once you've optimized your computational efficiency, focus on improving model performance through effective learning strategies.
Batch Size Selection
Choose the highest batch size your GPU and data loader can handle, but ensure you maintain some stochasticity in your updates. When you change batch size, adjust your learning rate accordingly (usually linearly).
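A quick sketch of the linear scaling heuristic (the base values below are illustrative, not recommendations):

```python
# Linear scaling heuristic: if you multiply the batch size by k, multiply the
# base learning rate by k as well (often combined with a warmup period).
base_batch_size = 256
base_lr = 3e-4

new_batch_size = 1024
new_lr = base_lr * (new_batch_size / base_batch_size)   # 1.2e-3
print(f"batch {new_batch_size} -> lr {new_lr:.1e}")
```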
Gradient Accumulation
If your learning is too noisy (loss oscillates up and down), consider gradient accumulation to smooth the updates:
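A minimal sketch with a toy model and dataset as placeholders; the accumulation factor is illustrative.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy setup used purely for illustration; swap in your own model, data, and loss.
model = nn.Linear(128, 10).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
loader = DataLoader(
    TensorDataset(torch.randn(4096, 128), torch.randint(0, 10, (4096,))),
    batch_size=64, shuffle=True,
)

accum_steps = 4   # effective batch size = 64 * 4 = 256

optimizer.zero_grad(set_to_none=True)
for step, (x, y) in enumerate(loader):
    loss = loss_fn(model(x.cuda()), y.cuda())
    (loss / accum_steps).backward()        # scale so accumulated grads average out
    if (step + 1) % accum_steps == 0:
        optimizer.step()                   # one optimizer update per accum_steps batches
        optimizer.zero_grad(set_to_none=True)
```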
Frequently Asked Questions
When do you stop training? What is the ideal loss?
There is no universal ideal loss value; it depends on the task, the data, and the irreducible noise in the labels. In practice, track validation loss and stop when it plateaus or starts rising, or when further compute no longer buys meaningful improvement.
What if training loss keeps dropping but validation loss increases?
This is classic overfitting. Solutions include:
- Add regularization (dropout, weight decay)
- Collect more training data
- Implement early stopping
- Reduce model complexity
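A bare-bones early-stopping loop might look like the sketch below; train_one_epoch and evaluate are hypothetical stand-ins for your own training and validation passes, and the patience value is illustrative.

```python
# Hypothetical stand-ins for your own training loop and validation pass.
def train_one_epoch() -> None:
    pass  # placeholder: run one epoch of training here

def evaluate() -> float:
    return 0.0  # placeholder: return the current validation loss here

max_epochs, patience = 100, 5
best_val_loss = float("inf")
epochs_without_improvement = 0

for epoch in range(max_epochs):
    train_one_epoch()
    val_loss = evaluate()
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_without_improvement = 0     # also a good point to checkpoint weights
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Stopping at epoch {epoch}: no improvement for {patience} epochs")
            break
```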
How do I know if my learning rate is too high or low?
Learning Rate Too High
- Loss oscillates or spikes
- Gradients explode
- Training becomes unstable
Learning Rate Too Low
- Loss crawls down slowly
- Training stalls early
- Very slow convergence
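One common way to avoid both failure modes is a linear warmup (guards against early instability from a high learning rate) followed by cosine decay (avoids crawling at a too-low rate for the whole run). The sketch below builds such a schedule with LambdaLR; the step counts, base learning rate, and toy model are placeholders.

```python
import math
import torch

warmup_steps, total_steps = 500, 10_000    # illustrative values

def lr_lambda(step: int) -> float:
    """Linear warmup to the base LR, then cosine decay to zero."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# Placeholder model/optimizer; swap in your own.
optimizer = torch.optim.AdamW(torch.nn.Linear(10, 10).parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Inside the training loop, call optimizer.step() and then scheduler.step().
```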
Key Takeaways
Effective ML training optimization requires balancing computational efficiency, proper profiling, and smart learning strategies. Always profile first to identify bottlenecks, then focus on model performance improvements. This systematic approach will save you both time and money.