Jaber Jaber 2025-10-16 04:47:50 +04:00 committed by GitHub
commit a6955c541a


@@ -53,20 +53,27 @@ def use_calculator(expr):
return eval_with_timeout(expr)
# -----------------------------------------------------------------------------
# KV cache with bounded memory growth
# >> why: the original resize_() grows the cache without bound, causing OOM in the web server during long conversations
# >> what: add a MAX_CACHE_SIZE limit plus proper memory management with explicit deallocation
# >> how: check the bound before growing, and allocate a fresh tensor instead of calling resize_ so the old buffer can be garbage collected
class KVCache:
"""
Works hand-in-hand with the GPT model to maintain the KV cache.
Note that the .pos advances automatically after the last layer of the Transformer inserts.
"""
    MAX_CACHE_SIZE = 32768  # hard cap on cached positions; inserts beyond this raise instead of growing
def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
# Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
self.kv_cache = None
        self.pos = 0 # current position in time in the cache
self.num_layers = num_layers
self.batch_size = batch_size
self.num_heads = num_heads
self.head_dim = head_dim
def reset(self):
self.pos = 0
if self.kv_cache is not None:
del self.kv_cache
self.kv_cache = None
def get_pos(self):
return self.pos
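
As a standalone illustration of the strategy the comments above describe (check the bound first, then grow by allocating a fresh tensor and copying, rather than calling resize_()), here is a minimal sketch. The names grow_cache, max_size, and dim are hypothetical and are not part of this commit.

import torch

def grow_cache(cache: torch.Tensor, needed_len: int, max_size: int, dim: int = 4) -> torch.Tensor:
    # Hypothetical helper: bounded growth along `dim` via fresh allocation + copy.
    if needed_len > max_size:
        raise RuntimeError(f"Sequence length {needed_len} exceeds max cache size {max_size}")
    if needed_len <= cache.size(dim):
        return cache  # still fits, nothing to do
    # round the requested length up to a multiple of 1024, capped at max_size
    target = min((needed_len + 1023) & ~1023, max_size)
    new_shape = list(cache.shape)
    new_shape[dim] = target
    new_cache = torch.empty(new_shape, dtype=cache.dtype, device=cache.device)
    new_cache.narrow(dim, 0, cache.size(dim)).copy_(cache)  # keep the existing entries
    return new_cache  # caller rebinds; the old buffer is then freed by the garbage collector

Rebinding the caller's reference to the returned tensor is what actually frees the old buffer, which is the same effect the commit gets by assigning self.kv_cache = new_cache and dropping old_cache.
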
@@ -98,27 +105,33 @@ class KVCache:
# 4) update the pos
self.pos = other.pos
    # Optimized insert with bounds checking and proper memory allocation
    # >> Prevents OOM by enforcing the MAX_CACHE_SIZE limit
    # >> Uses explicit allocation instead of resize_ for better garbage collection
    # >> Raises a clear error instead of failing silently on overflow
def insert_kv(self, layer_idx, k, v):
# Lazy initialize the cache here because we need to know the dtype/device
if self.kv_cache is None:
self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
# Insert new keys/values to the cache and return the full cache so far
B, H, T_add, D = k.size()
t0, t1 = self.pos, self.pos + T_add
        # Dynamically grow the cache if needed, but never beyond MAX_CACHE_SIZE
        if t1 > self.MAX_CACHE_SIZE:
            raise RuntimeError(f"Sequence length {t1} exceeds MAX_CACHE_SIZE {self.MAX_CACHE_SIZE}")
        if t1 > self.kv_cache.size(4):
            # as much as we need plus a buffer of 1024, rounded up to the nearest multiple of 1024
            t_needed = min(t1 + 1024, self.MAX_CACHE_SIZE)
            t_needed = (t_needed + 1023) & ~1023
            # allocate a fresh tensor and copy the old contents instead of calling resize_,
            # so the previous buffer is released cleanly
            new_cache = torch.empty((self.num_layers, 2, self.batch_size, self.num_heads, t_needed, self.head_dim),
                                    dtype=self.kv_cache.dtype, device=self.kv_cache.device)
            new_cache[..., :self.kv_cache.size(4), :] = self.kv_cache
            old_cache = self.kv_cache
            self.kv_cache = new_cache
            del old_cache
        # Insert k, v into the cache
self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
# Return the full cached keys/values up to current position (as a view)
key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
# Increment pos after the last layer of the Transformer processes
if layer_idx == self.kv_cache.size(0) - 1:
self.pos = t1
return key_view, value_view
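
For context, a minimal usage sketch of the class above, assuming KVCache is importable from this module; the tensor shapes and the three-step decode loop are illustrative only and not taken from the repository.

import torch

num_layers, batch_size, num_heads, seq_len, head_dim = 2, 1, 4, 16, 8
cache = KVCache(batch_size, num_heads, seq_len, head_dim, num_layers)

for step in range(3):  # pretend to decode three tokens, one position at a time
    for layer_idx in range(num_layers):
        k = torch.randn(batch_size, num_heads, 1, head_dim)
        v = torch.randn(batch_size, num_heads, 1, head_dim)
        keys, values = cache.insert_kv(layer_idx, k, v)  # views over positions [0, pos + 1)
    assert cache.get_pos() == step + 1  # pos advances only after the last layer inserts

cache.reset()  # rewinds pos and drops the cached tensor
assert cache.get_pos() == 0

If a conversation ever pushes t1 past MAX_CACHE_SIZE (32768 positions), insert_kv now raises a RuntimeError instead of growing without bound.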