fix normalization

Author: svlandeg
Date: 2025-12-30 11:13:05 +01:00
parent fc565d7294
commit 1ca9280328


@@ -221,9 +221,9 @@ for step in range(num_iterations):
     for micro_step, (train_inputs, train_targets) in enumerate(steps):
         batch_num_tokens = (train_targets >= 0).sum()
         with autocast_ctx:
-            loss = model(train_inputs, train_targets)
+            loss = model(train_inputs, train_targets, loss_reduction='sum')
+        loss = loss / num_tokens # normalize loss here
         train_loss = loss.detach() # for logging
-        loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
         loss.backward() # accumulate the gradient
     # learning rate scheduler
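
Background on the normalization: when micro-batches contain different numbers of loss tokens, taking each micro-batch's mean loss and dividing by grad_accum_steps weights tokens from small micro-batches more heavily, whereas summing per-token losses and dividing by the total token count reproduces the true mean over all tokens in the step. A minimal sketch of that difference, using made-up tensors and variable names that are not from the repository:

import torch

torch.manual_seed(0)

# Hypothetical per-token losses for two micro-batches of unequal size.
micro_losses = [torch.rand(5), torch.rand(3)]
grad_accum_steps = len(micro_losses)
num_tokens = sum(l.numel() for l in micro_losses)

# Old scheme: mean per micro-batch, then divide by grad_accum_steps.
old_total = sum(l.mean() / grad_accum_steps for l in micro_losses)

# New scheme: sum per micro-batch, then divide by the total token count.
new_total = sum(l.sum() / num_tokens for l in micro_losses)

# Reference: a single mean over every token in the step.
true_mean = torch.cat(micro_losses).mean()

print(old_total.item(), new_total.item(), true_mean.item())
# new_total matches true_mean; old_total over-weights the smaller micro-batch.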