Refactor wandb logging initialization

Author: Sermet Pekin (committed by GitHub), 2025-11-05 15:58:41 +03:00
commit 545bb8e772
parent b9f01eedd9


@@ -13,15 +13,17 @@ from collections import deque
 import os
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import time
-import wandb
 import torch
 import torch.distributed as dist
 from contextlib import nullcontext
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, get_wandb, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.checkpoint_manager import load_model
 import torch.distributed as dist
 from tasks.common import TaskMixture
 from tasks.gsm8k import GSM8K
@@ -62,8 +64,7 @@ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
 # wandb logging init
-use_dummy_wandb = run == "dummy" or not master_process
-wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config)
+wandb_run = get_wandb("nanochat-mid", run=run, master_process=master_process, user_config=user_config)
 # Load the model and tokenizer
 model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)
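
The refactor moves the dummy-vs-real logger decision out of each training script and into a single helper in nanochat/common.py. The helper's body is not part of this diff; below is a minimal sketch reconstructed from the inline logic the old call site carried (return a no-op DummyWandb when run == "dummy" or on non-master ranks, otherwise call wandb.init). The methods on DummyWandb and the exact keyword handling in the real helper are assumptions.

    import wandb

    class DummyWandb:
        # No-op stand-in so call sites can call .log()/.finish() unconditionally.
        def log(self, *args, **kwargs):
            pass
        def finish(self):
            pass

    def get_wandb(project, run=None, master_process=True, user_config=None):
        # Only the master process with a real run name initializes wandb;
        # other ranks (and run == "dummy") get the no-op logger, matching
        # the branch removed from the script above.
        if run == "dummy" or not master_process:
            return DummyWandb()
        return wandb.init(project=project, name=run, config=user_config)

With the guard centralized this way, every training script reduces its logging setup to a single get_wandb call, and a new script no longer needs to copy the master-process check.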