diff --git a/nanochat/engine.py b/nanochat/engine.py
index de1253a..b0da666 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -53,20 +53,27 @@ def use_calculator(expr):
     return eval_with_timeout(expr)
 
 # -----------------------------------------------------------------------------
+# KV cache with bounded memory growth
+# >> why: the original resize_() grows without bound, causing OOMs in the web server during long conversations
+# >> what: add a MAX_CACHE_SIZE limit plus explicit deallocation when the cache grows or resets
+# >> how: check the bound before growing, and allocate a fresh tensor instead of resize_() so the old one can be garbage collected
 class KVCache:
-    """
-    Works hand-in-hand with the GPT model to maintain the KV cache.
-    Note that the .pos advances automatically after the last layer of the Transformer inserts.
-    """
+    MAX_CACHE_SIZE = 32768
 
     def __init__(self, batch_size, num_heads, seq_len, head_dim, num_layers):
-        # Each of K/V is of shape (B, H, T, D) and we have one per layer of the Transformer.
         self.kv_shape = (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
         self.kv_cache = None
-        self.pos = 0 # current position in time in the cache
+        self.pos = 0  # current position in time in the cache
+        self.num_layers = num_layers
+        self.batch_size = batch_size
+        self.num_heads = num_heads
+        self.head_dim = head_dim
 
     def reset(self):
         self.pos = 0
+        if self.kv_cache is not None:
+            del self.kv_cache
+            self.kv_cache = None
 
     def get_pos(self):
         return self.pos
@@ -98,27 +105,33 @@ class KVCache:
         # 4) update the pos
         self.pos = other.pos
 
+    # Insert with bounds checking and explicit memory allocation
+    # >> prevents OOM by enforcing the MAX_CACHE_SIZE limit
+    # >> uses explicit allocation instead of resize_ for better garbage collection
+    # >> raises a clear error instead of failing silently on overflow
     def insert_kv(self, layer_idx, k, v):
-        # Lazy initialize the cache here because we need to know the dtype/device
         if self.kv_cache is None:
             self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
-        # Insert new keys/values to the cache and return the full cache so far
         B, H, T_add, D = k.size()
         t0, t1 = self.pos, self.pos + T_add
-        # Dynamically grow the cache if needed
+
+        if t1 > self.MAX_CACHE_SIZE:
+            raise RuntimeError(f"Sequence length {t1} exceeds MAX_CACHE_SIZE {self.MAX_CACHE_SIZE}")
+
         if t1 > self.kv_cache.size(4):
-            t_needed = t1 + 1024 # as much as we need plus buffer of 1024
-            t_needed = (t_needed + 1023) & ~1023 # then round up to the nearest multiple of 1024
-            current_shape = list(self.kv_cache.shape)
-            current_shape[4] = t_needed
-            self.kv_cache.resize_(current_shape)
-        # Insert k, v into the cache
+            t_needed = min(t1 + 1024, self.MAX_CACHE_SIZE)  # as much as we need plus a buffer, capped at the limit
+            t_needed = (t_needed + 1023) & ~1023  # round up to the nearest multiple of 1024
+            new_cache = torch.empty((self.num_layers, 2, self.batch_size, self.num_heads, t_needed, self.head_dim),
+                                    dtype=self.kv_cache.dtype, device=self.kv_cache.device)
+            new_cache[..., :self.kv_cache.size(4), :] = self.kv_cache  # copy existing entries into the new tensor
+            old_cache = self.kv_cache
+            self.kv_cache = new_cache
+            del old_cache  # drop the last reference so the old tensor can be freed
+
         self.kv_cache[layer_idx, 0, :, :, t0:t1] = k
         self.kv_cache[layer_idx, 1, :, :, t0:t1] = v
-        # Return the full cached keys/values up to current position (as a view)
         key_view = self.kv_cache[layer_idx, 0, :, :, :t1]
         value_view = self.kv_cache[layer_idx, 1, :, :, :t1]
-        # Increment pos after the last layer of the Transformer processes
         if layer_idx == self.kv_cache.size(0) - 1:
             self.pos = t1
         return key_view, value_view
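
For reference, a minimal usage sketch of the bounded cache (not part of the diff; it assumes the (B, H, T, D) layout above, CPU float32 tensors, and small illustrative dimensions), exercising both the growth path and the new overflow error:

import torch
from nanochat.engine import KVCache

# One layer, batch 1, 2 heads, head_dim 4, initial capacity of 1024 positions.
cache = KVCache(batch_size=1, num_heads=2, seq_len=1024, head_dim=4, num_layers=1)

k = torch.zeros(1, 2, 600, 4)
v = torch.zeros(1, 2, 600, 4)
cache.insert_kv(0, k, v)  # t1 = 600 fits in the initial 1024 slots; pos -> 600
cache.insert_kv(0, k, v)  # t1 = 1200 > 1024: a fresh 3072-slot tensor is allocated (1200 + 1024 buffer, rounded up to a multiple of 1024) and the old entries copied over

big_k = torch.zeros(1, 2, 32000, 4)
try:
    cache.insert_kv(0, big_k, big_k)  # t1 = 33200 > MAX_CACHE_SIZE (32768): a clear RuntimeError instead of a slow OOM
except RuntimeError as e:
    print(e)

cache.reset()  # pos -> 0 and the backing tensor is dropped so it can be garbage collected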