mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 21:25:21 +00:00
## Problem

When running SFT with a small device batch size (≤8), fully-masked micro-batches cause NaN loss from step 1, permanently corrupting the gradients. This happens when a micro-batch contains only 'User' tokens (all targets = -1), which is especially common with small batch sizes on consumer GPUs.

Root cause: `torch.nn.functional.cross_entropy` with `reduction='mean'` returns NaN when all labels are -1 (division by zero in the mean computation).

## Solution

Added validation in the training loop to detect and skip fully-masked batches:

- Check `(y != -1).any()` before computing the loss
- Skip `backward()` for batches with no valid targets (zero gradient contribution)
- Track skipped batches and warn the user if >5% are skipped in the first 100 steps
- Log skipped batches as loss=0 for transparency

## Testing

- Added a comprehensive test suite (`test_sft_masked_batches.py`)
- Tests cover fully masked, partially masked, and unmasked batches
- Documents `cross_entropy` behavior with `ignore_index=-1`
- Validates the fix logic

## Impact

- Fixes #590: NaN loss with small batch sizes
- No performance impact for normal batches
- Helps users on consumer GPUs (RTX 3060, etc.)
- Prevents silent gradient corruption

Resolves #590
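The root cause described above can be reproduced with a minimal snippet (toy tensor sizes, not nanochat's actual shapes):

```python
import torch
import torch.nn.functional as F

# A micro-batch where every target is the ignore index (-1),
# as happens when the batch contains only 'User' tokens.
logits = torch.randn(4, 10)     # 4 token positions, vocab of 10 (toy sizes)
targets = torch.full((4,), -1)  # all targets masked out

loss = F.cross_entropy(logits, targets, ignore_index=-1, reduction="mean")
print(loss)  # tensor(nan) -- the mean divides by zero valid targets
```

Once this NaN is backpropagated, every parameter gradient becomes NaN, which is why a single fully-masked micro-batch poisons training permanently.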
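A minimal sketch of the skip logic, under stated assumptions: `training_step`, `skipped_counter`, and the surrounding signature are illustrative stand-ins, not nanochat's actual training-loop code; only the guard itself mirrors the fix described above.

```python
import torch
import torch.nn.functional as F

def training_step(model, x, y, skipped_counter, step):
    """One micro-batch step that skips fully-masked batches.

    model, x, y are placeholders for the real training objects.
    skipped_counter is a mutable dict {'n': int} tracking skips.
    """
    if not (y != -1).any():
        # No valid targets: the mean loss would be NaN, so skip
        # backward() entirely (zero gradient contribution).
        skipped_counter["n"] += 1
        if step < 100 and skipped_counter["n"] > 0.05 * (step + 1):
            print(f"warning: {skipped_counter['n']} fully-masked "
                  f"batches in first {step + 1} steps")
        return 0.0  # logged as loss=0 for transparency

    logits = model(x)
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)), y.view(-1), ignore_index=-1
    )
    loss.backward()
    return loss.item()
```

Because the skipped batch never calls `backward()`, the accumulated gradients from other micro-batches in the same optimizer step stay finite.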