compile eval model also

Salman Mohammadi 2025-11-03 11:42:34 +00:00
parent cf587acb1a
commit 957a1f4394
2 changed files with 18 additions and 6 deletions

View File

@@ -75,6 +75,7 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
     # Evaluate each task
     results = {}
     centered_results = {}
+    task_times = []
     for task in tasks:
         start_time = time.time()
         label = task['label']
@@ -106,13 +107,20 @@ def evaluate_model(model, tokenizer, device, max_per_task=-1):
         centered_result = (accuracy - 0.01 * random_baseline) / (1.0 - 0.01 * random_baseline)
         centered_results[label] = centered_result
         end_time = time.time()
+        task_times.append(end_time - start_time)
         print0(f"accuracy: {accuracy:.4f} | centered: {centered_result:.4f} | time: {end_time - start_time:.2f}s")
     core_metric = sum(centered_results.values()) / len(centered_results)
+    total_time = sum(task_times)
+    print0(f"Task timing stats: total={total_time:.2f}s | "
+           f"avg={total_time/len(task_times):.2f}s | "
+           f"min={min(task_times):.2f}s | max={max(task_times):.2f}s")
     out = {
         "results": results,
         "centered_results": centered_results,
-        "core_metric": core_metric
+        "core_metric": core_metric,
+        "dt": total_time
     }
     return out
@@ -169,7 +177,8 @@ def main():
     model, tokenizer, meta = load_model("base", device, phase="eval")
     model_name = f"base_model (step {meta['step']})" # just for logging
     model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
+    model = torch.compile(model)
     # Evaluate the model
     with autocast_ctx:
         out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
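
Side note, not part of the commit: the per-task timing added above pairs naturally with torch.compile, because the first forward pass for a new input shape includes compilation time and can dominate the max. A minimal, self-contained sketch of that effect, using a toy nn.Linear rather than the repo's model:

import time
import torch
import torch.nn as nn

def timed_forward(model, x):
    # wall-clock a single no-grad forward pass
    t0 = time.time()
    with torch.no_grad():
        model(x)
    return time.time() - t0

model = torch.compile(nn.Linear(64, 64))
# the first call compiles the graph; later calls reuse it
times = [timed_forward(model, torch.randn(8, 64)) for _ in range(5)]
total = sum(times)
print(f"total={total:.2f}s | avg={total/len(times):.2f}s | "
      f"min={min(times):.2f}s | max={max(times):.2f}s")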

View File

@@ -112,7 +112,9 @@ with torch.device("meta"):
 model.to_empty(device=device)
 model.init_weights()
 orig_model = model # original, uncompiled model, for saving raw model state_dict
-model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
+eval_model = model
+eval_model = torch.compile(eval_model, fullgraph=True, dynamic=True)
+model = torch.compile(model, fullgraph=True, dynamic=False) # TODO: dynamic True/False think through
 num_params = sum(p.numel() for p in model.parameters())
 print0(f"Number of parameters: {num_params:,}")
 num_flops_per_token = model.estimate_flops()
@@ -206,8 +208,9 @@ for step in range(num_iterations + 1):
     if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
         model.eval()
         with autocast_ctx:
-            results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
-        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
+            results = evaluate_model(eval_model, tokenizer, device, max_per_task=core_metric_max_per_task)
+        print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f} | Eval time: {results['dt']:.4f}s")
         wandb_run.log({
             "step": step,
             "total_training_flops": flops_so_far,
@@ -229,7 +232,7 @@ for step in range(num_iterations + 1):
             "My favorite color is",
            "If 5*x + 3 = 13, then x is",
         ]
-        engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
+        engine = Engine(eval_model, tokenizer) # use eval_model to avoid recompilation
         for prompt in prompts:
             tokens = tokenizer(prompt, prepend="<|bos|>")
             with autocast_ctx:
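
A minimal sketch of the pattern this commit applies in the training script (the toy module and names below are illustrative, not the repo's code): keep the uncompiled module for saving a raw state_dict, compile one handle with static shapes for fixed-length training steps, and compile a second handle with dynamic shapes for evaluation, so eval batches of varying length do not force recompilation of the training graph.

import torch
import torch.nn as nn

base = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))  # stand-in for the GPT

orig_model = base                                               # uncompiled: raw state_dict for checkpoints
model = torch.compile(base, fullgraph=True, dynamic=False)      # training path: fixed shapes
eval_model = torch.compile(base, fullgraph=True, dynamic=True)  # eval path: varying shapes

# all three handles share the same parameters, so weight updates made
# through `model` are immediately visible to `eval_model`
model(torch.randn(16, 32))              # training-shaped batch
for n in (3, 5, 11):                    # eval batches of different sizes
    eval_model(torch.randn(n, 32))
sd = orig_model.state_dict()            # plain parameter names for saving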