nanochat/nanochat/common.py
google-labs-jules[bot] 1f9b734358 Use gloo backend for DDP on AMD ROCm to avoid NCCL crashes
On consumer AMD hardware (like APUs or gaming GPUs) running ROCm, the default `nccl` backend (which wraps RCCL) often fails with `invalid device function` due to architecture mismatches or kernel issues.

This change detects the presence of `torch.version.hip` and forces the `gloo` backend for `torch.distributed.init_process_group`. While `gloo` is slower for data transfer, it is CPU-based and significantly more robust for these setups, ensuring the training script can run without crashing.
2025-11-23 06:49:07 +00:00

214 lines
9.0 KiB
Python

"""
Common utilities for nanochat.
"""
import os
import re
import logging
import urllib.request
import torch
import torch.distributed as dist
from filelock import FileLock
class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds ANSI colors to log messages.

    The level name is wrapped in a per-level color (plus bold). For INFO
    records, quantities such as ``12.5 GB``, ``40%`` or ``3 docs`` and
    ``Shard N`` markers in the formatted line are highlighted as well.
    """
    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m', # Cyan
        'INFO': '\033[32m', # Green
        'WARNING': '\033[33m', # Yellow
        'ERROR': '\033[31m', # Red
        'CRITICAL': '\033[35m', # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'

    def format(self, record):
        """Return the colorized log line for ``record``.

        ``record.levelname`` is replaced by its colored form only while the
        base formatter runs and is restored afterwards. A LogRecord is shared
        by every handler attached to a logger, so leaving the escape codes on
        it would leak them into other handlers' output and would make a second
        call on the same record skip the coloring below.
        """
        levelname = record.levelname
        if levelname in self.COLORS:
            # Add color to the level name
            record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
        try:
            # Format the message
            message = super().format(record)
        finally:
            # Always hand the record back unmodified, even if formatting raised
            record.levelname = levelname
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message
def setup_default_logging():
    """Install a ColoredFormatter-backed stream handler on the root logger at INFO level."""
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(ColoredFormatter(log_format))
    logging.basicConfig(level=logging.INFO, handlers=[stream_handler])
# Configure logging once, as an import-time side effect of this module.
setup_default_logging()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def get_base_dir():
    """Return the nanochat working directory, creating it if it does not exist.

    A non-empty NANOCHAT_BASE_DIR environment variable takes precedence;
    otherwise the directory is ~/.cache/nanochat, so nanochat intermediates
    are co-located with other cached data.
    """
    override = os.environ.get("NANOCHAT_BASE_DIR")
    if override:
        nanochat_dir = override
    else:
        nanochat_dir = os.path.join(os.path.expanduser("~"), ".cache", "nanochat")
    os.makedirs(nanochat_dir, exist_ok=True)
    return nanochat_dir
def download_file_with_lock(url, filename, postprocess_fn=None):
    """
    Downloads a file from a URL to a local path in the base directory.
    Uses a lock file to prevent concurrent downloads among multiple ranks.

    Args:
        url: Address to fetch with urllib.
        filename: Destination name, joined onto the directory from get_base_dir().
        postprocess_fn: Optional callable invoked with the final file path
            after a fresh download (it is not called when the file already exists).

    Returns:
        The local path of the file.

    The payload is first written to a temporary sibling file and then moved
    into place with os.replace, so the destination path only ever appears
    with its full contents. This matters because the existence checks below
    treat "path exists" as "download finished".
    NOTE: postprocess_fn still runs after the file becomes visible, so a
    caller taking the lock-free fast path may return before it completes.
    """
    base_dir = get_base_dir()
    file_path = os.path.join(base_dir, filename)
    lock_path = file_path + ".lock"
    # Fast path: no need to take the lock if the file is already there
    if os.path.exists(file_path):
        return file_path
    with FileLock(lock_path):
        # Only a single rank can acquire this lock
        # All other ranks block until it is released
        # Recheck after acquiring lock
        if os.path.exists(file_path):
            return file_path
        # Download the content as bytes
        print(f"Downloading {url}...")
        with urllib.request.urlopen(url) as response:
            content = response.read() # bytes
        # Write to a temporary file, then atomically publish it under the final name
        tmp_path = file_path + ".tmp"
        try:
            with open(tmp_path, 'wb') as f:
                f.write(content)
            os.replace(tmp_path, file_path)
        finally:
            # Only still present if the write or the rename failed
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        print(f"Downloaded to {file_path}")
        # Run the postprocess function if provided
        if postprocess_fn is not None:
            postprocess_fn(file_path)
    return file_path
def print0(s="",**kwargs):
    """Print ``s`` (forwarding ``kwargs`` to print) unless the RANK env var names a non-zero rank."""
    if int(os.environ.get('RANK', 0)) != 0:
        return
    print(s, **kwargs)
def print_banner():
    """Emit the nanochat ASCII-art banner through print0."""
    # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
    ascii_art = """
█████ █████
░░███ ░░███
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
"""
    print0(ascii_art)
def is_ddp():
    """Report whether the RANK environment variable is set to something other than -1."""
    # TODO is there a proper way
    rank = int(os.environ.get('RANK', -1))
    return rank != -1
def get_dist_info():
    """Return ``(ddp, rank, local_rank, world_size)`` read from the environment.

    When is_ddp() is false the result is ``(False, 0, 0, 1)``. Otherwise
    RANK, LOCAL_RANK and WORLD_SIZE must all be set and are parsed as ints.
    """
    if not is_ddp():
        return False, 0, 0, 1
    required = ('RANK', 'LOCAL_RANK', 'WORLD_SIZE')
    assert all(var in os.environ for var in required)
    ddp_rank, ddp_local_rank, ddp_world_size = (int(os.environ[var]) for var in required)
    return True, ddp_rank, ddp_local_rank, ddp_world_size
def autodetect_device_type():
    """Pick "cuda", "mps" or "cpu" (in that order of preference) and report the choice via print0.

    "cuda" is chosen when CUDA is available or the PyTorch build reports a HIP
    (ROCm) version, and at least one device is actually visible.
    """
    # Either a working CUDA runtime or a ROCm build counts as a GPU runtime
    gpu_runtime = torch.cuda.is_available() or (hasattr(torch.version, "hip") and torch.version.hip)
    device_type = "cpu"
    # A runtime alone is not enough: also ensure we actually have devices
    if gpu_runtime and torch.cuda.device_count() > 0:
        device_type = "cuda"
    elif torch.backends.mps.is_available():
        device_type = "mps"
    print0(f"Autodetected device type: {device_type}")
    return device_type
def compute_init(device_type="cuda"): # cuda|cpu|mps
    """Shared start-up routine: validate the device, seed RNGs, and set up DDP if requested.

    Args:
        device_type: One of "cuda", "mps" or "cpu".

    Returns:
        A tuple ``(ddp, ddp_rank, ddp_local_rank, ddp_world_size, device)``
        where the first four values come from get_dist_info() and ``device``
        is the torch.device selected for this process.
    """
    assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
    on_cuda = device_type == "cuda"
    if on_cuda and not torch.cuda.is_available():
        # A ROCm build is allowed to continue with a warning; anything else is a hard error
        if hasattr(torch.version, "hip") and torch.version.hip:
            print("WARNING: CUDA not available but ROCm detected. Proceeding with 'cuda' device. You might need to set HSA_OVERRIDE_GFX_VERSION if on an APU.")
        else:
            assert False, "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
    if device_type == "mps":
        assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"

    # Reproducibility
    # The global seeds are set here, although most of the code uses explicit rng objects.
    # The only place where global rng might be used is nn.Module initialization of the model weights.
    torch.manual_seed(42)
    if on_cuda:
        torch.cuda.manual_seed(42)
    # skipping full reproducibility for now, possibly investigate slowdown later
    # torch.use_deterministic_algorithms(True)

    # Precision
    if on_cuda:
        torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls

    # Distributed setup: Distributed Data Parallel (DDP), optional
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    if ddp and on_cuda:
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device) # make "cuda" default to this device
        # On AMD ROCm (especially consumer/APU hardware), NCCL/RCCL can be unstable or mismatched.
        # Use gloo when HIP is detected to avoid "invalid device function" crashes.
        # While slower than NCCL, it keeps things working on a wider range of AMD hardware.
        if hasattr(torch.version, "hip") and torch.version.hip:
            # gloo backend does not accept 'device_id' argument
            group_kwargs = {"backend": "gloo"}
        else:
            group_kwargs = {"backend": "nccl", "device_id": device}
        dist.init_process_group(**group_kwargs)
        dist.barrier()
    elif ddp and device_type == "cpu":
        device = torch.device("cpu")
        dist.init_process_group(backend="gloo")
        dist.barrier()
    else:
        # mps under DDP, or any device type without DDP: no process group is created
        device = torch.device(device_type)

    if ddp_rank == 0:
        logger.info(f"Distributed world size: {ddp_world_size}")
    return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
def compute_cleanup():
    """Counterpart of compute_init: tear down the process group before the script exits."""
    if not is_ddp():
        return
    if dist.is_initialized():
        dist.destroy_process_group()
class DummyWandb:
    """No-op stand-in exposing the same call signatures as wandb, for runs that skip wandb."""

    def __init__(self):
        """Nothing to set up."""

    def log(self, *args, **kwargs):
        """Accept any arguments and discard them."""

    def finish(self):
        """Nothing to finalize."""