Kartik Vashishta 2026-02-01 22:31:03 -06:00 committed by GitHub
commit e0fc57fd45
3 changed files with 102 additions and 16 deletions

View File

@@ -172,6 +172,10 @@ class Engine:
        """Same as generate, but does single prefill and then clones the KV cache."""
        assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
        device = self.model.get_device()
        # The name of the device is either a string ("cpu") or a torch.device,
        # so we need to normalize it to a torch.device
        if isinstance(device, str):
            device = torch.device(device)
        # NOTE: setting the dtype here and in this way is an ugly hack.
        # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
        # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
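For reference, the normalization this hunk performs can be sketched as follows. This is a minimal illustration assuming the cuda -> bfloat16 / everything-else -> float32 convention stated in the NOTE; the KVCache constructor call is hypothetical and only shows where the resolved device and dtype would be used.

import torch

def resolve_device_and_dtype(device):
    # Normalize a plain string ("cpu", "cuda") into a torch.device
    if isinstance(device, str):
        device = torch.device(device)
    # Assumed repo convention (see NOTE above): cuda -> bfloat16, everything else -> float32
    dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
    return device, dtype

# Hypothetical usage when pre-allocating the KV cache tensors:
# device, dtype = resolve_device_and_dtype(self.model.get_device())
# kv_cache = KVCache(..., device=device, dtype=dtype)  # constructor args not shown in this diff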

View File

@@ -146,27 +146,54 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None
    # SDPA fallback: manually manage KV cache
    B, T_new, H, D = q.shape
    assert cache_seqlens is not None, "cache_seqlens is required for KV-cache SDPA fallback"
    # All rows decode in lockstep
    # (this is how Engine.generate() currently uses the cache)
    if torch.all(cache_seqlens == cache_seqlens[0]).item():
        pos = int(cache_seqlens[0].item())
        # Insert new k, v into cache (in-place, matching FA3 behavior)
        if k is not None and v is not None:
            k_cache[:, pos:pos+T_new, :, :] = k
            v_cache[:, pos:pos+T_new, :, :] = v
        # Get full cache up to current position + new tokens
        end_pos = pos + T_new
        k_full = k_cache[:, :end_pos, :, :]
        v_full = v_cache[:, :end_pos, :, :]
        # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
        q_sdpa = q.transpose(1, 2)
        k_sdpa = k_full.transpose(1, 2)
        v_sdpa = v_full.transpose(1, 2)
        enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
        y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
        return y_sdpa.transpose(1, 2)  # back to (B, T, H, D)
    # Per-row cache positions:
    # FA3's KV cache API supports per-row cache_seqlens, SDPA does not,
    # so we do a small per-row loop here (https://github.com/Dao-AILab/flash-attention)
    if k is not None and v is not None:
        for b in range(B):
            pos_b = int(cache_seqlens[b].item())
            k_cache[b, pos_b:pos_b+T_new, :, :] = k[b]
            v_cache[b, pos_b:pos_b+T_new, :, :] = v[b]
    y_out = torch.empty_like(q)
    for b in range(B):
        end_pos_b = int(cache_seqlens[b].item()) + T_new
        k_full = k_cache[b:b+1, :end_pos_b, :, :]
        v_full = v_cache[b:b+1, :end_pos_b, :, :]
        q_sdpa = q[b:b+1].transpose(1, 2)
        k_sdpa = k_full.transpose(1, 2)
        v_sdpa = v_full.transpose(1, 2)
        enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
        y_b = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
        y_out[b:b+1] = y_b.transpose(1, 2)
    return y_out
# =============================================================================
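Both branches above delegate to _sdpa_attention, whose body is not part of this diff. A minimal sketch of what such a helper plausibly does (causal masking plus a left attention window), under the assumptions that window_size=(left, right) follows FlashAttention's convention and that PyTorch >= 2.5 is available for the enable_gqa flag:

import torch
import torch.nn.functional as F

def _sdpa_attention_sketch(q, k, v, window_size, enable_gqa):
    # q: (B, Hq, Tq, D); k, v: (B, Hk, Tk, D). The Tq queries are the last Tq positions of the sequence.
    Tq, Tk = q.size(2), k.size(2)
    left, _right = window_size
    # Absolute positions: query i sits at (Tk - Tq + i), key j sits at j.
    q_pos = torch.arange(Tk - Tq, Tk, device=q.device).unsqueeze(1)  # (Tq, 1)
    k_pos = torch.arange(Tk, device=q.device).unsqueeze(0)           # (1, Tk)
    causal = k_pos <= q_pos                                          # no attending to future keys
    in_window = ((q_pos - k_pos) <= left) if left >= 0 else torch.ones_like(causal)
    mask = causal & in_window                                        # True = may attend
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)

The repo's real helper may differ (for instance in how it treats window_size=(-1, -1) or non-causal calls); this sketch only mirrors the way the fallback above calls it.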

View File

@@ -328,10 +328,65 @@ class TestSDPAOnly:
        )
        cache.advance(1)
        assert y_single.shape == (B, 1, H, D)
        assert cache.get_pos() == T_prefill + 1
        set_impl(None)

    def test_kvcache_variable_cache_seqlens(self):
        """SDPA fallback must handle per-row cache positions."""
        set_impl("sdpa")
        B, T_max, H, D = 2, 32, 4, 16
        T_new = 4
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)
        torch.manual_seed(0)
        # Different prefix lengths per row
        cache_seqlens = torch.tensor([8, 16], dtype=torch.int32, device=self.DEVICE)
        # Pre-fill cache with distinct random prefixes
        k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k_init = torch.randn(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_init = torch.randn(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
        for b in range(B):
            pre = int(cache_seqlens[b].item())
            k_cache[b, :pre] = k_init[b, :pre]
            v_cache[b, :pre] = v_init[b, :pre]
        q = torch.randn(B, T_new, H, D, device=self.DEVICE, dtype=self.DTYPE)
        k_new = torch.randn(B, T_new, H, D, device=self.DEVICE, dtype=self.DTYPE)
        v_new = torch.randn(B, T_new, H, D, device=self.DEVICE, dtype=self.DTYPE)
        y = flash_attn.flash_attn_with_kvcache(
            q, k_cache, v_cache, k=k_new, v=v_new,
            cache_seqlens=cache_seqlens, causal=True, window_size=(T_max, 0),
        )
        # Caches should have the new KV inserted at each row's own position
        for b in range(B):
            pos = int(cache_seqlens[b].item())
            torch.testing.assert_close(k_cache[b, pos:pos+T_new], k_new[b])
            torch.testing.assert_close(v_cache[b, pos:pos+T_new], v_new[b])
        # Reference: compute the correct per-row behavior independently for each row
        y_ref = torch.empty_like(y)
        for b in range(B):
            pre = int(cache_seqlens[b].item())
            k_full = torch.cat([k_init[b:b+1, :pre], k_new[b:b+1]], dim=1)  # (1, pre+T_new, H, D)
            v_full = torch.cat([v_init[b:b+1, :pre], v_new[b:b+1]], dim=1)
            q_sdpa = q[b:b+1].transpose(1, 2)  # (1, H, T_new, D)
            k_sdpa = k_full.transpose(1, 2)    # (1, H, Tk, D)
            v_sdpa = v_full.transpose(1, 2)
            y_b = fa_module._sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size=(T_max, 0), enable_gqa=False)
            y_ref[b:b+1] = y_b.transpose(1, 2)
        # bf16 is expected to have slightly larger numerical deltas
        atol = 1e-2 if self.DTYPE == torch.bfloat16 else 1e-4
        rtol = 1e-2 if self.DTYPE == torch.bfloat16 else 1e-4
        torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol)
        set_impl(None)
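Because the test pins the implementation with set_impl("sdpa"), it exercises the new per-row path even when FA3 is not installed. To run only this test, selecting it by name is enough; the selection is shown here through pytest.main so the snippet stays in Python (the test file's path is not named because it does not appear in this diff):

import sys
import pytest

# -k selects tests whose name matches the expression; -q keeps the output terse.
sys.exit(pytest.main(["-q", "-k", "test_kvcache_variable_cache_seqlens"]))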
# =============================================================================
# Override mechanism tests