diff --git a/scripts/base_train.py b/scripts/base_train.py
index 6502035..476c8d5 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -141,13 +141,9 @@ adamw_optimizer, muon_optimizer = optimizers
 # Initialize the DataLoaders for train/val
 base_dir = get_base_dir()
 tokens_dir = os.path.join(base_dir, "tokenized_data")
-<<<<<<< HEAD
-train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
-build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size_val, max_seq_len, split="val")
-=======
+
 train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
-build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
->>>>>>> 722da4f... trying to add basic cpu support, will try mps too
+build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size_val, max_seq_len, split="val", device=device)
 
 x, y = next(train_loader)  # kick off load of the very first batch of data
 # -----------------------------------------------------------------------------