diff --git a/nanochat/common.py b/nanochat/common.py index d4a9828..1f249c5 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -1,14 +1,30 @@ """ Common utilities for nanochat. """ - +from filelock import FileLock import os import re import logging import urllib.request import torch import torch.distributed as dist -from filelock import FileLock +import wandb + +class DummyWandb: + """Useful if we wish to not use wandb but have all the same signatures""" + def __init__(self): + pass + def log(self, *args, **kwargs): + pass + def finish(self): + pass + +def get_wandb(project:str="nanochat-rl", run="dummy", master_process:bool=False, user_config:dict=None): + """Initialize wandb logging or return a dummy logger for non-master processes.""" + # wandb logging init + use_dummy_wandb = run == "dummy" or not master_process + wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project=project, name=run, config=user_config) + return wandb_run class ColoredFormatter(logging.Formatter): """Custom formatter that adds colors to log messages.""" @@ -178,11 +194,4 @@ def compute_cleanup(): if is_ddp(): dist.destroy_process_group() -class DummyWandb: - """Useful if we wish to not use wandb but have all the same signatures""" - def __init__(self): - pass - def log(self, *args, **kwargs): - pass - def finish(self): - pass +