diff --git a/nanochat/adamw.py b/nanochat/adamw.py index 2cbf1d4..8816057 100644 --- a/nanochat/adamw.py +++ b/nanochat/adamw.py @@ -27,7 +27,7 @@ class DistAdamW(torch.optim.Optimizer): for group in self.param_groups: params: list[Tensor] = group["params"] for base_i in range(len(params)): - assert params[base_i].shape[0] % world_size == 0, f"Parameter shape {params[base_i].shape} must be divisible by world size {world_size}" + assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}" grad = params[base_i].grad rank_size = grad.shape[0] // world_size grad_slice = torch.empty_like(grad[:rank_size])