From ae02650afe5f60ae2ccda19a5f5cd57d16ea312b Mon Sep 17 00:00:00 2001
From: karpathy
Date: Thu, 16 Oct 2025 16:33:17 -0700
Subject: [PATCH] update the midtraining script too

---
 scripts/mid_train.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index 90ab954..3a90a9c 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -15,8 +15,8 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import time
 import wandb
 import torch
-
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
+from contextlib import nullcontext
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -30,6 +30,7 @@ from tasks.smoltalk import SmolTalk
 
 # -----------------------------------------------------------------------------
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
+device_type = "" # cuda|cpu|mps (empty => autodetect)
 model_tag = None # model tag to load the model from (base model or midtrained model)
 step = None # step to load the model from (base model or midtrained model)
 dtype = "bfloat16"
@@ -40,7 +41,7 @@ embedding_lr = 0.2
 matrix_lr = 0.02
 init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
 weight_decay = 0.0
-eval_every = 150
+eval_every = 150 # -1 = disable
 eval_tokens = 20*524288
 total_batch_size = 524288
 dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
@@ -50,10 +51,12 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
 
 # -----------------------------------------------------------------------------
 # Compute init
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if device_type == "" else device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0
-dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
+get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
 
 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process
@@ -168,7 +171,7 @@ while True:
         last_step = bool(last_step_tensor.item())
 
     # once in a while: evaluate the val bpb (all ranks participate)
-    if last_step or step % eval_every == 0:
+    if eval_every > 0 and (last_step or step % eval_every == 0):
         model.eval()
         val_loader = build_val_loader()
         eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
@@ -215,7 +218,7 @@ while True:
     # -------------------------------------------------------------------------
     # single training step
     # evaluate the gradient
-    torch.cuda.synchronize()
+    synchronize()
     t0 = time.time()
     for micro_step in range(grad_accum_steps):
         with autocast_ctx:
@@ -236,7 +239,7 @@ while True:
     for opt in optimizers:
         opt.step()
     model.zero_grad(set_to_none=True)
-    torch.cuda.synchronize()
+    synchronize()
     t1 = time.time()
     dt = t1 - t0
     # -------------------------------------------------------------------------
@@ -268,7 +271,7 @@ while True:
         })
 
 # print a few more stats
-print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
+print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")
 print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
 
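Note: the device handling introduced above reduces to three fallbacks (autocast context, synchronize, peak-memory query) that keep the training loop identical across cuda, cpu, and mps. Below is a minimal standalone sketch of that pattern, assuming only PyTorch; the detect_device_type helper and the toy model/step are illustrative stand-ins, not nanochat's autodetect_device_type or its actual training loop.

import time
import torch
from contextlib import nullcontext

def detect_device_type():
    # illustrative stand-in for nanochat.common.autodetect_device_type (assumption)
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

device_type = detect_device_type()

# CUDA gets bfloat16 autocast, timing synchronization, and memory stats;
# cpu/mps degrade to no-ops so the step code below needs no branching.
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

# toy "training step" just to exercise the helpers
model = torch.nn.Linear(256, 256).to(device_type)
x = torch.randn(32, 256, device=device_type)

synchronize()
t0 = time.time()
with autocast_ctx:
    loss = model(x).square().mean()
loss.backward()
synchronize()
print(f"step time: {time.time() - t0:.4f}s, peak memory: {get_max_memory() / 1024 / 1024:.2f}MiB")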