diff --git a/nanochat/engine.py b/nanochat/engine.py
index a1ba24c..878900d 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -172,6 +172,10 @@ class Engine:
         """Same as generate, but does single prefill and then clones the KV cache."""
         assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
         device = self.model.get_device()
+        # get_device() may return either a string ("cpu") or a torch.device,
+        # so normalize it to a torch.device before using it below.
+        if isinstance(device, str):
+            device = torch.device(device)
         # NOTE: setting the dtype here and in this way is an ugly hack.
         # Currently the repo assumes that cuda -> bfloat16 and everything else -> float32.
         # We need to know the dtype here to call __init__ on KVCache and pre-allocate its tensors.
diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index 89ca42b..e928d7a 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -146,27 +146,54 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
 
     # SDPA fallback: manually manage KV cache
     B, T_new, H, D = q.shape
-    pos = cache_seqlens[0].item()  # assume uniform position across batch
+    assert cache_seqlens is not None, "cache_seqlens is required for KV-cache SDPA fallback"
 
-    # Insert new k, v into cache (in-place, matching FA3 behavior)
+    # Fast path: all rows decode in lockstep
+    # (this is how Engine.generate() currently uses the cache)
+    if torch.all(cache_seqlens == cache_seqlens[0]).item():
+        pos = int(cache_seqlens[0].item())
+        # Insert new k, v into cache (in-place, matching FA3 behavior)
+        if k is not None and v is not None:
+            k_cache[:, pos:pos+T_new, :, :] = k
+            v_cache[:, pos:pos+T_new, :, :] = v
+
+        end_pos = pos + T_new
+        k_full = k_cache[:, :end_pos, :, :]
+        v_full = v_cache[:, :end_pos, :, :]
+
+        q_sdpa = q.transpose(1, 2)
+        k_sdpa = k_full.transpose(1, 2)
+        v_sdpa = v_full.transpose(1, 2)
+
+        enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
+        y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
+        return y_sdpa.transpose(1, 2)
+
+    # Slow path: per-row cache positions.
+    # FA3's KV cache API supports per-row cache_seqlens, SDPA does not,
+    # so we do a small per-row loop here (https://github.com/Dao-AILab/flash-attention)
     if k is not None and v is not None:
-        k_cache[:, pos:pos+T_new, :, :] = k
-        v_cache[:, pos:pos+T_new, :, :] = v
+        for b in range(B):
+            pos_b = int(cache_seqlens[b].item())
 
-    # Get full cache up to current position + new tokens
-    end_pos = pos + T_new
-    k_full = k_cache[:, :end_pos, :, :]
-    v_full = v_cache[:, :end_pos, :, :]
+            k_cache[b, pos_b:pos_b+T_new, :, :] = k[b]
+            v_cache[b, pos_b:pos_b+T_new, :, :] = v[b]
 
-    # Transpose to SDPA layout: (B, T, H, D) -> (B, H, T, D)
-    q_sdpa = q.transpose(1, 2)
-    k_sdpa = k_full.transpose(1, 2)
-    v_sdpa = v_full.transpose(1, 2)
+    y_out = torch.empty_like(q)
+    for b in range(B):
+        end_pos_b = int(cache_seqlens[b].item()) + T_new
 
-    enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
-    y_sdpa = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
+        k_full = k_cache[b:b+1, :end_pos_b, :, :]
+        v_full = v_cache[b:b+1, :end_pos_b, :, :]
 
-    return y_sdpa.transpose(1, 2)  # back to (B, T, H, D)
+        q_sdpa = q[b:b+1].transpose(1, 2)
+        k_sdpa = k_full.transpose(1, 2)
+        v_sdpa = v_full.transpose(1, 2)
+
+        enable_gqa = q_sdpa.size(1) != k_sdpa.size(1)
+        y_b = _sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size, enable_gqa)
+        y_out[b:b+1] = y_b.transpose(1, 2)
+    return y_out
 
 
 # =============================================================================
diff --git a/tests/test_attention_fallback.py b/tests/test_attention_fallback.py
index 9741c7f..250b76e 100644
--- a/tests/test_attention_fallback.py
+++ b/tests/test_attention_fallback.py
@@ -328,10 +328,65 @@ class TestSDPAOnly:
         )
         cache.advance(1)
-
         assert y_single.shape == (B, 1, H, D)
         assert cache.get_pos() == T_prefill + 1
         set_impl(None)
 
+    def test_kvcache_variable_cache_seqlens(self):
+        """
+        The SDPA fallback must handle per-row cache positions.
+        """
+        set_impl("sdpa")
+        B, T_max, H, D = 2, 32, 4, 16
+        T_new = 4
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(0)
+        torch.manual_seed(0)
+
+        # different prefix lengths per row
+        cache_seqlens = torch.tensor([8, 16], dtype=torch.int32, device=self.DEVICE)
+
+        # pre-fill the caches with distinct random prefixes
+        k_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_cache = torch.zeros(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_init = torch.randn(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_init = torch.randn(B, T_max, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        for b in range(B):
+            pre = int(cache_seqlens[b].item())
+            k_cache[b, :pre] = k_init[b, :pre]
+            v_cache[b, :pre] = v_init[b, :pre]
+
+        q = torch.randn(B, T_new, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        k_new = torch.randn(B, T_new, H, D, device=self.DEVICE, dtype=self.DTYPE)
+        v_new = torch.randn(B, T_new, H, D, device=self.DEVICE, dtype=self.DTYPE)
+
+        y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new, cache_seqlens=cache_seqlens, causal=True, window_size=(T_max, 0))
+
+        # the caches should have the new KV inserted at each row's own position
+        for b in range(B):
+            pos = int(cache_seqlens[b].item())
+            torch.testing.assert_close(k_cache[b, pos:pos+T_new], k_new[b])
+            torch.testing.assert_close(v_cache[b, pos:pos+T_new], v_new[b])
+
+        # reference: per-row correct behavior computed directly with SDPA
+        y_ref = torch.empty_like(y)
+        for b in range(B):
+            pre = int(cache_seqlens[b].item())
+            k_full = torch.cat([k_init[b:b+1, :pre], k_new[b:b+1]], dim=1)  # (1, pre+T_new, H, D)
+            v_full = torch.cat([v_init[b:b+1, :pre], v_new[b:b+1]], dim=1)
+
+            q_sdpa = q[b:b+1].transpose(1, 2)  # (1, H, T_new, D)
+            k_sdpa = k_full.transpose(1, 2)  # (1, H, Tk, D)
+            v_sdpa = v_full.transpose(1, 2)
+
+            y_b = fa_module._sdpa_attention(q_sdpa, k_sdpa, v_sdpa, window_size=(T_max, 0), enable_gqa=False)
+            y_ref[b:b+1] = y_b.transpose(1, 2)
+
+        # bf16 is expected to have slightly larger numerical error
+        atol = 1e-2 if self.DTYPE == torch.bfloat16 else 1e-4
+        rtol = 1e-2 if self.DTYPE == torch.bfloat16 else 1e-4
+        torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol)
+
+        set_impl(None)
 
 # =============================================================================
 # Override mechanism tests
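
A minimal sketch (not part of the diff) of exercising the new per-row path directly, assuming the SDPA fallback is the active implementation (e.g. FA3 is unavailable) and that flash_attn_with_kvcache is importable from nanochat.flash_attention as modified above; the shapes and positions are illustrative only:

    import torch
    from nanochat.flash_attention import flash_attn_with_kvcache

    B, T_max, H, D, T_new = 2, 32, 4, 16, 1
    k_cache = torch.zeros(B, T_max, H, D)
    v_cache = torch.zeros(B, T_max, H, D)
    # rows sit at different cache positions, which previously tripped the uniform-position assumption
    cache_seqlens = torch.tensor([8, 16], dtype=torch.int32)
    q = torch.randn(B, T_new, H, D)
    k_new = torch.randn(B, T_new, H, D)
    v_new = torch.randn(B, T_new, H, D)
    y = flash_attn_with_kvcache(q, k_cache, v_cache, k=k_new, v=v_new,
                                cache_seqlens=cache_seqlens, causal=True,
                                window_size=(T_max, 0))
    print(y.shape)  # torch.Size([2, 1, 4, 16]), i.e. (B, T_new, H, D)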