nanochat/tests
vivekvar-dl 55eb345515 fix: prevent NaN loss in SFT with fully-masked micro-batches (#590)
## Problem
When running SFT with a small device batch size (≤8), a fully-masked micro-batch
causes NaN loss from step 1, permanently corrupting gradients. This happens when
a micro-batch contains only 'User' tokens (all targets = -1), which is especially
common at the small batch sizes used on consumer GPUs.

Root cause: torch.nn.functional.cross_entropy with reduction='mean' returns NaN
when every label equals ignore_index (-1): the mean divides the summed loss by
zero valid tokens, i.e. 0/0.
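The failure mode can be reproduced in a few lines (a minimal sketch, not code from the repo; the tensor shapes are illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                        # 4 token positions, 10-class vocab
targets = torch.full((4,), -1, dtype=torch.long)   # every target masked, as in a 'User'-only batch

loss = F.cross_entropy(logits, targets, ignore_index=-1, reduction='mean')
print(loss)  # tensor(nan): a sum over zero valid tokens divided by zero valid tokens
```

Once this NaN reaches backward(), every parameter gradient becomes NaN, which is why the corruption is permanent rather than transient.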

## Solution
Added validation in the training loop to detect and skip fully-masked batches:
- Check (y != -1).any() before computing loss
- Skip backward() for batches with no valid targets (zero gradient contribution)
- Track skipped batches and warn the user if more than 5% of the first 100 steps are skipped
- Log skipped batches as loss=0 for transparency
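The skip logic above can be sketched as follows (a hypothetical helper, assuming the usual flattened-logits SFT loss; sft_micro_step and the Linear stand-in model are illustrative names, not code from the repo):

```python
import torch
import torch.nn.functional as F

def sft_micro_step(model, x, y, ignore_index=-1):
    # If every target is masked, skip both the loss and backward() so the
    # micro-batch contributes zero gradient, and report loss=0 for logging.
    if not (y != ignore_index).any():
        return 0.0
    logits = model(x)
    loss = F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        y.view(-1),
        ignore_index=ignore_index,
        reduction='mean',
    )
    loss.backward()
    return loss.item()

model = torch.nn.Linear(8, 10)   # stand-in for the real LM
x = torch.randn(2, 8)
masked_loss = sft_micro_step(model, x, torch.full((2,), -1, dtype=torch.long))
valid_loss = sft_micro_step(model, x, torch.tensor([3, 7]))
```

Returning 0.0 for a skipped batch costs a single .any() check per micro-batch, which is why normal batches see no measurable overhead.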

## Testing
- Added comprehensive test suite (test_sft_masked_batches.py)
- Tests cover: fully masked, partially masked, and unmasked batches
- Documents cross_entropy behavior with ignore_index=-1
- Validates the fix logic
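The contents of test_sft_masked_batches.py are not shown here, but a test documenting the cross_entropy behavior across the three masking cases might look like this sketch (the test name is hypothetical):

```python
import torch
import torch.nn.functional as F

def test_cross_entropy_ignore_index_minus_one():
    logits = torch.randn(3, 5)
    # Fully masked: the mean over zero valid tokens is NaN.
    fully = F.cross_entropy(
        logits, torch.full((3,), -1, dtype=torch.long), ignore_index=-1
    )
    assert torch.isnan(fully)
    # Partially masked: ignored positions simply drop out of the mean.
    partial = F.cross_entropy(logits, torch.tensor([2, -1, 4]), ignore_index=-1)
    assert torch.isfinite(partial)
    # Unmasked: ordinary finite loss.
    unmasked = F.cross_entropy(logits, torch.tensor([0, 1, 2]), ignore_index=-1)
    assert torch.isfinite(unmasked)

test_cross_entropy_ignore_index_minus_one()
```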

## Impact
- Fixes #590: NaN loss with small batch sizes
- No performance impact for normal batches
- Helps users on consumer GPUs (RTX 3060, etc.)
- Prevents silent gradient corruption

Resolves #590
2026-03-22 18:15:40 +00:00
test_attention_fallback.py delete autocast, an unnecessary thorn in my side, manage dtypes directly 2026-03-04 23:55:30 +00:00
test_engine.py Fix MockModel's device definition (#535) 2026-02-17 16:03:46 -08:00
test_sft_masked_batches.py fix: prevent NaN loss in SFT with fully-masked micro-batches (#590) 2026-03-22 18:15:40 +00:00