diff --git a/nanochat/common.py b/nanochat/common.py index a0867b0..d80d4ba 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -12,15 +12,14 @@ class ColoredFormatter(logging.Formatter): """Custom formatter that adds colors to log messages.""" # ANSI color codes COLORS = { - 'DEBUG': '\033[36m', # Cyan - 'INFO': '\033[32m', # Green + 'DEBUG': '\033[36m', # Cyan + 'INFO': '\033[32m', # Green 'WARNING': '\033[33m', # Yellow - 'ERROR': '\033[31m', # Red - 'CRITICAL': '\033[35m', # Magenta + 'ERROR': '\033[31m', # Red + 'CRITICAL': '\033[35m', # Magenta } RESET = '\033[0m' BOLD = '\033[1m' - def format(self, record): # Add color to the level name levelname = record.levelname @@ -35,7 +34,6 @@ class ColoredFormatter(logging.Formatter): message = re.sub(r'(Shard \d+)', rf'{self.COLORS["INFO"]}{self.BOLD}\1{self.RESET}', message) return message - def setup_default_logging(): handler = logging.StreamHandler() handler.setFormatter(ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) @@ -58,7 +56,7 @@ def get_base_dir(): os.makedirs(nanochat_dir, exist_ok=True) return nanochat_dir -def print0(s="", **kwargs): +def print0(s="",**kwargs): ddp_rank = int(os.environ.get('RANK', 0)) if ddp_rank == 0: print(s, **kwargs)