diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index 9d56bcb..c9689cd 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -61,7 +61,7 @@ dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype) # wandb logging init -wandb_run = get_wandb("nanochat-rl" , run=run, master_process=master_process, user_config=user_config) +wandb_run = get_wandb("nanochat-rl", run=run, master_process=master_process, user_config=user_config) # Init model and tokenizer model, tokenizer, meta = load_model(source, device, phase="eval")