From 5617ce8d69e0c65c005513f5388e8c1ff9aec2e8 Mon Sep 17 00:00:00 2001 From: Sermet Pekin <96650846+SermetPekin@users.noreply.github.com> Date: Tue, 21 Oct 2025 16:44:02 +0300 Subject: [PATCH] optimize print0 function on common.py optimize print0 function on common.py to check RANK environment variable only once while importing the module. --- nanochat/common.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index 3ec9992..0f430b3 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -56,10 +56,17 @@ def get_base_dir(): os.makedirs(nanochat_dir, exist_ok=True) return nanochat_dir -def print0(s="",**kwargs): - ddp_rank = int(os.environ.get('RANK', 0)) - if ddp_rank == 0: +# Determine at import time which function to use +_ddp_rank = int(os.environ.get('RANK', 0)) + +if _ddp_rank == 0: + # On rank 0: print0 is just print + def print0(s="", **kwargs): print(s, **kwargs) +else: + # On other ranks: print0 is a no-op + def print0(s="", **kwargs): + pass def print_banner(): # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/