diff --git a/nanochat/configurator.py b/nanochat/configurator.py index ec1b76d..4969faa 100644 --- a/nanochat/configurator.py +++ b/nanochat/configurator.py @@ -26,7 +26,9 @@ def print0(s="",**kwargs): for arg in sys.argv[1:]: if '=' not in arg: # assume it's the name of a config file - assert not arg.startswith('--') + if arg.startswith('--'): + # ignore flags like --help or others without = + continue config_file = arg print0(f"Overriding config with {config_file}:") with open(config_file) as f: diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 3271298..2657578 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -81,7 +81,8 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training yield inputs, targets, state_dict -def tokenizing_distributed_data_loader(*args, **kwargs): +def tokenizing_distributed_data_loader(*args, device="cuda", **kwargs): # helper function that only emits the inputs/targets and not the state_dict + kwargs["device"] = device for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs): yield inputs, targets