mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-20 04:59:08 +00:00
refactor out steps beforehand
This commit is contained in:
parent
8f979a8bda
commit
35ec43822c
|
|
@ -211,9 +211,15 @@ for step in range(num_iterations):
|
|||
break
|
||||
|
||||
# evaluate the gradient
|
||||
steps = [next(train_loader) for _ in range(grad_accum_steps)]
|
||||
num_tokens = torch.tensor(0, device=device) # the number of "active" tokens of supervision seen
|
||||
for micro_step in range(grad_accum_steps):
|
||||
train_inputs, train_targets = next(train_loader)
|
||||
num_tokens += sum((targets >= 0).sum() for _, targets in steps)
|
||||
|
||||
if ddp:
|
||||
dist.all_reduce(num_tokens, op=dist.ReduceOp.SUM) # sum over ranks
|
||||
|
||||
for micro_step, (train_inputs, train_targets) in enumerate(steps):
|
||||
batch_num_tokens = (train_targets >= 0).sum()
|
||||
with autocast_ctx:
|
||||
loss = model(train_inputs, train_targets)
|
||||
train_loss = loss.detach() # for logging
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user