# torch.compile Validation Report
## Status: READY FOR MANUAL TESTING (pending Tasks 42-44)
This document tracks the validation status of the `torch.compile` implementation for the chat SFT training script.
---
## Prerequisite Tasks Assessment
### Task 42: Fixed Padding Implementation
**Status**: ❌ NOT IMPLEMENTED
**Current State**:
- The `collate_and_yield` function (lines 89-109 in `scripts/chat_sft.py`) uses dynamic padding:
```python
ncols = max(len(ids) for ids, mask in batch) - 1 # Line 94
```
- No `max_seq_len` constant is defined (unlike `base_train.py` and `mid_train.py`)
**Required for Task 43**: Fixed padding to a constant `max_seq_len=2048` must be implemented first. With `dynamic=False`, `torch.compile` specializes on the input shapes it sees, so variable-length batches would trigger repeated recompilation.
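
A minimal sketch of fixed padding, written with plain Python lists so that it runs without PyTorch. The names `pad_batch_fixed`, `PAD_TOKEN_ID`, and `IGNORE_INDEX` are hypothetical and do not come from `scripts/chat_sft.py`; the real `collate_and_yield` would build tensors and use the tokenizer's own padding token and the loss's ignore value. It assumes each batch item is an `(ids, mask)` pair, as in the line quoted above. Every row comes out exactly `max_seq_len` wide regardless of the longest conversation in the batch, which is the property `dynamic=False` relies on.

```python
MAX_SEQ_LEN = 2048   # constant width named in this report
PAD_TOKEN_ID = 0     # hypothetical: substitute the tokenizer's real padding token
IGNORE_INDEX = -1    # hypothetical: substitute the value the loss function ignores


def pad_batch_fixed(batch, max_seq_len=MAX_SEQ_LEN):
    """Turn (ids, mask) pairs into fixed-width input and target rows."""
    inputs, targets = [], []
    for ids, mask in batch:
        # Keep max_seq_len + 1 tokens because inputs and targets are shifted by one.
        ids = ids[: max_seq_len + 1]
        mask = mask[: max_seq_len + 1]
        row_inputs = ids[:-1]
        # Supervise only positions where the mask is 1; ignore the rest.
        row_targets = [
            tok if m == 1 else IGNORE_INDEX
            for tok, m in zip(ids[1:], mask[1:])
        ]
        pad = max_seq_len - len(row_inputs)
        inputs.append(row_inputs + [PAD_TOKEN_ID] * pad)
        targets.append(row_targets + [IGNORE_INDEX] * pad)
    return inputs, targets
```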
---
### Task 43: torch.compile with dynamic=False
**Status**: ❌ NOT ENABLED
**Current State**:
- Line 72 in `scripts/chat_sft.py`:
```python
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
```
- The `torch.compile` call is commented out
- The commented-out call passes `dynamic=True`; it should pass `dynamic=False`
**Required Change**:
```python
model = torch.compile(model, dynamic=False)
```
---
### Task 44: Use orig_model for Evaluation and Checkpointing
**Status**: ⚠️ PARTIALLY IMPLEMENTED
**Current State**:
- ✅ Line 71: `orig_model = model` - Variable is created
- ⚠️ Line 173: Uses `model` for the validation loss (acceptable as-is; no change required)
- ❌ Line 192: `run_chat_eval("MMLU", model, ...)` - Should use `orig_model`
- ❌ Line 251: `model.state_dict()` - Should use `orig_model.state_dict()`
**Required Changes**:
1. Update evaluation calls to use `orig_model`:
```python
metrics["mmlu_acc"] = run_chat_eval("MMLU", orig_model, tokenizer, engine, ...)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", orig_model, tokenizer, engine, ...)
```
2. Update checkpoint saving to use `orig_model`:
```python
save_checkpoint(
    checkpoint_dir,
    step,
    orig_model.state_dict(),  # Changed from model.state_dict()
    None,
    {...}
)
```
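
Saving from `orig_model` matters because a module wrapped by `torch.compile` reports its parameters under keys prefixed with `_orig_mod.`, which an uncompiled model will not recognize on load. If a checkpoint has already been written from the compiled wrapper, a small helper along these lines could normalize the keys. The name `strip_compile_prefix` is hypothetical and not part of the codebase; this is a sketch that works on any dict-like state dict.

```python
def strip_compile_prefix(state_dict, prefix="_orig_mod."):
    """Return a copy of state_dict with the torch.compile wrapper prefix removed."""
    cleaned = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]  # drop the wrapper prefix only
        cleaned[key] = value
    return cleaned
```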
---
## Validation Instrumentation Added
The following temporary logging has been added to `scripts/chat_sft.py` to facilitate validation:
### 1. Compilation Status Detection (Line ~76)
```python
if hasattr(model, '_orig_mod'):
    print0("[VALIDATION] ✓ Model is compiled (torch.compile detected)")
else:
    print0("[VALIDATION] ✗ Model is NOT compiled (running in eager mode)")
```
**Purpose**: Confirms whether torch.compile is active at startup
---
### 2. Batch Shape Logging (Line ~211)
```python
if step < 3 and micro_step == 0:
    print0(f"[VALIDATION] Step {step} | Batch shape: {train_inputs.shape}")
```
**Purpose**: Verifies fixed padding by checking if all batches have constant shape `(4, 2048)`
**Expected Output** (with fixed padding):
```
[VALIDATION] Step 0 | Batch shape: torch.Size([4, 2048])
[VALIDATION] Step 1 | Batch shape: torch.Size([4, 2048])
[VALIDATION] Step 2 | Batch shape: torch.Size([4, 2048])
```
---
### 3. Performance Metrics (Line ~236)
```python
# Tracks step_time and calculates tokens/sec every 10 steps
# Excludes first 5 warmup iterations
```
**Purpose**: Measures performance improvement from torch.compile
**Expected Output**:
```
[VALIDATION] Avg time/step: 2.450s | Tokens/sec: 3265.3
[VALIDATION] Avg time/step: 2.380s | Tokens/sec: 3358.0
```
**Key Metrics**:
- Baseline (without compile): Record tokens/sec
- With compile: Expect 1.3-1.5x the baseline tokens/sec (30-50% higher throughput)
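
The tracking described above (skip the first 5 warmup steps, then report every 10 steps) could be sketched as follows. The class name `ThroughputTracker` and its parameters are hypothetical, not the instrumentation actually added to `scripts/chat_sft.py`. The caller is assumed to measure `step_time` itself, for example with `time.perf_counter()` around each step, and to pass the token count per step, for example `4 * 2048 = 8192` for a single `(4, 2048)` batch.

```python
class ThroughputTracker:
    """Average step time and tokens/sec over fixed windows, skipping warmup."""

    def __init__(self, tokens_per_step, warmup_steps=5, log_every=10):
        self.tokens_per_step = tokens_per_step
        self.warmup_steps = warmup_steps
        self.log_every = log_every
        self.window = []

    def update(self, step, step_time):
        """Record one step; return (avg_time, tokens_per_sec) when a window fills."""
        if step < self.warmup_steps:
            return None  # warmup steps include compilation and are excluded
        self.window.append(step_time)
        if len(self.window) < self.log_every:
            return None
        avg_time = sum(self.window) / len(self.window)
        self.window.clear()
        return avg_time, self.tokens_per_step / avg_time
```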
---
## Test Execution Plan
Once prerequisites (Tasks 42, 43, 44) are completed, run the following tests:
### Test 1: Baseline (No Compilation)
```bash
# Comment out line 72 (torch.compile)
torchrun --standalone --nproc_per_node=1 \
    -m scripts.chat_sft -- \
    --max_iterations=100 \
    --model_source=base \
    --model_tag=d20 \
    --step=0
```
**Record**:
- [ ] All batch shapes are `(4, 2048)`
- [ ] Tokens/sec: _______
- [ ] Avg time/step: _______
- [ ] Final loss: _______
---
### Test 2: With Compilation
```bash
# Uncomment line 72 and set dynamic=False
torchrun --standalone --nproc_per_node=1 \
    -m scripts.chat_sft -- \
    --max_iterations=100 \
    --model_source=base \
    --model_tag=d20 \
    --step=0
```
**Verify**:
- [ ] Compilation message appears: `[VALIDATION] ✓ Model is compiled`
- [ ] No recompilation messages after initial compilation
- [ ] Tokens/sec improvement: _______ (target: ≥1.3x baseline)
- [ ] Loss trajectory matches Test 1 (within ±5%)
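
The two numeric checks in this list can be computed from the values recorded in Test 1 and Test 2. The helper names `speedup` and `within_tolerance` below are hypothetical, and the sample figures used when calling them would come from your own runs; this is only a sketch of the arithmetic behind the 1.3x and ±5% thresholds.

```python
def speedup(baseline_tps, compiled_tps):
    """Ratio of compiled to baseline tokens/sec; the target is at least 1.3."""
    return compiled_tps / baseline_tps


def within_tolerance(baseline_loss, compiled_loss, rel_tol=0.05):
    """True if the compiled loss is within rel_tol (5%) of the baseline loss."""
    return abs(compiled_loss - baseline_loss) <= rel_tol * abs(baseline_loss)
```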
---
### Test 3: Multi-GPU (4 GPUs)
```bash
torchrun --standalone --nproc_per_node=4 \
    -m scripts.chat_sft -- \
    --max_iterations=100 \
    --model_source=base \
    --model_tag=d20 \
    --step=0
```
**Verify**:
- [ ] All 4 ranks initialize successfully
- [ ] No DDP synchronization errors
- [ ] Performance improvement similar to single-GPU test
---
## Success Criteria
### Functional Requirements
- [ ] Constant batch shapes throughout training (verified by logging)
- [ ] Successful compilation without errors
- [ ] Zero recompilations during training
- [ ] Zero recompilations during evaluation (using orig_model)
- [ ] Checkpoints save and load correctly
- [ ] Works in both single-GPU and multi-GPU configurations
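
One way to check the zero-recompilation criteria is PyTorch's `TORCH_LOGS` environment variable, whose `recompiles` setting logs each recompilation along with the reason for it. The snippet below is a sketch under two assumptions: that the installed PyTorch supports `TORCH_LOGS` (this may require a release newer than the 2.0 minimum), and that the variable is set before `torch` is imported. Exporting the variable in the shell before `torchrun` should have the same effect.

```python
import os

# Assumption: this runs before `import torch`, because the logging
# configuration is read from the environment when torch initializes.
os.environ["TORCH_LOGS"] = "recompiles"
```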
### Performance Requirements
- [ ] 30-50% speed improvement (tokens/sec ratio ≥ 1.3x)
- [ ] Initial compilation time ≤ 60 seconds
- [ ] GPU memory usage within 10% of baseline
### Accuracy Requirements
- [ ] Loss convergence matches baseline (within ±5% at iteration 100)
- [ ] Evaluation metrics match historical baselines
---
## Current Blockers
1. **Task 42 (Fixed Padding)**: Must be implemented to enable `dynamic=False` compilation
2. **Task 43 (Enable Compilation)**: Line 72 must be uncommented and changed to `dynamic=False`
3. **Task 44 (Use orig_model)**: Evaluation and checkpointing must use uncompiled model
**Recommendation**: Complete prerequisite tasks before proceeding with validation tests.
---
## Rollback Procedure
If validation fails, disable compilation by commenting out line 72:
```python
# model = torch.compile(model, dynamic=False)
```
To remove validation logging after successful testing:
1. Remove lines ~159-161 (performance tracking variables)
2. Remove line ~163 (step_start_time)
3. Remove lines ~211-213 (batch shape logging)
4. Remove lines ~233-245 (performance metrics calculation)
5. Remove lines ~76-80 (compilation status logging)
---
## Notes
- **Validation logging is temporary** and should be removed after testing
- Performance measurements should exclude the first 5 warmup iterations
- Expected net time savings: 15-20 minutes per full SFT training run
- PyTorch version must be ≥ 2.0 for torch.compile support