Mirror of https://github.com/karpathy/nanochat.git
The sample-evaluation code was duplicated between base_train and base_eval, so this change factors it out into a shared function. The logic is kept the same: during training we still sample only from prompts and wrap generation in disable_fp8, while in eval we sample both prompted and unconditioned completions without disabling FP8. The disable_fp8 context manager is also moved into base_eval, since it appears to be evaluation-specific.
This commit is contained in:
parent 0aaca56805
commit c784e222eb
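For reference, a minimal sketch (not part of the commit) of the two call patterns described above. `model` and `tokenizer` are assumed to already be constructed as in the respective scripts; the argument names follow the evaluate_sample signature added in the diff below.

```python
# Sketch only: how the refactored evaluate_sample is driven from each script.
from scripts.base_eval import evaluate_sample
from nanochat.common import print0

samples, unconditioned_samples = [], []  # illustrative collectors, as in base_eval's main()

def prompt_cb(generated: str):
    # called once per prompted sample
    print0("-" * 80)
    print0(generated)
    samples.append(generated)

def uncond_cb(generated: str):
    # called once per unconditioned sample
    print0("-" * 80)
    print0(generated)
    unconditioned_samples.append(generated)

# base_eval: prompted and unconditioned samples, FP8 modules left as-is
evaluate_sample(model, tokenizer, prompt_cb, dis_fp8=False, do_uncond=True, uncond_cb=uncond_cb)

# base_train: prompted samples only, FP8 temporarily disabled for BF16 generation
evaluate_sample(model, tokenizer, lambda s: print0(s), dis_fp8=True)
```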
@@ -31,6 +31,10 @@ import tempfile
 import argparse
 import torch

+from typing import Callable
+from contextlib import nullcontext, contextmanager
+from nanochat.gpt import Linear
+
 from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type, download_file_with_lock
 from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
 from nanochat.checkpoint_manager import load_model
@@ -103,6 +107,82 @@ def place_eval_bundle(file_path):
     shutil.move(extracted_bundle_dir, eval_bundle_dir)
     print0(f"Placed eval_bundle directory at {eval_bundle_dir}")

+# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
+@contextmanager
+def disable_fp8(model):
+    """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.
+
+    CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead,
+    we swap out Float8Linear modules entirely and restore them after.
+    """
+    import torch.nn as nn
+
+    # Find all Float8Linear modules and their locations
+    fp8_locations = [] # list of (parent_module, attr_name, fp8_module)
+    for name, module in model.named_modules():
+        if 'Float8' in type(module).__name__:
+            if '.' in name:
+                parent_name, attr_name = name.rsplit('.', 1)
+                parent = model.get_submodule(parent_name)
+            else:
+                parent = model
+                attr_name = name
+            fp8_locations.append((parent, attr_name, module))
+
+    if not fp8_locations:
+        yield # No FP8 modules, nothing to do
+        return
+
+    # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
+    # Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards
+    for parent, attr_name, fp8_module in fp8_locations:
+        linear = Linear(
+            fp8_module.in_features,
+            fp8_module.out_features,
+            bias=fp8_module.bias is not None,
+            device="meta", # Use meta device to avoid unnecessary VRAM allocation
+            dtype=fp8_module.weight.dtype,
+        )
+        linear.weight = fp8_module.weight # share, don't copy
+        if fp8_module.bias is not None:
+            linear.bias = fp8_module.bias
+        setattr(parent, attr_name, linear)
+
+    try:
+        yield
+    finally:
+        # Restore Float8Linear modules
+        for parent, attr_name, fp8_module in fp8_locations:
+            setattr(parent, attr_name, fp8_module)
+
+def evaluate_sample(model, tokenizer, prompt_cb: Callable[[str], None], dis_fp8: bool = True, do_uncond: bool = False, uncond_cb: Callable[[str], None] | None = None):
+    prompts = [
+        "The capital of France is",
+        "The chemical symbol of gold is",
+        "If yesterday was Friday, then tomorrow will be",
+        "The opposite of hot is",
+        "The planets of the solar system are:",
+        "My favorite color is",
+        "If 5*x + 3 = 13, then x is",
+    ]
+    engine = Engine(model, tokenizer)
+    print0("\nConditioned samples:")
+    for prompt in prompts:
+        tokens = tokenizer(prompt, prepend="<|bos|>")
+        ctx = disable_fp8(model) if dis_fp8 else nullcontext()
+        with ctx:
+            sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
+        sample_str = tokenizer.decode(sample[0])
+        if prompt_cb: prompt_cb(sample_str)
+
+    # Unconditioned samples, if also requested
+    if do_uncond:
+        print0("\nUnconditioned samples:")
+        tokens = tokenizer("", prepend="<|bos|>")
+        uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
+        for sample in uncond:
+            sample_str = tokenizer.decode(sample)
+            if uncond_cb: uncond_cb(sample_str)
+
 def evaluate_core(model, tokenizer, device, max_per_task=-1):
     """
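The docstring and comments above explain the approach: rather than mutating torchao's frozen configs, the Float8Linear modules are swapped out for the duration of evaluation and restored afterwards. Below is a toy, self-contained sketch of that swap-and-restore pattern (an assumption-laden illustration, not repo code: plain nn.Linear/nn.Identity stand in for Float8Linear and the repo's Linear).

```python
# Toy illustration of the swap-and-restore pattern used by disable_fp8:
# locate modules by name, resolve their parent with get_submodule, swap via setattr,
# and restore the originals in a finally block.
from contextlib import contextmanager
import torch.nn as nn

@contextmanager
def swap_modules(model, should_swap, make_replacement):
    locations = []  # (parent_module, attr_name, original_module)
    for name, module in model.named_modules():
        if should_swap(module):
            if '.' in name:
                parent_name, attr_name = name.rsplit('.', 1)
                parent = model.get_submodule(parent_name)
            else:
                parent, attr_name = model, name
            locations.append((parent, attr_name, module))
    for parent, attr_name, module in locations:
        setattr(parent, attr_name, make_replacement(module))
    try:
        yield
    finally:
        for parent, attr_name, module in locations:
            setattr(parent, attr_name, module)

# Example: temporarily replace every nn.Linear with nn.Identity
toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with swap_modules(toy, lambda m: isinstance(m, nn.Linear), lambda m: nn.Identity()):
    assert isinstance(toy[0], nn.Identity)
assert isinstance(toy[0], nn.Linear)  # originals restored on exit
```

The real disable_fp8 additionally allocates the replacement on the meta device and shares the original weight tensor, so no extra VRAM is used during the swap.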
@@ -227,33 +307,15 @@ def main():
         print0("Model Samples")
         print0("="*80)
         if ddp_rank == 0:
-            prompts = [
-                "The capital of France is",
-                "The chemical symbol of gold is",
-                "If yesterday was Friday, then tomorrow will be",
-                "The opposite of hot is",
-                "The planets of the solar system are:",
-                "My favorite color is",
-                "If 5*x + 3 = 13, then x is",
-            ]
-            engine = Engine(model, tokenizer)
-            print0("\nConditioned samples:")
-            for prompt in prompts:
-                tokens = tokenizer(prompt, prepend="<|bos|>")
-                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
-                sample_str = tokenizer.decode(sample[0])
-                print0("-" * 80)
-                print0(sample_str)
-                samples.append(sample_str)
-
-            print0("\nUnconditioned samples:")
-            tokens = tokenizer("", prepend="<|bos|>")
-            uncond, _ = engine.generate_batch(tokens, num_samples=8, max_tokens=128, temperature=1.0)
-            for sample in uncond:
-                sample_str = tokenizer.decode(sample)
-                print0("-" * 80)
-                print0(sample_str)
-                unconditioned_samples.append(sample_str)
+            def prompt_cb(generated: str):
+                print0("-" * 80)
+                print0(generated)
+                samples.append(generated)
+            def uncond_cb(generated: str):
+                print0("-" * 80)
+                print0(generated)
+                unconditioned_samples.append(generated)
+            evaluate_sample(model, tokenizer, prompt_cb, False, True, uncond_cb)
     elif 'sample' in eval_modes and is_hf_model:
         print0("\nSkipping sampling for HuggingFace models (not supported)")

@@ -33,7 +33,7 @@ from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
 from nanochat.flash_attention import HAS_FA3
-from scripts.base_eval import evaluate_core
+from scripts.base_eval import evaluate_core, disable_fp8, evaluate_sample
 print_banner()

 # -----------------------------------------------------------------------------
@@ -191,54 +191,6 @@ if args.fp8:
     num_skipped = num_linear - num_fp8
     print0(f"✓ FP8 training enabled ({args.fp8_recipe} scaling) - converted {num_fp8}/{num_linear} linear layers, skipped {num_skipped} (too small)")

-# Context manager to temporarily disable FP8 so that model evaluation remains in BF16
-@contextmanager
-def disable_fp8(model):
-    """Temporarily swap Float8Linear modules with nn.Linear for BF16 evaluation.
-
-    CastConfig is a frozen dataclass, so we can't mutate scaling_type. Instead,
-    we swap out Float8Linear modules entirely and restore them after.
-    """
-    import torch.nn as nn
-
-    # Find all Float8Linear modules and their locations
-    fp8_locations = [] # list of (parent_module, attr_name, fp8_module)
-    for name, module in model.named_modules():
-        if 'Float8' in type(module).__name__:
-            if '.' in name:
-                parent_name, attr_name = name.rsplit('.', 1)
-                parent = model.get_submodule(parent_name)
-            else:
-                parent = model
-                attr_name = name
-            fp8_locations.append((parent, attr_name, module))
-
-    if not fp8_locations:
-        yield # No FP8 modules, nothing to do
-        return
-
-    # Swap Float8Linear -> Linear (our custom class that casts weights to match input dtype)
-    # Use device="meta" to avoid VRAM spike - the weight tensor will be swapped in afterwards
-    for parent, attr_name, fp8_module in fp8_locations:
-        linear = Linear(
-            fp8_module.in_features,
-            fp8_module.out_features,
-            bias=fp8_module.bias is not None,
-            device="meta", # Use meta device to avoid unnecessary VRAM allocation
-            dtype=fp8_module.weight.dtype,
-        )
-        linear.weight = fp8_module.weight # share, don't copy
-        if fp8_module.bias is not None:
-            linear.bias = fp8_module.bias
-        setattr(parent, attr_name, linear)
-
-    try:
-        yield
-    finally:
-        # Restore Float8Linear modules
-        for parent, attr_name, fp8_module in fp8_locations:
-            setattr(parent, attr_name, fp8_module)
-
 # -----------------------------------------------------------------------------
 # Compile the model

@@ -456,21 +408,7 @@ while True:
     # use the original uncompiled model because the inputs keep changing shape
     if args.sample_every > 0 and master_process and (last_step or (step > 0 and step % args.sample_every == 0)):
         model.eval()
-        prompts = [
-            "The capital of France is",
-            "The chemical symbol of gold is",
-            "If yesterday was Friday, then tomorrow will be",
-            "The opposite of hot is",
-            "The planets of the solar system are:",
-            "My favorite color is",
-            "If 5*x + 3 = 13, then x is",
-        ]
-        engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
-        for prompt in prompts:
-            tokens = tokenizer(prompt, prepend="<|bos|>")
-            with disable_fp8(orig_model):
-                sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
-            print0(tokenizer.decode(sample[0]))
+        evaluate_sample(model, tokenizer, lambda x: print0(x), True)
         model.train()

     # save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step