Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 20:32:14 +00:00)
commit 134f9b7a8f
parent 7e87fa8a71

    remove kvcache import
@@ -22,7 +22,6 @@ import torch.nn.functional as F
 
 from nanochat.adamw import DistAdamW
 from nanochat.common import get_dist_info
-from nanochat.engine import KVCache
 from nanochat.muon import Muon, DistMuon
 
 
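Note on the hunk above: dropping this import removes the module's only reference to nanochat.engine, which is why the KVCache annotations in the hunks below have to go as well. As a hedged aside, here is one alternative the commit does not take: keeping the annotation visible to static checkers via typing.TYPE_CHECKING while still avoiding the runtime import. The decoupling motivation is an assumption on my part, not stated in the commit message.

# Hypothetical alternative sketch, NOT what this commit does: keep KVCache
# available to type checkers while skipping the runtime import of
# nanochat.engine.
from typing import TYPE_CHECKING

import torch
import torch.nn as nn

if TYPE_CHECKING:
    # Evaluated only by static checkers (mypy/pyright), never at runtime.
    from nanochat.engine import KVCache


class CausalSelfAttention(nn.Module):
    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
        kv_cache: "KVCache | None" = None,  # string annotation: no runtime lookup
    ) -> torch.Tensor:
        ...  # attention body unchanged; elided here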
@@ -73,7 +72,7 @@ class CausalSelfAttention(nn.Module):
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache: KVCache,
+        kv_cache,
     ) -> torch.Tensor:
         B, T, _ = x.size()
 
@@ -149,7 +148,7 @@ class Block(nn.Module):
         self,
         x: torch.Tensor,
         cos_sin: tuple[torch.Tensor, torch.Tensor],
-        kv_cache: KVCache,
+        kv_cache,
     ) -> torch.Tensor:
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
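The two hunks above make kv_cache duck-typed in CausalSelfAttention.forward and Block.forward: with no annotation, any cache-like object (or None) can be passed without gpt.py knowing about the engine's KVCache class. A minimal runnable sketch of that consequence follows; TinyAttn and its identity body are stand-ins for illustration, not nanochat code.

# Self-contained sketch of the duck-typing consequence; names are hypothetical.
import torch
import torch.nn as nn


class TinyAttn(nn.Module):
    """Stand-in for CausalSelfAttention, mirroring the post-commit signature."""

    def forward(
        self,
        x: torch.Tensor,
        cos_sin: tuple[torch.Tensor, torch.Tensor],
        kv_cache,  # untyped after this commit: anything cache-like may be passed
    ) -> torch.Tensor:
        # Real code would apply rotary embeddings and attention; identity here.
        return x


attn = TinyAttn()
x = torch.randn(1, 4, 8)
cos_sin = (torch.ones(1, 4, 1, 8), torch.zeros(1, 4, 1, 8))
y = attn(x, cos_sin, kv_cache=None)  # None works; so would any duck-typed cache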
@@ -268,7 +267,13 @@ class GPT(nn.Module):
                 group["initial_lr"] = group["lr"]
         return optimizers
 
-    def forward(self, idx: torch.Tensor, targets: torch.Tensor | None = None, kv_cache: KVCache | None = None, loss_reduction: str = 'mean') -> torch.Tensor | None:
+    def forward(
+        self,
+        idx: torch.Tensor,
+        targets: torch.Tensor | None = None,
+        kv_cache = None,
+        loss_reduction: str = 'mean',
+    ) -> torch.Tensor | None:
         _, T = idx.size()
 
         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
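The final hunk reflows GPT.forward from one long line into a parenthesized parameter list (plus the same annotation drop on kv_cache), so the parameter names, order, and defaults are untouched and every existing call site keeps working. A small self-contained check of that layout-equivalence, using toy functions rather than the real method:

# Sketch: a one-line and a multi-line def with the same parameters produce the
# same call contract; only the source layout differs. f and g are toy functions.
import inspect

def f(idx, targets=None, kv_cache=None, loss_reduction='mean'): ...

def g(
    idx,
    targets=None,
    kv_cache=None,
    loss_reduction='mean',
): ...

assert inspect.signature(f) == inspect.signature(g)  # identical call contract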