mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-15 16:52:14 +00:00)

compile eval model also

This commit is contained in:
parent cf587acb1a
commit 957a1f4394
@@ -75,6 +75,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
     # Evaluate each task
     results = {}
     centered_results = {}
+    task_times = []
     for task in tasks:
         start_time = time.time()
         label = task['label']
@@ -106,13 +107,20 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
         centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
         centered_results[label] = centered_result
         end_time = time.time()
+        task_times.append(end_time - start_time)
         print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")

     core_metric = sum(centered_results.values()) / len(centered_results)
+    total_time = sum(task_times)
+    print0(f"Task timing stats: total={total_time:.2f}s"
+           f" | avg={total_time/len(task_times):.2f}s"
+           f" | min={min(task_times):.2f}s | max={max(task_times):.2f}s")
     out = {
         "results": results,
         "centered_results": centered_results,
-        "core_metric": core_metric
+        "core_metric": core_metric,
+        "dt": total_time
     }
     return out
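The two hunks above add per-task wall-clock timing to evaluate_model: each task's duration goes into task_times, the aggregate is printed after the loop, and the total is returned as "dt". A minimal, self-contained sketch of the same pattern (the task list and the run_task stub below are placeholders, not nanochat code):

```python
import time

def evaluate_tasks(tasks):
    # Hypothetical stand-in for the real per-task evaluation.
    def run_task(task):
        time.sleep(0.01)
        return 0.5  # pretend accuracy

    results, task_times = {}, []
    for task in tasks:
        start_time = time.time()
        results[task] = run_task(task)
        task_times.append(time.time() - start_time)

    total_time = sum(task_times)
    print(f"Task timing stats: total={total_time:.2f}s"
          f" | avg={total_time/len(task_times):.2f}s"
          f" | min={min(task_times):.2f}s | max={max(task_times):.2f}s")
    # Return the total alongside the metrics, as the commit does with "dt".
    return {"results": results, "dt": total_time}

print(evaluate_tasks(["hellaswag", "arc_easy"]))
```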
@@ -169,7 +177,8 @@ def main():
     model, tokenizer, meta = load_model("base", device, phase="eval")
     model_name = f"base_model (step {meta['step']})" # just for logging
     model_slug = f"base_model_{meta['step']:06d}" # for the output csv file

+    model = torch.compile(model)
     # Evaluate the model
     with autocast_ctx:
         out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
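In the standalone eval entry point, the loaded base model is now wrapped with torch.compile before evaluation, so the whole CORE run uses the optimized graph. A rough sketch of that flow with a toy module standing in for the loaded model (torch.compile needs PyTorch 2.x; nothing here is nanochat-specific):

```python
import torch
import torch.nn as nn

# Toy stand-in for the model returned by load_model("base", ...).
model = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
model.eval()

# Compile once up front; the first forward pass triggers compilation,
# subsequent evaluation batches reuse the compiled graph.
model = torch.compile(model)

with torch.no_grad():
    x = torch.randn(8, 16)
    print(model(x).shape)
```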
@@ -112,7 +112,9 @@ with torch.device("meta"):
 model.to_empty(device=device)
 model.init_weights()
 orig_model = model # original, uncompiled model, for saving raw model state_dict
-model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
+eval_model = model
+eval_model = torch.compile(eval_model, fullgraph=True, dynamic=True)
+model = torch.compile(model, fullgraph=True, dynamic=False) # TODO: dynamic True/False think through
 num_params = sum(p.numel() for p in model.parameters())
 print0(f"Number of parameters: {num_params:,}")
 num_flops_per_token = model.estimate_flops()
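The training script now keeps three handles on the same weights: orig_model (uncompiled, for saving the raw state_dict), eval_model (compiled with fullgraph=True, dynamic=True so evaluation and sampling can feed varying sequence lengths without recompiling), and model (compiled with dynamic=False for the fixed-shape training step). A minimal sketch of that split, with a toy module in place of the GPT:

```python
import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self, dim=32):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

model = TinyModel()
orig_model = model  # uncompiled handle, used for checkpointing the raw state_dict

# One compiled view for eval/generation (dynamic shapes, whole-graph capture)...
eval_model = torch.compile(model, fullgraph=True, dynamic=True)
# ...and one for training, where the batch shape is fixed every step.
model = torch.compile(model, fullgraph=True, dynamic=False)

# All three share the same parameters; compiling does not copy weights.
print(sum(p.numel() for p in orig_model.parameters()))
```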
@@ -206,8 +208,9 @@ for step in range(num_iterations + 1):
     if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
         model.eval()
         with autocast_ctx:
-            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
-        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
+            results = evaluate_model(eval_model, tokenizer, device, max_per_task=core_metric_max_per_task)
+        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f} Eval time: {results['dt']:.4f}s")

         wandb_run.log({
             "step": step,
             "total_training_flops": flops_so_far,
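The periodic CORE evaluation now runs through eval_model, and the printed line also reports the wall-clock time that evaluate_model returns under "dt". A small illustrative snippet of how that return value is consumed (the dict literal and step value below are placeholders):

```python
# Placeholder values shaped like evaluate_model()'s return dict.
results = {"core_metric": 0.1234, "dt": 42.0}
step = 1000
print(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f} Eval time: {results['dt']:.4f}s")
```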
@@ -229,7 +232,7 @@ for step in range(num_iterations + 1):
             "My favorite color is",
             "If 5*x + 3 = 13, then x is",
         ]
-        engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
+        engine = Engine(eval_model, tokenizer) # use eval_model to avoid recompilation
         for prompt in prompts:
             tokens = tokenizer(prompt, prepend="<|bos|>")
             with autocast_ctx:
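Sampling now routes through the compiled eval_model rather than the uncompiled orig_model; because that module was compiled with dynamic=True, the growing sequence length during generation keeps hitting the same compiled graph instead of forcing recompiles. A self-contained sketch of that behaviour with a toy language model (the greedy loop below is illustrative only, not nanochat's Engine):

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab=256, dim=32):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, idx):
        return self.head(self.emb(idx))  # (B, T, vocab)

model = TinyLM()
# dynamic=True marks the sequence dimension as dynamic, so longer and longer
# inputs during decoding do not trigger fresh compilations.
eval_model = torch.compile(model, fullgraph=True, dynamic=True)

tokens = torch.tensor([[1, 2, 3]])
with torch.no_grad():
    for _ in range(5):  # toy greedy decoding loop
        logits = eval_model(tokens)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
print(tokens.tolist())
```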