Replace wandb initialization with get_wandb function

Sermet Pekin 2025-11-05 15:56:49 +03:00 committed by GitHub
parent 679ac96efe
commit 523714b5c8


@@ -19,11 +19,12 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
 import os
 import itertools
 import re
-import wandb
 import torch
 import torch.distributed as dist
-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_wandb
 from nanochat.checkpoint_manager import save_checkpoint, load_model
 from nanochat.engine import Engine
 from tasks.gsm8k import GSM8K
@@ -60,8 +61,7 @@ dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
 autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
 # wandb logging init
-use_dummy_wandb = run == "dummy" or not master_process
-wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
+wandb_run = get_wandb("nanochat-rl", run=run, master_process=master_process, user_config=user_config)
 # Init model and tokenizer
 model, tokenizer, meta = load_model(source, device, phase="eval")
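
The get_wandb helper itself is outside this hunk; judging by the import line it lives in nanochat/common.py next to DummyWandb, and the diff appears to touch scripts/chat_rl.py given the torchrun command in the hunk context. Below is a minimal sketch of what the helper presumably does, inferred only from the call site above and the two lines being removed; the DummyWandb stub is included just so the sketch is self-contained, since the real class is already provided by nanochat.common.

# Sketch only: a plausible get_wandb for nanochat/common.py, inferred from the
# call site get_wandb("nanochat-rl", run=run, master_process=master_process,
# user_config=user_config) and the removed DummyWandb/wandb.init branch.
import wandb


class DummyWandb:
    """Illustrative no-op stand-in; the real class already exists in nanochat.common."""
    def log(self, *args, **kwargs):
        pass

    def finish(self):
        pass


def get_wandb(project, *, run, master_process, user_config):
    """Return a real wandb run on the master process, else a DummyWandb.

    Mirrors the branch the scripts previously inlined: use the dummy when
    run == "dummy" or when this rank is not the master process.
    """
    use_dummy_wandb = run == "dummy" or not master_process
    if use_dummy_wandb:
        return DummyWandb()
    return wandb.init(project=project, name=run, config=user_config)

Centralizing the dummy-vs-real decision in one helper drops the per-script wandb setup to a single line, as the second hunk shows, and removes the need for each script to import wandb directly.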