This commit is contained in:
Salman Mohammadi 2025-11-03 12:28:15 +00:00
parent e0e168dacd
commit e243767cc3

View File

@@ -112,8 +112,7 @@ with torch.device("meta"):
# Model setup: allocate storage on `device`, initialize the weights, then build
# compiled wrappers for training and evaluation.
# NOTE(review): the hunk header's context line suggests the model was built
# under torch.device("meta") just before this point — confirm in the full file.
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
# NOTE(review): this excerpt is a diff hunk whose +/- markers were lost. Going by
# the hunk header (8 old lines -> 7 new lines), the next two `eval_model` lines
# look like the removed version and the third like its one-line replacement —
# confirm against the commit. Read as sequential code, only the last assignment
# to `eval_model` takes effect.
eval_model = model
eval_model = torch.compile(eval_model, fullgraph=True, dynamic=True)
eval_model = torch.compile(model, fullgraph=True, dynamic=True) # eval model compiled for dynamic shapes
# `model` is rebound to a compiled wrapper here; `orig_model` above still
# references the uncompiled module.
model = torch.compile(model, fullgraph=True, dynamic=False) # TODO: dynamic True/False think through
# Sum of element counts over everything yielded by model.parameters().
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")