mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Replace wandb initialization with get_wandb function
This commit is contained in:
parent
679ac96efe
commit
523714b5c8
|
|
@ -19,11 +19,12 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=default
|
|||
import os
|
||||
import itertools
|
||||
import re
|
||||
import wandb
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_wandb
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model
|
||||
from nanochat.engine import Engine
|
||||
from tasks.gsm8k import GSM8K
|
||||
|
|
@ -60,8 +61,7 @@ dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
|||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-rl", name=run, config=user_config)
|
||||
wandb_run = get_wandb("nanochat-rl" , run=run, master_process=master_process, user_config=user_config)
|
||||
|
||||
# Init model and tokenizer
|
||||
model, tokenizer, meta = load_model(source, device, phase="eval")
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user