Mirror of https://github.com/karpathy/nanochat.git
Merge 181e7f1c15 into 230d6cf6c6
commit e0fc57fd45
@@ -172,6 +172,10 @@ class Engine:
         """Same as generate, but does single prefill and then clones the KV cache."""
         assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
         device = self.model.get_device()
+        # The name of the device is either a string ("cpu") or a torch.device
+        # so we need to normalize it to a torch.device
+        if isinstance(device, str):
+            device = torch.device(device)
         # NOTE: setting the dtype here and in this way is an ugly hack.
         # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
         # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
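For context, a minimal standalone sketch (not the repo's code; the "cpu" value is a stand-in for whatever self.model.get_device() returns) of why normalizing the string form matters for the dtype convention the NOTE above describes:

import torch

# get_device() may hand back either a plain string or a torch.device;
# normalizing first makes the .type check below reliable.
device = "cpu"  # stand-in for self.model.get_device()
if isinstance(device, str):
    device = torch.device(device)
# The repo's stated convention: cuda -> bfloat16, everything else -> float32.
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
print(device, dtype)  # cpu torch.float32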
@@ -146,27 +146,54 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
     # SDPA fallback: manually manage KV cache
     B, T_new, H, D = q.shape
-    pos = cache_seqlens[0].item() # assume uniform position across batch
-
-    # Insert new k, v into cache (in-place, matching FA3 behavior)
-    if k is not None and v is not None:
-        k_cache[:, pos:pos+T_new, :, :] = k
-        v_cache[:, pos:pos+T_new, :, :] = v
-
-    # Get full cache up to current position + new tokens
-    end_pos = pos + T_new
-    k_full = k_cache[:, :end_pos, :, :]
-    v_full = v_cache[:, :end_pos, :, :]
-
-    # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
-    q_sdpa = q.transpose(1, 2)
-    k_sdpa = k_full.transpose(1, 2)
-    v_sdpa = v_full.transpose(1, 2)
-
-    enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
-    y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
-    return y_sdpa.transpose(1, 2) # back to (B, T, H, D)
+    assert cache_seqlens is not None, "cache_seqlens is required for KV-cache SDPA fallback"
+
+    # All rows decode in lockstep
+    # (this is how Engine.generate() currently uses the cache)
+    if torch.all(cache_seqlens == cache_seqlens[0]).item():
+        pos = int(cache_seqlens[0].item())
+        # Insert new k, v into cache (in-place, matching FA3 behavior)
+        if k is not None and v is not None:
+            k_cache[:, pos:pos+T_new, :, :] = k
+            v_cache[:, pos:pos+T_new, :, :] = v
+
+        end_pos = pos + T_new
+        k_full = k_cache[:, :end_pos, :, :]
+        v_full = v_cache[:, :end_pos, :, :]
+
+        q_sdpa = q.transpose(1, 2)
+        k_sdpa = k_full.transpose(1, 2)
+        v_sdpa = v_full.transpose(1, 2)
+
+        enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
+        y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
+        return y_sdpa.transpose(1, 2)
+
+    # Per-row cache positions:
+    # FA3's KV cache API supports per-row cache_seqlens, SDPA does not,
+    # so we do a small per-row loop here (https://github.com/Dao-AILab/flash-attention)
+    if k is not None and v is not None:
+        for b in range(B):
+            pos_b = int(cache_seqlens[b].item())
+            k_cache[b, pos_b:pos_b+T_new, :, :] = k[b]
+            v_cache[b, pos_b:pos_b+T_new, :, :] = v[b]
+
+    y_out = torch.empty_like(q)
+    for b in range(B):
+        end_pos_b = int(cache_seqlens[b].item()) + T_new
+        k_full = k_cache[b:b+1, :end_pos_b, :, :]
+        v_full = v_cache[b:b+1, :end_pos_b, :, :]
+
+        q_sdpa = q[b:b+1].transpose(1, 2)
+        k_sdpa = k_full.transpose(1, 2)
+        v_sdpa = v_full.transpose(1, 2)
+
+        enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
+        y_b = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
+        y_out[b:b+1] = y_b.transpose(1, 2)
+    return y_out
 
 
 # =============================================================================
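For intuition, a minimal self-contained sketch of the per-row idea above (shapes are made up, it calls torch.nn.functional.scaled_dot_product_attention directly, and it skips the window_size and enable_gqa handling that the module's _sdpa_attention helper covers):

import torch
import torch.nn.functional as F

# Each batch row has a different number of valid cached KV positions, so a single
# batched SDPA call over one shared length would be wrong; loop row by row instead.
B, T_max, T_new, H, D = 2, 32, 1, 4, 16
cache_seqlens = torch.tensor([8, 16])
k_cache = torch.randn(B, T_max, H, D)  # illustrative: random cache contents
v_cache = torch.randn(B, T_max, H, D)
q = torch.randn(B, T_new, H, D)

y = torch.empty_like(q)
for b in range(B):
    # This row's valid KV length (in the real fallback the new k/v were just
    # written into the cache at this row's position, so they are included).
    end = int(cache_seqlens[b]) + T_new
    q_b = q[b:b+1].transpose(1, 2)               # (1, H, T_new, D)
    k_b = k_cache[b:b+1, :end].transpose(1, 2)   # (1, H, end, D)
    v_b = v_cache[b:b+1, :end].transpose(1, 2)
    y[b:b+1] = F.scaled_dot_product_attention(q_b, k_b, v_b).transpose(1, 2)
print(y.shape)  # torch.Size([2, 1, 4, 16])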
@@ -328,10 +328,65 @@ class TestSDPAOnly:
         )
         cache.advance(1)
 
         assert y_single.shape == (B, 1, H, D)
         assert cache.get_pos() == T_prefill + 1
         set_impl(None)
+
+    def test_kvcache_variable_cache_seqlens(self):
+        """
+        SDPA fallback must handle per-row cache positions
+        """
+        set_impl("sdpa")
+        B, T_max, H, D = 2, 32, 4, 16
+        T_new = 4
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+        torch.manual_seed(0)
+
+        # Different prefix lengths per row
+        cache_seqlens = torch.tensor([8, 16], dtype=torch.int32, device=self.DEVICE)
+
+        # Pre-fill the cache with distinct random prefixes
+        k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_init = torch.randn(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_init = torch.randn(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        for b in range(B):
+            pre = int(cache_seqlens[b].item())
+            k_cache[b, :pre] = k_init[b, :pre]
+            v_cache[b, :pre] = v_init[b, :pre]
+
+        q = torch.randn(B, T_new, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_new = torch.randn(B, T_new, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_new = torch.randn(B, T_new, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens, causal=True, window_size=(T_max, 0))
+
+        # Caches should have the new KV inserted at each row's position
+        for b in range(B):
+            pos = int(cache_seqlens[b].item())
+            torch.testing.assert_close(k_cache[b, pos:pos+T_new], k_new[b])
+            torch.testing.assert_close(v_cache[b, pos:pos+T_new], v_new[b])
+
+        # Per-row reference for the correct behavior
+        y_ref = torch.empty_like(y)
+        for b in range(B):
+            pre = int(cache_seqlens[b].item())
+            k_full = torch.cat([k_init[b:b+1, :pre], k_new[b:b+1]], dim=1)  # (1, pre+T_new, H, D)
+            v_full = torch.cat([v_init[b:b+1, :pre], v_new[b:b+1]], dim=1)
+
+            q_sdpa = q[b:b+1].transpose(1, 2)  # (1, H, T_new, D)
+            k_sdpa = k_full.transpose(1, 2)    # (1, H, Tk, D)
+            v_sdpa = v_full.transpose(1, 2)
+
+            y_b = fa_module._sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size=(T_max, 0), enable_gqa=False)
+            y_ref[b:b+1] = y_b.transpose(1, 2)
+
+        # bf16 is expected to have slightly larger numerical deltas
+        atol = 1e-2 if self.DTYPE == torch.bfloat16 else 1e-4
+        rtol = 1e-2 if self.DTYPE == torch.bfloat16 else 1e-4
+        torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol)
+
+        set_impl(None)
 
 
 # =============================================================================
 # Override mechanism tests
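A brief aside on the tolerance choice (a sketch of the arithmetic, not part of the test): bfloat16 keeps only 7 explicit mantissa bits, so a single rounding near 1.0 can already be off by up to about 0.4%, and the sums inside attention accumulate several such roundings, which is why the bf16 branch loosens atol/rtol to 1e-2 while float32 keeps 1e-4:

import torch

# Machine epsilon: spacing from 1.0 to the next representable value.
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125 (= 2**-7)
print(torch.finfo(torch.float32).eps)   # ~1.19e-07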