mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 05:35:19 +00:00
Merge 9f15189853 into 5019accc5b
This commit is contained in:
commit
ff8942ac95
|
|
@ -26,6 +26,11 @@ def _patch_missing_config_keys(model_config_kwargs):
|
|||
if "window_pattern" not in model_config_kwargs:
|
||||
model_config_kwargs["window_pattern"] = "L"
|
||||
log0(f"Patching missing window_pattern in model config to 'L'")
|
||||
# AttnRes defaults to disabled
|
||||
if "attn_res" not in model_config_kwargs:
|
||||
model_config_kwargs["attn_res"] = False
|
||||
if "attn_res_block_size" not in model_config_kwargs:
|
||||
model_config_kwargs["attn_res_block_size"] = 4
|
||||
|
||||
def _patch_missing_keys(model_data, model_config):
|
||||
"""Add default values for new parameters that may be missing in old checkpoints."""
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ Notable features:
|
|||
- no bias in linear layers
|
||||
- Group-Query Attention (GQA) support for more efficient inference
|
||||
- Flash Attention 3 integration
|
||||
- Optional Block Attention Residuals (AttnRes): https://github.com/MoonshotAI/Attention-Residuals
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
|
|
@ -37,11 +38,23 @@ class GPTConfig:
|
|||
# Characters: L=long (full context), S=short (quarter context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
# Attention Residuals (Block AttnRes): replaces resid_lambdas/x0_lambdas with learned
|
||||
# depth-attention over block-level representations. See: https://github.com/MoonshotAI/Attention-Residuals
|
||||
attn_res: bool = False
|
||||
attn_res_block_size: int = 4 # layers per block (must be >= 2 and even)
|
||||
|
||||
|
||||
def norm(x):
|
||||
return F.rms_norm(x, (x.size(-1),)) # note that this will run in bf16, seems ok
|
||||
|
||||
def block_attn_res(blocks, partial_block, w):
|
||||
"""Depth-attention over block representations. w is a (D,) pseudo-query vector."""
|
||||
V = torch.stack(blocks + [partial_block]) # (N+1, B, T, D)
|
||||
K = norm(V)
|
||||
logits = torch.einsum('d, n b t d -> n b t', w, K)
|
||||
h = torch.einsum('n b t, n b t d -> b t d', logits.softmax(0), V)
|
||||
return h
|
||||
|
||||
class Linear(nn.Linear):
|
||||
"""nn.Linear that casts weights to match input dtype in forward.
|
||||
Replaces autocast: master weights stay fp32 for optimizer precision,
|
||||
|
|
@ -160,6 +173,13 @@ class GPT(nn.Module):
|
|||
"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# Validate AttnRes config
|
||||
if config.attn_res:
|
||||
assert config.attn_res_block_size >= 2 and config.attn_res_block_size % 2 == 0, \
|
||||
f"attn_res_block_size must be >= 2 and even, got {config.attn_res_block_size}"
|
||||
# Per-layer pseudo-query vectors for depth-attention (one before attn, one before mlp)
|
||||
self.attn_res_proj = nn.Parameter(torch.zeros(config.n_layer, config.n_embd))
|
||||
self.mlp_res_proj = nn.Parameter(torch.zeros(config.n_layer, config.n_embd))
|
||||
# Compute per-layer window sizes for sliding window attention
|
||||
# window_size is (left, right) tuple: (-1, 0) for full context, (N, 0) for sliding window
|
||||
self.window_sizes = self._compute_window_sizes(config)
|
||||
|
|
@ -177,6 +197,7 @@ class GPT(nn.Module):
|
|||
# resid_lambdas: scales the residual stream at each layer (init 1.0 = neutral)
|
||||
# x0_lambdas: blends initial embedding back in at each layer (init 0.0 = disabled)
|
||||
# Separate parameters so they can have different optimizer treatment
|
||||
# When AttnRes is enabled, these are still created (for checkpoint compat) but unused in forward
|
||||
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights()
|
||||
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights()
|
||||
# Smear: mix previous token's embedding into current token (cheap bigram-like info)
|
||||
|
|
@ -247,6 +268,11 @@ class GPT(nn.Module):
|
|||
if block.attn.ve_gate is not None:
|
||||
torch.nn.init.uniform_(block.attn.ve_gate.weight, 0.0, 0.02)
|
||||
|
||||
# AttnRes pseudo-query vectors: small init
|
||||
if self.config.attn_res:
|
||||
torch.nn.init.normal_(self.attn_res_proj, mean=0.0, std=0.02)
|
||||
torch.nn.init.normal_(self.mlp_res_proj, mean=0.0, std=0.02)
|
||||
|
||||
# Rotary embeddings
|
||||
head_dim = self.config.n_embd // self.config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
|
|
@ -324,9 +350,13 @@ class GPT(nn.Module):
|
|||
nparams = sum(p.numel() for p in self.parameters())
|
||||
# Exclude non-matmul params: embeddings and per-layer scalars
|
||||
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
||||
attn_res_numel = 0
|
||||
if self.config.attn_res:
|
||||
attn_res_numel = self.attn_res_proj.numel() + self.mlp_res_proj.numel()
|
||||
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
||||
self.resid_lambdas.numel() + self.x0_lambdas.numel() +
|
||||
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel())
|
||||
self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() +
|
||||
attn_res_numel)
|
||||
h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
|
||||
# Sum attention FLOPs per layer, accounting for sliding window
|
||||
attn_flops = 0
|
||||
|
|
@ -354,7 +384,8 @@ class GPT(nn.Module):
|
|||
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
||||
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
||||
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel()
|
||||
attn_res = (self.attn_res_proj.numel() + self.mlp_res_proj.numel()) if self.config.attn_res else 0
|
||||
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel() + self.smear_gate.weight.numel() + self.smear_lambda.numel() + self.backout_lambda.numel() + attn_res
|
||||
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
||||
assert total == sum(p.numel() for p in self.parameters()), "Parameter count mismatch"
|
||||
return {
|
||||
|
|
@ -372,13 +403,14 @@ class GPT(nn.Module):
|
|||
|
||||
# Separate out all parameters into groups
|
||||
matrix_params = list(self.transformer.h.parameters())
|
||||
attn_res_params = [self.attn_res_proj, self.mlp_res_proj] if self.config.attn_res else []
|
||||
value_embeds_params = list(self.value_embeds.parameters())
|
||||
embedding_params = list(self.transformer.wte.parameters())
|
||||
lm_head_params = list(self.lm_head.parameters())
|
||||
resid_params = [self.resid_lambdas]
|
||||
x0_params = [self.x0_lambdas]
|
||||
smear_params = [self.smear_gate.weight, self.smear_lambda, self.backout_lambda]
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params)
|
||||
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(smear_params) + len(attn_res_params)
|
||||
|
||||
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
|
||||
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
||||
|
|
@ -393,7 +425,7 @@ class GPT(nn.Module):
|
|||
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.05),
|
||||
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
|
||||
dict(kind='adamw', params=smear_params, lr=0.2, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0),
|
||||
]
|
||||
] + ([dict(kind='adamw', params=attn_res_params, lr=0.02, betas=(0.8, 0.95), eps=1e-10, weight_decay=0.0)] if attn_res_params else [])
|
||||
# Muon groups (matrix params, grouped by shape for stacking)
|
||||
for shape in sorted({p.shape for p in matrix_params}):
|
||||
group_params = [p for p in matrix_params if p.shape == shape]
|
||||
|
|
@ -444,16 +476,41 @@ class GPT(nn.Module):
|
|||
x = x + gate * x_pre_smear
|
||||
|
||||
# Forward the trunk of the Transformer
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
n_layer = self.config.n_layer
|
||||
backout_layer = n_layer // 2 # cache at halfway point
|
||||
x_backout = None
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
if i == backout_layer:
|
||||
x_backout = x
|
||||
if not self.config.attn_res:
|
||||
# Standard residual path with resid_lambdas / x0_lambdas
|
||||
x0 = x # save initial normalized embedding for x0 residual
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
if i == backout_layer:
|
||||
x_backout = x
|
||||
else:
|
||||
# AttnRes path: depth-attention over block-level representations
|
||||
# replaces resid_lambdas/x0_lambdas with learned depth-attention
|
||||
block_size = self.config.attn_res_block_size
|
||||
ar_blocks = [] # completed block representations
|
||||
partial = x
|
||||
for i, block in enumerate(self.transformer.h):
|
||||
ve = self.value_embeds[str(i)](idx).to(partial.dtype) if str(i) in self.value_embeds else None
|
||||
# Depth-attention before attention sublayer
|
||||
h = block_attn_res(ar_blocks, partial, self.attn_res_proj[i].to(partial.dtype))
|
||||
# Block boundary: save partial, start fresh
|
||||
if i > 0 and i % (block_size // 2) == 0:
|
||||
ar_blocks = ar_blocks + [partial]
|
||||
partial = None
|
||||
attn_out = block.attn(norm(h), ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
partial = partial + attn_out if partial is not None else attn_out
|
||||
# Depth-attention before MLP sublayer
|
||||
h = block_attn_res(ar_blocks, partial, self.mlp_res_proj[i].to(partial.dtype))
|
||||
mlp_out = block.mlp(norm(h))
|
||||
partial = partial + mlp_out
|
||||
if i == backout_layer:
|
||||
x_backout = partial
|
||||
x = partial
|
||||
# Subtract mid-layer residual to remove low-level features before logit projection
|
||||
if x_backout is not None:
|
||||
x = x - self.backout_lambda.to(x.dtype) * x_backout
|
||||
|
|
|
|||
|
|
@ -52,6 +52,8 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
|
|||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
parser.add_argument("--attn-res", action="store_true", help="enable Attention Residuals (Block AttnRes)")
|
||||
parser.add_argument("--attn-res-block-size", type=int, default=4, help="layers per AttnRes block (must be >= 2 and even)")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
|
|
@ -137,6 +139,7 @@ def build_model_meta(depth):
|
|||
sequence_len=args.max_seq_len, vocab_size=vocab_size,
|
||||
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
||||
window_pattern=args.window_pattern,
|
||||
attn_res=args.attn_res, attn_res_block_size=args.attn_res_block_size,
|
||||
)
|
||||
with torch.device("meta"):
|
||||
model_meta = GPT(config)
|
||||
|
|
|
|||
267
tests/test_attn_res.py
Normal file
267
tests/test_attn_res.py
Normal file
|
|
@ -0,0 +1,267 @@
|
|||
"""
|
||||
Test Attention Residuals (Block AttnRes) integration.
|
||||
|
||||
python -m pytest tests/test_attn_res.py -v
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
# Force float32 compute for CPU test stability (avoids bf16 NaN in small models
|
||||
# and dtype mismatches between model activations and KV cache on CUDA machines)
|
||||
import nanochat.gpt as _gpt_mod
|
||||
import nanochat.common as _common_mod
|
||||
_gpt_mod.COMPUTE_DTYPE = torch.float32
|
||||
_common_mod.COMPUTE_DTYPE = torch.float32
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig, block_attn_res
|
||||
|
||||
|
||||
def _make_config(attn_res=True, **kwargs):
|
||||
"""Create a small test config."""
|
||||
defaults = dict(
|
||||
sequence_len=64, vocab_size=256, n_layer=8, n_head=4, n_kv_head=4, n_embd=64,
|
||||
window_pattern="L", attn_res=attn_res, attn_res_block_size=4,
|
||||
)
|
||||
defaults.update(kwargs)
|
||||
return GPTConfig(**defaults)
|
||||
|
||||
|
||||
def _build_model(config):
|
||||
"""Build a model on meta device, move to CPU, init weights."""
|
||||
with torch.device("meta"):
|
||||
model = GPT(config)
|
||||
model.to_empty(device="cpu")
|
||||
model.init_weights()
|
||||
return model
|
||||
|
||||
|
||||
# ---- Unit tests for building blocks ----
|
||||
|
||||
def test_block_attn_res_function():
|
||||
"""block_attn_res produces correct shape and valid output."""
|
||||
B, T, D = 2, 8, 16
|
||||
blocks = [torch.randn(B, T, D) for _ in range(3)]
|
||||
partial = torch.randn(B, T, D)
|
||||
w = torch.randn(D)
|
||||
|
||||
h = block_attn_res(blocks, partial, w)
|
||||
assert h.shape == (B, T, D)
|
||||
assert torch.isfinite(h).all()
|
||||
|
||||
|
||||
def test_block_attn_res_single_block():
|
||||
"""With no prior blocks, output should equal the partial block itself."""
|
||||
B, T, D = 1, 4, 8
|
||||
partial = torch.randn(B, T, D)
|
||||
w = torch.randn(D)
|
||||
|
||||
h = block_attn_res([], partial, w)
|
||||
# With only one element, softmax over dim 0 gives 1.0, so h == partial
|
||||
assert torch.allclose(h, partial, atol=1e-6)
|
||||
|
||||
|
||||
# ---- Model construction tests ----
|
||||
|
||||
def test_model_creates_attn_res_params():
|
||||
"""When attn_res=True, GPT has attn_res_proj and mlp_res_proj parameters."""
|
||||
config = _make_config(attn_res=True)
|
||||
model = _build_model(config)
|
||||
assert hasattr(model, 'attn_res_proj')
|
||||
assert hasattr(model, 'mlp_res_proj')
|
||||
assert model.attn_res_proj.shape == (config.n_layer, config.n_embd)
|
||||
assert model.mlp_res_proj.shape == (config.n_layer, config.n_embd)
|
||||
|
||||
|
||||
def test_model_no_attn_res_params_when_disabled():
|
||||
"""When attn_res=False, GPT does NOT have AttnRes parameters."""
|
||||
config = _make_config(attn_res=False)
|
||||
model = _build_model(config)
|
||||
assert not hasattr(model, 'attn_res_proj')
|
||||
|
||||
|
||||
def test_attn_res_param_count_overhead():
|
||||
"""AttnRes adds minimal parameter overhead."""
|
||||
config_base = _make_config(attn_res=False)
|
||||
config_ar = _make_config(attn_res=True)
|
||||
model_base = _build_model(config_base)
|
||||
model_ar = _build_model(config_ar)
|
||||
n_base = sum(p.numel() for p in model_base.parameters())
|
||||
n_ar = sum(p.numel() for p in model_ar.parameters())
|
||||
overhead = (n_ar - n_base) / n_base
|
||||
# AttnRes adds 2 * n_layer * D params. For D=64, n_layer=8: 2*8*64 = 1024.
|
||||
assert overhead < 0.05, f"AttnRes overhead {overhead:.2%} is too high"
|
||||
assert n_ar > n_base, "AttnRes model should have more params"
|
||||
|
||||
|
||||
# ---- Forward pass tests ----
|
||||
|
||||
def test_forward_training_attn_res():
|
||||
"""Forward pass with attn_res=True produces valid loss."""
|
||||
config = _make_config(attn_res=True)
|
||||
model = _build_model(config)
|
||||
B, T = 2, 32
|
||||
idx = torch.randint(0, config.vocab_size, (B, T))
|
||||
targets = torch.randint(0, config.vocab_size, (B, T))
|
||||
loss = model(idx, targets=targets)
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert loss.item() > 0
|
||||
|
||||
|
||||
def test_forward_inference_attn_res():
|
||||
"""Forward pass without targets returns logits."""
|
||||
config = _make_config(attn_res=True)
|
||||
model = _build_model(config)
|
||||
model.eval()
|
||||
B, T = 1, 16
|
||||
idx = torch.randint(0, config.vocab_size, (B, T))
|
||||
logits = model(idx)
|
||||
assert logits.shape == (B, T, config.vocab_size)
|
||||
assert torch.isfinite(logits).all()
|
||||
|
||||
|
||||
def test_forward_standard_path_unchanged():
|
||||
"""Standard path (attn_res=False) still works correctly."""
|
||||
config = _make_config(attn_res=False)
|
||||
model = _build_model(config)
|
||||
# Disable backout for small test model (pre-existing NaN with n_embd=64)
|
||||
model.backout_lambda.data.fill_(0.0)
|
||||
B, T = 2, 32
|
||||
idx = torch.randint(0, config.vocab_size, (B, T))
|
||||
targets = torch.randint(0, config.vocab_size, (B, T))
|
||||
loss = model(idx, targets=targets)
|
||||
assert torch.isfinite(loss)
|
||||
|
||||
|
||||
def test_backward_attn_res():
|
||||
"""Gradients flow through AttnRes path."""
|
||||
config = _make_config(attn_res=True)
|
||||
model = _build_model(config)
|
||||
B, T = 2, 16
|
||||
idx = torch.randint(0, config.vocab_size, (B, T))
|
||||
targets = torch.randint(0, config.vocab_size, (B, T))
|
||||
loss = model(idx, targets=targets)
|
||||
loss.backward()
|
||||
# Check that AttnRes projection gradients are non-zero
|
||||
assert model.attn_res_proj.grad is not None
|
||||
assert model.mlp_res_proj.grad is not None
|
||||
|
||||
|
||||
# ---- Config validation tests ----
|
||||
|
||||
def test_invalid_block_size_odd():
|
||||
"""Odd block size should fail validation."""
|
||||
try:
|
||||
config = _make_config(attn_res=True, attn_res_block_size=3)
|
||||
_build_model(config)
|
||||
assert False, "Should have raised AssertionError"
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
|
||||
def test_invalid_block_size_one():
|
||||
"""block_size=1 should fail validation."""
|
||||
try:
|
||||
config = _make_config(attn_res=True, attn_res_block_size=1)
|
||||
_build_model(config)
|
||||
assert False, "Should have raised AssertionError"
|
||||
except AssertionError:
|
||||
pass
|
||||
|
||||
|
||||
def test_valid_block_sizes():
|
||||
"""Various valid block sizes should work."""
|
||||
for bs in [2, 4, 6, 8]:
|
||||
config = _make_config(attn_res=True, attn_res_block_size=bs, n_layer=8)
|
||||
model = _build_model(config)
|
||||
idx = torch.randint(0, 256, (1, 16))
|
||||
logits = model(idx)
|
||||
assert torch.isfinite(logits).all(), f"NaN/Inf with block_size={bs}"
|
||||
|
||||
|
||||
# ---- Optimizer tests ----
|
||||
|
||||
def test_optimizer_setup_attn_res():
|
||||
"""Optimizer setup accounts for all AttnRes params."""
|
||||
config = _make_config(attn_res=True)
|
||||
model = _build_model(config)
|
||||
optimizer = model.setup_optimizer()
|
||||
assert optimizer is not None
|
||||
|
||||
|
||||
def test_optimizer_setup_standard():
|
||||
"""Optimizer setup works for standard path too."""
|
||||
config = _make_config(attn_res=False)
|
||||
model = _build_model(config)
|
||||
optimizer = model.setup_optimizer()
|
||||
assert optimizer is not None
|
||||
|
||||
|
||||
# ---- FLOPs and param counting tests ----
|
||||
|
||||
def test_estimate_flops_attn_res():
|
||||
"""FLOPs estimation works with AttnRes."""
|
||||
config = _make_config(attn_res=True)
|
||||
model = _build_model(config)
|
||||
flops = model.estimate_flops()
|
||||
assert flops > 0
|
||||
|
||||
|
||||
def test_num_scaling_params_attn_res():
|
||||
"""Parameter count is consistent with AttnRes."""
|
||||
config = _make_config(attn_res=True)
|
||||
model = _build_model(config)
|
||||
counts = model.num_scaling_params()
|
||||
assert counts['total'] == sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
# ---- KV cache inference tests ----
|
||||
|
||||
def test_kv_cache_inference_attn_res():
|
||||
"""AttnRes works with KV cache for inference."""
|
||||
from nanochat.engine import KVCache
|
||||
config = _make_config(attn_res=True, n_layer=4, n_embd=64, n_head=4, n_kv_head=4)
|
||||
model = _build_model(config)
|
||||
model.eval()
|
||||
|
||||
B, T = 1, 8
|
||||
idx = torch.randint(0, config.vocab_size, (B, T))
|
||||
|
||||
# Prefill
|
||||
kv_cache = KVCache(
|
||||
batch_size=B, num_heads=config.n_kv_head,
|
||||
seq_len=32, head_dim=config.n_embd // config.n_head,
|
||||
num_layers=config.n_layer, device="cpu", dtype=torch.float32,
|
||||
)
|
||||
logits_prefill = model(idx, kv_cache=kv_cache)
|
||||
assert logits_prefill.shape == (B, T, config.vocab_size)
|
||||
|
||||
# Decode one token
|
||||
next_idx = torch.randint(0, config.vocab_size, (B, 1))
|
||||
logits_decode = model(next_idx, kv_cache=kv_cache)
|
||||
assert logits_decode.shape == (B, 1, config.vocab_size)
|
||||
assert torch.isfinite(logits_decode).all()
|
||||
|
||||
|
||||
def test_generate_attn_res():
|
||||
"""Full generation pipeline works with AttnRes."""
|
||||
from nanochat.engine import Engine
|
||||
from tests.test_engine import MockModel, ByteTokenizer
|
||||
|
||||
# We need a real model for this test, not mock
|
||||
config = _make_config(attn_res=True, n_layer=4, n_embd=64, n_head=4, n_kv_head=4)
|
||||
model = _build_model(config)
|
||||
model.eval()
|
||||
|
||||
tokenizer = ByteTokenizer()
|
||||
|
||||
# Override model methods that Engine needs
|
||||
model.get_device = lambda: torch.device("cpu")
|
||||
|
||||
engine = Engine(model, tokenizer)
|
||||
prompt = [0, 72, 101, 108, 108, 111] # some tokens
|
||||
|
||||
results, masks = engine.generate_batch(prompt, max_tokens=5, temperature=0.0, seed=42)
|
||||
assert len(results) == 1
|
||||
assert len(results[0]) >= len(prompt) # at least prompt tokens
|
||||
assert len(results[0]) <= len(prompt) + 5 # at most prompt + max_tokens
|
||||
Loading…
Reference in New Issue
Block a user