mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00
Update logo in code as well
This commit is contained in:
parent
938cb31f1a
commit
2b58e2dd2a
|
|
@ -8,43 +8,58 @@ import logging
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
class ColoredFormatter(logging.Formatter):
|
class ColoredFormatter(logging.Formatter):
|
||||||
"""Custom formatter that adds colors to log messages."""
|
"""Custom formatter that adds colors to log messages."""
|
||||||
|
|
||||||
# ANSI color codes
|
# ANSI color codes
|
||||||
COLORS = {
|
COLORS = {
|
||||||
'DEBUG': '\033[36m', # Cyan
|
"DEBUG": "\033[36m", # Cyan
|
||||||
'INFO': '\033[32m', # Green
|
"INFO": "\033[32m", # Green
|
||||||
'WARNING': '\033[33m', # Yellow
|
"WARNING": "\033[33m", # Yellow
|
||||||
'ERROR': '\033[31m', # Red
|
"ERROR": "\033[31m", # Red
|
||||||
'CRITICAL': '\033[35m', # Magenta
|
"CRITICAL": "\033[35m", # Magenta
|
||||||
}
|
}
|
||||||
RESET = '\033[0m'
|
RESET = "\033[0m"
|
||||||
BOLD = '\033[1m'
|
BOLD = "\033[1m"
|
||||||
|
|
||||||
def format(self, record):
|
def format(self, record):
|
||||||
# Add color to the level name
|
# Add color to the level name
|
||||||
levelname = record.levelname
|
levelname = record.levelname
|
||||||
if levelname in self.COLORS:
|
if levelname in self.COLORS:
|
||||||
record.levelname = f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
|
record.levelname = (
|
||||||
|
f"{self.COLORS[levelname]}{self.BOLD}{levelname}{self.RESET}"
|
||||||
|
)
|
||||||
# Format the message
|
# Format the message
|
||||||
message = super().format(record)
|
message = super().format(record)
|
||||||
# Add color to specific parts of the message
|
# Add color to specific parts of the message
|
||||||
if levelname == 'INFO':
|
if levelname == "INFO":
|
||||||
# Highlight numbers and percentages
|
# Highlight numbers and percentages
|
||||||
message = re.sub(r'(\d+\.?\d*\s*(?:GB|MB|%|docs))', rf'{self.BOLD}\1{self.RESET}', message)
|
message = re.sub(
|
||||||
message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message)
|
r"(\d+\.?\d*\s*(?:GB|MB|%|docs))",
|
||||||
|
rf"{self.BOLD}\1{self.RESET}",
|
||||||
|
message,
|
||||||
|
)
|
||||||
|
message = re.sub(
|
||||||
|
r"(Shard \d+)",
|
||||||
|
rf"{self.COLORS['INFO']}{self.BOLD}\1{self.RESET}",
|
||||||
|
message,
|
||||||
|
)
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
def setup_default_logging():
|
def setup_default_logging():
|
||||||
handler = logging.StreamHandler()
|
handler = logging.StreamHandler()
|
||||||
handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
|
handler.setFormatter(
|
||||||
logging.basicConfig(
|
ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
level=logging.INFO,
|
|
||||||
handlers=[handler]
|
|
||||||
)
|
)
|
||||||
|
logging.basicConfig(level=logging.INFO, handlers=[handler])
|
||||||
|
|
||||||
|
|
||||||
setup_default_logging()
|
setup_default_logging()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_base_dir():
|
def get_base_dir():
|
||||||
# co-locate nanochat intermediates with other cached data in ~/.cache (by default)
|
# co-locate nanochat intermediates with other cached data in ~/.cache (by default)
|
||||||
if os.environ.get("NANOCHAT_BASE_DIR"):
|
if os.environ.get("NANOCHAT_BASE_DIR"):
|
||||||
|
|
@ -56,39 +71,44 @@ def get_base_dir():
|
||||||
os.makedirs(nanochat_dir, exist_ok=True)
|
os.makedirs(nanochat_dir, exist_ok=True)
|
||||||
return nanochat_dir
|
return nanochat_dir
|
||||||
|
|
||||||
def print0(s="",**kwargs):
|
|
||||||
ddp_rank = int(os.environ.get('RANK', 0))
|
def print0(s="", **kwargs):
|
||||||
|
ddp_rank = int(os.environ.get("RANK", 0))
|
||||||
if ddp_rank == 0:
|
if ddp_rank == 0:
|
||||||
print(s, **kwargs)
|
print(s, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def print_banner():
|
def print_banner():
|
||||||
# Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
|
# Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/
|
||||||
banner = """
|
banner = """
|
||||||
█████ █████
|
█████ █████
|
||||||
░░███ ░░███
|
░░███ ░░███
|
||||||
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
|
████████ ██████ ████████ ██████ ██████ ░███████ ██████ ███████
|
||||||
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███ ░░░███░
|
░░███░░███ ░░░░░███ ░░███░░███ ███░░███ ███░░███ ░███░░███ ░░░░░███░░░███░
|
||||||
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
|
░███ ░███ ███████ ░███ ░███ ░███ ░███░███ ░░░ ░███ ░███ ███████ ░███
|
||||||
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
|
░███ ░███ ███░░███ ░███ ░███ ░███ ░███░███ ███ ░███ ░███ ███░░███ ░███ ███
|
||||||
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░████████ ░░█████
|
████ █████░░████████ ████ █████░░██████ ░░██████ ████ █████░░███████ ░░█████
|
||||||
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
|
░░░░ ░░░░░ ░░░░░░░░ ░░░░ ░░░░░ ░░░░░░ ░░░░░░ ░░░░ ░░░░░ ░░░░░░░░ ░░░░░
|
||||||
"""
|
"""
|
||||||
print0(banner)
|
print0(banner)
|
||||||
|
|
||||||
|
|
||||||
def is_ddp():
|
def is_ddp():
|
||||||
# TODO is there a proper way
|
# TODO is there a proper way
|
||||||
return int(os.environ.get('RANK', -1)) != -1
|
return int(os.environ.get("RANK", -1)) != -1
|
||||||
|
|
||||||
|
|
||||||
def get_dist_info():
|
def get_dist_info():
|
||||||
if is_ddp():
|
if is_ddp():
|
||||||
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
|
assert all(var in os.environ for var in ["RANK", "LOCAL_RANK", "WORLD_SIZE"])
|
||||||
ddp_rank = int(os.environ['RANK'])
|
ddp_rank = int(os.environ["RANK"])
|
||||||
ddp_local_rank = int(os.environ['LOCAL_RANK'])
|
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
||||||
ddp_world_size = int(os.environ['WORLD_SIZE'])
|
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
||||||
return True, ddp_rank, ddp_local_rank, ddp_world_size
|
return True, ddp_rank, ddp_local_rank, ddp_world_size
|
||||||
else:
|
else:
|
||||||
return False, 0, 0, 1
|
return False, 0, 0, 1
|
||||||
|
|
||||||
|
|
||||||
def compute_init():
|
def compute_init():
|
||||||
"""Basic initialization that we keep doing over and over, so make common."""
|
"""Basic initialization that we keep doing over and over, so make common."""
|
||||||
|
|
||||||
|
|
@ -121,16 +141,21 @@ def compute_init():
|
||||||
|
|
||||||
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
|
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
|
||||||
|
|
||||||
|
|
||||||
def compute_cleanup():
|
def compute_cleanup():
|
||||||
"""Companion function to compute_init, to clean things up before script exit"""
|
"""Companion function to compute_init, to clean things up before script exit"""
|
||||||
if is_ddp():
|
if is_ddp():
|
||||||
dist.destroy_process_group()
|
dist.destroy_process_group()
|
||||||
|
|
||||||
|
|
||||||
class DummyWandb:
|
class DummyWandb:
|
||||||
"""Useful if we wish to not use wandb but have all the same signatures"""
|
"""Useful if we wish to not use wandb but have all the same signatures"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def log(self, *args, **kwargs):
|
def log(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user