Optimize print0 function: Cache DDP rank evaluation for better performance

Refactor print0 function to conditionally define behavior based on DDP rank.
This commit is contained in:
Sermet Pekin 2025-10-21 16:07:56 +03:00 committed by GitHub
parent 2e9669e03a
commit ce64059d65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -18,10 +18,17 @@ import os
import sys
from ast import literal_eval
def print0(s="",**kwargs):
ddp_rank = int(os.environ.get('RANK', 0))
if ddp_rank == 0:
# Determine at import time which function to use
_ddp_rank = int(os.environ.get('RANK', 0))
if _ddp_rank == 0:
# On rank 0: print0 is just print
def print0(s="", **kwargs):
print(s, **kwargs)
else:
# On other ranks: print0 is a no-op
def print0(s="", **kwargs):
pass
for arg in sys.argv[1:]:
if '=' not in arg: