From 545bb8e7723b65c8f23974d4bb2e514caebf3535 Mon Sep 17 00:00:00 2001
From: Sermet Pekin <96650846+SermetPekin@users.noreply.github.com>
Date: Wed, 5 Nov 2025 15:58:41 +0300
Subject: [PATCH] Refactor wandb logging initialization

---
 scripts/mid_train.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index 6c2b82f..fa86f55 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -13,15 +13,17 @@ from collections import deque
 import os
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import time
-import wandb
+
 import torch
+import torch.distributed as dist
+
 from contextlib import nullcontext
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, get_wandb, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.checkpoint_manager import load_model
-import torch.distributed as dist
+
 from tasks.common import TaskMixture
 from tasks.gsm8k import GSM8K
 
@@ -62,8 +64,7 @@ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
 
 # wandb logging init
-use_dummy_wandb = run == "dummy" or not master_process
-wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
+wandb_run = get_wandb("nanochat-mid", run=run, master_process=master_process, user_config=user_config)
 
 # Load the model and tokenizer
 model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
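
Note: the new get_wandb() helper itself is not part of this patch; it is assumed to live in nanochat/common.py. A minimal sketch of what it presumably looks like, reconstructed from the inline logic this diff removes (the exact signature and the DummyWandb fallback are assumptions based on the call site above):

    # Hypothetical sketch of get_wandb() in nanochat/common.py -- not shown in
    # this patch; reconstructed from the removed inline wandb init logic.
    import wandb

    class DummyWandb:
        """No-op stand-in so non-master or dry runs can call .log()/.finish() safely."""
        def log(self, *args, **kwargs):
            pass
        def finish(self):
            pass

    def get_wandb(project, run, master_process, user_config):
        # Only the master process logs, and "dummy" runs skip wandb entirely.
        use_dummy_wandb = run == "dummy" or not master_process
        if use_dummy_wandb:
            return DummyWandb()
        return wandb.init(project=project, name=run, config=user_config)

This keeps each training script's call site to a single line and centralizes the dummy-vs-real wandb decision in one place.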