mirror of https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00

commit 957a1f4394 (parent cf587acb1a): compile eval model also
@@ -75,6 +75,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
     # Evaluate each task
     results = {}
     centered_results = {}
+    task_times = []
     for task in tasks:
         start_time = time.time()
         label = task['label']
@@ -106,13 +107,20 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
         centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
         centered_results[label] = centered_result
         end_time = time.time()
+        task_times.append(end_time - start_time)
         print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")

     core_metric = sum(centered_results.values()) / len(centered_results)
+    total_time = sum(task_times)
+    print0(f"Task timing stats: total={total_time:.2f}s "
+           f"| avg={total_time/len(task_times):.2f}s "
+           f"| min={min(task_times):.2f}s | max={max(task_times):.2f}s")
     out = {
         "results": results,
         "centered_results": centered_results,
-        "core_metric": core_metric
+        "core_metric": core_metric,
+        "dt": total_time
     }
     return out
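The centering above rescales raw accuracy so that a random-guess baseline maps to 0 and a perfect score to 1 (`random_baseline` is evidently expressed in percent, hence the 0.01 factor). A minimal standalone sketch of that arithmetic, with toy numbers:

```python
# Sketch of the centering used in evaluate_model; assumes random_baseline
# is given in percent (e.g. 25.0 for a 4-way multiple-choice task).
def center(accuracy: float, random_baseline: float) -> float:
    # Map chance-level accuracy to 0.0 and perfect accuracy to 1.0.
    return (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)

assert center(0.25, 25.0) == 0.0  # guessing earns no credit
assert center(1.00, 25.0) == 1.0  # a perfect score is unchanged
print(center(0.55, 25.0))         # ~0.4: only above-chance accuracy counts
```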
@@ -169,7 +177,8 @@ def main():
     model, tokenizer, meta = load_model("base", device, phase="eval")
     model_name = f"base_model (step {meta['step']})" # just for logging
     model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
+    model = torch.compile(model)
     # Evaluate the model
     with autocast_ctx:
         out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
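In the standalone eval script the model is now compiled before evaluation. `torch.compile` is lazy: it traces and compiles on the first forward call, so the first task pays the compile cost and later calls reuse the compiled graph. A minimal sketch of the pattern (toy module, not nanochat's):

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
net = torch.compile(net)   # no compilation happens yet

x = torch.randn(8, 64)
y = net(x)   # first call: trace + compile (slow)
y = net(x)   # later calls with the same shapes reuse the compiled graph
```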
@@ -112,7 +112,9 @@ with torch.device("meta"):
 model.to_empty(device=device)
 model.init_weights()
 orig_model = model # original, uncompiled model, for saving raw model state_dict
-model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
+eval_model = model
+eval_model = torch.compile(eval_model, fullgraph=True, dynamic=True)
+model = torch.compile(model, fullgraph=True, dynamic=False) # TODO: dynamic True/False think through
 num_params = sum(p.numel() for p in model.parameters())
 print0(f"Number of parameters: {num_params:,}")
 num_flops_per_token = model.estimate_flops()
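The training script now keeps two compiled views of the same weights: `model`, compiled with `dynamic=False` for the fixed-shape training batches, and `eval_model`, compiled with `dynamic=True` so the variable-length eval inputs don't force a recompile per new shape. A rough sketch of the difference on a toy module (exact behavior depends on PyTorch version and recompile limits):

```python
import torch
import torch.nn as nn

lin = nn.Linear(32, 32)
train_fn = torch.compile(lin, dynamic=False)  # specialize on exact shapes
eval_fn = torch.compile(lin, dynamic=True)    # compile with symbolic shapes

for t in (4, 7, 11):              # e.g. eval sequences of different lengths
    x = torch.randn(1, t, 32)
    eval_fn(x)                    # one dynamic-shape graph covers all lengths
    train_fn(x)                   # may recompile for each newly seen length
```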
@@ -206,8 +208,9 @@ for step in range(num_iterations + 1):
     if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
         model.eval()
         with autocast_ctx:
-            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
-        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
+            results = evaluate_model(eval_model, tokenizer, device, max_per_task=core_metric_max_per_task)
+        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f} Eval time: {results['dt']:.4f}s")
         wandb_run.log({
             "step": step,
             "total_training_flops": flops_so_far,
@@ -229,7 +232,7 @@ for step in range(num_iterations + 1):
         "My favorite color is",
         "If 5*x + 3 = 13, then x is",
     ]
-    engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
+    engine = Engine(eval_model, tokenizer) # use eval_model to avoid recompilation
     for prompt in prompts:
         tokens = tokenizer(prompt, prepend="<|bos|>")
         with autocast_ctx:
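The sampling hunk only shows the Engine construction; its generate calls sit outside the excerpt. The point of passing `eval_model` is the same as above: generation reuses the already-compiled dynamic-shape graph instead of triggering a fresh compile. As a hypothetical illustration of what such a greedy sampling loop does (not nanochat's actual Engine API):

```python
import torch

@torch.no_grad()
def greedy_generate(model, tokens, max_new_tokens=16):
    # tokens: prompt token ids; model maps (B, T) ids -> (B, T, vocab) logits
    ids = torch.tensor([tokens])
    for _ in range(max_new_tokens):
        logits = model(ids)                      # compiled model is reused here
        next_id = logits[0, -1].argmax().item()  # greedy: take the top token
        ids = torch.cat([ids, torch.tensor([[next_id]])], dim=1)
    return ids[0].tolist()
```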