torch.compile Validation Report
Status: READY FOR MANUAL TESTING
This document tracks the validation status of the torch.compile implementation for the chat SFT training script.
Prerequisite Tasks Assessment
Task 42: Fixed Padding Implementation
Status: ❌ NOT IMPLEMENTED
Current State:
- The `collate_and_yield` function (lines 89-109 in `scripts/chat_sft.py`) uses dynamic padding: `ncols = max(len(ids) for ids, mask in batch) - 1` (line 94)
- No `max_seq_len` constant is defined (unlike `base_train.py` and `mid_train.py`)
Required for Task 43: Fixed padding with a constant `max_seq_len=2048` must be implemented before torch.compile with `dynamic=False` can work effectively.
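To make the required change concrete, here is a minimal sketch of a fixed-padding collate function. The names (`collate_fixed`, `MAX_SEQ_LEN`, `PAD_TOKEN_ID`, the `-1` ignore index) are illustrative assumptions, not the actual `collate_and_yield` implementation in `scripts/chat_sft.py`:

```python
import torch

MAX_SEQ_LEN = 2048  # constant padding target (Task 42)
PAD_TOKEN_ID = 0    # assumed pad token id
IGNORE_INDEX = -1   # assumed loss-ignore value for padded/masked targets

def collate_fixed(batch):
    """Pad every (ids, mask) example to MAX_SEQ_LEN so batch shapes never vary."""
    ncols = MAX_SEQ_LEN  # fixed, instead of max(len(ids) for ids, mask in batch) - 1
    inputs = torch.full((len(batch), ncols), PAD_TOKEN_ID, dtype=torch.long)
    targets = torch.full((len(batch), ncols), IGNORE_INDEX, dtype=torch.long)
    for i, (ids, mask) in enumerate(batch):
        ids = ids[: ncols + 1]  # truncate overlong sequences
        n = len(ids) - 1        # next-token prediction: inputs are ids[:-1]
        inputs[i, :n] = torch.tensor(ids[:-1], dtype=torch.long)
        row = torch.tensor(ids[1:], dtype=torch.long)
        m = torch.tensor(mask[1 : n + 1], dtype=torch.bool)
        targets[i, :n] = torch.where(m, row, torch.full_like(row, IGNORE_INDEX))
    return inputs, targets
```

Because `ncols` no longer depends on the batch contents, every batch has shape `(batch_size, 2048)`, which is what `dynamic=False` compilation needs to avoid recompilation.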
Task 43: torch.compile with dynamic=False
Status: ❌ NOT ENABLED
Current State:
- Line 72 of `scripts/chat_sft.py` has the call commented out: `# model = torch.compile(model, dynamic=True)  # doesn't work super well because of variable lengths of inputs`
- It uses `dynamic=True` (should be `dynamic=False`)
Required Change:
model = torch.compile(model, dynamic=False)
Task 44: Use orig_model for Evaluation and Checkpointing
Status: ⚠️ PARTIALLY IMPLEMENTED
Current State:
- ✅ Line 71: `orig_model = model` (variable is created)
- ❌ Line 173: uses `model` for validation (should be OK for gradient computation)
- ❌ Line 192: `run_chat_eval("MMLU", model, ...)` should use `orig_model`
- ❌ Line 251: `model.state_dict()` should use `orig_model.state_dict()`
Required Changes:
1. Update evaluation calls to use `orig_model`:
   metrics["mmlu_acc"] = run_chat_eval("MMLU", orig_model, tokenizer, engine, ...)
   metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", orig_model, tokenizer, engine, ...)
2. Update checkpoint saving to use `orig_model`:
   save_checkpoint(
       checkpoint_dir,
       step,
       orig_model.state_dict(),  # changed from model.state_dict()
       None,
       {...},
   )
Validation Instrumentation Added
The following temporary logging has been added to scripts/chat_sft.py to facilitate validation:
1. Compilation Status Detection (Line ~76)
if hasattr(model, '_orig_mod'):
print0("[VALIDATION] ✓ Model is compiled (torch.compile detected)")
else:
print0("[VALIDATION] ✗ Model is NOT compiled (running in eager mode)")
Purpose: Confirms whether torch.compile is active at startup
2. Batch Shape Logging (Line ~211)
if step < 3 and micro_step == 0:
print0(f"[VALIDATION] Step {step} | Batch shape: {train_inputs.shape}")
Purpose: Verifies fixed padding by checking if all batches have constant shape (4, 2048)
Expected Output (with fixed padding):
[VALIDATION] Step 0 | Batch shape: torch.Size([4, 2048])
[VALIDATION] Step 1 | Batch shape: torch.Size([4, 2048])
[VALIDATION] Step 2 | Batch shape: torch.Size([4, 2048])
3. Performance Metrics (Line ~236)
# Tracks step_time and calculates tokens/sec every 10 steps
# Excludes first 5 warmup iterations
Purpose: Measures performance improvement from torch.compile
Expected Output:
[VALIDATION] Avg time/step: 2.450s | Tokens/sec: 3265.3
[VALIDATION] Avg time/step: 2.380s | Tokens/sec: 3358.0
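The timing logic described above can be sketched as follows. This is a minimal sketch, assuming `tokens_per_step = device_batch_size * max_seq_len` (4 × 2048) and illustrative names (`record_step`, `WARMUP_STEPS`), not the exact instrumentation added to `scripts/chat_sft.py`:

```python
import time

WARMUP_STEPS = 5   # exclude first iterations (includes compile time)
LOG_EVERY = 10     # report a rolling average every 10 steps

step_times = []
tokens_per_step = 4 * 2048  # assumed: device_batch_size * max_seq_len

def record_step(step, step_start_time):
    """Accumulate post-warmup step times and report an average periodically."""
    if step < WARMUP_STEPS:
        return  # skip warmup / compilation iterations
    step_times.append(time.time() - step_start_time)
    if step % LOG_EVERY == 0 and step_times:
        avg = sum(step_times) / len(step_times)
        print(f"[VALIDATION] Avg time/step: {avg:.3f}s | "
              f"Tokens/sec: {tokens_per_step / avg:.1f}")
```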
Key Metrics:
- Baseline (without compile): Record tokens/sec
- With compile: Should show 1.3-1.5x improvement (30-50% faster)
Test Execution Plan
Once prerequisites (Tasks 42, 43, 44) are completed, run the following tests:
Test 1: Baseline (No Compilation)
# Comment out line 72 (torch.compile)
torchrun --standalone --nproc_per_node=1 \
-m scripts.chat_sft -- \
--max_iterations=100 \
--model_source=base \
--model_tag=d20 \
--step=0
Record:
- All batch shapes are `(4, 2048)`
- Tokens/sec: _______
- Avg time/step: _______
- Final loss: _______
Test 2: With Compilation
# Uncomment line 72 and set dynamic=False
torchrun --standalone --nproc_per_node=1 \
-m scripts.chat_sft -- \
--max_iterations=100 \
--model_source=base \
--model_tag=d20 \
--step=0
Verify:
- Compilation message appears: `[VALIDATION] ✓ Model is compiled`
- No recompilation messages after initial compilation
- Tokens/sec improvement: _______ (target: ≥1.3x baseline)
- Loss trajectory matches Test 1 (within ±5%)
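The ±5% loss-trajectory criterion can be checked with a one-liner; the helper name is illustrative, and the loss values are placeholders for numbers recorded from the two runs:

```python
def within_tolerance(baseline_loss, compiled_loss, tol=0.05):
    """True if the compiled run's final loss is within tol (relative) of baseline."""
    return abs(compiled_loss - baseline_loss) <= tol * abs(baseline_loss)
```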
Test 3: Multi-GPU (4 GPUs)
torchrun --standalone --nproc_per_node=4 \
-m scripts.chat_sft -- \
--max_iterations=100 \
--model_source=base \
--model_tag=d20 \
--step=0
Verify:
- All 4 ranks initialize successfully
- No DDP synchronization errors
- Performance improvement similar to single-GPU test
Success Criteria
Functional Requirements
- Constant batch shapes throughout training (verified by logging)
- Successful compilation without errors
- Zero recompilations during training
- Zero recompilations during evaluation (using orig_model)
- Checkpoints save and load correctly
- Works in both single-GPU and multi-GPU configurations
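One way to verify the zero-recompilation criteria above, assuming PyTorch ≥ 2.1 (where the `TORCH_LOGS` environment variable is available), is to rerun a test with recompile logging enabled and confirm no recompile messages appear after the initial compilation:

```shell
# "recompiles" logs a reason every time a compiled frame is recompiled;
# with fixed padding, no such lines should appear after the first compile.
TORCH_LOGS="recompiles" torchrun --standalone --nproc_per_node=1 \
    -m scripts.chat_sft -- \
    --max_iterations=100 \
    --model_source=base \
    --model_tag=d20 \
    --step=0
```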
Performance Requirements
- 30-50% speed improvement (tokens/sec ratio ≥ 1.3x)
- Initial compilation time ≤ 60 seconds
- GPU memory usage within 10% of baseline
Accuracy Requirements
- Loss convergence matches baseline (within ±5% at iteration 100)
- Evaluation metrics match historical baselines
Current Blockers
- Task 42 (Fixed Padding): must be implemented to enable `dynamic=False` compilation
- Task 43 (Enable Compilation): line 72 must be uncommented and changed to `dynamic=False`
- Task 44 (Use orig_model): evaluation and checkpointing must use the uncompiled model
Recommendation: Complete prerequisite tasks before proceeding with validation tests.
Rollback Procedure
If validation fails, disable compilation by commenting out line 72:
# model = torch.compile(model, dynamic=False)
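As an alternative to commenting code in and out, compilation could be gated behind a flag. The sketch below is hypothetical (`use_compile` is not an existing option in `scripts/chat_sft.py`); it also keeps the uncompiled handle that Task 44 needs:

```python
import torch

def maybe_compile(model, use_compile: bool):
    """Return the (possibly compiled) model plus the eager original."""
    orig_model = model  # uncompiled handle for evaluation / checkpointing
    if use_compile:
        model = torch.compile(model, dynamic=False)
    return model, orig_model
```

With this, rollback is a flag flip rather than a source edit, and `orig_model` is always defined for `save_checkpoint` and `run_chat_eval`.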
To remove validation logging after successful testing:
- Remove lines ~159-161 (performance tracking variables)
- Remove line ~163 (step_start_time)
- Remove lines ~211-213 (batch shape logging)
- Remove lines ~233-245 (performance metrics calculation)
- Remove lines ~76-80 (compilation status logging)
Notes
- Validation logging is temporary and should be removed after testing
- Performance measurements should exclude the first 5 warmup iterations
- Expected net time savings: 15-20 minutes per full SFT training run
- PyTorch version must be ≥ 2.0 for torch.compile support