optimize print0 function on common.py

optimize print0 function on common.py to check RANK environment variable only once while importing the module.
This commit is contained in:
Sermet Pekin 2025-10-21 16:44:02 +03:00 committed by GitHub
parent ce64059d65
commit 5617ce8d69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -56,10 +56,17 @@ def get_base_dir():
os.makedirs(nanochat_dir, exist_ok=True) os.makedirs(nanochat_dir, exist_ok=True)
return nanochat_dir return nanochat_dir
def print0(s="",**kwargs): # Determine at import time which function to use
ddp_rank = int(os.environ.get('RANK', 0)) _ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
if _ddp_rank == 0:
# On rank 0: print0 is just print
def print0(s="", **kwargs):
print(s, **kwargs) print(s, **kwargs)
else:
# On other ranks: print0 is a no-op
def print0(s="", **kwargs):
pass
def print_banner(): def print_banner():
# Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/ # Cool DOS Rebel font ASCII banner made with https://manytools.org/hacker-tools/ascii-banner/