mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
big change: add pretraining resumption logic so that checkpoints can now be approximately resumed and training can continue. this is useful for very long runs when you don't want the anxiety of your run crashing for some reason. alternatively, it's a way to recover training in the event of loss spikes. i mean, this should have been there in v0 but it's ok. the resumption is approximate to control complexity and bloat, but it's possible we want to change that in the future. to use, set --save_every to a step interval to write checkpoints with, and then use --resume_from_step to resume optimization from a given step. only base model training (pretraining) supports this atm, but it's ok because midtraining is comparably quite a bit faster.
This commit is contained in:
parent
91f09ccd0d
commit
c6abcdfe3a
|
|
@ -20,33 +20,32 @@ def log0(message):
|
|||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
logger.info(message)
|
||||
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
|
||||
assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
|
||||
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
|
||||
if rank == 0:
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
# Save the model state (parameters)
|
||||
# Save the model state parameters
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
torch.save(model_data, model_path)
|
||||
log0(f"Saved model file to: {model_path}")
|
||||
# Save the optimizer state (useful for SFT or any other fine-tuning)
|
||||
if optimizer_data is not None:
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
|
||||
torch.save(optimizer_data, optimizer_path)
|
||||
log0(f"Saved optimizer file to: {optimizer_path}")
|
||||
logger.info(f"Saved model parameters to: {model_path}")
|
||||
# Save the metadata dict as json
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
with open(meta_path, "w", encoding="utf-8") as f:
|
||||
json.dump(meta_data, f, indent=2)
|
||||
log0(f"Saved metadata file to: {meta_path}")
|
||||
logger.info(f"Saved metadata to: {meta_path}")
|
||||
# Note that optimizer state is sharded across ranks, so each rank must save its own.
|
||||
if optimizer_data is not None:
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
||||
torch.save(optimizer_data, optimizer_path)
|
||||
logger.info(f"Saved optimizer state to: {optimizer_path}")
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
|
||||
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
|
||||
# Load the model state
|
||||
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
|
||||
model_data = torch.load(model_path, map_location=device)
|
||||
# Load the optimizer state if requested
|
||||
optimizer_data = None
|
||||
if load_optimizer:
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
|
||||
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
|
||||
optimizer_data = torch.load(optimizer_path, map_location=device)
|
||||
# Load the metadata
|
||||
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
|
||||
|
|
|
|||
|
|
@ -148,6 +148,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
|
|||
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
|
||||
|
||||
# Reproducibility
|
||||
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
|
||||
# The only place where global rng might be used is nn.Module initialization of the model weights.
|
||||
torch.manual_seed(42)
|
||||
if device_type == "cuda":
|
||||
torch.cuda.manual_seed(42)
|
||||
|
|
|
|||
|
|
@ -1,49 +1,87 @@
|
|||
from collections import deque
|
||||
|
||||
import torch
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
from nanochat.common import get_dist_info
|
||||
from nanochat.dataset import parquets_iter_batched
|
||||
from nanochat.dataset import list_parquet_files
|
||||
from nanochat.tokenizer import get_tokenizer
|
||||
|
||||
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
|
||||
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
|
||||
def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
|
||||
"""
|
||||
Stream pretraining text from parquet files, tokenize, yield training batches.
|
||||
|
||||
This implementation became a bit more complex because we wish to support approximate resume training.
|
||||
Instead of turning this into a Class, we opt to return the state_dict with every batch,
|
||||
and then the caller can pass in a state_dict to resume training from a desired point.
|
||||
Note that this resumption is atm only *approximate* for simplicity.
|
||||
We won't repeat the same documents but we might skip a few.
|
||||
The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
|
||||
|
||||
Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
|
||||
# infinite iterator over document batches (list of text strings)
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
def document_batches():
|
||||
parquet_paths = list_parquet_files()
|
||||
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
|
||||
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
|
||||
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
|
||||
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
|
||||
while True: # iterate infinitely (multi-epoch)
|
||||
while pq_idx < len(parquet_paths): # iterate over all parquet files
|
||||
filepath = parquet_paths[pq_idx]
|
||||
pf = pq.ParquetFile(filepath)
|
||||
# Start from resume point if resuming on same file, otherwise from DDP rank
|
||||
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
|
||||
if resume_rg_idx is not None:
|
||||
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
|
||||
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
|
||||
rg_idx = base_idx * ddp_world_size + ddp_rank
|
||||
resume_rg_idx = None # set to None as we only want to do this a single time
|
||||
else:
|
||||
rg_idx = ddp_rank
|
||||
while rg_idx < pf.num_row_groups:
|
||||
rg = pf.read_row_group(rg_idx)
|
||||
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
|
||||
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
|
||||
rg_idx += ddp_world_size # advance to the next row group (in DDP)
|
||||
pq_idx += 1 # advance to the next parquet file
|
||||
batches = document_batches()
|
||||
|
||||
# Now emit batches of tokens.
|
||||
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
|
||||
# get the tokenizer and the bos token
|
||||
tokenizer = get_tokenizer()
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
# scratch buffer holds the tokens for one iteration
|
||||
token_buffer = deque() # we stream tokens on the right and pop from the left
|
||||
|
||||
# infinite iterator over document batches
|
||||
def document_batches():
|
||||
while True:
|
||||
# batch will iterate in group size of the parquet files, usually e.g. 1024 rows
|
||||
for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
|
||||
# for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
|
||||
for i in range(0, len(batch), tokenizer_batch_size):
|
||||
yield batch[i:i+tokenizer_batch_size]
|
||||
batches = document_batches()
|
||||
|
||||
batch_index = 0
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding.
|
||||
while len(token_buffer) < needed_tokens:
|
||||
doc_batch = next(batches)
|
||||
doc_batch, (pq_idx, rg_idx) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
token_buffer.extend(tokens)
|
||||
batch_index += 1
|
||||
# Move tokens from the deque into the scratch buffer
|
||||
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
|
||||
# CUDA supports memory pinning for faster transfers between CPU and GPU:
|
||||
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))
|
||||
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
|
||||
use_cuda_optimizations = device == "cuda"
|
||||
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
|
||||
# Create the inputs/targets as 1D tensors
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
inputs_cpu = scratch[:-1]
|
||||
targets_cpu = scratch[1:]
|
||||
# Reshape to 2D and move to GPU async
|
||||
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
|
||||
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
|
||||
yield inputs, targets, state_dict
|
||||
|
||||
def tokenizing_distributed_data_loader(*args, **kwargs):
|
||||
# helper function that only emits the inputs/targets and not the state_dict
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
|
||||
yield inputs, targets
|
||||
|
|
|
|||
|
|
@ -20,10 +20,10 @@ import wandb
|
|||
import torch
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from scripts.base_eval import evaluate_model
|
||||
|
|
@ -52,12 +52,14 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
|
|||
warmup_ratio = 0.0 # ratio of iterations for LR warmup
|
||||
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
|
||||
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
|
||||
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
|
||||
# Evaluation
|
||||
eval_every = 250 # every how many steps to evaluate the model for val bpb
|
||||
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
|
||||
core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
||||
sample_every = 2000 # every how many steps to sample from the model
|
||||
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
|
||||
# Output
|
||||
model_tag = "" # optionally override the model tag for the output checkpoint directory name
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
|
|
@ -103,16 +105,31 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
|||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
|
||||
# Create a new model with random weights
|
||||
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
|
||||
with torch.device("meta"):
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
model.to_empty(device=device)
|
||||
model.init_weights()
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict
|
||||
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
|
||||
|
||||
# If we are resuming, overwrite the model parameters with those of the checkpoint
|
||||
base_dir = get_base_dir()
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
resuming = resume_from_step != -1
|
||||
if resuming:
|
||||
print0(f"Resuming optimization from step {resume_from_step}")
|
||||
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
|
||||
model.load_state_dict(model_data, strict=True, assign=True)
|
||||
del model_data # free up this memory after the copy
|
||||
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
|
||||
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
|
||||
num_params = sum(p.numel() for p in model.parameters())
|
||||
print0(f"Number of parameters: {num_params:,}")
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
|
|
@ -143,12 +160,18 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
|
|||
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
|
||||
adamw_optimizer, muon_optimizer = optimizers
|
||||
|
||||
if resuming:
|
||||
for opt, dat in zip(optimizers, optimizer_data):
|
||||
opt.load_state_dict(dat)
|
||||
del optimizer_data # free up the memory
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
base_dir = get_base_dir()
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
|
||||
x, y = next(train_loader) # kick off load of the very first batch of data
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Set up hyperparameter schedulers
|
||||
|
|
@ -172,14 +195,24 @@ def get_muon_momentum(it):
|
|||
return momentum
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
# Loop state (variables updated by the training loop)
|
||||
|
||||
if not resuming:
|
||||
step = 0
|
||||
min_val_bpb = float("inf")
|
||||
smooth_train_loss = 0 # EMA of training loss
|
||||
ema_beta = 0.9 # EMA decay factor
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
# note that we run +1 steps only so that we can eval and save at the end
|
||||
for step in range(num_iterations + 1):
|
||||
last_step = step == num_iterations
|
||||
else:
|
||||
step = meta_data["step"]
|
||||
loop_state = meta_data["loop_state"]
|
||||
min_val_bpb = loop_state["min_val_bpb"]
|
||||
smooth_train_loss = loop_state["smooth_train_loss"]
|
||||
total_training_time = loop_state["total_training_time"]
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
while True:
|
||||
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
|
||||
flops_so_far = num_flops_per_token * total_batch_size * step
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
|
|
@ -237,25 +270,31 @@ for step in range(num_iterations + 1):
|
|||
print0(tokenizer.decode(sample[0]))
|
||||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (only on master process)
|
||||
if master_process and last_step:
|
||||
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
|
||||
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
|
||||
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
|
||||
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
|
||||
save_checkpoint(
|
||||
checkpoint_dir,
|
||||
step,
|
||||
orig_model.state_dict(),
|
||||
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
|
||||
{
|
||||
orig_model.state_dict(), # model parameters
|
||||
[opt.state_dict() for opt in optimizers], # optimizer states
|
||||
{ # metadata saved as json
|
||||
"step": step,
|
||||
"val_bpb": val_bpb, # loss at last step
|
||||
"model_config": model_config_kwargs,
|
||||
"user_config": user_config, # inputs to the training script
|
||||
"device_batch_size": device_batch_size,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
"dataloader_state_dict": dataloader_state_dict,
|
||||
"loop_state": { # all loop state (other than step) so that we can resume training
|
||||
"min_val_bpb": min_val_bpb,
|
||||
"smooth_train_loss": smooth_train_loss,
|
||||
"total_training_time": total_training_time,
|
||||
},
|
||||
},
|
||||
rank=ddp_rank,
|
||||
)
|
||||
|
||||
# termination conditions (TODO: possibly also add loss explosions etc.)
|
||||
if last_step:
|
||||
break
|
||||
|
||||
|
|
@ -270,7 +309,7 @@ for step in range(num_iterations + 1):
|
|||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# gradient clipping
|
||||
grad_clip_enabled = grad_clip > 0.0
|
||||
if grad_clip_enabled:
|
||||
|
|
@ -293,6 +332,7 @@ for step in range(num_iterations + 1):
|
|||
# -------------------------------------------------------------------------
|
||||
|
||||
# logging
|
||||
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * step / num_iterations
|
||||
|
|
@ -319,6 +359,9 @@ for step in range(num_iterations + 1):
|
|||
log_data["train/grad_norm"] = grad_norm
|
||||
wandb_run.log(log_data)
|
||||
|
||||
# state update
|
||||
step += 1
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user