Mirror of https://github.com/karpathy/nanochat.git, synced 2025-12-06 04:12:13 +00:00
Fix process group initialization for CPU DDP and improve cleanup safety
This commit is contained in:
parent 104308cf78
commit 9235fe4000
@@ -162,11 +162,18 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
     # Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp and device_type == "cuda":
-        device = torch.device("cuda", ddp_local_rank)
-        torch.cuda.set_device(device) # make "cuda" default to this device
-        dist.init_process_group(backend="nccl", device_id=device)
-        dist.barrier()
+    if ddp:
+        if device_type == "cuda":
+            device = torch.device("cuda", ddp_local_rank)
+            torch.cuda.set_device(device) # make "cuda" default to this device
+            dist.init_process_group(backend="nccl", device_id=device)
+            dist.barrier()
+        elif device_type == "cpu":
+            device = torch.device("cpu")
+            dist.init_process_group(backend="gloo")
+            dist.barrier()
+        else:
+            device = torch.device(device_type) # mps
     else:
         device = torch.device(device_type) # mps|cpu
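The key change in compute_init: DDP is no longer gated on CUDA. CUDA ranks still initialize NCCL pinned to their local device, while CPU ranks now fall back to the gloo backend, which needs no GPU at all. A minimal standalone sketch of that gloo-on-CPU pattern (ddp_cpu_check.py is a hypothetical script, not part of the repo; torchrun supplies the RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT env vars that init_process_group reads by default):

    # ddp_cpu_check.py -- run with: torchrun --nproc_per_node=2 ddp_cpu_check.py
    import torch
    import torch.distributed as dist

    def main():
        # gloo is the CPU-capable backend; unlike nccl it needs no device_id
        dist.init_process_group(backend="gloo")
        rank, world_size = dist.get_rank(), dist.get_world_size()
        # quick sanity check that the process group is actually wired up
        t = torch.ones(1) * (rank + 1)
        dist.all_reduce(t, op=dist.ReduceOp.SUM)  # sums across all ranks
        # every rank should now hold 1 + 2 + ... + world_size
        assert t.item() == world_size * (world_size + 1) / 2
        print(f"rank {rank}/{world_size}: all_reduce ok")
        dist.destroy_process_group()

    if __name__ == "__main__":
        main()

Note the mps case stays single-process: neither nccl nor gloo targets mps tensors, so the inner else just selects the device without creating a process group.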
@@ -177,7 +184,7 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps

 def compute_cleanup():
     """Companion function to compute_init, to clean things up before script exit"""
-    if is_ddp():
+    if is_ddp() and dist.is_initialized():
         dist.destroy_process_group()

 class DummyWandb:
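The cleanup change is a guard, not a behavior change: dist.is_initialized() makes compute_cleanup safe when no process group was ever created (e.g. the mps branch above, or an exception during startup) and safe to call more than once. A small sketch of the pattern in isolation, assuming an is_ddp() helper that infers DDP from the RANK env var torchrun sets:

    import os
    import torch.distributed as dist

    def is_ddp():
        # assumption: torchrun exports RANK for every worker; absent means plain launch
        return int(os.environ.get("RANK", -1)) != -1

    def compute_cleanup():
        # without the is_initialized() check, calling this when no process group
        # exists (or calling it twice) raises an error inside destroy_process_group
        if is_ddp() and dist.is_initialized():
            dist.destroy_process_group()

    compute_cleanup()  # no-op under a plain (non-torchrun) launch -- which is the point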