implement differential attention layers

This commit is contained in:
Tianyu Luo 2026-03-05 21:56:47 -05:00
parent 83dccc20ae
commit c6d44cf463
2 changed files with 149 additions and 21 deletions

View File

@ -14,6 +14,7 @@ Notable features:
from functools import partial
from dataclasses import dataclass
import math
import torch
import torch.nn as nn
@ -23,7 +24,7 @@ from nanochat.common import get_dist_info, print0
from nanochat.optim import MuonAdamW, DistMuonAdamW
# Our custom Flash Attention module that automatically uses FA3 on Hopper+ and SDPA fallback elsewhere
from nanochat.flash_attention import flash_attn
from nanochat.flash_attention import flash_attn, _sdpa_attention
@dataclass
class GPTConfig:
@ -37,6 +38,11 @@ class GPTConfig:
# Characters: L=long (full context), S=short (half context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "SSSL"
# Differential attention (Differential Transformer, arxiv 2410.05258).
# Splits each head's Q and K into two halves (Q1/Q2, K1/K2) and computes
# DiffAttn = (softmax(Q1 K1^T/√d) - λ·softmax(Q2 K2^T/√d)) V,
# cancelling attention noise and promoting sparse patterns.
use_diff_attn: bool = False
def norm(x):
@ -72,6 +78,19 @@ class CausalSelfAttention(nn.Module):
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
self.ve_gate_channels = 32
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
# Differential attention parameters (paper: arxiv 2410.05258)
self.use_diff_attn = config.use_diff_attn
if self.use_diff_attn:
assert self.head_dim % 2 == 0, "head_dim must be even for differential attention"
self.half_head_dim = self.head_dim // 2
# Per-layer constant λ_init = 0.8 - 0.6·exp(-0.3·(l-1)), l=1-indexed layer
self.lambda_init = 0.8 - 0.6 * math.exp(-0.3 * layer_idx)
# Learnable vectors for the re-parameterized λ per head:
# λ = exp(λq1·λk1) - exp(λq2·λk2) + λ_init
self.lambda_q1 = nn.Parameter(torch.empty(self.n_head, self.half_head_dim))
self.lambda_k1 = nn.Parameter(torch.empty(self.n_head, self.half_head_dim))
self.lambda_q2 = nn.Parameter(torch.empty(self.n_head, self.half_head_dim))
self.lambda_k2 = nn.Parameter(torch.empty(self.n_head, self.half_head_dim))
def forward(self, x, ve, cos_sin, window_size, kv_cache):
B, T, C = x.size()
@ -91,32 +110,120 @@ class CausalSelfAttention(nn.Module):
# Apply Rotary Embeddings to queries and keys to get relative positional encoding
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k) # QK norm
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
if self.use_diff_attn:
y = self._diff_attn_forward(q, k, v, window_size, kv_cache, B, T)
else:
# Inference: use flash_attn_with_kvcache which handles cache management
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
y = flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache,
k=k, v=v,
cache_seqlens=kv_cache.cache_seqlens,
causal=True,
window_size=window_size,
)
# Advance position after last layer processes
if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T)
q, k = norm(q), norm(k) # QK norm
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
else:
# Inference: use flash_attn_with_kvcache which handles cache management
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
y = flash_attn.flash_attn_with_kvcache(
q, k_cache, v_cache,
k=k, v=v,
cache_seqlens=kv_cache.cache_seqlens,
causal=True,
window_size=window_size,
)
# Advance position after last layer processes
if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T)
# Re-assemble the heads and project back to residual stream
y = y.contiguous().view(B, T, -1)
y = self.c_proj(y)
return y
def _diff_attn_forward(self, q, k, v, window_size, kv_cache, B, T):
    """
    Differential attention forward pass (Differential Transformer, arxiv 2410.05258).

    Splits each head's Q and K into two halves (Q1/Q2, K1/K2) and computes
        DiffAttn = (softmax(Q1 K1^T/√d) - λ·softmax(Q2 K2^T/√d)) V
    where λ = exp(λq1·λk1) - exp(λq2·λk2) + λ_init is a per-head scalar.
    After combining all heads, applies per-head RMSNorm and scales by (1 - λ_init).
    V retains the full head_dim so the output shape matches standard attention.
    """
    # Split Q and K each into two halves along the head_dim axis.
    # q: (B, T, n_head, head_dim)    -> q1, q2: (B, T, n_head, half_head_dim)
    # k: (B, T, n_kv_head, head_dim) -> k1, k2: (B, T, n_kv_head, half_head_dim)
    q1 = q[..., :self.half_head_dim]
    q2 = q[..., self.half_head_dim:]
    k1 = k[..., :self.half_head_dim]
    k2 = k[..., self.half_head_dim:]
    # QK norm applied independently to each sub-head (mirrors the
    # norm(q), norm(k) step of the standard attention path).
    q1, q2 = norm(q1), norm(q2)
    k1, k2 = norm(k1), norm(k2)
    if kv_cache is None:
        # Training: two separate causal attention computations (FA3 or SDPA)
        # attn1, attn2: (B, T, n_head, head_dim) — V dimension is head_dim
        attn1 = flash_attn.flash_attn_func(q1, k1, v, causal=True, window_size=window_size)
        attn2 = flash_attn.flash_attn_func(q2, k2, v, causal=True, window_size=window_size)
    else:
        # Inference: manually manage the KV cache.
        # The cache stores full K (= K1||K2, head_dim) and V; we split K after reading.
        # BUGFIX: cache the *normed* K halves re-concatenated, not the raw k.
        # The training branch above attends over normed keys (and the standard
        # attention path norms k before flash_attn_with_kvcache writes the cache);
        # caching raw k would make decode attend over un-normed keys — a
        # train/inference mismatch.
        k_full_new = torch.cat([k1, k2], dim=-1)  # (B, T, n_kv_head, head_dim)
        k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
        pos = kv_cache.cache_seqlens[0].item()
        # Write new K (full head_dim) and V into the cache
        k_cache[:, pos:pos + T] = k_full_new
        v_cache[:, pos:pos + T] = v
        end_pos = pos + T
        # Advance position after the last layer has written its cache
        if self.layer_idx == kv_cache.n_layers - 1:
            kv_cache.advance(T)
        # Read the full cached context
        k_ctx = k_cache[:, :end_pos]  # (B, end_pos, n_kv_head, head_dim)
        v_ctx = v_cache[:, :end_pos]  # (B, end_pos, n_kv_head, head_dim)
        # Split cached (already-normed) K back into K1 and K2
        k1_ctx = k_ctx[..., :self.half_head_dim]
        k2_ctx = k_ctx[..., self.half_head_dim:]
        # Compute two attention operations via SDPA (handles decode/prefill/window cases)
        enable_gqa = self.n_head != self.n_kv_head
        # Transpose to SDPA layout (B, H, T, D)
        q1_t = q1.transpose(1, 2)
        q2_t = q2.transpose(1, 2)
        k1_t = k1_ctx.transpose(1, 2)
        k2_t = k2_ctx.transpose(1, 2)
        v_t = v_ctx.transpose(1, 2)
        attn1 = _sdpa_attention(q1_t, k1_t, v_t, window_size, enable_gqa).transpose(1, 2)
        attn2 = _sdpa_attention(q2_t, k2_t, v_t, window_size, enable_gqa).transpose(1, 2)
    # Per-head scalar λ = exp(λq1·λk1) - exp(λq2·λk2) + λ_init
    # lambda_q1, lambda_k1: (n_head, half_head_dim) -> dot product -> (n_head,)
    lambda_val = (
        torch.exp((self.lambda_q1 * self.lambda_k1).sum(-1)) -
        torch.exp((self.lambda_q2 * self.lambda_k2).sum(-1)) +
        self.lambda_init
    )  # shape: (n_head,)
    # Differential combination: subtract the noise attention scaled by λ
    # attn1, attn2: (B, T, n_head, head_dim); lambda_val broadcast over (1, 1, n_head, 1)
    y = attn1 - lambda_val[None, None, :, None] * attn2
    # Per-head RMSNorm (equivalent to GroupNorm with n_head groups):
    # norm() normalizes over the last dim (head_dim) for each (B, T, head) element.
    y = norm(y)
    # Fixed scaling by (1 - λ_init) as specified in the paper
    y = y * (1.0 - self.lambda_init)
    return y
class MLP(nn.Module):
def __init__(self, config):
@ -229,6 +336,14 @@ class GPT(nn.Module):
if block.attn.ve_gate is not None:
torch.nn.init.zeros_(block.attn.ve_gate.weight)
# Differential attention lambda vectors: Normal(0, 0.1) as per the paper
for block in self.transformer.h:
if block.attn.use_diff_attn:
torch.nn.init.normal_(block.attn.lambda_q1, mean=0.0, std=0.1)
torch.nn.init.normal_(block.attn.lambda_k1, mean=0.0, std=0.1)
torch.nn.init.normal_(block.attn.lambda_q2, mean=0.0, std=0.1)
torch.nn.init.normal_(block.attn.lambda_k2, mean=0.0, std=0.1)
# Rotary embeddings
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
@ -349,14 +464,23 @@ class GPT(nn.Module):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Differential attention lambda vectors live inside transformer.h but must be
# excluded from Muon (which expects square weight matrices) and trained with AdamW.
_diff_lambda_names = {'lambda_q1', 'lambda_k1', 'lambda_q2', 'lambda_k2'}
diff_lambda_params = [
p for name, p in self.transformer.h.named_parameters()
if name.split('.')[-1] in _diff_lambda_names
]
_diff_lambda_ids = {id(p) for p in diff_lambda_params}
# Separate out all parameters into groups
matrix_params = list(self.transformer.h.parameters())
matrix_params = [p for p in self.transformer.h.parameters() if id(p) not in _diff_lambda_ids]
value_embeds_params = list(self.value_embeds.parameters())
embedding_params = list(self.transformer.wte.parameters())
lm_head_params = list(self.lm_head.parameters())
resid_params = [self.resid_lambdas]
x0_params = [self.x0_lambdas]
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params)
assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) + len(diff_lambda_params)
# Scale the LR for the AdamW parameters by ∝1/√dmodel (tuned for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
@ -370,6 +494,8 @@ class GPT(nn.Module):
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0), # higher beta1 for x0
# Differential attention lambda vectors trained with AdamW (not Muon)
dict(kind='adamw', params=diff_lambda_params, lr=matrix_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
]
# Muon groups (matrix params, grouped by shape for stacking)
for shape in sorted({p.shape for p in matrix_params}):

View File

@ -51,6 +51,7 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
parser.add_argument("--use-diff-attn", action="store_true", help="enable Differential Attention (arxiv 2410.05258)")
# Training horizon (only one used, in order of precedence)
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
@ -133,6 +134,7 @@ def build_model_meta(depth):
sequence_len=args.max_seq_len, vocab_size=vocab_size,
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
window_pattern=args.window_pattern,
use_diff_attn=args.use_diff_attn,
)
with torch.device("meta"):
model_meta = GPT(config)