The sample-evaluation code was duplicated between base_train and base_eval, so this commit factors it into a shared evaluate_sample() function. The logic is unchanged: during training we still sample only from the fixed prompts and wrap generation in disable_fp8, while during evaluation we sample from both the prompts and the unconditioned case and skip disable_fp8. The disable_fp8 context manager itself moves to base_eval, since it only exists to support evaluation.
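For reference, a minimal sketch of the two resulting call sites (keyword-argument form of the calls in the diff below; prompt_cb and uncond_cb are the rank-0 callbacks defined in base_eval's main()):

# scripts/base_train.py, mid-training sampling: prompts only, FP8 layers temporarily swapped to BF16
evaluate_sample(orig_model, tokenizer, print0, dis_fp8=True)

# scripts/base_eval.py, final evaluation: prompts plus unconditioned samples, no FP8 swap needed
evaluate_sample(model, tokenizer, prompt_cb, dis_fp8=False, do_uncond=True, uncond_cb=uncond_cb)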

xrdddd 2026-05-05 08:22:57 +00:00
parent 0aaca56805
commit c784e222eb
2 changed files with 89 additions and 89 deletions

scripts/base_eval.py

@@ -31,6 +31,10 @@ import tempfile
import argparse
import torch
from typing import Callable
from contextlib import nullcontext, contextmanager
from nanochat.gpt import Linear
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
from nanochat.checkpoint_manager import load_model
@@ -103,6 +107,82 @@ def place_eval_bundle(file_path):
    shutil.move(extracted_bundle_dir, eval_bundle_dir)
    print0(f"Placed eval_bundle directory at {eval_bundle_dir}")

# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
@contextmanager
def disable_fp8(model):
    """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.
    CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead,
    we swap out Float8Linear modules entirely and restore them after.
    """
    import torch.nn as nn
    # Find all Float8Linear modules and their locations
    fp8_locations = []  # list of (parent_module, attr_name, fp8_module)
    for name, module in model.named_modules():
        if 'Float8' in type(module).__name__:
            if '.' in name:
                parent_name, attr_name = name.rsplit('.', 1)
                parent = model.get_submodule(parent_name)
            else:
                parent = model
                attr_name = name
            fp8_locations.append((parent, attr_name, module))
    if not fp8_locations:
        yield  # No FP8 modules, nothing to do
        return
    # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
    # Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards
    for parent, attr_name, fp8_module in fp8_locations:
        linear = Linear(
            fp8_module.in_features,
            fp8_module.out_features,
            bias=fp8_module.bias is not None,
            device="meta",  # Use meta device to avoid unnecessary VRAM allocation
            dtype=fp8_module.weight.dtype,
        )
        linear.weight = fp8_module.weight  # share, don't copy
        if fp8_module.bias is not None:
            linear.bias = fp8_module.bias
        setattr(parent, attr_name, linear)
    try:
        yield
    finally:
        # Restore Float8Linear modules
        for parent, attr_name, fp8_module in fp8_locations:
            setattr(parent, attr_name, fp8_module)

def evaluate_sample(model, tokenizer, prompt_cb: Callable[[str], None], dis_fp8: bool = True, do_uncond: bool = False, uncond_cb: Callable[[str], None] | None = None):
    """Generate samples from a fixed set of prompts (and optionally unconditioned samples) and pass each decoded string to a callback."""
    prompts = [
        "The capital of France is",
        "The chemical symbol of gold is",
        "If yesterday was Friday, then tomorrow will be",
        "The opposite of hot is",
        "The planets of the solar system are:",
        "My favorite color is",
        "If 5*x + 3 = 13, then x is",
    ]
    engine = Engine(model, tokenizer)
    print0("\nConditioned samples:")
    for prompt in prompts:
        tokens = tokenizer(prompt, prepend="<|bos|>")
        ctx = disable_fp8(model) if dis_fp8 else nullcontext()
        with ctx:
            sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
        sample_str = tokenizer.decode(sample[0])
        if prompt_cb:
            prompt_cb(sample_str)
    # Unconditioned samples, if requested
    if do_uncond:
        print0("\nUnconditioned samples:")
        tokens = tokenizer("", prepend="<|bos|>")
        uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
        for sample in uncond:
            sample_str = tokenizer.decode(sample)
            if uncond_cb:
                uncond_cb(sample_str)

def evaluate_core(model, tokenizer, device, max_per_task=-1):
"""
@@ -227,33 +307,15 @@ def main():
print0("Model Samples")
print0("="*80)
if ddp_rank == 0:
prompts = [
"The capital of France is",
"The chemical symbol of gold is",
"If yesterday was Friday, then tomorrow will be",
"The opposite of hot is",
"The planets of the solar system are:",
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
engine = Engine(model, tokenizer)
print0("\nConditioned samples:")
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
sample_str = tokenizer.decode(sample[0])
def prompt_cb(generated:str):
print0("-" * 80)
print0(sample_str)
samples.append(sample_str)
print0("\nUnconditioned samples:")
tokens = tokenizer("", prepend="<|bos|>")
uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
for sample in uncond:
sample_str = tokenizer.decode(sample)
print0(generated)
samples.append(generated)
def uncond_cb(generated:str):
print0("-" * 80)
print0(sample_str)
unconditioned_samples.append(sample_str)
print0(generated)
unconditioned_samples.append(generated)
evaluate_sample(model, tokenizer, prompt_cb, False, True, uncond_cb)
elif 'sample' in eval_modes and is_hf_model:
print0("\nSkipping sampling for HuggingFace models (not supported)")

scripts/base_train.py

@@ -33,7 +33,7 @@ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from scripts.base_eval import evaluate_core
from scripts.base_eval import evaluate_core, disable_fp8, evaluate_sample
print_banner()
# -----------------------------------------------------------------------------
@@ -191,54 +191,6 @@ if args.fp8:
    num_skipped = num_linear - num_fp8
    print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")

# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
@contextmanager
def disable_fp8(model):
    """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.
    CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead,
    we swap out Float8Linear modules entirely and restore them after.
    """
    import torch.nn as nn
    # Find all Float8Linear modules and their locations
    fp8_locations = []  # list of (parent_module, attr_name, fp8_module)
    for name, module in model.named_modules():
        if 'Float8' in type(module).__name__:
            if '.' in name:
                parent_name, attr_name = name.rsplit('.', 1)
                parent = model.get_submodule(parent_name)
            else:
                parent = model
                attr_name = name
            fp8_locations.append((parent, attr_name, module))
    if not fp8_locations:
        yield  # No FP8 modules, nothing to do
        return
    # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
    # Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards
    for parent, attr_name, fp8_module in fp8_locations:
        linear = Linear(
            fp8_module.in_features,
            fp8_module.out_features,
            bias=fp8_module.bias is not None,
            device="meta",  # Use meta device to avoid unnecessary VRAM allocation
            dtype=fp8_module.weight.dtype,
        )
        linear.weight = fp8_module.weight  # share, don't copy
        if fp8_module.bias is not None:
            linear.bias = fp8_module.bias
        setattr(parent, attr_name, linear)
    try:
        yield
    finally:
        # Restore Float8Linear modules
        for parent, attr_name, fp8_module in fp8_locations:
            setattr(parent, attr_name, fp8_module)
# -----------------------------------------------------------------------------
# Compile the model
@@ -456,21 +408,7 @@ while True:
    # use the original uncompiled model because the inputs keep changing shape
    if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
        model.eval()
        prompts = [
            "The capital of France is",
            "The chemical symbol of gold is",
            "If yesterday was Friday, then tomorrow will be",
            "The opposite of hot is",
            "The planets of the solar system are:",
            "My favorite color is",
            "If 5*x + 3 = 13, then x is",
        ]
        engine = Engine(orig_model, tokenizer)  # use orig_model to avoid recompilation
        for prompt in prompts:
            tokens = tokenizer(prompt, prepend="<|bos|>")
            with disable_fp8(orig_model):
                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
            print0(tokenizer.decode(sample[0]))
        evaluate_sample(orig_model, tokenizer, print0, dis_fp8=True)  # pass orig_model to avoid recompilation; FP8 layers are swapped to BF16 during sampling
        model.train()
    # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step