nanochat/nanochat/common.py
Sermet Pekin d9be7d4f14
add get_wandb function that will either return DummyWandb or real wandb initalized
This will centralize current approact of creating wandb or DummyWandb depending on some variables set on the script
2025-11-05 15:54:48 +03:00

198 lines
7.8 KiB
Python

"""
Common utilities for nanochat.
"""
from filelock import FileLock
import os
import re
import logging
import urllib.request
import torch
import torch.distributed as dist
import wandb
class DummyWandb:
"""Useful if we wish to not use wandb but have all the same signatures"""
def __init__(self):
pass
def log(self, *args, **kwargs):
pass
def finish(self):
pass
def get_wandb(project:str="nanochat-rl", run="dummy", master_process:bool=False, user_config:dict=None):
"""Initialize wandb logging or return a dummy logger for non-master processes."""
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project=project, name=run, config=user_config)
return wandb_run
class ColoredFormatter(logging.Formatter):
"""Custom formatter that adds colors to log messages."""
# ANSI color codes
COLORS = {
'DEBUG': '\033[36m', # Cyan
'INFO': '\033[32m', # Green
'WARNING': '\033[33m', # Yellow
'ERROR': '\033[31m', # Red
'CRITICAL': '\033[35m', # Magenta
}
RESET = '\033[0m'
BOLD = '\033[1m'
def format(self, record):
# Add color to the level name
levelname = record.levelname
if levelname in self.COLORS:
record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
# Format the message
message = super().format(record)
# Add color to specific parts of the message
if levelname == 'INFO':
# Highlight numbers and percentages
message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
return message
def setup_default_logging():
handler = logging.StreamHandler()
handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logging.basicConfig(
level=logging.INFO,
handlers=[handler]
)
setup_default_logging()
logger = logging.getLogger(__name__)
def get_base_dir():
# co-locate nanochat intermediates with other cached data in ~/.cache (by default)
if os.environ.get("NANOCHAT_BASE_DIR"):
nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
else:
home_dir = os.path.expanduser("~")
cache_dir = os.path.join(home_dir, ".cache")
nanochat_dir = os.path.join(cache_dir, "nanochat")
os.makedirs(nanochat_dir, exist_ok=True)
return nanochat_dir
def download_file_with_lock(url, filename, postprocess_fn=None):
"""
Downloads a file from a URL to a local path in the base directory.
Uses a lock file to prevent concurrent downloads among multiple ranks.
"""
base_dir = get_base_dir()
file_path = os.path.join(base_dir, filename)
lock_path = file_path + ".lock"
if os.path.exists(file_path):
return file_path
with FileLock(lock_path):
# Only a single rank can acquire this lock
# All other ranks block until it is released
# Recheck after acquiring lock
if os.path.exists(file_path):
return file_path
# Download the content as bytes
print(f"Downloading {url}...")
with urllib.request.urlopen(url) as response:
content = response.read() # bytes
# Write to local file
with open(file_path, 'wb') as f:
f.write(content)
print(f"Downloaded to {file_path}")
# Run the postprocess function if provided
if postprocess_fn is not None:
postprocess_fn(file_path)
return file_path
def print0(s="",**kwargs):
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
print(s, **kwargs)
def print_banner():
# Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
banner = """
█████ █████
░░███ ░░███
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
"""
print0(banner)
def is_ddp():
# TODO is there a proper way
return int(os.environ.get('RANK', -1)) != -1
def get_dist_info():
if is_ddp():
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
ddp_rank = int(os.environ['RANK'])
ddp_local_rank = int(os.environ['LOCAL_RANK'])
ddp_world_size = int(os.environ['WORLD_SIZE'])
return True, ddp_rank, ddp_local_rank, ddp_world_size
else:
return False, 0, 0, 1
def autodetect_device_type():
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
if torch.cuda.is_available():
device_type = "cuda"
elif torch.backends.mps.is_available():
device_type = "mps"
else:
device_type = "cpu"
print0(f"Autodetected device type: {device_type}")
return device_type
def compute_init(device_type="cuda"): # cuda|cpu|mps
"""Basic initialization that we keep doing over and over, so make common."""
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
if device_type == "cuda":
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
if device_type == "mps":
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
# Reproducibility
torch.manual_seed(42)
if device_type == "cuda":
torch.cuda.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later
# torch.use_deterministic_algorithms(True)
# Precision
if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp and device_type == "cuda":
device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
else:
device = torch.device(device_type) # mps|cpu
if ddp_rank == 0:
logger.info(f"Distributed world size: {ddp_world_size}")
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
def compute_cleanup():
"""Companion function to compute_init, to clean things up before script exit"""
if is_ddp():
dist.destroy_process_group()