diff --git a/scripts/base_eval.py b/scripts/base_eval.py
index a57bbaf6..0b6b6e22 100644
--- a/scripts/base_eval.py
+++ b/scripts/base_eval.py
@@ -31,6 +31,10 @@ import tempfile
 import argparse
 
 import torch
+from typing import Callable
+from contextlib import nullcontext, contextmanager
+from nanochat.gpt import Linear
+
 from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
 from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
 from nanochat.checkpoint_manager import load_model
@@ -103,6 +107,82 @@ def place_eval_bundle(file_path):
     shutil.move(extracted_bundle_dir, eval_bundle_dir)
     print0(f"Placed eval_bundle directory at {eval_bundle_dir}")
 
+# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
+@contextmanager
+def disable_fp8(model):
+    """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.
+
+    CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead,
+    we swap out Float8Linear modules entirely and restore them after.
+    """
+    import torch.nn as nn
+
+    # Find all Float8Linear modules and their locations
+    fp8_locations = [] # list of (parent_module, attr_name, fp8_module)
+    for name, module in model.named_modules():
+        if 'Float8' in type(module).__name__:
+            if '.' in name:
+                parent_name, attr_name = name.rsplit('.', 1)
+                parent = model.get_submodule(parent_name)
+            else:
+                parent = model
+                attr_name = name
+            fp8_locations.append((parent, attr_name, module))
+
+    if not fp8_locations:
+        yield # No FP8 modules, nothing to do
+        return
+
+    # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
+    # Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards
+    for parent, attr_name, fp8_module in fp8_locations:
+        linear = Linear(
+            fp8_module.in_features,
+            fp8_module.out_features,
+            bias=fp8_module.bias is not None,
+            device="meta", # Use meta device to avoid unnecessary VRAM allocation
+            dtype=fp8_module.weight.dtype,
+        )
+        linear.weight = fp8_module.weight # share, don't copy
+        if fp8_module.bias is not None:
+            linear.bias = fp8_module.bias
+        setattr(parent, attr_name, linear)
+
+    try:
+        yield
+    finally:
+        # Restore Float8Linear modules
+        for parent, attr_name, fp8_module in fp8_locations:
+            setattr(parent, attr_name, fp8_module)
+
+def evaluate_sample(model, tokenizer, prompt_cb: Callable[[str], None], dis_fp8: bool = True, do_uncond: bool = False, uncond_cb: Callable[[str], None] | None = None):
+    prompts = [
+        "The capital of France is",
+        "The chemical symbol of gold is",
+        "If yesterday was Friday, then tomorrow will be",
+        "The opposite of hot is",
+        "The planets of the solar system are:",
+        "My favorite color is",
+        "If 5*x + 3 = 13, then x is",
+    ]
+    engine = Engine(model, tokenizer)
+    print0("\nConditioned samples:")
+    for prompt in prompts:
+        tokens = tokenizer(prompt, prepend="<|bos|>")
+        ctx = disable_fp8(model) if dis_fp8 else nullcontext()
+        with ctx:
+            sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
+        sample_str = tokenizer.decode(sample[0])
+        if prompt_cb: prompt_cb(sample_str)
+
+    # Unconditioned samples, if requested
+    if do_uncond:
+        print0("\nUnconditioned samples:")
+        tokens = tokenizer("", prepend="<|bos|>")
+        uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
+        for sample in uncond:
+            sample_str = tokenizer.decode(sample)
+            if uncond_cb: uncond_cb(sample_str)
 
 def evaluate_core(model, tokenizer, device, max_per_task=-1):
     """
@@ -227,33 +307,15 @@ def main():
         print0("Model Samples")
         print0("="*80)
         if ddp_rank == 0:
-            prompts = [
-                "The capital of France is",
-                "The chemical symbol of gold is",
-                "If yesterday was Friday, then tomorrow will be",
-                "The opposite of hot is",
-                "The planets of the solar system are:",
-                "My favorite color is",
-                "If 5*x + 3 = 13, then x is",
-            ]
-            engine = Engine(model, tokenizer)
-            print0("\nConditioned samples:")
-            for prompt in prompts:
-                tokens = tokenizer(prompt, prepend="<|bos|>")
-                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
-                sample_str = tokenizer.decode(sample[0])
+            def prompt_cb(generated: str):
                 print0("-" * 80)
-                print0(sample_str)
-                samples.append(sample_str)
-
-            print0("\nUnconditioned samples:")
-            tokens = tokenizer("", prepend="<|bos|>")
-            uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
-            for sample in uncond:
-                sample_str = tokenizer.decode(sample)
+                print0(generated)
+                samples.append(generated)
+            def uncond_cb(generated: str):
                 print0("-" * 80)
-                print0(sample_str)
-                unconditioned_samples.append(sample_str)
+                print0(generated)
+                unconditioned_samples.append(generated)
+            evaluate_sample(model, tokenizer, prompt_cb, dis_fp8=False, do_uncond=True, uncond_cb=uncond_cb)
     elif 'sample' in eval_modes and is_hf_model:
         print0("\nSkipping sampling for HuggingFace models (not supported)")
 
diff --git a/scripts/base_train.py b/scripts/base_train.py
index a161c477..01edd1db 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -33,7 +33,7 @@ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
 from nanochat.flash_attention import HAS_FA3
-from scripts.base_eval import evaluate_core
+from scripts.base_eval import evaluate_core, disable_fp8, evaluate_sample
 print_banner()
 
 # -----------------------------------------------------------------------------
@@ -191,54 +191,6 @@ if args.fp8:
     num_skipped = num_linear - num_fp8
     print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")
 
-# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
-@contextmanager
-def disable_fp8(model):
-    """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.
-
-    CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead,
-    we swap out Float8Linear modules entirely and restore them after.
-    """
-    import torch.nn as nn
-
-    # Find all Float8Linear modules and their locations
-    fp8_locations = [] # list of (parent_module, attr_name, fp8_module)
-    for name, module in model.named_modules():
-        if 'Float8' in type(module).__name__:
-            if '.' in name:
-                parent_name, attr_name = name.rsplit('.', 1)
-                parent = model.get_submodule(parent_name)
-            else:
-                parent = model
-                attr_name = name
-            fp8_locations.append((parent, attr_name, module))
-
-    if not fp8_locations:
-        yield # No FP8 modules, nothing to do
-        return
-
-    # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
-    # Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards
-    for parent, attr_name, fp8_module in fp8_locations:
-        linear = Linear(
-            fp8_module.in_features,
-            fp8_module.out_features,
-            bias=fp8_module.bias is not None,
-            device="meta", # Use meta device to avoid unnecessary VRAM allocation
-            dtype=fp8_module.weight.dtype,
-        )
-        linear.weight = fp8_module.weight # share, don't copy
-        if fp8_module.bias is not None:
-            linear.bias = fp8_module.bias
-        setattr(parent, attr_name, linear)
-
-    try:
-        yield
-    finally:
-        # Restore Float8Linear modules
-        for parent, attr_name, fp8_module in fp8_locations:
-            setattr(parent, attr_name, fp8_module)
-
 # -----------------------------------------------------------------------------
 # Compile the model
@@ -456,21 +408,7 @@ while True:
     # use the original uncompiled model because the inputs keep changing shape
     if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
         model.eval()
-        prompts = [
-            "The capital of France is",
-            "The chemical symbol of gold is",
-            "If yesterday was Friday, then tomorrow will be",
-            "The opposite of hot is",
-            "The planets of the solar system are:",
-            "My favorite color is",
-            "If 5*x + 3 = 13, then x is",
-        ]
-        engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
-        for prompt in prompts:
-            tokens = tokenizer(prompt, prepend="<|bos|>")
-            with disable_fp8(orig_model):
-                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
-            print0(tokenizer.decode(sample[0]))
+        evaluate_sample(orig_model, tokenizer, print0, dis_fp8=True)
         model.train()
 
     # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
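
Reviewer note (not part of the patch): the refactor makes sampling reusable outside the two scripts. A minimal usage sketch of the new helper is below; it assumes a `model` and `tokenizer` have already been loaded elsewhere (e.g. via nanochat.checkpoint_manager.load_model), and collects the conditioned samples into a list instead of printing them.

    # Hypothetical usage sketch, not part of the patch.
    # Assumes `model` and `tokenizer` are already loaded.
    from scripts.base_eval import evaluate_sample

    collected = []
    evaluate_sample(
        model,
        tokenizer,
        prompt_cb=collected.append,  # called once per decoded conditioned sample
        dis_fp8=False,               # leave any Float8Linear modules in place
        do_uncond=False,             # skip the unconditioned samples
    )
    print("\n".join(collected))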