diff --git a/nanochat/common.py b/nanochat/common.py index 8b10df9..fb6c49a 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -76,7 +76,8 @@ def print_banner(): print0(banner) def is_ddp(): - # TODO is there a proper way + if dist.is_initialized(): + return True return int(os.environ.get('RANK', -1)) != -1 def get_dist_info():