diff --git a/nanochat/common.py b/nanochat/common.py index fb6c49a..8b10df9 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -76,8 +76,7 @@ def print_banner(): print0(banner) def is_ddp(): - if dist.is_initialized(): - return True + # TODO is there a proper way return int(os.environ.get('RANK', -1)) != -1 def get_dist_info():