nanochat/nanochat/gpt.py
Yan Meng 741e54f360 perf: architecture + optimizer optimizations — 94.6 min to GPT-2 (4.3% speedup)
Two rounds of WeCo-guided D12 optimization, validated on D24.
Key changes: smaller sliding windows (seq/8), VE every 3rd layer,
RoPE 200K, smear removed, exponential residual decay, optimizer
buffer pre-allocation. Mean CORE=0.2591 across 3 D24 runs.
2026-03-19 23:15:26 +00:00

510 lines
25 KiB
Python

"""
GPT model (rewrite, a lot simpler)
Notable features:
- rotary embeddings (and no positional embeddings)
- QK norm
- untied weights for token embedding and lm_head
- relu^2 activation in MLP
- norm after token embedding
- no learnable params in rmsnorm
- no bias in linear layers
- Group-Query Attention (GQA) support for more efficient inference
- Flash Attention 3 integration
"""
from functools import partial
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from nanochat.common import get_dist_info, print0, COMPUTE_DTYPE
from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
@dataclass
class GPTConfig:
    # Architecture hyperparameters consumed by the GPT model in this file.
    sequence_len: int = 2048  # context length; also sizes the rotary cache and the attention windows
    vocab_size: int = 32768  # logical vocab size (GPT pads it internally, logits are cropped back)
    n_layer: int = 12  # number of transformer blocks
    n_head: int = 6 # number of query heads
    n_kv_head: int = 6 # number of key/value heads (GQA)
    n_embd: int = 768  # residual stream width; head_dim = n_embd // n_head
    # Sliding window attention pattern string, tiled across layers. Final layer always L.
    # Characters: L=long (full context), S=short (one-eighth of the context,
    # rounded up to a multiple of 128 and never below 256; see GPT._compute_window_sizes)
    # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
    window_pattern: str = "SSSL"
def norm(x):
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
class Linear(nn.Linear):
    """nn.Linear that casts its parameters to match the input dtype in forward.

    Replaces autocast: master weights stay fp32 for optimizer precision,
    but matmuls run in the activation dtype (typically bf16 from embeddings).

    The bias (when the layer is constructed with bias=True) is cast and applied
    as well; every layer in this file uses bias=False, in which case self.bias
    is None and the call is identical to a bias-free F.linear.
    """

    def forward(self, x):
        weight = self.weight
        if weight.dtype != x.dtype:
            weight = weight.to(dtype=x.dtype)
        # Previously the bias was silently dropped; honour it if present.
        bias = self.bias
        if bias is not None and bias.dtype != x.dtype:
            bias = bias.to(dtype=x.dtype)
        return F.linear(x, weight, bias)
class EmbeddingLinear(nn.Module):
    """Lightweight bias-free linear layer used for lm_head.

    The weight is normally stored directly in the compute dtype, so the common
    path performs no cast at all. When the stored dtype differs from the
    activation dtype (e.g. the weight is kept in fp32 while activations are
    fp16), the weight is cast on the fly; without this F.linear fails on the
    dtype mismatch.
    """

    def __init__(self, in_features, out_features, bias=False, device=None, dtype=None):
        super().__init__()
        assert not bias
        self.in_features = in_features
        self.out_features = out_features
        # Uninitialized storage: the owner is expected to initialize the values explicitly.
        self.weight = nn.Parameter(torch.empty(out_features, in_features, device=device, dtype=dtype))

    def forward(self, x):
        weight = self.weight
        if weight.dtype != x.dtype:  # cheap attribute comparison; no cast when dtypes already agree
            weight = weight.to(dtype=x.dtype)
        return F.linear(x, weight)
def has_ve(layer_idx, n_layer):
    """Tell whether a layer carries a Value Embedding.

    Layers are selected every 3rd layer counting back from the last one,
    so the final layer (n_layer - 1) is always included.
    """
    distance_from_last = (n_layer - 1) - layer_idx
    return distance_from_last % 3 == 0
def apply_rotary_emb(x, cos, sin):
    """Rotate the two halves of the last dimension of x by the angles encoded in cos/sin.

    x must be 4-dimensional (multi-head layout); cos and sin must broadcast
    against each half of x's last dimension.
    """
    assert x.ndim == 4 # multihead attention
    half = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated_first = first * cos + second * sin
    rotated_second = first * (-sin) + second * cos
    return torch.cat([rotated_first, rotated_second], 3)
class CausalSelfAttention(nn.Module):
    """Causal self-attention with GQA, rotary embeddings, QK norm and an optional
    gated value-embedding residual."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0
        assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
        q_width = self.n_head * self.head_dim
        kv_width = self.n_kv_head * self.head_dim
        self.c_q = Linear(self.n_embd, q_width, bias=False)
        self.c_k = Linear(self.n_embd, kv_width, bias=False)
        self.c_v = Linear(self.n_embd, kv_width, bias=False)
        self.c_proj = Linear(self.n_embd, self.n_embd, bias=False)
        # The value-embedding gate reads only the first ve_gate_channels input channels
        self.ve_gate_channels = 12
        if has_ve(layer_idx, config.n_layer):
            self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False)
        else:
            self.ve_gate = None

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        B, T, _ = x.size()

        # Projections, viewed as (B, T, H, D): the layout the attention kernels take directly
        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)

        # Value residual (ResFormer): add the value embedding, scaled by an
        # input-dependent per-kv-head gate in the range (0, 3)
        if ve is not None:
            ve = ve.view(B, T, self.n_kv_head, self.head_dim)
            gate_input = x[..., :self.ve_gate_channels]
            gate = 3 * torch.sigmoid(self.ve_gate(gate_input)) # (B, T, n_kv_head)
            v = v + gate.unsqueeze(-1) * ve

        # Rotary position encoding, then QK norm, then a fixed 1.2 scale on both
        # q and k for sharper attention
        cos, sin = cos_sin
        q = norm(apply_rotary_emb(q, cos, sin)) * 1.2
        k = norm(apply_rotary_emb(k, cos, sin)) * 1.2

        # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere).
        # window_size is a (left, right) tuple forwarded unchanged to the kernel.
        if kv_cache is None:
            # Training: causal attention with optional sliding window
            y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
        else:
            # Inference: the kv-cache variant also takes care of writing k/v into the cache
            k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
            y = flash_attn.flash_attn_with_kvcache(
                q, k_cache, v_cache,
                k=k, v=v,
                cache_seqlens=kv_cache.cache_seqlens,
                causal=True,
                window_size=window_size,
            )
            # Only the last layer advances the shared cache position
            is_last_layer = self.layer_idx == kv_cache.n_layers - 1
            if is_last_layer:
                kv_cache.advance(T)

        # Merge heads back together and project into the residual stream
        merged = y.contiguous().view(B, T, -1)
        return self.c_proj(merged)
class MLP(nn.Module):
    """Feed-forward block: expand 4x, relu^2 activation, project back. No biases."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * width
        self.c_fc = Linear(width, hidden_width, bias=False)
        self.c_proj = Linear(hidden_width, width, bias=False)

    def forward(self, x):
        hidden = F.relu(self.c_fc(x))
        return self.c_proj(hidden.square())
class Block(nn.Module):
    """One transformer layer: pre-norm attention and pre-norm MLP, each added residually."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.attn = CausalSelfAttention(config, layer_idx)
        self.mlp = MLP(config)

    def forward(self, x, ve, cos_sin, window_size, kv_cache):
        attn_out = self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
        x = x + attn_out
        mlp_out = self.mlp(norm(x))
        return x + mlp_out
class GPT(nn.Module):
def __init__(self, config, pad_vocab_size_to=64):
    """
    Declare all modules, parameters and buffers of the model.

    NOTE a major footgun: this __init__ function runs in meta device context (!!)
    Therefore, any calculations inside here are shapes and dtypes only, no actual data.
    => We actually initialize all data (parameters, buffers, etc.) in init_weights() instead.
    """
    super().__init__()
    self.config = config

    # Per-layer (left, right) attention windows for sliding window attention
    self.window_sizes = self._compute_window_sizes(config)

    # Round the vocab up to a multiple of pad_vocab_size_to for efficiency (DDP, tensor cores).
    # This is purely an optimization - the logits are cropped back in forward().
    # https://huggingface.co/docs/transformers/main_classes/model#transformers.PreTrainedModel.resize_token_embeddings
    n_vocab_chunks = (config.vocab_size + pad_vocab_size_to - 1) // pad_vocab_size_to
    padded_vocab_size = n_vocab_chunks * pad_vocab_size_to
    if padded_vocab_size != config.vocab_size:
        print0(f"Padding vocab_size from {config.vocab_size} to {padded_vocab_size} for efficiency")

    # Token embedding, transformer trunk and (untied) output head
    token_embedding = nn.Embedding(padded_vocab_size, config.n_embd)
    blocks = nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)])
    self.transformer = nn.ModuleDict({
        "wte": token_embedding,
        "h": blocks,
    })
    self.lm_head = EmbeddingLinear(config.n_embd, padded_vocab_size, bias=False)

    # Per-layer learnable scalars (inspired by modded-nanogpt)
    #   resid_lambdas: scales the residual stream entering each layer
    #   x0_lambdas: blends the initial embedding back in at each layer
    # Kept as separate parameters so they can have different optimizer treatment.
    # The values below are placeholders; the real init happens in init_weights().
    self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
    self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
    # Backout: weight of the cached mid-layer residual subtracted before the final norm
    self.backout_lambda = nn.Parameter(torch.full((1,), 0.2))

    # Value embeddings (ResFormer-style), one table per layer selected by has_ve()
    head_dim = config.n_embd // config.n_head
    kv_dim = config.n_kv_head * head_dim
    ve_layers = [i for i in range(config.n_layer) if has_ve(i, config.n_layer)]
    self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in ve_layers})

    # Rotary tables. Under the meta device these are just "fake" tensors that fix shapes.
    # The tables are small/cheap in memory, so over-compute them for 10X the configured
    # sequence length; forward() asserts if that is ever exceeded.
    # In the future the cache could grow dynamically, for now this is fine.
    self.rotary_seq_len = config.sequence_len * 10 # TODO make nicer?
    cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
    # persistent=False means the buffers are not saved to the checkpoint
    self.register_buffer("cos", cos, persistent=False)
    self.register_buffer("sin", sin, persistent=False)
@torch.no_grad()
def init_weights(self):
    """
    Initialize the full model in this one function for maximum clarity.

    With s = sqrt(3) / sqrt(n_embd) (uniform bound giving std = 1/sqrt(n_embd)):
        wte (embedding):     normal, std=0.8
        lm_head:             normal, std=0.001
        for each block:
            attn.c_q:        uniform(-s, s)
            attn.c_k:        uniform(-s, s)
            attn.c_v:        uniform(-0.85*s, 0.85*s)
            attn.c_proj:     uniform(-0.008, 0.008)
            mlp.c_fc:        uniform(-0.4*s, 0.4*s)
            mlp.c_proj:      zeros
            attn.ve_gate:    uniform(0, 0.02) (only on layers with a value embedding)
        resid_lambdas:       exponential decay from 1.18 (first layer) to 1.06 (last layer)
        x0_lambdas:          linear decay 0.24 -> 0.08 over the first half of the layers, 0 after
        backout_lambda:      0.2
        value_embeds:        uniform(-s, s)

    Also recomputes the rotary tables and casts the embedding-like tables to
    COMPUTE_DTYPE (unless COMPUTE_DTYPE is fp16).
    """
    import math

    n_embd = self.config.n_embd
    n_layer = self.config.n_layer

    # Embedding and unembedding
    torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=0.8)
    torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)

    # Transformer blocks: uniform init with bound = sqrt(3) * std (same standard deviation as normal)
    s = 3**0.5 * n_embd**-0.5 # sqrt(3) multiplier makes sure Uniform achieves the same std as Normal
    for block in self.transformer.h:
        torch.nn.init.uniform_(block.attn.c_q.weight, -s, s) # weights use Uniform to avoid outliers
        torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
        torch.nn.init.uniform_(block.attn.c_v.weight, -0.85 * s, 0.85 * s)
        torch.nn.init.uniform_(block.attn.c_proj.weight, -0.008, 0.008) # small nonzero init
        torch.nn.init.uniform_(block.mlp.c_fc.weight, -s * 0.4, s * 0.4) # 0.4x init scale for c_fc
        torch.nn.init.zeros_(block.mlp.c_proj.weight)

    # Per-layer scalars
    # resid init: exponential decay, stronger at early layers
    resid_start, resid_end = 1.18, 1.06
    resid_decay = math.log(resid_start / resid_end) / max(n_layer - 1, 1)
    for i in range(n_layer):
        self.resid_lambdas.data[i] = resid_start * math.exp(-resid_decay * i)
    # x0 init: first-half only, linearly decaying, zero for deep layers
    half_depth = max(1, n_layer // 2)
    for i in range(n_layer):
        if i < half_depth:
            frac = i / max(half_depth - 1, 1)
            self.x0_lambdas.data[i] = 0.24 * (1.0 - frac) + 0.08 * frac
        else:
            self.x0_lambdas.data[i] = 0.0
    # backout_lambda must be set explicitly too: __init__ runs under the meta device,
    # so the 0.2 written there carries no data once the model is materialized.
    torch.nn.init.constant_(self.backout_lambda, 0.2)

    # Value embeddings: uniform with bound s (note that c_v itself uses 0.85 * s)
    for ve in self.value_embeds.values():
        torch.nn.init.uniform_(ve.weight, -s, s)

    # Gate weights init with small positive values so gates start slightly above neutral
    for block in self.transformer.h:
        if block.attn.ve_gate is not None:
            torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)

    # Rotary embeddings
    head_dim = n_embd // self.config.n_head
    cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
    self.cos, self.sin = cos, sin

    # Cast embeddings to COMPUTE_DTYPE: optimizer can tolerate reduced-precision
    # embeddings and it saves memory. Exception: fp16 requires fp32 embeddings
    # because GradScaler cannot unscale fp16 gradients.
    if COMPUTE_DTYPE != torch.float16:
        self.transformer.wte.to(dtype=COMPUTE_DTYPE)
        self.lm_head.to(dtype=COMPUTE_DTYPE)
        for ve in self.value_embeds.values():
            ve.to(dtype=COMPUTE_DTYPE)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=200000, device=None):
    """Build the rotary cos/sin tables.

    Returns a (cos, sin) pair, each of shape (1, seq_len, 1, head_dim // 2) in
    COMPUTE_DTYPE, with singleton batch and head dims for later broadcasting.
    When device is None the device of the token embedding weights is used.
    """
    if device is None:
        # autodetect the device from model embeddings
        device = self.transformer.wte.weight.device
    # one inverse frequency per pair of channels
    even_channels = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (base ** (even_channels / head_dim))
    # rotation angle at every (time step, channel pair)
    positions = torch.arange(seq_len, dtype=torch.float32, device=device)
    angles = torch.outer(positions, inv_freq)
    # cast down and add the batch and head dims
    cos, sin = [table.to(COMPUTE_DTYPE)[None, :, None, :] for table in (angles.cos(), angles.sin())]
    return cos, sin
def _compute_window_sizes(self, config):
    """
    Compute per-layer window sizes for sliding window attention.

    Returns a list with one (left, right) tuple per layer, in the form taken by
    the attention kernel's window_size parameter:
    - left: how many tokens before the current position to attend to
    - right: how many tokens after the current position to attend to (always 0, causal)

    The pattern string (case-insensitive) is tiled across layers, and the final
    layer always gets L. Characters:
    - L = long: left = config.sequence_len (full context)
    - S = short: left = sequence_len / 8, rounded up to a multiple of 128 and
      never below 256 (e.g. 2048 -> 256)

    Raises AssertionError if the pattern is empty or contains characters other
    than S and L. Returns an empty list when config.n_layer is 0.
    """
    pattern = config.window_pattern.upper()
    # An empty pattern would otherwise surface as a ZeroDivisionError in the tiling below
    assert pattern, "Invalid window_pattern: empty. Use only S and L."
    assert all(c in "SL" for c in pattern), f"Invalid window_pattern: {pattern}. Use only S and L."
    # Map characters to window sizes
    long_window = config.sequence_len
    short_window = max(256, -(-long_window // 8 // 128) * 128) # ceil to FA3 tile size (2048 -> 256)
    char_to_window = {
        "L": (long_window, 0),
        "S": (short_window, 0),
    }
    # Tile pattern across layers
    window_sizes = [char_to_window[pattern[layer_idx % len(pattern)]] for layer_idx in range(config.n_layer)]
    # Final layer always gets full context (nothing to override for a zero-layer config)
    if window_sizes:
        window_sizes[-1] = (long_window, 0)
    return window_sizes
def get_device(self):
    """Return the device the model lives on, read off the token embedding weights."""
    embedding_weight = self.transformer.wte.weight
    return embedding_weight.device
def estimate_flops(self):
    """
    Return the estimated FLOPs per token for the model (forward + backward).

    Every matmul weight costs 2 FLOPs (multiply, accumulate) in the forward pass and
    twice that in the backward pass => 6 per parameter.
    Cleanest explanation of this: https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4
    The key @ query matmuls inside attention add 12 * h * q * effective_seq_len per layer,
    where effective_seq_len is the sequence length capped by that layer's sliding window.
    Ref: https://arxiv.org/abs/2204.02311 (PaLM paper).
    This is ~1% off from the exact formulas of Chinchilla paper, the difference is:
    - Chinchilla counts the embedding layer as flops (it's just a lookup => we ignore)
    - Chinchilla counts exp/sum/divide in attention softmax as flops (very tiny => we ignore)
    """
    cfg = self.config
    total_params = sum(p.numel() for p in self.parameters())

    # Parameters that never take part in a matmul: embedding tables and per-layer scalars
    non_matmul = [self.transformer.wte.weight]
    non_matmul.extend(ve.weight for ve in self.value_embeds.values())
    non_matmul.extend([self.resid_lambdas, self.x0_lambdas, self.backout_lambda])
    excluded_params = sum(p.numel() for p in non_matmul)

    n_head = cfg.n_head
    head_dim = cfg.n_embd // cfg.n_head
    seq_len = cfg.sequence_len

    # Attention FLOPs per layer; only the left side of each (left, right) window matters,
    # and a negative left side means the window is unlimited
    attn_flops = 0
    for window_size in self.window_sizes:
        left = window_size[0]
        effective_seq = seq_len if left < 0 else min(left, seq_len)
        attn_flops += 12 * n_head * head_dim * effective_seq

    return 6 * (total_params - excluded_params) + attn_flops
def num_scaling_params(self):
    """
    Return detailed parameter counts for scaling law analysis.

    Different papers use different conventions:
    - Kaplan et al. excluded embedding parameters
    - Chinchilla included all parameters
    Ref: https://arxiv.org/abs/2203.15556 (Chinchilla paper)
    Ref: https://arxiv.org/abs/2001.08361 (Kaplan et al. original scaling laws paper)

    Returns a dict with one count per parameter group plus 'total', so downstream
    analysis can experiment with which combination gives the cleanest scaling laws.
    """
    def count(module):
        return sum(p.numel() for p in module.parameters())

    # One entry per group (mirrors the grouping in setup_optimizer)
    counts = {
        'wte': count(self.transformer.wte),
        'value_embeds': count(self.value_embeds),
        'lm_head': count(self.lm_head),
        'transformer_matrices': count(self.transformer.h),
        'scalars': self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.backout_lambda.numel(),
    }
    total = sum(counts.values())
    # The groups must partition the full parameter set
    assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
    counts['total'] = total
    return counts
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
    """Build the combined Muon + AdamW optimizer for this model.

    AdamW handles lm_head, the token embedding, the value embeddings and the
    per-layer scalars; Muon handles every parameter inside the transformer
    blocks, one param group per distinct shape. The weight_decay argument
    applies to the Muon groups only (the AdamW groups use fixed values).
    Each group's starting lr is also recorded under "initial_lr".
    """
    model_dim = self.config.n_embd
    ddp, _rank, _local_rank, _world_size = get_dist_info()

    # Separate out all parameters into groups
    matrix_params = list(self.transformer.h.parameters())
    value_embeds_params = list(self.value_embeds.parameters())
    embedding_params = list(self.transformer.wte.parameters())
    lm_head_params = list(self.lm_head.parameters())
    resid_params = [self.resid_lambdas]
    x0_params = [self.x0_lambdas]
    backout_params = [self.backout_lambda]
    # Every parameter must land in exactly one group
    n_grouped = (len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params)
                 + len(resid_params) + len(x0_params) + len(backout_params))
    assert len(list(self.parameters())) == n_grouped

    # Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
    dmodel_lr_scale = (model_dim / 768) ** -0.5
    print0(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")

    def adamw_group(params, lr, betas, wd):
        # All required fields are spelled out explicitly for every group
        return dict(kind='adamw', params=params, lr=lr, betas=betas, eps=1e-10, weight_decay=wd)

    # AdamW groups (embeddings, lm_head, scalars)
    param_groups = [
        adamw_group(lm_head_params, unembedding_lr * dmodel_lr_scale, (0.8, 0.96), 0.01),
        adamw_group(embedding_params, embedding_lr * dmodel_lr_scale, (0.8, 0.995), 0.001),
        adamw_group(value_embeds_params, embedding_lr * dmodel_lr_scale * 0.5, (0.8, 0.995), 0.01),
        adamw_group(resid_params, scalar_lr * 0.01, (0.8, 0.95), 0.05),
        adamw_group(x0_params, scalar_lr, (0.96, 0.95), 0.0), # higher beta1 for x0
        adamw_group(backout_params, 0.15, (0.8, 0.95), 0.0),
    ]

    # Muon groups (matrix params, grouped by shape for stacking)
    params_by_shape = {}
    for p in matrix_params:
        params_by_shape.setdefault(p.shape, []).append(p)
    for shape in sorted(params_by_shape):
        param_groups.append(dict(
            kind='muon', params=params_by_shape[shape], lr=matrix_lr,
            momentum=0.95, ns_steps=5, beta2=0.9, weight_decay=weight_decay,
        ))

    # Distributed variant when running under DDP
    if ddp:
        Factory = DistMuonAdamW
    else:
        Factory = MuonAdamW
    optimizer = Factory(param_groups)
    for group in optimizer.param_groups:
        group["initial_lr"] = group["lr"]
    return optimizer
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
    """
    Run the model on a batch of token ids.

    Args:
        idx: integer token ids of shape (B, T).
        targets: optional target ids; flattened next to the logits for the
            cross-entropy loss, with entries equal to -1 ignored.
        kv_cache: optional inference cache. Its get_pos() gives the offset of the
            current tokens, and it is passed through to every block.
        loss_reduction: forwarded to F.cross_entropy as `reduction`.

    Returns:
        The loss when targets is given, otherwise the fp32, softcapped logits of
        shape (B, T, config.vocab_size).
    """
    B, T = idx.size()
    # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
    T0 = 0 if kv_cache is None else kv_cache.get_pos()
    # Grab the rotary embeddings for the current positions (they are of shape (1, seq_len, 1, head_dim/2)).
    # The bound must include the cache offset: slicing past the end would silently return a
    # shorter (possibly empty) table that then broadcasts instead of failing.
    assert T0 + T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T0 + T} > {self.cos.size(1)}"
    assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
    assert self.cos.dtype == COMPUTE_DTYPE, f"Rotary embeddings must be in {COMPUTE_DTYPE}, got {self.cos.dtype}"
    cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
    # Embed the tokens
    x = self.transformer.wte(idx) # embed current token
    x = x.to(COMPUTE_DTYPE) # ensure activations are in compute dtype (no-op usually, but active for fp16 code path)
    x = norm(x)
    # Forward the trunk of the Transformer
    x0 = x # save initial normalized embedding for x0 residual
    n_layer = self.config.n_layer
    backout_layer = n_layer // 2 # cache at halfway point
    x_backout = None
    for i, block in enumerate(self.transformer.h):
        # rescale the residual stream and blend the initial embedding back in
        x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
        ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
        x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
        if i == backout_layer:
            x_backout = x
    # Subtract mid-layer residual to remove low-level features before logit projection
    if x_backout is not None:
        x = x - self.backout_lambda.to(x.dtype) * x_backout
    x = norm(x)
    # Forward the lm_head (compute logits)
    softcap = 15 # smoothly cap the logits to the range [-softcap, softcap]
    logits = self.lm_head(x) # (B, T, padded_vocab_size) <- very big tensor, large amount of memory
    logits = logits[..., :self.config.vocab_size] # slice to remove padding
    logits = logits.float() # switch to fp32 for logit softcap and loss computation
    logits = softcap * torch.tanh(logits / softcap) # squash the logits
    if targets is not None:
        # training: given the targets, compute and return the loss
        # TODO experiment with chunked cross-entropy?
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
        return loss
    else:
        # inference: just return the logits directly
        return logits
@torch.inference_mode()
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
    """
    Naive autoregressive streaming inference.

    Kept deliberately simple:
    - batch size is 1
    - the prompt is a plain Python list of ints and tokens are yielded as Python ints
    - the whole sequence is re-run through forward() at every step

    A temperature > 0 samples (seeded by `seed`); otherwise decoding is greedy.
    A positive top_k restricts the choice to the k highest logits.
    """
    assert isinstance(tokens, list)
    device = self.get_device()
    sample = temperature > 0
    rng = None
    if sample:
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
    ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
    for _ in range(max_tokens):
        last_logits = self.forward(ids)[:, -1, :] # (B, vocab_size), logits of the final position
        if top_k is not None and top_k > 0:
            # mask out everything below the k-th largest logit
            k = min(top_k, last_logits.size(-1))
            kth_largest = torch.topk(last_logits, k).values[:, [-1]]
            last_logits[last_logits < kth_largest] = -float('Inf')
        if sample:
            probs = F.softmax(last_logits / temperature, dim=-1)
            next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
        else:
            next_ids = torch.argmax(last_logits, dim=-1, keepdim=True)
        ids = torch.cat((ids, next_ids), dim=1)
        yield next_ids.item()