diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index d3f7859..cf78adb 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -22,7 +22,6 @@ import torch.nn.functional as F
 
 from nanochat.adamw import DistAdamW
 from nanochat.common import get_dist_info
-from nanochat.engine import KVCache
 from nanochat.muon import Muon, DistMuon
 
 
@@ -73,7 +72,7 @@ class CausalSelfAttention(nn.Module):
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache: KVCache,
+        kv_cache,
     ) -> torch.Tensor:
         B, T, _ = x.size()
 
@@ -149,7 +148,7 @@ class Block(nn.Module):
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache: KVCache,
+        kv_cache,
     ) -> torch.Tensor:
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
@@ -268,7 +267,13 @@ class GPT(nn.Module):
                 group["initial_lr"] = group["lr"]
         return optimizers
 
-    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache: KVCache | None = None, loss_reduction: str = 'mean') -> torch.Tensor | None:
+    def forward(
+        self,
+        idx: torch.Tensor,
+        targets: torch.Tensor | None = None,
+        kv_cache = None,
+        loss_reduction: str = 'mean',
+    ) -> torch.Tensor | None:
         _, T = idx.size()
         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
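
Note (not part of the patch above): the hunks drop the KVCache annotations together with the nanochat.engine import. If the intent is to remove gpt.py's runtime dependency on the engine module (an assumption; the diff itself states no rationale), a minimal alternative sketch would keep the annotations visible to static type checkers by guarding the import with typing.TYPE_CHECKING and using string annotations:

    import torch
    import torch.nn as nn
    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by static type checkers (e.g. mypy, pyright); never
        # executed at runtime, so gpt.py keeps no runtime dependency on
        # nanochat.engine.
        from nanochat.engine import KVCache


    class CausalSelfAttention(nn.Module):
        def forward(
            self,
            x: torch.Tensor,
            cos_sin: tuple[torch.Tensor, torch.Tensor],
            kv_cache: "KVCache | None" = None,  # string annotation: no runtime lookup
        ) -> torch.Tensor:
            ...

This only illustrates the pattern; the patch itself simply leaves the kv_cache parameter untyped.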