revert formatting changes to minimize diff and merge conflicts

This commit is contained in:
svlandeg 2025-10-29 11:42:56 +01:00
parent 2b58e2dd2a
commit cbd560a83d

View File

@ -8,58 +8,45 @@ import logging
import torch import torch
import torch.distributed as dist import torch.distributed as dist
class ColoredFormatter(logging.Formatter):
    """Custom formatter that adds colors to log messages.

    The level name is wrapped in an ANSI color (plus bold) for the levels
    listed in ``COLORS``. For INFO records, quantities such as ``12.5 MB``,
    ``3 GB``, ``40%`` or ``100 docs`` are additionally rendered bold and
    ``Shard <n>`` is rendered bold green.
    """

    # ANSI color codes
    COLORS = {
        'DEBUG': '\033[36m',  # Cyan
        'INFO': '\033[32m',  # Green
        'WARNING': '\033[33m',  # Yellow
        'ERROR': '\033[31m',  # Red
        'CRITICAL': '\033[35m',  # Magenta
    }
    RESET = '\033[0m'
    BOLD = '\033[1m'

    def format(self, record):
        """Return the formatted, colorized text for ``record``.

        ``record.levelname`` is only replaced by its colored variant while the
        base formatter runs and is restored afterwards, so other handlers that
        receive the same record do not see the escape sequences.
        """
        # Add color to the level name
        levelname = record.levelname
        color = self.COLORS.get(levelname)
        if color is not None:
            record.levelname = f"{color}{self.BOLD}{levelname}{self.RESET}"
        try:
            # Format the message
            message = super().format(record)
        finally:
            # The LogRecord is shared between handlers: undo the temporary change.
            record.levelname = levelname
        # Add color to specific parts of the message
        if levelname == 'INFO':
            # Highlight numbers and percentages
            message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
            message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
        return message
def setup_default_logging():
    """Configure logging at INFO level with one stream handler that uses ColoredFormatter."""
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(ColoredFormatter(log_format))
    logging.basicConfig(level=logging.INFO, handlers=[stream_handler])
# Logging is configured as an import-time side effect; then this module's logger is created.
setup_default_logging()
logger = logging.getLogger(__name__)
def get_base_dir(): def get_base_dir():
# co-locate nanochat intermediates with other cached data in ~/.cache (by default) # co-locate nanochat intermediates with other cached data in ~/.cache (by default)
if os.environ.get("NANOCHAT_BASE_DIR"): if os.environ.get("NANOCHAT_BASE_DIR"):
@ -71,13 +58,11 @@ def get_base_dir():
os.makedirs(nanochat_dir, exist_ok=True) os.makedirs(nanochat_dir, exist_ok=True)
return nanochat_dir return nanochat_dir
def print0(s="", **kwargs):
    """Print ``s`` only when the ``RANK`` environment variable is 0 or unset.

    Extra keyword arguments are forwarded to ``print`` unchanged.
    """
    rank = int(os.environ.get('RANK', 0))
    if rank != 0:
        return
    print(s, **kwargs)
def print_banner(): def print_banner():
# Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
banner = """ banner = """
@ -92,23 +77,20 @@ def print_banner():
""" """
print0(banner) print0(banner)
def is_ddp():
    """Return True when the ``RANK`` environment variable is set to anything other than -1."""
    # TODO is there a proper way
    rank = os.environ.get('RANK', -1)
    return int(rank) != -1
def get_dist_info():
    """Return ``(ddp, ddp_rank, ddp_local_rank, ddp_world_size)`` read from the environment.

    When ``is_ddp()`` is false the result is ``(False, 0, 0, 1)``. Otherwise
    ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE`` are parsed as integers.

    Raises:
        AssertionError: if ``is_ddp()`` is true but one of the three
            variables is not present in the environment.
    """
    if not is_ddp():
        return False, 0, 0, 1
    required = ('RANK', 'LOCAL_RANK', 'WORLD_SIZE')
    missing = [var for var in required if var not in os.environ]
    if missing:
        # Raised explicitly (rather than via `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept for callers.
        raise AssertionError(
            f"missing environment variables for a distributed run: {', '.join(missing)}"
        )
    ddp_rank, ddp_local_rank, ddp_world_size = (int(os.environ[var]) for var in required)
    return True, ddp_rank, ddp_local_rank, ddp_world_size
def compute_init(): def compute_init():
"""Basic initialization that we keep doing over and over, so make common.""" """Basic initialization that we keep doing over and over, so make common."""
@ -124,7 +106,7 @@ def compute_init():
# torch.backends.cudnn.benchmark = False # torch.backends.cudnn.benchmark = False
# Precision # Precision
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
# Distributed setup: Distributed Data Parallel (DDP), optional # Distributed setup: Distributed Data Parallel (DDP), optional
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@ -141,21 +123,16 @@ def compute_init():
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
def compute_cleanup():
    """Companion function to compute_init, to clean things up before script exit

    The process group is only destroyed when one actually exists, so calling
    this twice, or before any group was created, is a no-op instead of an error.
    """
    # is_ddp() only inspects the environment; also confirm a group is initialized.
    if is_ddp() and dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
class DummyWandb:
    """Useful if we wish to not use wandb but have all the same signatures"""

    def __init__(self):
        """Nothing to set up."""

    def log(self, *args, **kwargs):
        """Accept any arguments and discard them."""
        return None

    def finish(self):
        """Do nothing."""
        return None