From cbd560a83d93a8de8ebb238608ee571e7952e2ac Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 29 Oct 2025 11:42:56 +0100 Subject: [PATCH] revert formatting changes to minimize diff and merge conflicts --- nanochat/common.py | 67 +++++++++++++++------------------------------- 1 file changed, 22 insertions(+), 45 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index bb825ff..a0867b0 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -8,58 +8,45 @@ import logging import torch import torch.distributed as dist - class ColoredFormatter(logging.Formatter): """Custom formatter that adds colors to log messages.""" - # ANSI color codes COLORS = { - "DEBUG": "\033[36m", # Cyan - "INFO": "\033[32m", # Green - "WARNING": "\033[33m", # Yellow - "ERROR": "\033[31m", # Red - "CRITICAL": "\033[35m", # Magenta + 'DEBUG': '\033[36m', # Cyan + 'INFO': '\033[32m', # Green + 'WARNING': '\033[33m', # Yellow + 'ERROR': '\033[31m', # Red + 'CRITICAL': '\033[35m', # Magenta } - RESET = "\033[0m" - BOLD = "\033[1m" + RESET = '\033[0m' + BOLD = '\033[1m' def format(self, record): # Add color to the level name levelname = record.levelname if levelname in self.COLORS: - record.levelname = ( - f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}" - ) + record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}" # Format the message message = super().format(record) # Add color to specific parts of the message - if levelname == "INFO": + if levelname == 'INFO': # Highlight numbers and percentages - message = re.sub( - r"(\d+\.?\d*\s*(?:GB|MB|%|docs))", - rf"{self.BOLD}\1{self.RESET}", - message, - ) - message = re.sub( - r"(Shard \d+)", - rf"{self.COLORS['INFO']}{self.BOLD}\1{self.RESET}", - message, - ) + message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message) + message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message) return message def setup_default_logging(): handler = logging.StreamHandler() - handler.setFormatter( - ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) + logging.basicConfig( + level=logging.INFO, + handlers=[handler] ) - logging.basicConfig(level=logging.INFO, handlers=[handler]) - setup_default_logging() logger = logging.getLogger(__name__) - def get_base_dir(): # co-locate nanochat intermediates with other cached data in ~/.cache (by default) if os.environ.get("NANOCHAT_BASE_DIR"): @@ -71,13 +58,11 @@ def get_base_dir(): os.makedirs(nanochat_dir, exist_ok=True) return nanochat_dir - def print0(s="", **kwargs): - ddp_rank = int(os.environ.get("RANK", 0)) + ddp_rank = int(os.environ.get('RANK', 0)) if ddp_rank == 0: print(s, **kwargs) - def print_banner(): # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ banner = """ @@ -92,23 +77,20 @@ def print_banner(): """ print0(banner) - def is_ddp(): # TODO is there a proper way - return int(os.environ.get("RANK", -1)) != -1 - + return int(os.environ.get('RANK', -1)) != -1 def get_dist_info(): if is_ddp(): - assert all(var in os.environ for var in ["RANK", "LOCAL_RANK", "WORLD_SIZE"]) - ddp_rank = int(os.environ["RANK"]) - ddp_local_rank = int(os.environ["LOCAL_RANK"]) - ddp_world_size = int(os.environ["WORLD_SIZE"]) + assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE']) + ddp_rank = int(os.environ['RANK']) + ddp_local_rank = int(os.environ['LOCAL_RANK']) + ddp_world_size = int(os.environ['WORLD_SIZE']) return True, ddp_rank, ddp_local_rank, ddp_world_size else: return False, 0, 0, 1 - def compute_init(): """Basic initialization that we keep doing over and over, so make common.""" @@ -124,7 +106,7 @@ def compute_init(): # torch.backends.cudnn.benchmark = False # Precision - torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls + torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls # Distributed setup: Distributed Data Parallel (DDP), optional ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() @@ -141,21 +123,16 @@ def compute_init(): return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device - def compute_cleanup(): """Companion function to compute_init, to clean things up before script exit""" if is_ddp(): dist.destroy_process_group() - class DummyWandb: """Useful if we wish to not use wandb but have all the same signatures""" - def __init__(self): pass - def log(self, *args, **kwargs): pass - def finish(self): pass