diff --git a/nanochat/adamw.py b/nanochat/adamw.py
index 07b82de..db591de 100644
--- a/nanochat/adamw.py
+++ b/nanochat/adamw.py
@@ -26,7 +26,6 @@ class DistAdamW(torch.optim.Optimizer):
         grad_slices = []
         for group in self.param_groups:
             params: list[Tensor] = group["params"]
-            grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly
            for base_i in range(len(params)):
                 grad = params[base_i].grad
                 rank_size = grad.shape[0] // world_size
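
Note: the removed line answers its own TODO. It is a dead store: `grad` is unconditionally rebound to `params[base_i].grad` on the very first iteration of the inner loop, so the pre-allocated tensor is never read and can be dropped without changing behavior. A minimal standalone sketch of the pattern below (not nanochat's actual optimizer; `params` and the world size of 2 are made up for illustration):

```python
import torch

# Hypothetical parameter list with populated gradients.
params = [torch.randn(4, 2, requires_grad=True) for _ in range(3)]
torch.stack([p.sum() for p in params]).sum().backward()

grad = torch.empty_like(params[-1])  # dead store: never read before being rebound below
for base_i in range(len(params)):
    grad = params[base_i].grad       # rebinds `grad` on every iteration, including the first
    rank_size = grad.shape[0] // 2   # e.g. world_size == 2
    # ... each rank would then reduce-scatter and update its slice of `grad`
```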