mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)
big change: add pretraining resumption logic so that training can now be approximately resumed from a checkpoint and continue. this is useful for very long runs when you don't want the anxiety of your run crashing for some reason. alternatively, it's a way to recover training in the event of loss spikes. i mean, this should have been there in v0 but it's ok. the resumption is approximate to control complexity and bloat, but it's possible we'll want to change that in the future. to use, set --save_every to the step interval at which to write checkpoints, and then use --resume_from_step to resume optimization from a given step. only base model training (pretraining) supports this atm, but that's ok because midtraining is comparatively quite a bit faster.
This commit is contained in:
parent 91f09ccd0d
commit c6abcdfe3a
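To make the two new flags concrete before the diff: --save_every controls which steps write a checkpoint, and --resume_from_step then picks one of those steps to continue from. A tiny standalone sketch of that cadence (the helper below is illustrative only; it mirrors the save condition added to the training loop further down):

def writes_checkpoint(step, num_iterations, save_every=-1, resume_from_step=-1):
    # a checkpoint is written at the last step, or every save_every steps,
    # but never at step 0 and never at the step we just resumed from
    last_step = step == num_iterations
    return last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0)

# e.g. a 3000-iteration run launched with --save_every=1000 writes checkpoints at steps 1000, 2000, 3000;
# if it crashes after step 2000, relaunching with --resume_from_step=2000 continues from that checkpoint.
print([s for s in range(3001) if writes_checkpoint(s, 3000, save_every=1000)])  # [1000, 2000, 3000]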
@@ -20,33 +20,32 @@ def log0(message):
     if int(os.environ.get('RANK', 0)) == 0:
         logger.info(message)
 
-def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
-    assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
-    os.makedirs(checkpoint_dir, exist_ok=True)
-    # Save the model state (parameters)
-    model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
-    torch.save(model_data, model_path)
-    log0(f"Saved model file to: {model_path}")
-    # Save the optimizer state (useful for SFT or any other fine-tuning)
-    if optimizer_data is not None:
-        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
-        torch.save(optimizer_data, optimizer_path)
-        log0(f"Saved optimizer file to: {optimizer_path}")
-    # Save the metadata dict as json
-    meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
-    with open(meta_path, "w", encoding="utf-8") as f:
-        json.dump(meta_data, f, indent=2)
-    log0(f"Saved metadata file to: {meta_path}")
+def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
+    if rank == 0:
+        os.makedirs(checkpoint_dir, exist_ok=True)
+        # Save the model state parameters
+        model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
+        torch.save(model_data, model_path)
+        logger.info(f"Saved model parameters to: {model_path}")
+        # Save the metadata dict as json
+        meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
+        with open(meta_path, "w", encoding="utf-8") as f:
+            json.dump(meta_data, f, indent=2)
+        logger.info(f"Saved metadata to: {meta_path}")
+    # Note that optimizer state is sharded across ranks, so each rank must save its own.
+    if optimizer_data is not None:
+        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
+        torch.save(optimizer_data, optimizer_path)
+        logger.info(f"Saved optimizer state to: {optimizer_path}")
 
-def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
+def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
     # Load the model state
     model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
     model_data = torch.load(model_path, map_location=device)
     # Load the optimizer state if requested
     optimizer_data = None
     if load_optimizer:
-        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
+        optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
         optimizer_data = torch.load(optimizer_path, map_location=device)
     # Load the metadata
     meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
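The hunk above reworks save_checkpoint/load_checkpoint so that only rank 0 writes the shared model and metadata files, while every rank writes and reads its own optimizer shard. A minimal sketch of the resulting on-disk layout, assuming nanochat is importable; the directory, step number and tensors are illustrative stand-ins, and both ranks are simulated in one process here:

import torch
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint

checkpoint_dir = "/tmp/base_checkpoints/demo"   # hypothetical location
step = 1000
model_data = {"w": torch.zeros(4)}              # stand-in for orig_model.state_dict()
meta_data = {"step": step}
optim_data = [{"lr": 0.02}]                     # stand-in for [opt.state_dict() for opt in optimizers]

# rank 0 writes model_001000.pt, meta_001000.json and optim_001000_rank0.pt
save_checkpoint(checkpoint_dir, step, model_data, optim_data, meta_data, rank=0)
# rank 1 runs the same call but only adds optim_001000_rank1.pt
save_checkpoint(checkpoint_dir, step, model_data, optim_data, meta_data, rank=1)

# at resume time each rank loads the shared model/meta plus its own optimizer shard
model_data, optimizer_data, meta_data = load_checkpoint(
    checkpoint_dir, step, device="cpu", load_optimizer=True, rank=1)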
@@ -148,6 +148,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
         assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
 
     # Reproducibility
+    # Note that we set the global seeds here, but most of the code uses explicit rng objects.
+    # The only place where global rng might be used is nn.Module initialization of the model weights.
     torch.manual_seed(42)
     if device_type == "cuda":
         torch.cuda.manual_seed(42)
@@ -1,49 +1,87 @@
 from collections import deque
 
 import torch
+import pyarrow.parquet as pq
 
 from nanochat.common import get_dist_info
-from nanochat.dataset import parquets_iter_batched
+from nanochat.dataset import list_parquet_files
 from nanochat.tokenizer import get_tokenizer
 
-def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
-    """Stream pretraining text from parquet files, tokenize, yield training batches."""
+def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
+    """
+    Stream pretraining text from parquet files, tokenize, yield training batches.
+
+    This implementation became a bit more complex because we wish to support approximate resume training.
+    Instead of turning this into a Class, we opt to return the state_dict with every batch,
+    and then the caller can pass in a state_dict to resume training from a desired point.
+    Note that this resumption is atm only *approximate* for simplicity.
+    We won't repeat the same documents but we might skip a few.
+    The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
+
+    Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
+    """
     assert split in ["train", "val"], "split must be 'train' or 'val'"
+
+    # infinite iterator over document batches (list of text strings)
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
+    def document_batches():
+        parquet_paths = list_parquet_files()
+        parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
+        resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
+        resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
+        pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
+        while True: # iterate infinitely (multi-epoch)
+            while pq_idx < len(parquet_paths): # iterate over all parquet files
+                filepath = parquet_paths[pq_idx]
+                pf = pq.ParquetFile(filepath)
+                # Start from resume point if resuming on same file, otherwise from DDP rank
+                # I know this state resumption is a little bit tricky and a little bit hacky... sigh.
+                if resume_rg_idx is not None:
+                    base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
+                    base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
+                    rg_idx = base_idx * ddp_world_size + ddp_rank
+                    resume_rg_idx = None # set to None as we only want to do this a single time
+                else:
+                    rg_idx = ddp_rank
+                while rg_idx < pf.num_row_groups:
+                    rg = pf.read_row_group(rg_idx)
+                    batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
+                    # the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
+                    for i in range(0, len(batch), tokenizer_batch_size):
+                        yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
+                    rg_idx += ddp_world_size # advance to the next row group (in DDP)
+                pq_idx += 1 # advance to the next parquet file
+    batches = document_batches()
+
+    # Now emit batches of tokens.
     needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
     # get the tokenizer and the bos token
     tokenizer = get_tokenizer()
     bos_token = tokenizer.get_bos_token_id()
     # scratch buffer holds the tokens for one iteration
     token_buffer = deque() # we stream tokens on the right and pop from the left
-
-    # infinite iterator over document batches
-    def document_batches():
-        while True:
-            # batch will iterate in group size of the parquet files, usually e.g. 1024 rows
-            for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
-                # for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
-                for i in range(0, len(batch), tokenizer_batch_size):
-                    yield batch[i:i+tokenizer_batch_size]
-    batches = document_batches()
-
-    batch_index = 0
     while True:
         # Accumulate enough tokens for one iteration before yielding.
         while len(token_buffer) < needed_tokens:
-            doc_batch = next(batches)
+            doc_batch, (pq_idx, rg_idx) = next(batches)
             token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
             for tokens in token_lists:
                 token_buffer.extend(tokens)
-            batch_index += 1
         # Move tokens from the deque into the scratch buffer
         tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
-        # CUDA supports memory pinning for faster transfers between CPU and GPU:
-        scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))
+        # CUDA supports memory pinning for asynchronous transfers between CPU and GPU
+        use_cuda_optimizations = device == "cuda"
+        scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
         # Create the inputs/targets as 1D tensors
-        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
+        inputs_cpu = scratch[:-1]
         targets_cpu = scratch[1:]
         # Reshape to 2D and move to GPU async
-        inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
-        targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
+        inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
+        targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
+        state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
+        yield inputs, targets, state_dict
+
+def tokenizing_distributed_data_loader(*args, **kwargs):
+    # helper function that only emits the inputs/targets and not the state_dict
+    for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
         yield inputs, targets
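The caller contract of the new loader, sketched under the assumption that the nanochat tokenizer and pretraining parquet shards are already set up locally (the batch shape and the example indices are illustrative):

from nanochat.dataloader import tokenizing_distributed_data_loader_with_state

loader = tokenizing_distributed_data_loader_with_state(B=4, T=256, split="train", device="cpu")
for step, (inputs, targets, state_dict) in enumerate(loader):
    # ... forward / backward / optimizer step would go here ...
    if step == 99:
        saved = dict(state_dict)  # e.g. {"pq_idx": 0, "rg_idx": 7}, persisted inside the checkpoint metadata
        break

# Later, possibly in a fresh process: hand the saved state back in. The loader starts at the next
# row group after (pq_idx, rg_idx) for this rank, so no documents are repeated, though the few
# that were in flight at save time are skipped -- hence "approximate" resumption.
resumed = tokenizing_distributed_data_loader_with_state(
    B=4, T=256, split="train", device="cpu", resume_state_dict=saved)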
@@ -20,10 +20,10 @@ import wandb
 import torch
 
 from nanochat.gpt import GPT, GPTConfig
-from nanochat.dataloader import tokenizing_distributed_data_loader
+from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
 from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
-from nanochat.checkpoint_manager import save_checkpoint
+from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
 from scripts.base_eval import evaluate_model
@@ -52,12 +52,14 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
 warmup_ratio = 0.0 # ratio of iterations for LR warmup
 warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
 final_lr_frac = 0.0 # final LR is this fraction of the initial LR
+resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
 # Evaluation
 eval_every = 250 # every how many steps to evaluate the model for val bpb
 eval_tokens = 20*524288 # number of tokens to evaluate val loss on
 core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
 core_metric_max_per_task = 500 # examples per task in estimating the core metric
 sample_every = 2000 # every how many steps to sample from the model
+save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
 # Output
 model_tag = "" # optionally override the model tag for the output checkpoint directory name
 # now allow CLI to override the settings via the configurator lol
@@ -103,16 +105,31 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
 print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
 print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
 
 # -----------------------------------------------------------------------------
 # Initialize the Model
+
+# Create a new model with random weights
 model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
 with torch.device("meta"):
     model_config = GPTConfig(**model_config_kwargs)
     model = GPT(model_config)
 model.to_empty(device=device)
 model.init_weights()
-orig_model = model # original, uncompiled model, for saving raw model state_dict
-model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
+
+# If we are resuming, overwrite the model parameters with those of the checkpoint
+base_dir = get_base_dir()
+output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
+checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
+resuming = resume_from_step != -1
+if resuming:
+    print0(f"Resuming optimization from step {resume_from_step}")
+    model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
+    model.load_state_dict(model_data, strict=True, assign=True)
+    del model_data # free up this memory after the copy
+
+orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because the shapes may change shape)
+model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
 num_params = sum(p.numel() for p in model.parameters())
 print0(f"Number of parameters: {num_params:,}")
 num_flops_per_token = model.estimate_flops()
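The resume path above loads checkpoint weights into a model that was constructed on the meta device and then materialized with to_empty(). A small self-contained PyTorch illustration (a toy module, not the GPT class) of why load_state_dict(..., assign=True) fits that pattern:

import torch
import torch.nn as nn

with torch.device("meta"):
    m = nn.Linear(4, 4)             # parameters exist only as shapes, no storage yet
m.to_empty(device="cpu")            # allocate real but uninitialized storage

state = {"weight": torch.ones(4, 4), "bias": torch.zeros(4)}  # stand-in for model_data
m.load_state_dict(state, strict=True, assign=True)  # assign=True (PyTorch >= 2.1) adopts the checkpoint tensors directly
print(m.weight.sum().item())        # 16.0 -- the loaded values are now the module's parameters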
@@ -143,12 +160,18 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
 optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
 adamw_optimizer, muon_optimizer = optimizers
 
+if resuming:
+    for opt, dat in zip(optimizers, optimizer_data):
+        opt.load_state_dict(dat)
+    del optimizer_data # free up the memory
+
+# -----------------------------------------------------------------------------
 # Initialize the DataLoaders for train/val
-base_dir = get_base_dir()
 tokens_dir = os.path.join(base_dir, "tokenized_data")
-train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
+dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
+train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
 build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
-x, y = next(train_loader) # kick off load of the very first batch of data
+x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
 
 # -----------------------------------------------------------------------------
 # Set up hyperparameter schedulers
@@ -172,14 +195,24 @@ def get_muon_momentum(it):
     return momentum
 
 # -----------------------------------------------------------------------------
-# Training loop
+# Loop state (variables updated by the training loop)
 
-min_val_bpb = float("inf")
-smooth_train_loss = 0 # EMA of training loss
-ema_beta = 0.9 # EMA decay factor
-total_training_time = 0 # total wall-clock time of training
-# note that we run +1 steps only so that we can eval and save at the end
-for step in range(num_iterations + 1):
-    last_step = step == num_iterations
+if not resuming:
+    step = 0
+    min_val_bpb = float("inf")
+    smooth_train_loss = 0 # EMA of training loss
+    total_training_time = 0 # total wall-clock time of training
+else:
+    step = meta_data["step"]
+    loop_state = meta_data["loop_state"]
+    min_val_bpb = loop_state["min_val_bpb"]
+    smooth_train_loss = loop_state["smooth_train_loss"]
+    total_training_time = loop_state["total_training_time"]
+
+# -----------------------------------------------------------------------------
+# Training loop
+while True:
+    last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
     flops_so_far = num_flops_per_token * total_batch_size * step
 
     # once in a while: evaluate the val bpb (all ranks participate)
@@ -237,25 +270,31 @@ for step in range(num_iterations + 1):
         print0(tokenizer.decode(sample[0]))
         model.train()
 
-    # save checkpoint at the end of the run (only on master process)
-    if master_process and last_step:
-        output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
-        checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
+    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
+    if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
         save_checkpoint(
             checkpoint_dir,
             step,
-            orig_model.state_dict(),
-            [opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
-            {
+            orig_model.state_dict(), # model parameters
+            [opt.state_dict() for opt in optimizers], # optimizer states
+            { # metadata saved as json
                 "step": step,
                 "val_bpb": val_bpb, # loss at last step
                 "model_config": model_config_kwargs,
                 "user_config": user_config, # inputs to the training script
                 "device_batch_size": device_batch_size,
                 "max_seq_len": max_seq_len,
-            }
+                "dataloader_state_dict": dataloader_state_dict,
+                "loop_state": { # all loop state (other than step) so that we can resume training
+                    "min_val_bpb": min_val_bpb,
+                    "smooth_train_loss": smooth_train_loss,
+                    "total_training_time": total_training_time,
+                },
+            },
+            rank=ddp_rank,
         )
 
+    # termination conditions (TODO: possibly also add loss explosions etc.)
     if last_step:
         break
 
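For reference, the metadata file written by this call (e.g. meta_000500.json for step 500) now looks roughly like the following; all values are made up, the keys simply mirror the dict passed above:

meta_example = {
    "step": 500,
    "val_bpb": 0.9712,
    "model_config": {"sequence_len": 2048, "vocab_size": 65536, "n_layer": 12, "n_head": 8, "n_kv_head": 8, "n_embd": 768},
    "user_config": {"save_every": 500, "resume_from_step": -1},
    "device_batch_size": 32,
    "max_seq_len": 2048,
    "dataloader_state_dict": {"pq_idx": 0, "rg_idx": 7},
    "loop_state": {
        "min_val_bpb": 0.9712,
        "smooth_train_loss": 2.84,
        "total_training_time": 5123.4,
    },
}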
@@ -270,7 +309,7 @@ for step in range(num_iterations + 1):
         train_loss = loss.detach() # for logging
         loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward()
-        x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
+        x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
     # gradient clipping
     grad_clip_enabled = grad_clip > 0.0
     if grad_clip_enabled:
@@ -293,6 +332,7 @@ for step in range(num_iterations + 1):
     # -------------------------------------------------------------------------
 
     # logging
+    ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
     smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
     debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
     pct_done = 100 * step / num_iterations
@@ -319,6 +359,9 @@ for step in range(num_iterations + 1):
         log_data["train/grad_norm"] = grad_norm
         wandb_run.log(log_data)
 
+    # state update
+    step += 1
+
 # print a few more stats
 print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")