nanochat/nanochat/core_eval.py
haltingstate c4a183dfef Move memory cleanup settings to configurable eval_config
Extract hardcoded memory cleanup interval (100 → 256) and enable flags
to eval_config.py for better maintainability and tuning flexibility.

Changes:

1. Created nanochat/eval_config.py:
   - CACHE_CLEANUP_INTERVAL = 256 (changed from hardcoded 100)
   - ENABLE_PERIODIC_CLEANUP = True (allows disabling cleanup)
   - ENABLE_FINAL_CLEANUP = True (allows skipping final cleanup)
   - Documented rationale for 256: balances overhead vs fragmentation

2. Updated nanochat/core_eval.py:
   - Import eval_config module
   - Use eval_config.CACHE_CLEANUP_INTERVAL instead of hardcoded 100
   - Check eval_config.ENABLE_PERIODIC_CLEANUP flag before cleanup
   - Check eval_config.ENABLE_FINAL_CLEANUP flag for final cleanup

Rationale for 256 vs 100:
- Power of 2 (efficient modulo operation)
- Lower overhead: HellaSwag 10,000 examples: 39 cleanups (~2s) vs 100 cleanups (~5s)
- Still frequent enough to prevent fragmentation
- For MMLU (100-1000 examples): 0-4 cleanups (negligible impact)

Benefits:
- Centralizes tuning parameters in one location
- Allows easy experimentation with cleanup intervals
- Can disable cleanup for debugging/profiling
- Documents tradeoffs in config comments
- No magic numbers in evaluation code

Related: Previous commit a7066b8 (hellaswag memory leak fix)
2026-02-09 14:37:59 +08:00

298 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Functions for evaluating the CORE metric, as described in the DCLM paper.
https://arxiv.org/abs/2406.11794
TODOs:
- All tasks ~match except for squad. We get 31% reference is 37%. Figure out why.
"""
import random
from jinja2 import Template
import torch
import torch.distributed as dist
from nanochat import eval_config
# -----------------------------------------------------------------------------
# Prompt rendering utilities
def render_prompts_mc(item, continuation_delimiter, fewshot_examples=None):
"""Render complete prompts for a multiple choice question"""
template_str = """
{%- for example in fewshot_examples -%}
{{ example.query }}{{ continuation_delimiter }}{{ example.choices[example.gold] }}
{% endfor -%}
{{ item.query }}{{ continuation_delimiter }}{{ choice }}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
prompts = [template.render(choice=choice, **context) for choice in item['choices']]
return prompts
def render_prompts_schema(item, continuation_delimiter, fewshot_examples=None):
"""Render complete prompts for a schema question"""
template_str = """
{%- for example in fewshot_examples -%}
{{ example.context_options[example.gold] }}{{ continuation_delimiter }}{{ example.continuation }}
{% endfor -%}
{{ context }}{{ continuation_delimiter }}{{ item.continuation }}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
prompts = [template.render(context=context_option, **context)
for context_option in item['context_options']]
return prompts
def render_prompts_lm(item, continuation_delimiter, fewshot_examples=None):
"""
Render complete prompt for a language modeling task.
Notice that we manually trim the context in the template,
which in some datasets seems to have trailing whitespace (which we don't want).
"""
template_str = """
{%- for example in fewshot_examples -%}
{{ example.context | trim }}{{ continuation_delimiter }}{{ example.continuation }}
{% endfor -%}
{{ item.context | trim }}{{ continuation_delimiter }}{% if include_continuation %}{{ item.continuation }}{% endif %}""".strip()
template = Template(template_str)
fewshot_examples = fewshot_examples or []
context = {
'fewshot_examples': fewshot_examples,
'continuation_delimiter': continuation_delimiter,
'item': item
}
# Return two prompts: without and with the continuation
prompt_without = template.render(include_continuation=False, **context)
prompt_with = template.render(include_continuation=True, **context)
# Due to the way the data seems to be stored, I think I need to strip in the case of LM here.
# Otherwise we may get trailing whitespaces in prompt_without (which get absorbed into the next
# token in prompt_with), meaning we don't get a nice and clean prefix in the token space
# to detect the final continuation. Tokenizers...
prompt_without = prompt_without.strip()
return [prompt_without, prompt_with]
def find_common_length(token_sequences, direction='left'):
"""
Find the length of the common prefix or suffix across token sequences
- direction: 'left' for prefix, 'right' for suffix
"""
min_len = min(len(seq) for seq in token_sequences)
indices = {
'left': range(min_len),
'right': range(-1, -min_len-1, -1)
}[direction]
# Find the first position where the token sequences differ
for i, idx in enumerate(indices):
token = token_sequences[0][idx]
if not all(seq[idx] == token for seq in token_sequences):
return i
return min_len
def stack_sequences(tokens, pad_token_id):
"""Stack up a list of token sequences, pad to longest on the right"""
bsz, seq_len = len(tokens), max(len(x) for x in tokens)
input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
for i, x in enumerate(tokens):
input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
return input_ids
def batch_sequences_mc(tokenizer, prompts):
# In multiple choice, contexts are the same but the continuation is different (common prefix)
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
# figure out the start and end of each continuation
answer_start_idx = find_common_length(tokens, direction='left')
start_indices = [answer_start_idx] * len(prompts)
end_indices = [len(x) for x in tokens]
return tokens, start_indices, end_indices
def batch_sequences_schema(tokenizer, prompts):
# In schema tasks, contexts vary but continuation is the same (common suffix)
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
# figure out the start and end of each context
suffix_length = find_common_length(tokens, direction='right')
end_indices = [len(x) for x in tokens]
start_indices = [ei - suffix_length for ei in end_indices]
return tokens, start_indices, end_indices
def batch_sequences_lm(tokenizer, prompts):
# In LM tasks, we have two prompts: without and with continuation
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
tokens_without, tokens_with = tokens
start_idx, end_idx = len(tokens_without), len(tokens_with)
assert start_idx < end_idx, "prompt without is supposed to be a prefix of prompt with"
assert tokens_without == tokens_with[:start_idx], "prompt without is supposed to be a prefix of prompt with"
# we only need the with continuation prompt in the LM task, i.e. batch size of 1
return [tokens_with], [start_idx], [end_idx]
@torch.no_grad()
def forward_model(model, input_ids):
"""
Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
The last column of losses is set to nan because we don't have autoregressive targets there.
MEMORY FIX: Explicitly cleanup intermediate tensors to prevent GPU memory accumulation.
"""
batch_size, seq_len = input_ids.size()
outputs = model(input_ids)
# Roll the tensor to the left by one position to get the (autoregressive) target ids
target_ids = torch.roll(input_ids, shifts=-1, dims=1)
# Calculate cross entropy at all positions
losses = torch.nn.functional.cross_entropy(
outputs.view(batch_size * seq_len, -1),
target_ids.view(batch_size * seq_len),
reduction='none'
).view(batch_size, seq_len)
# Set the last column to be nan because there is no autoregressive loss there
losses[:, -1] = float('nan')
# Get the argmax predictions at each position
predictions = outputs.argmax(dim=-1)
# MEMORY FIX: Explicitly free large intermediate tensors
del outputs # outputs is largest tensor (B×T×V, ~GB for large models)
del target_ids # target_ids is B×T
return losses, predictions
@torch.no_grad()
def evaluate_example(idx, model, tokenizer, data, device, task_meta):
"""Evaluate a single example, return True if correct, False otherwise"""
item = data[idx]
task_type = task_meta['task_type']
num_fewshot = task_meta['num_fewshot']
continuation_delimiter = task_meta['continuation_delimiter']
# Sample few-shot examples (excluding current item)
fewshot_examples = []
if num_fewshot > 0:
rng = random.Random(1234 + idx)
available_indices = [i for i in range(len(data)) if i != idx]
fewshot_indices = rng.sample(available_indices, num_fewshot)
fewshot_examples = [data[i] for i in fewshot_indices]
# Render prompts and batch sequences based on task type
if task_type == 'multiple_choice':
prompts = render_prompts_mc(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_mc(tokenizer, prompts)
elif task_type == 'schema':
prompts = render_prompts_schema(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_schema(tokenizer, prompts)
elif task_type == 'language_modeling':
prompts = render_prompts_lm(item, continuation_delimiter, fewshot_examples)
tokens, start_idxs, end_idxs = batch_sequences_lm(tokenizer, prompts)
else:
raise ValueError(f"Unsupported task type: {task_type}")
# Some models can't forward sequences beyond a certain length (e.g. GPT-2)
# In these cases, we have to truncate sequences to max length and adjust the indices
if hasattr(model, 'max_seq_len') and model.max_seq_len is not None:
max_tokens = model.max_seq_len
new_tokens, new_start_idxs, new_end_idxs = [], [], []
for t, s, e in zip(tokens, start_idxs, end_idxs):
if len(t) > max_tokens:
num_to_crop = len(t) - max_tokens
new_tokens.append(t[-max_tokens:]) # take the last max_tokens tokens
new_start_idxs.append(s - num_to_crop) # shift the indices down
new_end_idxs.append(e - num_to_crop)
assert s - num_to_crop >= 0, "this should never happen right?"
assert e - num_to_crop >= 0, "this should never happen right?"
else:
new_tokens.append(t) # keep unchanged
new_start_idxs.append(s)
new_end_idxs.append(e)
tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
# Stack up all the sequences into a batch
pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
input_ids = stack_sequences(tokens, pad_token_id)
input_ids = input_ids.to(device)
# Forward the model, get the autoregressive loss and argmax prediction at each token
losses, predictions = forward_model(model, input_ids)
# See if the losses/predictions come out correctly
if task_type == 'language_modeling':
# language modeling task is currently always batch size 1
si = start_idxs[0]
ei = end_idxs[0]
# predictions[i] predict input_ids[i+1] autoregressively
predicted_tokens = predictions[0, si-1:ei-1]
actual_tokens = input_ids[0, si:ei]
is_correct = torch.all(predicted_tokens == actual_tokens).item()
elif task_type in ['multiple_choice', 'schema']:
# For MC/schema: find the option with lowest average loss
mean_losses = [losses[i, si-1:ei-1].mean().item()
for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
pred_idx = mean_losses.index(min(mean_losses))
is_correct = pred_idx == item['gold']
else:
raise ValueError(f"Unsupported task type: {task_type}")
# MEMORY FIX: Explicitly free tensors after extracting scalar result
del losses, predictions, input_ids
return is_correct
def evaluate_task(model, tokenizer, data, device, task_meta):
"""
This function is responsible for evaluating one task across many examples.
It also handles dispatch to all processes if the script is run with torchrun.
MEMORY FIX: Added periodic cache cleanup to prevent memory accumulation.
"""
import gc # For explicit garbage collection
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
correct = torch.zeros(len(data), dtype=torch.float32, device=device)
# stride the examples to each rank
for idx in range(rank, len(data), world_size):
is_correct = evaluate_example(idx, model, tokenizer, data, device, task_meta)
correct[idx] = float(is_correct)
# MEMORY FIX: Periodic cache cleanup
# This releases cached GPU memory and triggers Python GC
# Prevents progressive slowdown from memory fragmentation
# Interval configurable via eval_config.CACHE_CLEANUP_INTERVAL (default: 256)
if eval_config.ENABLE_PERIODIC_CLEANUP and idx % eval_config.CACHE_CLEANUP_INTERVAL == 0 and idx > 0:
# Release PyTorch cached memory back to GPU
if torch.cuda.is_available() and device.type == 'cuda':
torch.cuda.empty_cache()
# Force Python garbage collection
gc.collect()
# sync results across all the processes if running distributed
if world_size > 1:
dist.barrier()
dist.all_reduce(correct, op=dist.ReduceOp.SUM)
# compute the mean
mean_correct = correct.mean().item()
# MEMORY FIX: Final cleanup after task completes
del correct
if eval_config.ENABLE_FINAL_CLEANUP:
if torch.cuda.is_available() and device.type == 'cuda':
torch.cuda.empty_cache()
return mean_correct