Allow torchrun with 1 device

This commit is contained in:
Chris McCormick 2026-01-26 12:03:22 -08:00
parent 2e58d05782
commit d1595fb2d1

View File

@@ -350,8 +350,8 @@ class GPT(nn.Module):
dict(params=x0_params, lr=scalar_lr),
]
-# MuonAdamW for single-GPU, DistMuonAdamW for multi-GPU (with communication overlap)
-OptimizerClass = DistMuonAdamW if ddp else MuonAdamW
+# MuonAdamW for single-GPU, DistMuonAdamW for multi-GPU
+OptimizerClass = DistMuonAdamW if (ddp and world_size > 1) else MuonAdamW
optimizer = OptimizerClass(
adamw_groups=adam_groups,
muon_params=matrix_params,