Floating-Point Precision & Exploding Gradients
Today, while fine-tuning Llama 3.1 locally, I felt first-hand the pain of training in lower precision. Floating-point precision isn't just about saving memory or slightly less accurate inference: it directly affects training stability.
Previously, I saw reduced floating-point precision (like float16 or float8) primarily as a memory optimization, useful for fitting larger models and speeding up computation. While I was aware it could slightly hurt inference quality due to rounding, I hadn't fully internalized how significantly these errors impact training stability.
During backpropagation, low-precision floating-point arithmetic hurts in two ways: tiny rounding errors accumulate with every calculation, and the narrower dynamic range means very large or very small values overflow or underflow outright. In deep networks or recurrent models with many sequential operations, these effects compound. The result can magnify gradients uncontrollably (exploding gradients) or shrink them to zero (vanishing gradients), both of which destabilize or halt the training process.
Here's a small example showing how float16 overflows to infinity on a value that float32 (PyTorch's default dtype) and float64 handle without trouble:
import torch

initial_grad = 1e5   # larger than float16's maximum of 65504
multiplier = 0.25
iterations = 10

dtypes = {
    "float64": torch.float64,
    "float32": torch.float32,  # PyTorch's default dtype
    "float16": torch.float16,
}

results = {}
for name, dtype in dtypes.items():
    # In float16, 1e5 already overflows to inf at this conversion.
    grad = torch.tensor(initial_grad, dtype=dtype)
    for _ in range(iterations):
        grad *= multiplier
    results[name] = grad.item()

for name, value in results.items():
    print(f"{name}: {value}")
# Results
+-------------+-------------+
| Data Type | Result |
+=============+=============+
| float64 | 0.0953674 |
+-------------+-------------+
| float32 | 0.0953674 |
+-------------+-------------+
| float16 | inf |
+-------------+-------------+
Running this shows float16 blowing up to inf: float16's largest representable value is 65504, so the initial 1e5 overflows to infinity the moment it is converted, and no amount of later scaling can bring it back. float32 and float64 represent the same value comfortably and shrink the gradient smoothly to about 0.095. In a real training run, this kind of overflow (and its counterpart, underflow to zero) is exactly what shows up as exploding or vanishing gradients.
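The mirror image happens at the small end of the range. Here's a companion sketch I added (the numbers are purely illustrative) showing a gradient that float32 still represents underflowing to exactly zero in float16:
import torch

grad16 = torch.tensor(1e-4, dtype=torch.float16)
grad32 = torch.tensor(1e-4, dtype=torch.float32)
for _ in range(10):
    grad16 *= 0.25
    grad32 *= 0.25

# float16's smallest positive (subnormal) value is about 6e-8, so the
# shrinking gradient eventually rounds to 0.0, while float32 still
# holds roughly 9.5e-11.
print(grad16.item(), grad32.item())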
To mitigate floating-point instability during training/fine-tuning:
- Gradient Clipping (torch.nn.utils.clip_grad_norm_) is essential to cap runaway gradient values before they cause issues.
- Lower Learning Rates reduce the magnitude of weight updates, lessening the impact of potential gradient explosions.
- Mixed-Precision Training: selectively use higher precision (e.g., float32) for critical computations like loss calculation or weight updates while keeping less sensitive parts in lower precision. A sketch combining this with clipping follows the list.
- Normalization Layers (Batch Norm, Layer Norm, RMS Norm) help keep activations and gradients within stable numerical ranges.
- Proper Weight Initialization (e.g., Xavier/Glorot or He initialization) sets initial weights to values less likely to cause immediate gradient issues. A second sketch below pairs this with normalization.
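As a sketch of how the first three points fit together in practice, here's a minimal mixed-precision training step with gradient clipping. The model, optimizer, and loss are placeholders I've made up for illustration, and it assumes a CUDA device; it is not the exact setup from my fine-tuning run:
import torch

# Placeholder model/optimizer; substitute your own fine-tuning setup.
model = torch.nn.Linear(512, 512).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # conservative learning rate
scaler = torch.cuda.amp.GradScaler()  # rescales the loss so float16 gradients don't underflow

def train_step(batch: torch.Tensor, targets: torch.Tensor) -> float:
    optimizer.zero_grad(set_to_none=True)

    # Forward pass runs in float16 where it's safe; autocast keeps
    # numerically sensitive ops in float32.
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        output = model(batch)
        loss = torch.nn.functional.mse_loss(output, targets)

    # Backward on the scaled loss, then unscale before clipping so the
    # threshold applies to the true gradient magnitudes.
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # The optimizer update itself happens on the float32 parameters.
    scaler.step(optimizer)
    scaler.update()
    return loss.item()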
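For the last two points, here's a toy module (again my own illustration, not code from my run) that pairs LayerNorm with He-initialized weights feeding a ReLU:
import torch

class StableBlock(torch.nn.Module):
    """Toy block: LayerNorm keeps activations in a stable range, and
    He (Kaiming) initialization keeps early gradients well-scaled for ReLU."""

    def __init__(self, dim: int = 512):
        super().__init__()
        self.norm = torch.nn.LayerNorm(dim)
        self.linear = torch.nn.Linear(dim, dim)
        # He initialization, matched to the ReLU nonlinearity below.
        torch.nn.init.kaiming_normal_(self.linear.weight, nonlinearity="relu")
        torch.nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(self.norm(x)))

block = StableBlock()
x = torch.randn(8, 512)
print(block(x).std())  # activations stay on the order of 1 rather than drifting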