From f97e55eb93bade3223f84e3f67a3e2d2df3c730a Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Fri, 21 Nov 2025 05:12:07 +0000 Subject: [PATCH] Fix TypeError in tokenizing_distributed_data_loader and improve robustness in configurator.py - Explicitly add `device` argument to `tokenizing_distributed_data_loader` in `nanochat/dataloader.py` to prevent `TypeError: unexpected keyword argument 'device'` when called from `scripts/base_train.py`. - Update `nanochat/configurator.py` to ignore command-line flags that start with `--` and contain no `=` (e.g. `--help`) instead of raising `AssertionError`, improving robustness when running with various launchers or flags. --- nanochat/configurator.py | 4 +++- nanochat/dataloader.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/nanochat/configurator.py b/nanochat/configurator.py index ec1b76d..4969faa 100644 --- a/nanochat/configurator.py +++ b/nanochat/configurator.py @@ -26,7 +26,9 @@ def print0(s="",**kwargs): for arg in sys.argv[1:]: if '=' not in arg: # assume it's the name of a config file - assert not arg.startswith('--') + if arg.startswith('--'): + # ignore flags like --help or others without = + continue config_file = arg print0(f"Overriding config with {config_file}:") with open(config_file) as f: diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 3271298..2657578 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -81,7 +81,8 @@ def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training yield inputs, targets, state_dict -def tokenizing_distributed_data_loader(*args, **kwargs): +def tokenizing_distributed_data_loader(*args, device="cuda", **kwargs): # helper function that only emits the inputs/targets and not the state_dict + kwargs["device"] = device for inputs, targets, state_dict in 
tokenizing_distributed_data_loader_with_state(*args, **kwargs): yield inputs, targets