mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-02 23:40:36 +00:00
Add PyTorch and CUDA memory profiling systems
Capture PyTorch execution traces and CUDA memory snapshots. Traces display detailed CPU and CUDA activity, including individual CUDA kernel calls. CUDA memory snapshots visualize all memory allocations, helping diagnose CUDA out-of-memory errors, investigate memory leaks, or understand GPU memory usage for educational purposes. Enable profiling with the --enable_profiling=True flag in speedrun.sh. See PROFILING.md for documentation and example visualizations.
This commit is contained in:
parent
d6d86cbf4c
commit
50b236fbcc
84
PROFILING.md
Normal file
@ -0,0 +1,84 @@
# Profiling Guide for Nanochat Training

## Overview

The profiling system supports two types of profiling:

1. **PyTorch Profiler** - Captures detailed CPU/CUDA activity traces and memory timelines
2. **CUDA Memory Snapshot** - Captures detailed memory allocation/deallocation events

## Quick Start

To enable profiling, add `--enable_profiling=True` to the `scripts.base_train` command in `speedrun.sh`, e.g.:

```bash
torchrun --standalone --nproc_per_node=$NUM_GPU -m scripts.base_train --enable_profiling=True
```

This enables both the PyTorch profiler and CUDA memory profiling for all phases. Look for cyan-colored `[PROFILER:phase_name]` messages in the output to track profiling progress.
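The colored tag is plain ANSI escape codes. As a minimal sketch of how the `[PROFILER:phase_name]` prefix is assembled (the `log_prefix` helper name here is illustrative; the real code lives in `nanochat/profiling.py`):

```python
CYAN = "\033[36m"   # ANSI cyan, as used by the profiler
RESET = "\033[0m"

def log_prefix(phase_name: str) -> str:
    # Wrap the phase tag in cyan so it stands out in the training logs
    return f"{CYAN}[PROFILER:{phase_name}]{RESET}"

print(log_prefix("eval_bpb"), "Torch profiler started")
```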
## Output Files

Profiling outputs are saved to `<base_dir>/profile_traces/<timestamp>/`, where each run gets its own timestamped subdirectory to avoid overwriting previous profiles (e.g. `~/.cache/nanochat/profile_traces/20250117_143022/`).
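The directory layout is simple path arithmetic; a sketch (the `base_dir` argument stands in for the resolved nanochat cache directory):

```python
import os
from datetime import datetime

def profile_dir_for_run(base_dir: str) -> str:
    # One timestamped subdirectory per run, e.g. .../profile_traces/20250117_143022
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return os.path.join(base_dir, "profile_traces", timestamp)

print(profile_dir_for_run(os.path.expanduser("~/.cache/nanochat")))
```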
### PyTorch Profiler Outputs

- `{stage}-{phase_name}_trace.json` - Chrome trace file (viewable in chrome://tracing or https://ui.perfetto.dev/)
- `{stage}-{phase_name}_memory_timeline.html` - Memory timeline visualization

### CUDA Memory Snapshot Outputs

- `{stage}-{phase_name}_mem.pickle` - Memory snapshot (can be analyzed by visiting https://docs.pytorch.org/memory_viz)

**Stage prefixes**: Files are prefixed with `stage1-`, `stage2-`, `stage3-`, `stage4-` to indicate the order of profiling phases:

- `stage1-` = model_init (model creation, weight init, compilation)
- `stage2-` = eval_bpb (first validation evaluation)
- `stage3-` = training_microsteps (training micro-steps)
- `stage4-` = optimizer_step (gradient clipping, learning rate scheduling, optimizer step, zero_grad)
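The naming scheme above can be sketched as a small helper (mirroring the `stage_map` in `nanochat/profiling.py`; the `output_filename` function is illustrative):

```python
STAGE_MAP = {
    "model_init": "stage1",
    "eval_bpb": "stage2",
    "training_microsteps": "stage3",
    "optimizer_step": "stage4",
}

def output_filename(phase_name: str, suffix: str) -> str:
    # Unknown phases fall back to "stageX", as in ProfilingManager._get_stage_prefix
    stage = STAGE_MAP.get(phase_name, "stageX")
    return f"{stage}-{phase_name}_{suffix}"

print(output_filename("eval_bpb", "trace.json"))        # stage2-eval_bpb_trace.json
print(output_filename("optimizer_step", "mem.pickle"))  # stage4-optimizer_step_mem.pickle
```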
### Example Output Files

A typical profiling run generates the following files (sizes vary with model size and profiling duration):

```bash
$ cd ~/.cache/nanochat/profile_traces/20251018_111748
$ du -hs *
100K    stage1-model_init_mem.pickle
56K     stage1-model_init_memory_timeline.html
64M     stage1-model_init_trace.json
53M     stage2-eval_bpb_mem.pickle
52K     stage2-eval_bpb_memory_timeline.html
73M     stage2-eval_bpb_trace.json
43M     stage3-training_microsteps_mem.pickle
236K    stage3-training_microsteps_memory_timeline.html
126M    stage3-training_microsteps_trace.json
1.2M    stage4-optimizer_step_mem.pickle
76K     stage4-optimizer_step_memory_timeline.html
369M    stage4-optimizer_step_trace.json
```

**Total**: ~728 MB for a complete profiling run with all phases enabled.
## Viewing Profiling Results

### Chrome Traces

1. Open a web browser and visit https://ui.perfetto.dev/
2. Click "Open trace file" and select a `*_trace.json` file
3. Explore the timeline view with zoom and pan



*Nanochat profiling results: training micro-steps trace showing the CPU/CUDA activity timeline down to individual CUDA kernel calls*

### Memory Timelines

1. Open a `*_memory_timeline.html` file in a web browser
2. Explore memory allocation patterns over time



*Nanochat profiling results: memory timeline visualization showing allocation patterns across training micro-steps*

### CUDA Memory Snapshots

1. Visit https://pytorch.org/memory_viz in a web browser
2. Upload a `*_mem.pickle` file to analyze memory patterns



*Nanochat profiling results: CUDA memory snapshot showing detailed memory allocations by category*
@ -117,6 +117,10 @@ I haven't invested too much here but some tests exist, especially for the tokeni
python -m pytest tests/test_rustbpe.py -v -s
```

## Profiling

This project includes tools to capture PyTorch traces and CUDA memory snapshots for detailed profiling of CPU/GPU activity and memory usage. Traces show execution down to individual CUDA kernel calls, while memory snapshots visualize allocations to help debug OOM errors, memory leaks, or simply learn how GPU memory is managed. Enable with `--enable_profiling=True` in `speedrun.sh`, and see [PROFILING.md](PROFILING.md) for details and example visualizations.

## Contributing

nanochat is nowhere near finished. The goal is to improve the state of the art in micro models that are accessible to work with end to end on budgets under $1,000. Accessibility is about overall cost but also about cognitive complexity - nanochat is not an exhaustively configurable LLM "framework"; there will be no giant configuration objects, model factories, or if-then-else monsters in the code base. It is a single, cohesive, minimal, readable, hackable, maximally-forkable "strong baseline" codebase designed to run start to end and produce a concrete ChatGPT clone and its report card.
BIN  assets/images/nanochat-profiling-memory-snapshot.jpeg  Normal file
Binary file not shown. After Width: | Height: | Size: 708 KiB
BIN  assets/images/nanochat-profiling-memory-timeline.jpeg  Normal file
Binary file not shown. After Width: | Height: | Size: 193 KiB
BIN  assets/images/nanochat-profiling-trace.jpeg  Normal file
Binary file not shown. After Width: | Height: | Size: 452 KiB
@ -6,7 +6,7 @@ import torch
import torch.distributed as dist

@torch.no_grad()
def evaluate_bpb(model, batches, steps, token_bytes):
def evaluate_bpb(model, batches, steps, token_bytes, profiler=None, profile_phase="eval_bpb"):
    """
    Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
    which is a tokenization vocab-size-independent metric, meaning you are still comparing
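As a quick sketch of the bpb arithmetic (with hypothetical numbers; the real function accumulates `total_nats` and `total_bytes` over batches): the summed cross-entropy in nats is converted to bits via log(2), then normalized by the number of raw bytes the tokens decode to.

```python
import math

def bits_per_byte(total_nats: float, total_bytes: int) -> float:
    # nats -> bits via log(2), then normalize by the decoded byte count
    return total_nats / (math.log(2) * total_bytes)

# e.g. 2000 nats of total loss over 1000 decoded bytes:
print(bits_per_byte(2000.0, 1000))  # ~2.885 bits per byte
```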
@ -28,6 +28,13 @@ def evaluate_bpb(model, batches, steps, token_bytes):
    total_nats = torch.tensor(0.0, dtype=torch.float32, device=model.get_device())
    total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
    batch_iter = iter(batches)

    # Start profiling if a profiler is provided
    profile_ctx = None
    if profiler is not None:
        profile_ctx = profiler.profile_section(profile_phase, warmup=1, active=min(10, steps))
        profile_ctx.__enter__()

    for _ in range(steps):
        x, y = next(batch_iter)
        loss2d = model(x, y, loss_reduction='none') # (B, T)

@ -51,6 +58,14 @@ def evaluate_bpb(model, batches, steps, token_bytes):
        num_bytes2d = token_bytes[y]
        total_nats += (loss2d * (num_bytes2d > 0)).sum()
        total_bytes += num_bytes2d.sum()

        # Step the profiler if active
        if profile_ctx is not None:
            profile_ctx.step()

    # Stop profiling if it was started
    if profile_ctx is not None:
        profile_ctx.__exit__(None, None, None)
    # sum reduce across all ranks
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    if world_size > 1:
199  nanochat/profiling.py  Normal file

@ -0,0 +1,199 @@
"""
|
||||
Profiling utilities for nanochat training.
|
||||
|
||||
Provides unified profiling interface that activates both PyTorch profiler (traces)
|
||||
and CUDA memory snapshots for different training phases.
|
||||
"""
|
||||
|
||||
import os
|
||||
from datetime import datetime
|
||||
import torch
|
||||
from torch.profiler import profile, ProfilerActivity
|
||||
from typing import Optional, Callable
|
||||
|
||||
|
||||
class ProfilingManager:
|
||||
"""Manages profiling for different phases of training."""
|
||||
|
||||
# ANSI color codes
|
||||
CYAN = "\033[36m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_dir: str,
|
||||
ddp_local_rank: int,
|
||||
master_process: bool,
|
||||
enable_profiling: bool = False,
|
||||
print_fn: Callable = print,
|
||||
):
|
||||
self.base_dir = base_dir
|
||||
self.ddp_local_rank = ddp_local_rank
|
||||
self.master_process = master_process
|
||||
self.enable_profiling = enable_profiling
|
||||
self.max_mem_events_per_snapshot = 200000
|
||||
self.print_fn = print_fn
|
||||
|
||||
# Colored prefix for profiler messages
|
||||
self.prefix = f"{self.CYAN}[PROFILER]{self.RESET}"
|
||||
|
||||
# Create timestamped subdirectory for this run to avoid overwriting previous profiles
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self.profile_dir = os.path.join(base_dir, "profile_traces", timestamp)
|
||||
if enable_profiling and master_process:
|
||||
os.makedirs(self.profile_dir, exist_ok=True)
|
||||
self.print_fn(f"{self.prefix} Output directory: {self.profile_dir}")
|
||||
|
||||
self.active_profiler: Optional[profile] = None
|
||||
self.cuda_memory_recording = False
|
||||
self.current_phase: Optional[str] = None # Track current profiling phase
|
||||
|
||||
# Stage mapping for ordered file naming
|
||||
self.stage_map = {
|
||||
"model_init": "stage1",
|
||||
"eval_bpb": "stage2",
|
||||
"training_microsteps": "stage3",
|
||||
"optimizer_step": "stage4",
|
||||
}
|
||||
|
||||
def _get_stage_prefix(self, phase_name: str) -> str:
|
||||
"""Get stage prefix for a phase name."""
|
||||
return self.stage_map.get(phase_name, "stageX")
|
||||
|
||||
def _get_log_prefix(self, phase_name: str) -> str:
|
||||
"""Get colored log prefix with phase name."""
|
||||
return f"{self.CYAN}[PROFILER:{phase_name}]{self.RESET}"
|
||||
|
||||
def start_torch_profiler(self, phase_name: str, warmup: int = 0, active: int = 1, repeat: int = 1):
|
||||
"""
|
||||
Start PyTorch profiler for a specific phase.
|
||||
|
||||
Args:
|
||||
phase_name: Name of the profiling phase (e.g., "model_init", "eval_bpb")
|
||||
warmup: Number of warmup steps before active profiling (default: 0)
|
||||
active: Number of active profiling steps that capture traces (default: 1)
|
||||
repeat: Number of times to repeat the profiling cycle (default: 1)
|
||||
|
||||
The profiler runs according to its schedule (warmup + active steps) and
|
||||
auto-completes after the scheduled steps. No need to call stop() explicitly.
|
||||
Call step_torch_profiler() on each iteration to advance the profiler.
|
||||
"""
|
||||
if not self.enable_profiling or not self.master_process:
|
||||
return None
|
||||
|
||||
stage_prefix = self._get_stage_prefix(phase_name)
|
||||
log_prefix = self._get_log_prefix(phase_name)
|
||||
|
||||
def trace_handler(p):
|
||||
output_path = os.path.join(self.profile_dir, f"{stage_prefix}-{phase_name}_trace.json")
|
||||
self.print_fn(f"{log_prefix} Exporting Chrome trace to: {output_path}")
|
||||
p.export_chrome_trace(output_path)
|
||||
memory_path = os.path.join(self.profile_dir, f"{stage_prefix}-{phase_name}_memory_timeline.html")
|
||||
self.print_fn(f"{log_prefix} Exporting memory timeline to: {memory_path}")
|
||||
p.export_memory_timeline(memory_path, device=f"cuda:{self.ddp_local_rank}")
|
||||
self.print_fn(f"{log_prefix} Trace export complete")
|
||||
|
||||
prof = profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
schedule=torch.profiler.schedule(wait=0, warmup=warmup, active=active, repeat=repeat),
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
with_stack=True,
|
||||
on_trace_ready=trace_handler,
|
||||
)
|
||||
prof.start()
|
||||
self.print_fn(f"{log_prefix} Torch profiler started: warmup={warmup}, active={active}, repeat={repeat}")
|
||||
self.active_profiler = prof
|
||||
self.current_phase = phase_name # Track current phase
|
||||
return prof
|
||||
|
||||
def stop_torch_profiler(self):
|
||||
"""
|
||||
Stop the active PyTorch profiler (optional - profiler stops automatically after active steps).
|
||||
Only needed if you want to stop profiling early before the schedule completes.
|
||||
"""
|
||||
if self.active_profiler is not None:
|
||||
log_prefix = self._get_log_prefix(self.current_phase) if self.current_phase else self.prefix
|
||||
self.active_profiler.stop()
|
||||
self.print_fn(f"{log_prefix} Torch profiler stopped (early)")
|
||||
self.active_profiler = None
|
||||
self.current_phase = None # Clear current phase
|
||||
|
||||
def step_torch_profiler(self):
|
||||
"""Step the active PyTorch profiler."""
|
||||
if self.active_profiler is not None:
|
||||
self.active_profiler.step()
|
||||
|
||||
def start_cuda_memory_recording(self, phase_name: Optional[str] = None):
|
||||
"""Start CUDA memory snapshot recording."""
|
||||
if not self.enable_profiling or not self.master_process:
|
||||
return
|
||||
|
||||
if not self.cuda_memory_recording:
|
||||
if phase_name:
|
||||
self.current_phase = phase_name # Track phase if provided
|
||||
log_prefix = self._get_log_prefix(self.current_phase) if self.current_phase else self.prefix
|
||||
self.print_fn(f"{log_prefix} Starting CUDA memory snapshot recording with max_entries={self.max_mem_events_per_snapshot}")
|
||||
torch.cuda.memory._record_memory_history(
|
||||
max_entries=self.max_mem_events_per_snapshot
|
||||
)
|
||||
self.cuda_memory_recording = True
|
||||
|
||||
def dump_cuda_memory_snapshot(self, phase_name: str):
|
||||
"""Dump and stop CUDA memory snapshot."""
|
||||
if not self.enable_profiling or not self.master_process:
|
||||
return
|
||||
|
||||
if self.cuda_memory_recording:
|
||||
stage_prefix = self._get_stage_prefix(phase_name)
|
||||
log_prefix = self._get_log_prefix(phase_name)
|
||||
snapshot_path = os.path.join(self.profile_dir, f"{stage_prefix}-{phase_name}_mem.pickle")
|
||||
self.print_fn(f"{log_prefix} Dumping CUDA memory snapshot to: {snapshot_path}")
|
||||
torch.cuda.memory._dump_snapshot(snapshot_path)
|
||||
self.print_fn(f"{log_prefix} Stopping CUDA memory snapshot recording")
|
||||
torch.cuda.memory._record_memory_history(enabled=None)
|
||||
self.cuda_memory_recording = False
|
||||
self.print_fn(f"{log_prefix} CUDA memory snapshot complete")
|
||||
|
||||
def profile_section(self, phase_name: str, warmup: int = 0, active: int = 1, repeat: int = 1):
|
||||
"""
|
||||
Context manager for profiling a section of code.
|
||||
|
||||
Args:
|
||||
phase_name: Name of the profiling phase
|
||||
warmup: Number of warmup steps before active profiling (default: 0)
|
||||
active: Number of active profiling steps that capture traces (default: 1)
|
||||
repeat: Number of times to repeat the profiling cycle (default: 1)
|
||||
|
||||
Usage:
|
||||
with profiler.profile_section("model_loading", warmup=0, active=1):
|
||||
# code to profile
|
||||
pass
|
||||
"""
|
||||
return ProfilingContext(self, phase_name, warmup, active, repeat)
|
||||
|
||||
|
||||
class ProfilingContext:
|
||||
"""Context manager for profiling a specific section."""
|
||||
|
||||
def __init__(self, manager: ProfilingManager, phase_name: str, warmup: int, active: int, repeat: int):
|
||||
self.manager = manager
|
||||
self.phase_name = phase_name
|
||||
self.warmup = warmup
|
||||
self.active = active
|
||||
self.repeat = repeat
|
||||
|
||||
def __enter__(self):
|
||||
self.manager.start_cuda_memory_recording(self.phase_name)
|
||||
self.manager.start_torch_profiler(self.phase_name, warmup=self.warmup, active=self.active, repeat=self.repeat)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Profiler auto-completes after scheduled steps, no need to stop explicitly
|
||||
self.manager.dump_cuda_memory_snapshot(self.phase_name)
|
||||
|
||||
def step(self):
|
||||
"""Call this at the end of each iteration in the profiled section."""
|
||||
self.manager.step_torch_profiler()
|
||||
# Profiler will auto-complete after warmup + active steps
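The wait/warmup/active schedule used here can be modeled without torch. This sketch assumes the standard `torch.profiler.schedule` semantics (each cycle is `wait` idle steps, then `warmup` steps, then `active` recorded steps, repeated `repeat` times) and shows which 0-indexed step calls land in an active slot:

```python
def recorded_steps(num_steps: int, wait: int = 0, warmup: int = 1, active: int = 10, repeat: int = 1) -> list:
    # Which 0-indexed profiler steps fall in an "active" (recorded) slot
    cycle = wait + warmup + active
    recorded = []
    for step in range(num_steps):
        if repeat and step >= cycle * repeat:
            break  # schedule is finished after `repeat` cycles
        if step % cycle >= wait + warmup:
            recorded.append(step)
    return recorded

# warmup=1, active=10 (the evaluate_bpb settings): steps 1..10 are recorded
print(recorded_steps(20, wait=0, warmup=1, active=10, repeat=1))
```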
@ -8,6 +8,7 @@ dependencies = [
    "datasets>=4.0.0",
    "fastapi>=0.117.1",
    "files-to-prompt>=0.6",
    "matplotlib>=3.0.0",
    "numpy==1.26.4",
    "psutil>=7.1.0",
    "regex>=2025.9.1",
@ -21,6 +21,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.profiling import ProfilingManager
from scripts.base_eval import evaluate_model
print_banner()
@ -48,6 +49,9 @@ eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
# Profiling configuration (output files will be placed in ~/.cache/nanochat/profile_traces/<timestamp>/ by default)
# Master switch: enables both the PyTorch profiler (traces) and the CUDA memory profiler (snapshots)
enable_profiling = False
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
@ -61,6 +65,18 @@ ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

# Get the base directory early for profiling setup
base_dir = get_base_dir()

# Initialize the profiling manager
profiler = ProfilingManager(
    base_dir=base_dir,
    ddp_local_rank=ddp_local_rank,
    master_process=master_process,
    enable_profiling=enable_profiling,
    print_fn=print0,
)

# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
@ -93,6 +109,10 @@ print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {
# -----------------------------------------------------------------------------
# Initialize the Model
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
# Start profiling model initialization
if enable_profiling:
    profiler.start_cuda_memory_recording("model_init")
    profiler.start_torch_profiler("model_init", warmup=0, active=1)
with torch.device("meta"):
    model_config = GPTConfig(**model_config_kwargs)
    model = GPT(model_config)

@ -100,6 +120,10 @@ model.to_empty(device="cuda")
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
# Complete profiling of model initialization
if enable_profiling:
    profiler.step_torch_profiler()
    profiler.dump_cuda_memory_snapshot("model_init")
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
@ -131,7 +155,6 @@ optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=
adamw_optimizer, muon_optimizer = optimizers

# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
@ -179,7 +202,9 @@ for step in range(num_iterations + 1):
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
        with autocast_ctx:
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
            # Pass the profiler for the first evaluation if profiling is enabled
            prof_arg = profiler if (enable_profiling and step == 0) else None
            val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes, profiler=prof_arg)
        print0(f"Step {step:05d} | Validation bpb: {val_bpb:.4f}")
        if val_bpb < min_val_bpb:
            min_val_bpb = val_bpb
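The `eval_steps` computation is simple integer arithmetic; a sketch with hypothetical shapes (batch size 32, sequence length 2048, and 8 GPUs are illustrative; the real values come from the config):

```python
def compute_eval_steps(eval_tokens: int, device_batch_size: int, max_seq_len: int, ddp_world_size: int) -> int:
    # Each step, every rank consumes device_batch_size * max_seq_len tokens
    return eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)

# e.g. 20*524288 eval tokens, batch 32, seq len 2048, 8 GPUs:
print(compute_eval_steps(20 * 524288, 32, 2048, 8))  # 20
```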
@ -254,6 +279,13 @@ for step in range(num_iterations + 1):
    # evaluate the gradient
    torch.cuda.synchronize()
    t0 = time.time()

    # Profile micro-steps if enabled (step 0 only; captures up to 10 micro-steps)
    profile_ctx = None
    if enable_profiling and step == 0:
        profile_ctx = profiler.profile_section("training_microsteps", warmup=1, active=10)
        profile_ctx.__enter__()

    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
            loss = model(x, y)
@ -261,6 +293,19 @@ for step in range(num_iterations + 1):
        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize the loss here
        loss.backward()
        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
        if profile_ctx is not None:
            profile_ctx.step()

    # Close the profiling context if it was opened
    if profile_ctx is not None:
        profile_ctx.__exit__(None, None, None)

    # Start optimizer-step profiling if enabled
    optimizer_profile_ctx = None
    if enable_profiling and step == 0:
        optimizer_profile_ctx = profiler.profile_section("optimizer_step", warmup=0, active=1)
        optimizer_profile_ctx.__enter__()

    # gradient clipping (TODO: possibly experiment with this)
    if grad_clip > 0.0:
        torch.nn.utils.clip_grad_norm_(orig_model.parameters(), grad_clip)
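Why the division by `grad_accum_steps` in the loop above: each `.backward()` adds into `.grad`, so scaling every micro-batch loss by `1/N` makes the accumulated gradient equal the gradient of the mean loss. A tiny numeric sketch of that sum-vs-mean identity (hypothetical loss values, no autograd involved):

```python
micro_losses = [2.0, 4.0, 6.0, 8.0]  # hypothetical per-micro-batch losses
grad_accum_steps = len(micro_losses)

# accumulate the scaled losses, as the training loop does with gradients
accumulated = sum(loss / grad_accum_steps for loss in micro_losses)
mean_loss = sum(micro_losses) / grad_accum_steps

print(accumulated, mean_loss)  # both 5.0
```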
@ -276,6 +321,11 @@ for step in range(num_iterations + 1):
        opt.step()
    model.zero_grad(set_to_none=True)
    torch.cuda.synchronize()

    # Step and close optimizer profiling if active
    if optimizer_profile_ctx is not None:
        optimizer_profile_ctx.step()
        optimizer_profile_ctx.__exit__(None, None, None)
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------