mirror of https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00

update the midtraining script too

parent df600b6ed5
commit ae02650afe
@@ -15,8 +15,8 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import time
 import wandb
 import torch
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
+from contextlib import nullcontext
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
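A side note on the new import: contextlib.nullcontext is the standard-library do-nothing context manager, which is what lets the non-CUDA path below reuse the same "with autocast_ctx:" block unchanged. For example:

from contextlib import nullcontext

# nullcontext() can stand in anywhere a context manager is expected
with nullcontext():
    pass  # the body runs exactly as if it were not wrapped at all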
@@ -30,6 +30,7 @@ from tasks.smoltalk import SmolTalk
 # -----------------------------------------------------------------------------
 run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
+device_type = "" # cuda|cpu|mps (empty => autodetect)
 model_tag = None # model tag to load the model from (base model or midtrained model)
 step = None # step to load the model from (base model or midtrained model)
 dtype = "bfloat16"
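For reference, a minimal sketch of what an autodetect_device_type helper could look like; the real implementation lives in nanochat/common.py and is not shown in this diff, so the exact checks and preference order here are assumptions:

import torch

def autodetect_device_type() -> str:
    # hypothetical sketch; the actual nanochat helper may differ
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"  # Apple Silicon
    return "cpu"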
@@ -40,7 +41,7 @@ embedding_lr = 0.2
 matrix_lr = 0.02
 init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
 weight_decay = 0.0
-eval_every = 150
+eval_every = 150 # -1 = disable
 eval_tokens = 20*524288
 total_batch_size = 524288
 dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
@@ -50,10 +51,12 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
 # -----------------------------------------------------------------------------

 # Compute init
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if device_type == "" else device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0
-dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
+get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process
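The three new assignments keep the training loop free of per-device branches: CUDA gets real autocast, synchronization, and memory accounting, while cpu/mps get cheap no-ops with the same call shape. A self-contained sketch of the pattern (the names match the diff; the toy workload is purely illustrative):

from contextlib import nullcontext
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0

with autocast_ctx:  # bf16 autocast on CUDA, plain fp32 elsewhere
    y = torch.randn(8, 8) @ torch.randn(8, 8)
synchronize()  # waits for the GPU on CUDA, does nothing elsewhere
print(f"peak memory: {get_max_memory() / 1024 / 1024:.2f}MiB")  # 0.00MiB off-CUDA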
@@ -168,7 +171,7 @@ while True:
     last_step = bool(last_step_tensor.item())

     # once in a while: evaluate the val bpb (all ranks participate)
-    if last_step or step % eval_every == 0:
+    if eval_every > 0 and (last_step or step % eval_every == 0):
         model.eval()
         val_loader = build_val_loader()
         eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
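The rewritten guard makes the new eval_every = -1 setting a real off switch: the whole disjunction sits behind eval_every > 0, so even the final-step eval is skipped when disabled, and step % eval_every can no longer divide by zero. A tiny sketch of the semantics:

def should_eval(step, last_step, eval_every):
    # mirrors the condition in the diff above
    return eval_every > 0 and (last_step or step % eval_every == 0)

assert should_eval(300, False, 150)    # periodic eval fires
assert should_eval(999, True, 150)     # last step always evals when enabled
assert not should_eval(999, True, -1)  # -1 disables even the final eval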
@@ -215,7 +218,7 @@ while True:
     # -------------------------------------------------------------------------
     # single training step
     # evaluate the gradient
-    torch.cuda.synchronize()
+    synchronize()
     t0 = time.time()
     for micro_step in range(grad_accum_steps):
         with autocast_ctx:
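The synchronize() swap matters here because this code brackets a timer: CUDA kernels launch asynchronously, so reading the clock without a sync would measure launch overhead rather than execution time, while on cpu/mps the no-op lambda keeps the identical timing code free. A runnable sketch of the same bracket (the matmul is a stand-in for the forward/backward micro-steps):

import time
import torch

device_type = "cuda" if torch.cuda.is_available() else "cpu"
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None

x = torch.randn(1024, 1024, device=device_type)
synchronize()  # drain any pending GPU work before starting the clock
t0 = time.time()
y = x @ x  # stand-in for the training micro-steps
synchronize()  # wait for this step's kernels before stopping the clock
t1 = time.time()
print(f"step time: {(t1 - t0) * 1000:.2f}ms")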
@@ -236,7 +239,7 @@ while True:
     for opt in optimizers:
         opt.step()
     model.zero_grad(set_to_none=True)
-    torch.cuda.synchronize()
+    synchronize()
     t1 = time.time()
     dt = t1 - t0
     # -------------------------------------------------------------------------
@@ -268,7 +271,7 @@ while True:
         })

 # print a few more stats
-print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
+print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
 print0(f"Total training time: {total_training_time/60:.2f}m")
 print0(f"Minimum validation bpb: {min_val_bpb:.4f}")