diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 6c2b82f..fa86f55 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -13,15 +13,17 @@ from collections import deque import os os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import time -import wandb + import torch +import torch.distributed as dist + from contextlib import nullcontext -from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type +from nanochat.common import compute_init, compute_cleanup, print0, get_wandb, get_base_dir, autodetect_device_type from nanochat.tokenizer import get_token_bytes from nanochat.checkpoint_manager import save_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.checkpoint_manager import load_model -import torch.distributed as dist + from tasks.common import TaskMixture from tasks.gsm8k import GSM8K @@ -62,8 +64,7 @@ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0 # wandb logging init -use_dummy_wandb = run == "dummy" or not master_process -wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-mid", name=run, config=user_config) +wandb_run = get_wandb("nanochat-mid" , run=run, master_process=master_process, user_config=user_config) # Load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="train", model_tag=model_tag, step=step)