diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 595d24c1..8e17935a 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -68,7 +68,7 @@ ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16 autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext() # wandb logging init -wandb_run = get_wandb("nanochat-sft" , run=run, master_process=master_process, user_config=user_config) +wandb_run = get_wandb("nanochat-sft", run=run, master_process=master_process, user_config=user_config) # Load the model and tokenizer model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)