From 134f9b7a8fb08aa28833cf61b37209f59101ca1b Mon Sep 17 00:00:00 2001
From: Matt Murphy
Date: Tue, 14 Oct 2025 07:40:05 +0000
Subject: [PATCH] remove kvcache import

---
 nanochat/gpt.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index d3f7859..cf78adb 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -22,7 +22,6 @@ import torch.nn.functional as F
 
 from nanochat.adamw import DistAdamW
 from nanochat.common import get_dist_info
-from nanochat.engine import KVCache
 from nanochat.muon import Muon, DistMuon
 
 
@@ -73,7 +72,7 @@ class CausalSelfAttention(nn.Module):
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache: KVCache,
+        kv_cache,
     ) -> torch.Tensor:
         B, T, _ = x.size()
 
@@ -149,7 +148,7 @@ class Block(nn.Module):
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache: KVCache,
+        kv_cache,
     ) -> torch.Tensor:
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
@@ -268,7 +267,13 @@ class GPT(nn.Module):
             group["initial_lr"] = group["lr"]
         return optimizers
 
-    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache: KVCache | None = None, loss_reduction: str = 'mean') -> torch.Tensor | None:
+    def forward(
+        self,
+        idx: torch.Tensor,
+        targets: torch.Tensor | None = None,
+        kv_cache=None,
+        loss_reduction: str = 'mean',
+    ) -> torch.Tensor | None:
         _, T = idx.size()
 
         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
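
A possible alternative worth noting: the patch drops the kv_cache type annotations along with the import. If the annotations are worth keeping for static analysis, a common pattern is to guard the import with typing.TYPE_CHECKING so type checkers still see KVCache while the runtime dependency on nanochat.engine stays removed. A minimal sketch, assuming the pre-patch import path nanochat.engine.KVCache still exists:

from __future__ import annotations  # annotations are stored as strings, never evaluated at runtime

from typing import TYPE_CHECKING

import torch
import torch.nn as nn

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); this import
    # never executes at runtime, so the runtime coupling is still gone.
    from nanochat.engine import KVCache


class CausalSelfAttention(nn.Module):
    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
        kv_cache: KVCache | None = None,  # checkable annotation, no runtime import
    ) -> torch.Tensor:
        ...

Whether that trade-off is wanted here is a separate question from what this patch does; the sketch only shows that removing the import and removing the annotations are independent choices.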