diff --git a/scripts/base_loss.py b/scripts/base_loss.py index 508f22d..df74314 100644 --- a/scripts/base_loss.py +++ b/scripts/base_loss.py @@ -71,7 +71,7 @@ if ddp_rank == 0: # Log to report from nanochat.report import get_report -get_report().log(section="Base model loss", data=[ +get_report(exp_name=model_tag).log(section="Base model loss", data=[ { "train bpb": bpb_results["train"], "val bpb": bpb_results["val"], diff --git a/scripts/base_train.py b/scripts/base_train.py index 86dfba9..2766247 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -327,7 +327,7 @@ print0(f"Minimum validation bpb: {min_val_bpb:.4f}") # Log to report from nanochat.report import get_report -get_report().log(section="Base model training", data=[ +get_report(exp_name=run).log(section="Base model training", data=[ user_config, # CLI args { # stats about the training setup "Number of parameters": num_params,