delete spurious torch.empty allocation in adamw

fix: remove unnecessary tensor allocation in DistAdamW optimizer
Andrej 2025-10-21 11:35:17 -07:00 committed by GitHub
commit 2e938530ce


@@ -26,7 +26,6 @@ class DistAdamW(torch.optim.Optimizer):
         grad_slices = []
         for group in self.param_groups:
             params: list[Tensor] = group["params"]
-            grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly
             for base_i in range(len(params)):
                 grad = params[base_i].grad
                 rank_size = grad.shape[0] // world_size
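The removed allocation was dead on arrival: `grad` is unconditionally rebound to `params[base_i].grad` on the first pass of the inner loop, so the `torch.empty_like(params[-1])` buffer was created and discarded without ever being read. A minimal sketch of the pattern under that assumption (the helper name and the `[:rank_size]` slice are illustrative stand-ins; only the loop structure comes from the hunk above):

import torch
from torch import Tensor

def collect_grad_slices(param_groups: list[dict], world_size: int) -> list[Tensor]:
    grad_slices: list[Tensor] = []
    for group in param_groups:
        params: list[Tensor] = group["params"]
        # The deleted line sat here:
        #   grad = torch.empty_like(params[-1])
        # Dead code: `grad` is rebound immediately below, so that buffer
        # was never read before being overwritten.
        for base_i in range(len(params)):
            grad = params[base_i].grad
            rank_size = grad.shape[0] // world_size  # rows owned by each rank
            grad_slices.append(grad[:rank_size])     # illustrative use of rank_size
    return grad_slices

Deleting the line changes no observable behavior; it only avoids allocating one extra parameter-sized tensor per step.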