From 70d0abe4321a428d648c62a0aace497a790c244a Mon Sep 17 00:00:00 2001 From: Filip Date: Tue, 10 Mar 2026 15:45:11 -0400 Subject: [PATCH] Fix a few cosmetic wandb issues: 1. Update the pinned `wandb` library version. The old version raises when given new `wandb` API keys! 2. Move the `step` argument to the right place in `wandb.log` calls. The signature is `wandb.log(data: dict, step: int, commit: bool)` - previously, step counts were being included in the data dict, meaning wandb metrics incorrectly had x-axes corresponding to the number of calls to `.log` instead of the number of training steps. 3. Move `wandb.init` later in `chat_sft.py` and `base_train.py` to include config values that are calculated or read from a checkpoint. --- pyproject.toml | 2 +- scripts/base_train.py | 36 +++++++++++++++++++++++------------- scripts/chat_rl.py | 15 ++++----------- scripts/chat_sft.py | 17 +++++++---------- uv.lock | 24 ++++++++++++------------ 5 files changed, 47 insertions(+), 47 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8b6fd95..524198b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "torch==2.9.1", "transformers>=4.57.3", "uvicorn>=0.36.0", - "wandb>=0.21.3", + "wandb>=0.25.0", "zstandard>=0.25.0", ] diff --git a/scripts/base_train.py b/scripts/base_train.py index cfbfe28..6e57ae0 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -80,7 +80,7 @@ parser.add_argument("--model-tag", type=str, default=None, help="override model args = parser.parse_args() user_config = vars(args).copy() # for logging # ----------------------------------------------------------------------------- -# Compute init and wandb logging +# Compute init device_type = autodetect_device_type() if args.device_type == "" else args.device_type ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) @@ -95,10 +95,6 @@ else: gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})") -# wandb logging init -use_dummy_wandb = args.run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config) - # Flash Attention status from nanochat.flash_attention import USE_FA3 using_fa3 = USE_FA3 @@ -377,6 +373,24 @@ def get_muon_momentum(it): def get_weight_decay(it): return weight_decay_scaled * 0.5 * (1 + math.cos(math.pi * it / num_iterations)) +# ----------------------------------------------------------------------------- +# wandb logging init +use_dummy_wandb = args.run == "dummy" or not master_process +if use_dummy_wandb: + wandb_run = DummyWandb() +else: + realized_config = dict( + device_type=device_type, + compute_dtype=str(COMPUTE_DTYPE), + target_tokens=target_tokens, + total_batch_size=total_batch_size, + batch_lr_scale=batch_lr_scale, + weight_decay_scaled=weight_decay_scaled, + num_iterations=num_iterations, + total_tokens=total_tokens, + ) + wandb_run = wandb.init(project="nanochat", name=args.run, config=user_config | realized_config) + # ----------------------------------------------------------------------------- # Training loop @@ -420,11 +434,10 @@ while True: if val_bpb < min_val_bpb: min_val_bpb = val_bpb wandb_run.log({ - "step": step, "total_training_flops": flops_so_far, "total_training_time": total_training_time, "val/bpb": val_bpb, - }) + }, step) model.train() # once in a while: estimate the CORE metric (all ranks participate) @@ -437,11 +450,10 @@ while True: results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task) print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}") wandb_run.log({ - "step": step, "total_training_flops": flops_so_far, "core_metric": results["core_metric"], "centered_results": results["centered_results"], - }) + }, step) model.train() # once in a while: sample from the model (only on master process) @@ -558,8 +570,7 @@ while True: epoch = f"{dataloader_state_dict['epoch']} pq: {dataloader_state_dict['pq_idx']} rg: {dataloader_state_dict['rg_idx']}" print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | bf16_mfu: {mfu:.2f} | epoch: {epoch} | total time: {total_training_time/60:.2f}m{eta_str}") if step % 100 == 0: - log_data = { - "step": step, + wandb_run.log({ "total_training_flops": flops_so_far, "total_training_time": total_training_time, "train/loss": debiased_smooth_loss, @@ -568,8 +579,7 @@ while True: "train/tok_per_sec": tok_per_sec, "train/mfu": mfu, "train/epoch": epoch, - } - wandb_run.log(log_data) + }, step) # state update first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step) diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index cb2cb0e..07f00d1 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -237,10 +237,7 @@ for step in range(num_steps): print_passk = [f"Pass@{k}: {passk[k - 1].item():.4f}" for k in range(1, args.device_batch_size + 1)] print0(f"Step {step} | {', '.join(print_passk)}") log_passk = {f"pass@{k}": passk[k - 1].item() for k in range(1, args.device_batch_size + 1)} - wandb_run.log({ - "step": step, - **log_passk, - }) + wandb_run.log(log_passk, step) # Forward/Backward on rollouts over multiple examples in the dataset rewards_list = [] @@ -287,11 +284,6 @@ for step in range(num_steps): mean_reward = mean_reward_tensor.item() mean_sequence_length = mean_sequence_length_tensor.item() print0(f"Step {step}/{num_steps} | Average reward: {mean_reward} | Average sequence length: {mean_sequence_length:.2f}") - wandb_run.log({ - "step": step, - "reward": mean_reward, - "sequence_length": mean_sequence_length, - }) # Update the model parameters lrm = get_lr_multiplier(step) @@ -300,9 +292,10 @@ for step in range(num_steps): optimizer.step() model.zero_grad(set_to_none=True) wandb_run.log({ - "step": step, + "reward": mean_reward, + "sequence_length": mean_sequence_length, "lrm": lrm, - }) + }, step) # Master process saves the model once in a while. Skip first step. Save last step. if master_process and ((step > 0 and step % args.save_every == 0) or step == num_steps - 1): diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb6..d9a6e6c 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -84,10 +84,6 @@ if device_type == "cuda": else: gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS -# wandb logging init -use_dummy_wandb = args.run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config) - # Flash Attention status if not HAS_FA3: print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.") @@ -116,6 +112,10 @@ for name, fallback, source in [ else: print0(f"Using {name}={arg_val}") +# wandb logging init +use_dummy_wandb = args.run == "dummy" or not master_process +wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=args) + orig_model = model model = torch.compile(model, dynamic=False) depth = model.config.n_layer @@ -353,11 +353,10 @@ while True: if val_bpb < min_val_bpb: min_val_bpb = val_bpb wandb_run.log({ - "step": step, "total_training_flops": flops_so_far, "total_training_time": total_training_time, "val/bpb": val_bpb, - }) + }, step) model.train() # once in a while: estimate the ChatCORE metric (all ranks participate) @@ -387,12 +386,11 @@ while True: chatcore_cat = centered_mean(categorical_tasks) print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}") wandb_run.log({ - "step": step, "total_training_flops": flops_so_far, "chatcore_metric": chatcore, "chatcore_cat": chatcore_cat, **{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()}, - }) + }, step) model.train() # save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard) @@ -476,7 +474,6 @@ while True: print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m") if step % 10 == 0: wandb_run.log({ - "step": step, "total_training_flops": flops_so_far, "total_training_time": total_training_time, "train/loss": debiased_smooth_loss, @@ -485,7 +482,7 @@ while True: "train/tok_per_sec": tok_per_sec, "train/mfu": mfu, "train/epoch": current_epoch, - }) + }, step) # The garbage collector spends ~500ms scanning for cycles quite frequently. # We manually manage it to avoid these pauses during training. diff --git a/uv.lock b/uv.lock index bbc9519..eb653d1 100644 --- a/uv.lock +++ b/uv.lock @@ -1550,7 +1550,7 @@ requires-dist = [ { name = "torch", marker = "extra == 'gpu'", specifier = "==2.9.1", index = "https://download.pytorch.org/whl/cu128", conflict = { package = "nanochat", extra = "gpu" } }, { name = "transformers", specifier = ">=4.57.3" }, { name = "uvicorn", specifier = ">=0.36.0" }, - { name = "wandb", specifier = ">=0.21.3" }, + { name = "wandb", specifier = ">=0.25.0" }, { name = "zstandard", specifier = ">=0.25.0" }, ] provides-extras = ["cpu", "gpu"] @@ -3319,7 +3319,7 @@ wheels = [ [[package]] name = "wandb" -version = "0.21.3" +version = "0.25.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -3333,17 +3333,17 @@ dependencies = [ { name = "sentry-sdk" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2f/84/af6ccdf95e56f15aceb360e437fbfcca3dc91ad8ca335fe482083e29f7a5/wandb-0.21.3.tar.gz", hash = "sha256:031e24e2aad0ce735dfdcc74baf2f2c12c106f500ed24798de6ef9b9e63bb432", size = 40146972, upload-time = "2025-08-30T18:21:55.138Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fd/60/d94952549920469524b689479c864c692ca47eca4b8c2fe3389b64a58778/wandb-0.25.0.tar.gz", hash = "sha256:45840495a288e34245d69d07b5a0b449220fbc5b032e6b51c4f92ec9026d2ad1", size = 43951335, upload-time = "2026-02-13T00:17:45.515Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/e8/b5bfbbc7f76c11fd0665b92be8a38c6a83b27f353552233b9959b21be488/wandb-0.21.3-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:f85bac45b4482742ec9ff190af38eb00a877ddeb4875475e7e487dc19300ff03", size = 18820209, upload-time = "2025-08-30T18:21:33.47Z" }, - { url = "https://files.pythonhosted.org/packages/59/a3/03f0fcde49609df1cb3a382fb5053f601b88da448bcd415ed7f75272eee7/wandb-0.21.3-py3-none-macosx_12_0_arm64.whl", hash = "sha256:8a2b3ba419b91d47edead2755f04cef54f9e3c4496ee0c9854c3cfeff4216dd3", size = 18310636, upload-time = "2025-08-30T18:21:37.405Z" }, - { url = "https://files.pythonhosted.org/packages/1d/c3/d6048db30ff2e3c67089ba0e94878572fd26137b146f8e3b27bbdf428b31/wandb-0.21.3-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:35a1972881f3b85755befab004118234593792a9f05e07fd6345780172f4420e", size = 19053277, upload-time = "2025-08-30T18:21:39.389Z" }, - { url = "https://files.pythonhosted.org/packages/ea/7f/805c3d2fa9e3b8b6bf2bc534887c9ed97bdf22007ca8ba59424a1c8bb360/wandb-0.21.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d9cf8588cb090a2a41f589037fda72c57c9e23edfbd2ad829e575f1305d942c", size = 18130850, upload-time = "2025-08-30T18:21:41.573Z" }, - { url = "https://files.pythonhosted.org/packages/5b/af/a3252e5afac98a036f83c65ec92cadf6677ccdaacbbb2151da29f694d136/wandb-0.21.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ff24b6b8e0f9da840b6bd5c7f60b0a5507bd998db40c9c2d476f9a340bec8ed", size = 19570305, upload-time = "2025-08-30T18:21:43.811Z" }, - { url = "https://files.pythonhosted.org/packages/4d/f9/4404b5a24bfd4ba027c19d30152b0fc7ebca8c49b202dee6ecb7f316082c/wandb-0.21.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4975dec19e2b343e23ed6e60f7e1290120553719f82e87a22205bede758416ad", size = 18135806, upload-time = "2025-08-30T18:21:46.211Z" }, - { url = "https://files.pythonhosted.org/packages/ff/32/9580f42899e54f3d0b4ea619b6f6a54980a4e36fd0675d58c09f0a08d3f6/wandb-0.21.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:514a0aad40ecc0bdb757b1dc86e4ac98f61d2d760445b6e1f555291562320f2d", size = 19646760, upload-time = "2025-08-30T18:21:48.768Z" }, - { url = "https://files.pythonhosted.org/packages/75/d3/faa6ddb792a158c154fb704b25c96d0478e71eabf96e3f17529fb23b6894/wandb-0.21.3-py3-none-win32.whl", hash = "sha256:45aa3d8ad53c6ee06f37490d7a329ed7d0f5ca4dbd5d05bb0c01d5da22f14691", size = 18709408, upload-time = "2025-08-30T18:21:50.859Z" }, - { url = "https://files.pythonhosted.org/packages/d8/2d/7ef56e25f78786e59fefd9b19867c325f9686317d9f7b93b5cb340360a3e/wandb-0.21.3-py3-none-win_amd64.whl", hash = "sha256:56d5a5697766f552a9933d8c6a564202194768eb0389bd5f9fe9a99cd4cee41e", size = 18709411, upload-time = "2025-08-30T18:21:52.874Z" }, + { url = "https://files.pythonhosted.org/packages/c1/7d/0c131db3ec9deaabbd32263d90863cbfbe07659527e11c35a5c738cecdc5/wandb-0.25.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:5eecb3c7b5e60d1acfa4b056bfbaa0b79a482566a9db58c9f99724b3862bc8e5", size = 23287536, upload-time = "2026-02-13T00:17:20.265Z" }, + { url = "https://files.pythonhosted.org/packages/c3/95/31bb7f76a966ec87495e5a72ac7570685be162494c41757ac871768dbc4f/wandb-0.25.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:daeedaadb183dc466e634fba90ab2bab1d4e93000912be0dee95065a0624a3fd", size = 25196062, upload-time = "2026-02-13T00:17:23.356Z" }, + { url = "https://files.pythonhosted.org/packages/d9/a1/258cdedbf30cebc692198a774cf0ef945b7ed98ee64bdaf62621281c95d8/wandb-0.25.0-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:5e0127dbcef13eea48f4b84268da7004d34d3120ebc7b2fa9cefb72b49dbb825", size = 22799744, upload-time = "2026-02-13T00:17:26.437Z" }, + { url = "https://files.pythonhosted.org/packages/de/91/ec9465d014cfd199c5b2083d271d31b3c2aedeae66f3d8a0712f7f54bdf3/wandb-0.25.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:6c4c38077836f9b7569a35b0e1dcf1f0c43616fcd936d182f475edbfea063665", size = 25262839, upload-time = "2026-02-13T00:17:28.8Z" }, + { url = "https://files.pythonhosted.org/packages/c7/95/cb2d1c7143f534544147fb53fe87944508b8cb9a058bc5b6f8a94adbee15/wandb-0.25.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:6edd8948d305cb73745bf564b807bd73da2ccbd47c548196b8a362f7df40aed8", size = 22853714, upload-time = "2026-02-13T00:17:31.68Z" }, + { url = "https://files.pythonhosted.org/packages/d7/94/68163f70c1669edcf130822aaaea782d8198b5df74443eca0085ec596774/wandb-0.25.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:ada6f08629bb014ad6e0a19d5dec478cdaa116431baa3f0a4bf4ab8d9893611f", size = 25358037, upload-time = "2026-02-13T00:17:34.676Z" }, + { url = "https://files.pythonhosted.org/packages/cc/fb/9578eed2c01b2fc6c8b693da110aa9c73a33d7bb556480f5cfc42e48c94e/wandb-0.25.0-py3-none-win32.whl", hash = "sha256:020b42ca4d76e347709d65f59b30d4623a115edc28f462af1c92681cb17eae7c", size = 24604118, upload-time = "2026-02-13T00:17:37.641Z" }, + { url = "https://files.pythonhosted.org/packages/25/97/460f6cb738aaa39b4eb2e6b4c630b2ae4321cdd70a79d5955ea75a878981/wandb-0.25.0-py3-none-win_amd64.whl", hash = "sha256:78307ac0b328f2dc334c8607bec772851215584b62c439eb320c4af4fb077a00", size = 24604122, upload-time = "2026-02-13T00:17:39.991Z" }, + { url = "https://files.pythonhosted.org/packages/27/6c/5847b4dda1dfd52630dac08711d4348c69ed657f0698fc2d949c7f7a6622/wandb-0.25.0-py3-none-win_arm64.whl", hash = "sha256:c6174401fd6fb726295e98d57b4231c100eca96bd17de51bfc64038a57230aaf", size = 21785298, upload-time = "2026-02-13T00:17:42.475Z" }, ] [[package]]