"""
|
|
Common utilities for nanochat.
|
|
"""
|
|
|
|
import os
|
|
import re
|
|
import logging
|
|
import urllib.request
|
|
import torch
|
|
import torch.distributed as dist
|
|
from filelock import FileLock
|
|
|
|
class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds colors to log messages."""
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m',     # Cyan
        'INFO': '\033[32m',      # Green
        'WARNING': '\033[33m',   # Yellow
        'ERROR': '\033[31m',     # Red
        'CRITICAL': '\033[35m',  # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'

    def format(self, record):
        # Add color to the level name
        levelname = record.levelname
        if levelname in self.COLORS:
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        # Format the message
        message = super().format(record)
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message

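# Example (illustrative, not part of the original module): with the formatter above, an INFO
# record such as "Shard 3: 1000 docs, 1.2 GB written (45%)" would have "1000 docs", "1.2 GB"
# and "45%" rendered in bold, and "Shard 3" in bold green.
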
def setup_default_logging():
    handler = logging.StreamHandler()
    handler.setLevel(logging.INFO)
    handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logging.basicConfig(
        level=logging.INFO,
        handlers=[handler]
    )

def setup_file_logger(logger_name, filename, level=logging.DEBUG, formatter=None):
    """Attach a file handler writing to <logs_dir>/<filename> (".log" appended if missing) and return the log file path."""
    clean_name = os.path.basename(filename)
    if clean_name != filename or not clean_name:
        raise ValueError(f"Invalid log filename provided: {filename}")
    if not clean_name.endswith(".log"):
        clean_name += ".log"
    logs_dir = get_logs_dir()
    path = os.path.join(logs_dir, clean_name)

    handler = logging.FileHandler(path, mode="w")
    handler.setLevel(level)
    handler.setFormatter(
        formatter
        or ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger = logging.getLogger(logger_name)
    logger.addHandler(handler)
    return path

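# Example (illustrative; the logger name and filename are hypothetical):
#   log_path = setup_file_logger("nanochat.train", "train_run")
#   logging.getLogger("nanochat.train").info("this line also goes to %s", log_path)
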
setup_default_logging()
logger = logging.getLogger(__name__)

def get_base_dir():
    # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
    if os.environ.get("NANOCHAT_BASE_DIR"):
        nanochat_dir = os.environ.get("NANOCHAT_BASE_DIR")
    else:
        home_dir = os.path.expanduser("~")
        cache_dir = os.path.join(home_dir, ".cache")
        nanochat_dir = os.path.join(cache_dir, "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir

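# Example (illustrative; the path and script name are hypothetical): relocate all nanochat
# intermediates for a run by setting the override before launching, e.g.
#   NANOCHAT_BASE_DIR=/data/nanochat python your_script.py
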
def get_project_root():
    # locates the project root by walking upward from this file
    _PROJECT_MARKERS = ('.git', 'uv.lock')
    curr = os.path.dirname(os.path.abspath(__file__))
    while True:
        if any(os.path.exists(os.path.join(curr, m)) for m in _PROJECT_MARKERS):
            return curr
        parent = os.path.dirname(curr)
        if parent == curr:  # reached filesystem root
            return None
        curr = parent

def get_logs_dir():
    """
    Resolves the directory where log files should be written.
    - if $LOG_DIR is set, use that.
    - else, if a project root is detected, use <project_root>/logs.
    - else, fall back to <get_base_dir()>/logs.
    """
    env = os.environ.get("LOG_DIR")
    if env:
        path = os.path.abspath(env)
        os.makedirs(path, exist_ok=True)
        return path

    root = get_project_root()
    if not root:
        root = get_base_dir()
    logs = os.path.join(root, 'logs')
    os.makedirs(logs, exist_ok=True)
    return logs

def download_file_with_lock(url, filename, postprocess_fn=None):
    """
    Downloads a file from a URL to a local path in the base directory.
    Uses a lock file to prevent concurrent downloads among multiple ranks.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"

    if os.path.exists(file_path):
        return file_path

    with FileLock(lock_path):
        # Only a single rank can acquire this lock
        # All other ranks block until it is released

        # Recheck after acquiring lock
        if os.path.exists(file_path):
            return file_path

        # Download the content as bytes
        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            content = response.read()  # bytes

        # Write to local file
        with open(file_path, 'wb') as f:
            f.write(content)
        print(f"Downloaded to {file_path}")

        # Run the postprocess function if provided
        if postprocess_fn is not None:
            postprocess_fn(file_path)

        return file_path

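# Example (illustrative; the URL, filename, and postprocessing step are hypothetical):
#   def _unpack(path):
#       ...  # e.g. decompress or validate the downloaded artifact in place
#   vocab_path = download_file_with_lock("https://example.com/vocab.bin", "vocab.bin", postprocess_fn=_unpack)
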
def print0(s="", **kwargs):
    """Print only from rank 0 (prints normally when not running distributed)."""
    ddp_rank = int(os.environ.get('RANK', 0))
    if ddp_rank == 0:
        print(s, **kwargs)

def print_banner():
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    banner = """
█████ █████
░░███ ░░███
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
    """
    print0(banner)

def is_ddp():
    # TODO is there a proper way
    return int(os.environ.get('RANK', -1)) != -1

def get_dist_info():
    """Returns (is_ddp, rank, local_rank, world_size); single-process defaults when not distributed."""
    if is_ddp():
        assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        return True, ddp_rank, ddp_local_rank, ddp_world_size
    else:
        return False, 0, 0, 1

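# Example (illustrative; `script.py` is hypothetical): under `torchrun --nproc_per_node=8 script.py`
# each process sees RANK/LOCAL_RANK/WORLD_SIZE and gets (True, rank, local_rank, 8); a plain
# `python script.py` launch gets (False, 0, 0, 1).
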
def autodetect_device_type():
    # prefer CUDA if available, otherwise MPS, otherwise fall back to CPU
    if torch.cuda.is_available():
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    else:
        device_type = "cpu"
    print0(f"Autodetected device type: {device_type}")
    return device_type

def compute_init(device_type="cuda"):  # cuda|cpu|mps
    """Basic initialization that we keep doing over and over, so make common."""

    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
    if device_type == "cuda":
        assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
    if device_type == "mps":
        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

    # Reproducibility
    torch.manual_seed(42)
    if device_type == "cuda":
        torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)

    # Precision
    if device_type == "cuda":
        torch.set_float32_matmul_precision("high")  # uses tf32 instead of fp32 for matmuls

    # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp and device_type == "cuda":
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device)  # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    else:
        device = torch.device(device_type)  # mps|cpu

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")

    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device

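# Example (illustrative) of the typical script lifecycle around compute_init/compute_cleanup:
#   ddp, rank, local_rank, world_size, device = compute_init(autodetect_device_type())
#   ...  # build the model on `device`, train/evaluate, checkpoint, etc.
#   compute_cleanup()
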
def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit"""
    if is_ddp():
        dist.destroy_process_group()

class DummyWandb:
    """Useful if we wish to not use wandb but have all the same signatures"""
    def __init__(self):
        pass
    def log(self, *args, **kwargs):
        pass
    def finish(self):
        pass
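
# Example (illustrative; `use_wandb`, `rank`, `step`, and `loss` are hypothetical variables):
# scripts can treat wandb as optional by swapping in DummyWandb, e.g.
#   wandb_run = wandb.init(project="nanochat") if (use_wandb and rank == 0) else DummyWandb()
#   wandb_run.log({"step": step, "loss": loss})
#   wandb_run.finish()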