mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 05:35:19 +00:00
Merge 70d0abe432 into 1b1cc3c599
This commit is contained in:
commit
05f5849d04
|
|
@ -22,7 +22,7 @@ dependencies = [
|
|||
"torch==2.9.1",
|
||||
"transformers>=4.57.3",
|
||||
"uvicorn>=0.36.0",
|
||||
"wandb>=0.21.3",
|
||||
"wandb>=0.25.0",
|
||||
"zstandard>=0.25.0",
|
||||
]
|
||||
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ parser.add_argument("--model-tag", type=str, default=None, help="override model
|
|||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
# Compute init and wandb logging
|
||||
# Compute init
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
|
|
@ -95,10 +95,6 @@ else:
|
|||
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
||||
print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
from nanochat.flash_attention import USE_FA3
|
||||
using_fa3 = USE_FA3
|
||||
|
|
@ -384,6 +380,24 @@ def get_muon_momentum(it):
|
|||
def get_weight_decay(it):
|
||||
return weight_decay_scaled * 0.5 * (1 + math.cos(math.pi * it / num_iterations))
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
if use_dummy_wandb:
|
||||
wandb_run = DummyWandb()
|
||||
else:
|
||||
realized_config = dict(
|
||||
device_type=device_type,
|
||||
compute_dtype=str(COMPUTE_DTYPE),
|
||||
target_tokens=target_tokens,
|
||||
total_batch_size=total_batch_size,
|
||||
batch_lr_scale=batch_lr_scale,
|
||||
weight_decay_scaled=weight_decay_scaled,
|
||||
num_iterations=num_iterations,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
wandb_run = wandb.init(project="nanochat", name=args.run, config=user_config | realized_config)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
|
||||
|
|
@ -427,11 +441,10 @@ while True:
|
|||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"val/bpb": val_bpb,
|
||||
})
|
||||
}, step)
|
||||
model.train()
|
||||
|
||||
# once in a while: estimate the CORE metric (all ranks participate)
|
||||
|
|
@ -444,11 +457,10 @@ while True:
|
|||
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"core_metric": results["core_metric"],
|
||||
"centered_results": results["centered_results"],
|
||||
})
|
||||
}, step)
|
||||
model.train()
|
||||
|
||||
# once in a while: sample from the model (only on master process)
|
||||
|
|
@ -565,8 +577,7 @@ while True:
|
|||
epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}"
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}")
|
||||
if step % 100 == 0:
|
||||
log_data = {
|
||||
"step": step,
|
||||
wandb_run.log({
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"train/loss": debiased_smooth_loss,
|
||||
|
|
@ -575,8 +586,7 @@ while True:
|
|||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": epoch,
|
||||
}
|
||||
wandb_run.log(log_data)
|
||||
}, step)
|
||||
|
||||
# state update
|
||||
first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
|
||||
|
|
|
|||
|
|
@ -237,10 +237,7 @@ for step in range(num_steps):
|
|||
print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, args.device_batch_size + 1)]
|
||||
print0(f"Step {step} | {', '.join(print_passk)}")
|
||||
log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, args.device_batch_size + 1)}
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
**log_passk,
|
||||
})
|
||||
wandb_run.log(log_passk, step)
|
||||
|
||||
# Forward/Backward on rollouts over multiple examples in the dataset
|
||||
rewards_list = []
|
||||
|
|
@ -287,11 +284,6 @@ for step in range(num_steps):
|
|||
mean_reward = mean_reward_tensor.item()
|
||||
mean_sequence_length = mean_sequence_length_tensor.item()
|
||||
print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"reward": mean_reward,
|
||||
"sequence_length": mean_sequence_length,
|
||||
})
|
||||
|
||||
# Update the model parameters
|
||||
lrm = get_lr_multiplier(step)
|
||||
|
|
@ -300,9 +292,10 @@ for step in range(num_steps):
|
|||
optimizer.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"reward": mean_reward,
|
||||
"sequence_length": mean_sequence_length,
|
||||
"lrm": lrm,
|
||||
})
|
||||
}, step)
|
||||
|
||||
# Master process saves the model once in a while. Skip first step. Save last step.
|
||||
if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1):
|
||||
|
|
|
|||
|
|
@ -84,10 +84,6 @@ if device_type == "cuda":
|
|||
else:
|
||||
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
if not HAS_FA3:
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
|
||||
|
|
@ -116,6 +112,10 @@ for name, fallback, source in [
|
|||
else:
|
||||
print0(f"Using {name}={arg_val}")
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=args)
|
||||
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
|
|
@ -353,11 +353,10 @@ while True:
|
|||
if val_bpb < min_val_bpb:
|
||||
min_val_bpb = val_bpb
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"val/bpb": val_bpb,
|
||||
})
|
||||
}, step)
|
||||
model.train()
|
||||
|
||||
# once in a while: estimate the ChatCORE metric (all ranks participate)
|
||||
|
|
@ -387,12 +386,11 @@ while True:
|
|||
chatcore_cat = centered_mean(categorical_tasks)
|
||||
print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}")
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"chatcore_metric": chatcore,
|
||||
"chatcore_cat": chatcore_cat,
|
||||
**{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()},
|
||||
})
|
||||
}, step)
|
||||
model.train()
|
||||
|
||||
# save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard)
|
||||
|
|
@ -476,7 +474,6 @@ while True:
|
|||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 10 == 0:
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
"total_training_flops": flops_so_far,
|
||||
"total_training_time": total_training_time,
|
||||
"train/loss": debiased_smooth_loss,
|
||||
|
|
@ -485,7 +482,7 @@ while True:
|
|||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": current_epoch,
|
||||
})
|
||||
}, step)
|
||||
|
||||
# The garbage collector spends ~500ms scanning for cycles quite frequently.
|
||||
# We manually manage it to avoid these pauses during training.
|
||||
|
|
|
|||
24
uv.lock
24
uv.lock
|
|
@ -1550,7 +1550,7 @@ requires-dist = [
|
|||
{ name = "torch", marker = "extra == 'gpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } },
|
||||
{ name = "transformers", specifier = ">=4.57.3" },
|
||||
{ name = "uvicorn", specifier = ">=0.36.0" },
|
||||
{ name = "wandb", specifier = ">=0.21.3" },
|
||||
{ name = "wandb", specifier = ">=0.25.0" },
|
||||
{ name = "zstandard", specifier = ">=0.25.0" },
|
||||
]
|
||||
provides-extras = ["cpu", "gpu"]
|
||||
|
|
@ -3319,7 +3319,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "wandb"
|
||||
version = "0.21.3"
|
||||
version = "0.25.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "click" },
|
||||
|
|
@ -3333,17 +3333,17 @@ dependencies = [
|
|||
{ name = "sentry-sdk" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/2f/84/af6ccdf95e56f15aceb360e437fbfcca3dc91ad8ca335fe482083e29f7a5/wandb-0.21.3.tar.gz", hash = "sha256:031e24e2aad0ce735dfdcc74baf2f2c12c106f500ed24798de6ef9b9e63bb432", size = 40146972, upload-time = "2025-08-30T18:21:55.138Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fd/60/d94952549920469524b689479c864c692ca47eca4b8c2fe3389b64a58778/wandb-0.25.0.tar.gz", hash = "sha256:45840495a288e34245d69d07b5a0b449220fbc5b032e6b51c4f92ec9026d2ad1", size = 43951335, upload-time = "2026-02-13T00:17:45.515Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/aa/e8/b5bfbbc7f76c11fd0665b92be8a38c6a83b27f353552233b9959b21be488/wandb-0.21.3-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:f85bac45b4482742ec9ff190af38eb00a877ddeb4875475e7e487dc19300ff03", size = 18820209, upload-time = "2025-08-30T18:21:33.47Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/59/a3/03f0fcde49609df1cb3a382fb5053f601b88da448bcd415ed7f75272eee7/wandb-0.21.3-py3-none-macosx_12_0_arm64.whl", hash = "sha256:8a2b3ba419b91d47edead2755f04cef54f9e3c4496ee0c9854c3cfeff4216dd3", size = 18310636, upload-time = "2025-08-30T18:21:37.405Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/1d/c3/d6048db30ff2e3c67089ba0e94878572fd26137b146f8e3b27bbdf428b31/wandb-0.21.3-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:35a1972881f3b85755befab004118234593792a9f05e07fd6345780172f4420e", size = 19053277, upload-time = "2025-08-30T18:21:39.389Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ea/7f/805c3d2fa9e3b8b6bf2bc534887c9ed97bdf22007ca8ba59424a1c8bb360/wandb-0.21.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d9cf8588cb090a2a41f589037fda72c57c9e23edfbd2ad829e575f1305d942c", size = 18130850, upload-time = "2025-08-30T18:21:41.573Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/5b/af/a3252e5afac98a036f83c65ec92cadf6677ccdaacbbb2151da29f694d136/wandb-0.21.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ff24b6b8e0f9da840b6bd5c7f60b0a5507bd998db40c9c2d476f9a340bec8ed", size = 19570305, upload-time = "2025-08-30T18:21:43.811Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/4d/f9/4404b5a24bfd4ba027c19d30152b0fc7ebca8c49b202dee6ecb7f316082c/wandb-0.21.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4975dec19e2b343e23ed6e60f7e1290120553719f82e87a22205bede758416ad", size = 18135806, upload-time = "2025-08-30T18:21:46.211Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ff/32/9580f42899e54f3d0b4ea619b6f6a54980a4e36fd0675d58c09f0a08d3f6/wandb-0.21.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:514a0aad40ecc0bdb757b1dc86e4ac98f61d2d760445b6e1f555291562320f2d", size = 19646760, upload-time = "2025-08-30T18:21:48.768Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/75/d3/faa6ddb792a158c154fb704b25c96d0478e71eabf96e3f17529fb23b6894/wandb-0.21.3-py3-none-win32.whl", hash = "sha256:45aa3d8ad53c6ee06f37490d7a329ed7d0f5ca4dbd5d05bb0c01d5da22f14691", size = 18709408, upload-time = "2025-08-30T18:21:50.859Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d8/2d/7ef56e25f78786e59fefd9b19867c325f9686317d9f7b93b5cb340360a3e/wandb-0.21.3-py3-none-win_amd64.whl", hash = "sha256:56d5a5697766f552a9933d8c6a564202194768eb0389bd5f9fe9a99cd4cee41e", size = 18709411, upload-time = "2025-08-30T18:21:52.874Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/7d/0c131db3ec9deaabbd32263d90863cbfbe07659527e11c35a5c738cecdc5/wandb-0.25.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:5eecb3c7b5e60d1acfa4b056bfbaa0b79a482566a9db58c9f99724b3862bc8e5", size = 23287536, upload-time = "2026-02-13T00:17:20.265Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c3/95/31bb7f76a966ec87495e5a72ac7570685be162494c41757ac871768dbc4f/wandb-0.25.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:daeedaadb183dc466e634fba90ab2bab1d4e93000912be0dee95065a0624a3fd", size = 25196062, upload-time = "2026-02-13T00:17:23.356Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d9/a1/258cdedbf30cebc692198a774cf0ef945b7ed98ee64bdaf62621281c95d8/wandb-0.25.0-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:5e0127dbcef13eea48f4b84268da7004d34d3120ebc7b2fa9cefb72b49dbb825", size = 22799744, upload-time = "2026-02-13T00:17:26.437Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/de/91/ec9465d014cfd199c5b2083d271d31b3c2aedeae66f3d8a0712f7f54bdf3/wandb-0.25.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:6c4c38077836f9b7569a35b0e1dcf1f0c43616fcd936d182f475edbfea063665", size = 25262839, upload-time = "2026-02-13T00:17:28.8Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c7/95/cb2d1c7143f534544147fb53fe87944508b8cb9a058bc5b6f8a94adbee15/wandb-0.25.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6edd8948d305cb73745bf564b807bd73da2ccbd47c548196b8a362f7df40aed8", size = 22853714, upload-time = "2026-02-13T00:17:31.68Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/d7/94/68163f70c1669edcf130822aaaea782d8198b5df74443eca0085ec596774/wandb-0.25.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:ada6f08629bb014ad6e0a19d5dec478cdaa116431baa3f0a4bf4ab8d9893611f", size = 25358037, upload-time = "2026-02-13T00:17:34.676Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/cc/fb/9578eed2c01b2fc6c8b693da110aa9c73a33d7bb556480f5cfc42e48c94e/wandb-0.25.0-py3-none-win32.whl", hash = "sha256:020b42ca4d76e347709d65f59b30d4623a115edc28f462af1c92681cb17eae7c", size = 24604118, upload-time = "2026-02-13T00:17:37.641Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/25/97/460f6cb738aaa39b4eb2e6b4c630b2ae4321cdd70a79d5955ea75a878981/wandb-0.25.0-py3-none-win_amd64.whl", hash = "sha256:78307ac0b328f2dc334c8607bec772851215584b62c439eb320c4af4fb077a00", size = 24604122, upload-time = "2026-02-13T00:17:39.991Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/27/6c/5847b4dda1dfd52630dac08711d4348c69ed657f0698fc2d949c7f7a6622/wandb-0.25.0-py3-none-win_arm64.whl", hash = "sha256:c6174401fd6fb726295e98d57b4231c100eca96bd17de51bfc64038a57230aaf", size = 21785298, upload-time = "2026-02-13T00:17:42.475Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user