mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-21 02:44:13 +00:00
slightly nicer error message
This commit is contained in:
parent
f1bf69d562
commit
bbc57da7d5
|
|
@@ -27,7 +27,7 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
for base_i in range(len(params)):
|
||||
- assert params[base_i].shape[0] % world_size == 0, f"Parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
|
||||
+ assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}"
|
||||
grad = params[base_i].grad
|
||||
rank_size = grad.shape[0] // world_size
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user