mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-04 08:20:28 +00:00
Merge pull request #2 from Dianababaei/feat/gpt-initialize-kvcache-efficient-generation
Add Rotary Position Embeddings (RoPE) support to GPT model with configurable flag
This commit is contained in:
commit
d0383978df
|
|
@@ -305,6 +305,15 @@ class GPT(nn.Module):
|
|||
if temperature > 0:
|
||||
rng = torch.Generator(device=device)
|
||||
rng.manual_seed(seed)
|
||||
# Initialize KV cache for efficient generation
|
||||
kv_length_hint = len(tokens) + max_tokens
|
||||
kv_cache = KVCache(
|
||||
batch_size=1,
|
||||
num_heads=self.config.n_kv_head,
|
||||
seq_len=kv_length_hint,
|
||||
head_dim=self.config.n_embd // self.config.n_head,
|
||||
num_layers=self.config.n_layer
|
||||
)
|
||||
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
|
||||
for _ in range(max_tokens):
|
||||
logits = self.forward(ids) # (B, T, vocab_size)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user