Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)
Commit a6955c541a: Merge 07f193ab0d into fae3aca951
@@ -53,20 +53,27 @@ def use_calculator(expr):
     return eval_with_timeout(expr)
 
 # -----------------------------------------------------------------------------
+# KV cache with bounded memory growth :D
+# >> why: the original resize_() grows unbounded, causing OOM in the web server with long conversations
+# >> what: add a MAX_CACHE_SIZE limit + proper memory management with explicit deallocation
+# >> how: check the bounds before growing, and use a new tensor allocation instead of resize_ for proper GC
 class KVCache:
     """
     Works hand-in-hand with the GPT model to maintain the KV cache.
     Note that the .pos advances automatically after the last layer of the Transformer inserts.
     """
+    MAX_CACHE_SIZE = 32768
 
     def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
         # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
         self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
         self.kv_cache = None
-        self.pos = 0 # current position in time in the cache
+        self.pos = 0
+        self.num_layers = num_layers
+        self.batch_size = batch_size
+        self.num_heads = num_heads
+        self.head_dim = head_dim
 
     def reset(self):
         self.pos = 0
+        if self.kv_cache is not None:
+            del self.kv_cache
+            self.kv_cache = None
 
     def get_pos(self):
         return self.pos
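To put a rough number on the "why" above: the cache holds num_layers * 2 * B * H * T * D values per conversation, so letting T grow without a bound grows memory linearly with conversation length. A back-of-the-envelope sketch (the model dimensions below are made-up placeholders, not nanochat's actual config):

# Hypothetical dimensions, for illustration only
num_layers, batch_size, num_heads, head_dim = 12, 1, 6, 128
bytes_per_value = 2  # e.g. bfloat16

def kv_cache_bytes(seq_len):
    # one K and one V tensor of shape (B, H, T, D) per layer
    return num_layers * 2 * batch_size * num_heads * seq_len * head_dim * bytes_per_value

print(kv_cache_bytes(2048) / 1e9)   # ~0.08 GB
print(kv_cache_bytes(32768) / 1e9)  # ~1.2 GB at the new MAX_CACHE_SIZE cap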
@@ -98,27 +105,33 @@ class KVCache:
         # 4) update the pos
         self.pos = other.pos
 
+    # Optimized insert with bounds checking and proper memory allocation :)
+    # >> Prevents OOM by enforcing the MAX_CACHE_SIZE limit
+    # >> Uses explicit allocation instead of resize_ for better garbage collection
+    # >> Raises a clear error instead of failing silently on overflow
     def insert_kv(self, layer_idx, k, v):
         # Lazy initialize the cache here because we need to know the dtype/device
         if self.kv_cache is None:
             self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
         # Insert new keys/values to the cache and return the full cache so far
         B, H, T_add, D = k.size()
         t0, t1 = self.pos, self.pos + T_add
-        # Dynamically grow the cache if needed
+
+        if t1 > self.MAX_CACHE_SIZE:
+            raise RuntimeError(f"Sequence length {t1} exceeds MAX_CACHE_SIZE {self.MAX_CACHE_SIZE}")
+
         if t1 > self.kv_cache.size(4):
-            t_needed = t1 + 1024 # as much as we need plus buffer of 1024
-            t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
-            current_shape = list(self.kv_cache.shape)
-            current_shape[4] = t_needed
-            self.kv_cache.resize_(current_shape)
-        # Insert k, v into the cache
+            t_needed = min(t1 + 1024, self.MAX_CACHE_SIZE) # as much as we need plus a buffer, capped at the limit
+            t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
+            new_cache = torch.empty((self.num_layers, 2, self.batch_size, self.num_heads, t_needed, self.head_dim),
+                                    dtype=self.kv_cache.dtype, device=self.kv_cache.device)
+            new_cache[..., :self.kv_cache.size(4), :] = self.kv_cache
+            old_cache = self.kv_cache
+            self.kv_cache = new_cache
+            del old_cache
+
         self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
         self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
         # Return the full cached keys/values up to current position (as a view)
         key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
         value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
         # Increment pos after the last layer of the Transformer processes
         if layer_idx == self.kv_cache.size(0) - 1:
             self.pos = t1
         return key_view, value_view
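The growth arithmetic in the new insert_kv is: grow to at least 1024 slots past what is needed, round up to a multiple of 1024, and never exceed MAX_CACHE_SIZE; past the cap it raises instead of growing. A minimal sketch of just that arithmetic (standalone, not the repository's code):

MAX_CACHE_SIZE = 32768

def grown_length(t1):
    t_needed = min(t1 + 1024, MAX_CACHE_SIZE)  # as much as we need plus a buffer, capped at the limit
    return (t_needed + 1023) & ~1023           # round up to the nearest multiple of 1024

print(grown_length(5000))   # 6144 (5000 + 1024 = 6024, rounded up to a multiple of 1024)
print(grown_length(32000))  # 32768 (clamped to MAX_CACHE_SIZE, which is already a multiple of 1024)
# A request with t1 > 32768 never reaches this arithmetic: insert_kv raises
# RuntimeError("Sequence length ... exceeds MAX_CACHE_SIZE 32768") first.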
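And a toy usage sketch of the class itself, illustrating the docstring's note that .pos advances only after the last layer inserts. The shapes are hypothetical, and it assumes torch plus the KVCache class from this diff are in scope:

import torch

# tiny, made-up dimensions just to exercise the API
cache = KVCache(batch_size=1, num_heads=2, seq_len=16, head_dim=4, num_layers=3)
k = torch.zeros(1, 2, 5, 4)  # (B, H, T_add, D)
v = torch.zeros(1, 2, 5, 4)

for layer_idx in range(3):
    keys, values = cache.insert_kv(layer_idx, k, v)
    print(layer_idx, cache.get_pos(), keys.shape)
# pos stays 0 for layers 0 and 1 and only becomes 5 after layer 2 (the last layer);
# keys/values are views over the cache up to the current write position.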