diff --git a/nanochat/common.py b/nanochat/common.py index 8b10df9..747669b 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -76,8 +76,7 @@ def print_banner(): print0(banner) def is_ddp(): - # TODO is there a proper way - return int(os.environ.get('RANK', -1)) != -1 + return dist.is_available() and dist.is_initialized() def get_dist_info(): if is_ddp():