mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-14 18:19:41 +00:00
bring back an assert guarding against bad param sizing
This commit is contained in:
parent
012da1a78b
commit
98eed6df18
|
|
@ -377,6 +377,7 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
param_infos[p] = dict(future=future, grad_slice=grad, is_small=True)
|
||||
else:
|
||||
# Large params: reduce_scatter
|
||||
assert grad.shape[0] % world_size == 0, f"AdamW reduce_scatter requires shape[0] ({grad.shape[0]}) divisible by world_size ({world_size})"
|
||||
rank_size = grad.shape[0] // world_size
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user