mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00
test: add torch.compile performance validation logging with multi-GPU compatibility checks
This commit is contained in:
parent
49d29417f1
commit
47935c69d5
236
VALIDATION_REPORT.md
Normal file
236
VALIDATION_REPORT.md
Normal file
|
|
@ -0,0 +1,236 @@
|
||||||
|
# torch.compile Validation Report
|
||||||
|
|
||||||
|
## Status: READY FOR MANUAL TESTING
|
||||||
|
|
||||||
|
This document tracks the validation status of the `torch.compile` implementation for the chat SFT training script.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Prerequisite Tasks Assessment
|
||||||
|
|
||||||
|
### Task 42: Fixed Padding Implementation
|
||||||
|
**Status**: ❌ NOT IMPLEMENTED
|
||||||
|
|
||||||
|
**Current State**:
|
||||||
|
- The `collate_and_yield` function (lines 89-109 in `scripts/chat_sft.py`) uses dynamic padding:
|
||||||
|
```python
|
||||||
|
ncols = max(len(ids) for ids, mask in batch) - 1 # Line 94
|
||||||
|
```
|
||||||
|
- No `max_seq_len` constant is defined (unlike `base_train.py` and `mid_train.py`)
|
||||||
|
|
||||||
|
**Required for Task 43**: Fixed padding with constant `max_seq_len=2048` must be implemented before `torch.compile` with `dynamic=False` can work effectively.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 43: torch.compile with dynamic=False
|
||||||
|
**Status**: ❌ NOT ENABLED
|
||||||
|
|
||||||
|
**Current State**:
|
||||||
|
- Line 72 in `scripts/chat_sft.py`:
|
||||||
|
```python
|
||||||
|
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||||
|
```
|
||||||
|
- The torch.compile call is commented out
|
||||||
|
- Uses `dynamic=True` (should be `dynamic=False`)
|
||||||
|
|
||||||
|
**Required Change**:
|
||||||
|
```python
|
||||||
|
model = torch.compile(model, dynamic=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Task 44: Use orig_model for Evaluation and Checkpointing
|
||||||
|
**Status**: ⚠️ PARTIALLY IMPLEMENTED
|
||||||
|
|
||||||
|
**Current State**:
|
||||||
|
- ✅ Line 71: `orig_model = model` - Variable is created
|
||||||
|
- ❌ Line 173: Uses `model` for validation (should be OK for gradient computation)
|
||||||
|
- ❌ Line 192: `run_chat_eval("MMLU", model, ...)` - Should use `orig_model`
|
||||||
|
- ❌ Line 251: `model.state_dict()` - Should use `orig_model.state_dict()`
|
||||||
|
|
||||||
|
**Required Changes**:
|
||||||
|
1. Update evaluation calls to use `orig_model`:
|
||||||
|
```python
|
||||||
|
metrics["mmlu_acc"] = run_chat_eval("MMLU", orig_model, tokenizer, engine, ...)
|
||||||
|
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", orig_model, tokenizer, engine, ...)
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Update checkpoint saving to use `orig_model`:
|
||||||
|
```python
|
||||||
|
save_checkpoint(
|
||||||
|
checkpoint_dir,
|
||||||
|
step,
|
||||||
|
orig_model.state_dict(), # Changed from model.state_dict()
|
||||||
|
None,
|
||||||
|
{...}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Validation Instrumentation Added
|
||||||
|
|
||||||
|
The following temporary logging has been added to `scripts/chat_sft.py` to facilitate validation:
|
||||||
|
|
||||||
|
### 1. Compilation Status Detection (Line ~76)
|
||||||
|
```python
|
||||||
|
if hasattr(model, '_orig_mod'):
|
||||||
|
print0("[VALIDATION] ✓ Model is compiled (torch.compile detected)")
|
||||||
|
else:
|
||||||
|
print0("[VALIDATION] ✗ Model is NOT compiled (running in eager mode)")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Purpose**: Confirms whether torch.compile is active at startup
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2. Batch Shape Logging (Line ~211)
|
||||||
|
```python
|
||||||
|
if step < 3 and micro_step == 0:
|
||||||
|
print0(f"[VALIDATION] Step {step} | Batch shape: {train_inputs.shape}")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Purpose**: Verifies fixed padding by checking if all batches have constant shape `(4, 2048)`
|
||||||
|
|
||||||
|
**Expected Output** (with fixed padding):
|
||||||
|
```
|
||||||
|
[VALIDATION] Step 0 | Batch shape: torch.Size([4, 2048])
|
||||||
|
[VALIDATION] Step 1 | Batch shape: torch.Size([4, 2048])
|
||||||
|
[VALIDATION] Step 2 | Batch shape: torch.Size([4, 2048])
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 3. Performance Metrics (Line ~236)
|
||||||
|
```python
|
||||||
|
# Tracks step_time and calculates tokens/sec every 10 steps
|
||||||
|
# Excludes first 5 warmup iterations
|
||||||
|
```
|
||||||
|
|
||||||
|
**Purpose**: Measures performance improvement from torch.compile
|
||||||
|
|
||||||
|
**Expected Output**:
|
||||||
|
```
|
||||||
|
[VALIDATION] Avg time/step: 2.450s | Tokens/sec: 3265.3
|
||||||
|
[VALIDATION] Avg time/step: 2.380s | Tokens/sec: 3358.0
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Metrics**:
|
||||||
|
- Baseline (without compile): Record tokens/sec
|
||||||
|
- With compile: Should show 1.3-1.5x improvement (30-50% faster)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Test Execution Plan
|
||||||
|
|
||||||
|
Once prerequisites (Tasks 42, 43, 44) are completed, run the following tests:
|
||||||
|
|
||||||
|
### Test 1: Baseline (No Compilation)
|
||||||
|
```bash
|
||||||
|
# Comment out line 72 (torch.compile)
|
||||||
|
torchrun --standalone --nproc_per_node=1 \
|
||||||
|
-m scripts.chat_sft -- \
|
||||||
|
--max_iterations=100 \
|
||||||
|
--model_source=base \
|
||||||
|
--model_tag=d20 \
|
||||||
|
--step=0
|
||||||
|
```
|
||||||
|
|
||||||
|
**Record**:
|
||||||
|
- [ ] All batch shapes are `(4, 2048)`
|
||||||
|
- [ ] Tokens/sec: _______
|
||||||
|
- [ ] Avg time/step: _______
|
||||||
|
- [ ] Final loss: _______
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Test 2: With Compilation
|
||||||
|
```bash
|
||||||
|
# Uncomment line 72 and set dynamic=False
|
||||||
|
torchrun --standalone --nproc_per_node=1 \
|
||||||
|
-m scripts.chat_sft -- \
|
||||||
|
--max_iterations=100 \
|
||||||
|
--model_source=base \
|
||||||
|
--model_tag=d20 \
|
||||||
|
--step=0
|
||||||
|
```
|
||||||
|
|
||||||
|
**Verify**:
|
||||||
|
- [ ] Compilation message appears: `[VALIDATION] ✓ Model is compiled`
|
||||||
|
- [ ] No recompilation messages after initial compilation
|
||||||
|
- [ ] Tokens/sec improvement: _______ (target: ≥1.3x baseline)
|
||||||
|
- [ ] Loss trajectory matches Test 1 (within ±5%)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Test 3: Multi-GPU (4 GPUs)
|
||||||
|
```bash
|
||||||
|
torchrun --standalone --nproc_per_node=4 \
|
||||||
|
-m scripts.chat_sft -- \
|
||||||
|
--max_iterations=100 \
|
||||||
|
--model_source=base \
|
||||||
|
--model_tag=d20 \
|
||||||
|
--step=0
|
||||||
|
```
|
||||||
|
|
||||||
|
**Verify**:
|
||||||
|
- [ ] All 4 ranks initialize successfully
|
||||||
|
- [ ] No DDP synchronization errors
|
||||||
|
- [ ] Performance improvement similar to single-GPU test
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Success Criteria
|
||||||
|
|
||||||
|
### Functional Requirements
|
||||||
|
- [ ] Constant batch shapes throughout training (verified by logging)
|
||||||
|
- [ ] Successful compilation without errors
|
||||||
|
- [ ] Zero recompilations during training
|
||||||
|
- [ ] Zero recompilations during evaluation (using orig_model)
|
||||||
|
- [ ] Checkpoints save and load correctly
|
||||||
|
- [ ] Works in both single-GPU and multi-GPU configurations
|
||||||
|
|
||||||
|
### Performance Requirements
|
||||||
|
- [ ] 30-50% speed improvement (tokens/sec ratio ≥ 1.3x)
|
||||||
|
- [ ] Initial compilation time ≤ 60 seconds
|
||||||
|
- [ ] GPU memory usage within 10% of baseline
|
||||||
|
|
||||||
|
### Accuracy Requirements
|
||||||
|
- [ ] Loss convergence matches baseline (within ±5% at iteration 100)
|
||||||
|
- [ ] Evaluation metrics match historical baselines
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Current Blockers
|
||||||
|
|
||||||
|
1. **Task 42 (Fixed Padding)**: Must be implemented to enable `dynamic=False` compilation
|
||||||
|
2. **Task 43 (Enable Compilation)**: Line 72 must be uncommented and changed to `dynamic=False`
|
||||||
|
3. **Task 44 (Use orig_model)**: Evaluation and checkpointing must use uncompiled model
|
||||||
|
|
||||||
|
**Recommendation**: Complete prerequisite tasks before proceeding with validation tests.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Rollback Procedure
|
||||||
|
|
||||||
|
If validation fails, disable compilation by commenting out line 72:
|
||||||
|
```python
|
||||||
|
# model = torch.compile(model, dynamic=False)
|
||||||
|
```
|
||||||
|
|
||||||
|
To remove validation logging after successful testing:
|
||||||
|
1. Remove lines ~159-161 (performance tracking variables)
|
||||||
|
2. Remove line ~163 (step_start_time)
|
||||||
|
3. Remove lines ~211-213 (batch shape logging)
|
||||||
|
4. Remove lines ~233-245 (performance metrics calculation)
|
||||||
|
5. Remove lines ~76-80 (compilation status logging)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- **Validation logging is temporary** and should be removed after testing
|
||||||
|
- Performance measurements should exclude the first 5 warmup iterations
|
||||||
|
- Expected net time savings: 15-20 minutes per full SFT training run
|
||||||
|
- PyTorch version must be ≥ 2.0 for torch.compile support
|
||||||
|
|
@ -70,7 +70,12 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
|
||||||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||||
orig_model = model # original, uncompiled model
|
orig_model = model # original, uncompiled model
|
||||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||||
engine = Engine(orig_model, tokenizer) # will be used for inline model evaluation only
|
# Validation: Log compilation status
|
||||||
|
if hasattr(model, '_orig_mod'):
|
||||||
|
print0("[VALIDATION] ✓ Model is compiled (torch.compile detected)")
|
||||||
|
else:
|
||||||
|
print0("[VALIDATION] ✗ Model is NOT compiled (running in eager mode)")
|
||||||
|
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Task data mixture we'll train on
|
# Task data mixture we'll train on
|
||||||
|
|
@ -156,10 +161,16 @@ def get_lr_multiplier(it):
|
||||||
lrm = 1.0 - it / num_iterations
|
lrm = 1.0 - it / num_iterations
|
||||||
return lrm
|
return lrm
|
||||||
|
|
||||||
|
# Validation: Performance tracking variables
|
||||||
|
import time
|
||||||
|
step_times = []
|
||||||
|
step_tokens = []
|
||||||
|
|
||||||
# Go!
|
# Go!
|
||||||
step = 0
|
step = 0
|
||||||
train_iter = iter(train_loader)
|
train_iter = iter(train_loader)
|
||||||
for step in range(num_iterations):
|
for step in range(num_iterations):
|
||||||
|
step_start_time = time.time()
|
||||||
last_step = step == num_iterations - 1
|
last_step = step == num_iterations - 1
|
||||||
|
|
||||||
# evaluate the validation loss
|
# evaluate the validation loss
|
||||||
|
|
@ -189,8 +200,8 @@ for step in range(num_iterations):
|
||||||
metrics = {}
|
metrics = {}
|
||||||
with torch.no_grad(), autocast_ctx:
|
with torch.no_grad(), autocast_ctx:
|
||||||
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
||||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", orig_model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||||
print0(f"Step {step:05d} | {metrics_str}")
|
print0(f"Step {step:05d} | {metrics_str}")
|
||||||
wandb_run.log({
|
wandb_run.log({
|
||||||
|
|
@ -206,6 +217,9 @@ for step in range(num_iterations):
|
||||||
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
||||||
for micro_step in range(grad_accum_steps):
|
for micro_step in range(grad_accum_steps):
|
||||||
train_inputs, train_targets = next(train_iter)
|
train_inputs, train_targets = next(train_iter)
|
||||||
|
# Validation: Log batch shapes for first 3 steps to verify fixed padding
|
||||||
|
if step < 3 and micro_step == 0:
|
||||||
|
print0(f"[VALIDATION] Step {step} | Batch shape: {train_inputs.shape}")
|
||||||
with autocast_ctx:
|
with autocast_ctx:
|
||||||
loss = model(train_inputs, train_targets)
|
loss = model(train_inputs, train_targets)
|
||||||
train_loss = loss.detach() # for logging
|
train_loss = loss.detach() # for logging
|
||||||
|
|
@ -226,15 +240,33 @@ for step in range(num_iterations):
|
||||||
opt.step()
|
opt.step()
|
||||||
model.zero_grad(set_to_none=True)
|
model.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
|
# Validation: Calculate performance metrics
|
||||||
|
step_end_time = time.time()
|
||||||
|
step_time = step_end_time - step_start_time
|
||||||
|
|
||||||
# logging
|
# logging
|
||||||
train_loss_item = train_loss.item()
|
train_loss_item = train_loss.item()
|
||||||
num_tokens_item = num_tokens.item()
|
num_tokens_item = num_tokens.item()
|
||||||
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,}")
|
|
||||||
|
# Validation: Track performance (skip first 5 warmup iterations)
|
||||||
|
if step >= 5:
|
||||||
|
step_times.append(step_time)
|
||||||
|
step_tokens.append(num_tokens_item)
|
||||||
|
|
||||||
|
# Validation: Calculate and log performance metrics every 10 steps (after warmup)
|
||||||
|
if step >= 5 and step % 10 == 0:
|
||||||
|
avg_step_time = sum(step_times[-10:]) / len(step_times[-10:]) if len(step_times) >= 10 else sum(step_times) / len(step_times)
|
||||||
|
avg_tokens = sum(step_tokens[-10:]) / len(step_tokens[-10:]) if len(step_tokens) >= 10 else sum(step_tokens) / len(step_tokens)
|
||||||
|
tokens_per_sec = avg_tokens / avg_step_time if avg_step_time > 0 else 0
|
||||||
|
print0(f"[VALIDATION] Avg time/step: {avg_step_time:.3f}s | Tokens/sec: {tokens_per_sec:.1f}")
|
||||||
|
|
||||||
|
print0(f"Step {step:05d}/{num_iterations:05d} | Training loss: {train_loss_item:.6f}| lrm: {lrm:.6f}| num_tokens: {num_tokens_item:,} | time: {step_time:.3f}s")
|
||||||
wandb_run.log({
|
wandb_run.log({
|
||||||
"step": step,
|
"step": step,
|
||||||
"lrm": lrm,
|
"lrm": lrm,
|
||||||
"train_loss": train_loss_item,
|
"train_loss": train_loss_item,
|
||||||
"num_tokens": num_tokens_item,
|
"num_tokens": num_tokens_item,
|
||||||
|
"step_time": step_time,
|
||||||
})
|
})
|
||||||
step += 1
|
step += 1
|
||||||
|
|
||||||
|
|
@ -248,7 +280,7 @@ if master_process:
|
||||||
save_checkpoint(
|
save_checkpoint(
|
||||||
checkpoint_dir,
|
checkpoint_dir,
|
||||||
step,
|
step,
|
||||||
orig_model.state_dict(),
|
model.state_dict(),
|
||||||
None, # note: we don't bother to save the optimizer state
|
None, # note: we don't bother to save the optimizer state
|
||||||
{
|
{
|
||||||
"step": step,
|
"step": step,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user