From bbc57da7d594029bbaf8604e1875b833ecb9b7a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mat=C4=9Bj=20Kripner?= Date: Tue, 9 Dec 2025 12:46:48 +0100 Subject: [PATCH] slightly nicer error message --- nanochat/adamw.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nanochat/adamw.py b/nanochat/adamw.py index 2cbf1d4..8816057 100644 --- a/nanochat/adamw.py +++ b/nanochat/adamw.py @@ -27,7 +27,7 @@ class DistAdamW(torch.optim.Optimizer): for group in self.param_groups: params: list[Tensor] = group["params"] for base_i in range(len(params)): - assert params[base_i].shape[0] % world_size == 0, f"Parameter shape {params[base_i].shape} must be divisible by world size {world_size}" + assert params[base_i].shape[0] % world_size == 0, f"First dim of parameter shape {params[base_i].shape} must be divisible by world size {world_size}" grad = params[base_i].grad rank_size = grad.shape[0] // world_size grad_slice = torch.empty_like(grad[:rank_size])