mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)

Commit 134f9b7a8f (parent 7e87fa8a71): remove kvcache import
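The change removes `from nanochat.engine import KVCache` and drops the `KVCache` annotation from the `kv_cache` parameter in `CausalSelfAttention.forward`, `Block.forward`, and `GPT.forward`, so the model file no longer imports anything from the inference engine; `GPT.forward`'s signature is also reflowed to one argument per line.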
@@ -22,7 +22,6 @@ import torch.nn.functional as F
 from nanochat.adamw import DistAdamW
 from nanochat.common import get_dist_info
-from nanochat.engine import KVCache
 from nanochat.muon import Muon, DistMuon
 
 
 
@@ -73,7 +72,7 @@ class CausalSelfAttention(nn.Module):
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache: KVCache,
+        kv_cache,
     ) -> torch.Tensor:
         B, T, _ = x.size()
 
@@ -149,7 +148,7 @@ class Block(nn.Module):
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache: KVCache,
+        kv_cache,
     ) -> torch.Tensor:
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
@@ -268,7 +267,13 @@ class GPT(nn.Module):
             group["initial_lr"] = group["lr"]
         return optimizers
 
-    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache: KVCache | None = None, loss_reduction: str = 'mean') -> torch.Tensor | None:
+    def forward(
+        self,
+        idx: torch.Tensor,
+        targets: torch.Tensor | None = None,
+        kv_cache=None,
+        loss_reduction: str = 'mean',
+    ) -> torch.Tensor | None:
         _, T = idx.size()
 
         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
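The effect, sketched below, is that `kv_cache` becomes duck-typed: training code passes `None`, and the inference engine can pass its own cache object without the model module importing the cache class. This is a hypothetical, minimal illustration, not nanochat code; `TinyAttention` and the list used as a stand-in cache are invented for the example, and the real cache API lives in `nanochat.engine`.

import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    # Stand-in for an attention layer whose forward takes an untyped kv_cache,
    # mirroring the shape of the change above (no KVCache import needed here).
    def forward(self, x: torch.Tensor, kv_cache=None) -> torch.Tensor:
        if kv_cache is not None:
            # Duck typing: any object exposing the methods the layer actually
            # calls will do. Here we only append, purely for illustration.
            kv_cache.append(x)
        return x

attn = TinyAttention()

# Training-style call: no cache object at all.
y = attn(torch.randn(1, 4, 8))

# Inference-style call: a plain list works as a stand-in cache, because the
# model module no longer constrains the parameter's type.
cache: list[torch.Tensor] = []
y = attn(torch.randn(1, 1, 8), kv_cache=cache)

A plausible design motivation: the type information moves out of the model, so the dependency points only one way (engine imports model, not the reverse), which helps avoid a circular import between `nanochat.engine` and the model file.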