remove kvcache import

Matt Murphy 2025-10-14 07:40:05 +00:00
parent 7e87fa8a71
commit 134f9b7a8f


@@ -22,7 +22,6 @@ import torch.nn.functional as F
 from nanochat.adamw import DistAdamW
 from nanochat.common import get_dist_info
-from nanochat.engine import KVCache
 from nanochat.muon import Muon, DistMuon
@@ -73,7 +72,7 @@ class CausalSelfAttention(nn.Module):
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache: KVCache,
+        kv_cache,
     ) -> torch.Tensor:
         B, T, _ = x.size()
@@ -149,7 +148,7 @@ class Block(nn.Module):
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache: KVCache,
+        kv_cache,
     ) -> torch.Tensor:
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
@@ -268,7 +267,13 @@ class GPT(nn.Module):
             group["initial_lr"] = group["lr"]
         return optimizers
 
-    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache: KVCache | None = None, loss_reduction: str = 'mean') -> torch.Tensor | None:
+    def forward(
+        self,
+        idx: torch.Tensor,
+        targets: torch.Tensor | None = None,
+        kv_cache = None,
+        loss_reduction: str = 'mean',
+    ) -> torch.Tensor | None:
         _, T = idx.size()
         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
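Side note: with the annotation dropped, the model file no longer imports anything from nanochat.engine, and kv_cache is effectively duck-typed. If a static hint were still wanted without reintroducing the import, one option (not part of this commit) is a typing.Protocol; the insert_kv method below is a hypothetical placeholder for illustration, not necessarily the real KVCache API.

from typing import Protocol

import torch

class KVCacheLike(Protocol):
    # Structural stand-in for the engine-side KV cache. The method name and
    # signature are assumptions for illustration only; the actual nanochat
    # KVCache interface may differ.
    def insert_kv(
        self, layer_idx: int, k: torch.Tensor, v: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]: ...

# The model could then annotate without depending on the engine module:
#     def forward(self, ..., kv_cache: KVCacheLike | None = None) -> ...

Because a Protocol is matched structurally, the engine's concrete cache class would satisfy it without any import in either direction.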