From b9f01eedd9e1f69f006cfd88f709ae584b1e36b3 Mon Sep 17 00:00:00 2001 From: Sermet Pekin <96650846+SermetPekin@users.noreply.github.com> Date: Wed, 5 Nov 2025 15:58:02 +0300 Subject: [PATCH] Refactor wandb initialization in chat_sft.py --- scripts/chat_sft.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index e6e4565..595d24c 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -12,12 +12,11 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft import os os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" -import wandb import torch import torch.distributed as dist from contextlib import nullcontext -from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type +from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, get_wandb, autodetect_device_type from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import save_checkpoint from nanochat.engine import Engine @@ -69,8 +68,7 @@ ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16 autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() # wandb logging init -use_dummy_wandb = run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True) +wandb_run = get_wandb("nanochat-sft" , run=run, master_process=master_process, user_config=user_config) # Load the model and tokenizer model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)