diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py index f683cd6..274d2ab 100644 --- a/nanochat/core_eval.py +++ b/nanochat/core_eval.py @@ -11,6 +11,8 @@ from jinja2 import Template import torch import torch.distributed as dist +from nanochat import eval_config + # ----------------------------------------------------------------------------- # Prompt rendering utilities @@ -267,10 +269,11 @@ def evaluate_task(model, tokenizer, data, device, task_meta): is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta) correct[idx] = float(is_correct) - # MEMORY FIX: Periodic cache cleanup every 100 examples + # MEMORY FIX: Periodic cache cleanup # This releases cached GPU memory and triggers Python GC # Prevents progressive slowdown from memory fragmentation - if idx % 100 == 0 and idx > 0: + # Interval configurable via eval_config.CACHE_CLEANUP_INTERVAL (default: 256) + if eval_config.ENABLE_PERIODIC_CLEANUP and idx % eval_config.CACHE_CLEANUP_INTERVAL == 0 and idx > 0: # Release PyTorch cached memory back to GPU if torch.cuda.is_available() and device.type == 'cuda': torch.cuda.empty_cache() @@ -287,7 +290,8 @@ def evaluate_task(model, tokenizer, data, device, task_meta): # MEMORY FIX: Final cleanup after task completes del correct - if torch.cuda.is_available() and device.type == 'cuda': - torch.cuda.empty_cache() + if eval_config.ENABLE_FINAL_CLEANUP: + if torch.cuda.is_available() and device.type == 'cuda': + torch.cuda.empty_cache() return mean_correct diff --git a/nanochat/eval_config.py b/nanochat/eval_config.py new file mode 100644 index 0000000..6127b1c --- /dev/null +++ b/nanochat/eval_config.py @@ -0,0 +1,31 @@ +""" +Configuration for evaluation memory management and performance tuning. + +These settings control memory cleanup intervals and other evaluation parameters +to prevent memory leaks and progressive slowdown during long-running evaluations. 
+""" + +# Memory Management Settings +# --------------------------- + +# Periodic cache cleanup interval (in examples processed) +# After processing this many examples, trigger torch.cuda.empty_cache() and gc.collect() +# to prevent memory fragmentation and progressive slowdown. +# +# Rationale for 256: +# - Balances cleanup overhead (~10-50ms per cleanup) vs memory accumulation +# - Power of 2 (a conventional round interval; modulo cost is negligible either way) +# - For HellaSwag (10,000 examples): 39 cleanups total (~2s overhead) +# - For MMLU (100-1000 examples): 0-3 cleanups total (negligible overhead) +# +# Lower values (e.g., 100): More frequent cleanup, less fragmentation, higher overhead +# Higher values (e.g., 512): Less overhead, more fragmentation risk +CACHE_CLEANUP_INTERVAL = 256 + +# Enable periodic cache cleanup during evaluation +# Set to False to disable all periodic cleanup (not recommended for long evaluations) +ENABLE_PERIODIC_CLEANUP = True + +# Enable final cleanup after task completes +# Set to False to skip final cleanup (saves ~50ms but leaves memory cached) +ENABLE_FINAL_CLEANUP = True