Refactor wandb logging initialization

Sermet Pekin 2025-11-05 15:55:53 +03:00 committed by GitHub
parent d9be7d4f14
commit 679ac96efe

@@ -16,12 +16,11 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import time
 from contextlib import nullcontext
-import wandb
 import torch
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, get_wandb, print_banner, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -75,8 +74,7 @@ synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
 get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
 # wandb logging init
-use_dummy_wandb = run == "dummy" or not master_process
-wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
+wandb_run = get_wandb("nanochat", run=run, master_process=master_process, user_config=user_config)
 # Tokenizer will be useful for evaluation, also we need the vocab size
 tokenizer = get_tokenizer()
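
For reference, a minimal sketch of what the new get_wandb helper in nanochat/common.py presumably does, reconstructed from the deleted inline logic and the new call site; the actual implementation in the repository may differ.

# Hypothetical sketch of get_wandb in nanochat/common.py, inferred from this diff.
class DummyWandb:
    """No-op stand-in so training code can call log()/finish() unconditionally."""
    def log(self, *args, **kwargs):
        pass
    def finish(self):
        pass

def get_wandb(project, run=None, master_process=True, user_config=None):
    """Return a real wandb run on the master process, else a no-op dummy."""
    if run == "dummy" or not master_process:
        return DummyWandb()
    import wandb  # imported lazily, so the training script no longer needs a top-level `import wandb`
    return wandb.init(project=project, name=run, config=user_config)

Moving the DummyWandb-vs-wandb.init decision behind a single helper reduces the training script to one call and lets it drop both the top-level wandb import and the use_dummy_wandb flag.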