Implements Block AttnRes from MoonshotAI (https://github.com/MoonshotAI/Attention-Residuals)
as an optional alternative to resid_lambdas/x0_lambdas for residual connections.
When enabled via --attn-res, replaces standard residual scaling with learned
depth-attention over block-level representations. Layers are partitioned into
blocks; at each sublayer, a softmax-weighted combination of all completed blocks
plus the current partial block determines the input to attention/MLP.
Design follows nanochat conventions:
- Two nn.Parameter(n_layer, D) pseudo-query vectors on GPT (like resid_lambdas)
- Uses existing parameterless norm() for key normalization (no learnable RMSNorm)
- Block class unchanged — all AttnRes logic lives in GPT.forward
- Minimal 6-line block_attn_res() core function
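
For illustration, the depth-attention step described above can be sketched in plain Python. This is not the repository's block_attn_res(); the names rms_norm, softmax and depth_attention are hypothetical, and the sketch assumes one pseudo-query vector per sublayer and a parameterless RMS norm applied to the keys only:

```python
import math

def rms_norm(v, eps=1e-6):
    # Parameterless RMS normalization (assumed to mirror what norm() does).
    mean_sq = sum(x * x for x in v) / len(v)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [x * scale for x in v]

def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def depth_attention(sources, query):
    # sources: completed block representations plus the current partial
    # block, each a list of D floats. query: this sublayer's pseudo-query.
    keys = [rms_norm(s) for s in sources]
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    weights = softmax(scores)
    dim = len(query)
    # Values are the unnormalized sources; only the keys are normalized.
    mixed = [sum(w * s[i] for w, s in zip(weights, sources)) for i in range(dim)]
    return mixed, weights
```

In this sketch an all-zero query gives every source the same weight, so the output is simply the average of the sources.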
Changes:
- nanochat/gpt.py: block_attn_res(), AttnRes path in GPT.forward, config/init/optimizer
- nanochat/checkpoint_manager.py: backward-compat config patching
- scripts/base_train.py: --attn-res and --attn-res-block-size CLI args
- tests/test_attn_res.py: 18 tests covering unit/forward/backward/optimizer/inference
GPU results (depth=4, 20 steps, RTX 6000 Ada):
Standard: val_bpb 3.21 → 2.80, ~840K tok/sec
AttnRes: val_bpb 3.21 → 2.61, ~780K tok/sec
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
SDPA fallback now respects sliding window during single-token KV-cache
decode by slicing K/V to the last (window + 1) tokens.
Also simplifies mask building for chunk inference so that the sliding
window is applied correctly in that path too.
Fixes #452
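
A minimal sketch of the two paths, using plain Python lists in place of tensors. The helper names and the convention that a negative window means "unlimited" are assumptions for illustration, not the fallback's actual code:

```python
def slice_kv_for_decode(keys, values, window):
    # Single-token decode: the new query sits at the last position, so it
    # may only see itself plus the `window` tokens before it.
    if window < 0:  # hypothetical convention: negative means no limit
        return keys, values
    keep = window + 1
    return keys[-keep:], values[-keep:]

def chunk_mask(num_queries, num_keys, window):
    # Chunk inference: the queries are the last `num_queries` positions of
    # a cache holding `num_keys` tokens. True means "may attend".
    offset = num_keys - num_queries
    mask = []
    for qi in range(num_queries):
        q_pos = offset + qi
        row = [
            k <= q_pos and (window < 0 or q_pos - k <= window)
            for k in range(num_keys)
        ]
        mask.append(row)
    return mask
```

With a single query, the mask row allows exactly the same keys that the slice keeps, which is the property the two paths need to share.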
Co-Authored-By: Kartik Vashishta <kartikv776@gmail.com>
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* test: add engine generation tests for expected invariants
- test_seed_reproducibility
- test_temperature_zero_determinism
- test_max_tokens_respected
- test_num_samples_count
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
* Fix temperature test
* add test for seed variation in sampling
Add test for seed variation in sampling with temperature > 0.
* Rename test for clarity
* Shorten assert msg
---------
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Previously, when generating multiple samples (num_samples > 1), the first
token after prefill was sampled once and broadcast to all rows, causing
all samples to start identically. Now the prefill logits are expanded to
num_samples and sampled independently for each row.
Also simplified the generation loop by moving the forward pass to the end
of the loop, eliminating the first_iteration flag and if/else branching.
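
The per-row sampling can be sketched as follows, in plain Python rather than the engine's tensor code. The function names, the seed handling and the temperature == 0 shortcut are assumptions made for this example:

```python
import math
import random

def sample_token(logits, rng, temperature=1.0):
    # Greedy pick at temperature 0, otherwise sample from the softmax.
    if temperature == 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

def sample_first_tokens(prefill_logits, num_samples, seed, temperature=1.0):
    # Expand the single row of prefill logits to num_samples rows, then
    # draw once per row instead of drawing once and broadcasting.
    rng = random.Random(seed)
    rows = [list(prefill_logits) for _ in range(num_samples)]
    return [sample_token(row, rng, temperature) for row in rows]
```

Because each row gets its own draw from the shared generator, the samples can diverge from the first token onward while remaining reproducible for a fixed seed.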
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Performance varies by machine and load, making hard assertions flaky.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>