diff --git a/speedrun_submit_multinode.sh b/pretrain_submit.sh
similarity index 68%
rename from speedrun_submit_multinode.sh
rename to pretrain_submit.sh
index 928f480..8c7cff0 100644
--- a/speedrun_submit_multinode.sh
+++ b/pretrain_submit.sh
@@ -22,7 +22,7 @@ set -x # Enable debug output
-DATA_NAME=climbmix
+DATA_NAME=smollm
 export DATA_DIR=/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/data/$DATA_NAME
 export MATRIX_LR=0.02
@@ -104,14 +104,14 @@ maturin develop --release --manifest-path rustbpe/Cargo.toml
 # each data shard is ~250M chars
 # so we download 2e9 / 250e6 = 8 data shards at this point
 # each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
-python -m nanochat.dataset -n 8 --data_dir=$DATA_DIR
+python -m nanochat.dataset -n 8
 # Immediately also kick off downloading more shards in the background while tokenizer trains
 # See comment below for why 240 is the right number here
-python -m nanochat.dataset -n 240 --data_dir=$DATA_DIR &
+python -m nanochat.dataset -n 240 &
 DATASET_DOWNLOAD_PID=$!
 # train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
 # Use unique tokenizer name based on dataset
-TOKENIZER_NAME="tokenizer_${DATA_NAME}"
+export TOKENIZER_NAME="tokenizer_${DATA_NAME}"
 python -m scripts.tok_train --max_chars=2000000000 --data_dir=$DATA_DIR --tokenizer_name=$TOKENIZER_NAME
 # evaluate the tokenizer (report compression ratio etc.)
 python -m scripts.tok_eval --tokenizer_name=$TOKENIZER_NAME
@@ -141,42 +141,44 @@ wait $DATASET_DOWNLOAD_PID
 srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; python -c "import torch; print(torch.cuda.device_count())"'
 # pretrain the d20 model (multi-node)
-srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_train -- --depth=20 --run=$WANDB_RUN --data_dir=$DATA_DIR --matrix_lr=$MATRIX_LR'
+srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_train -- --depth=20 --run=$WANDB_RUN --data_dir=$DATA_DIR --matrix_lr=$MATRIX_LR --tokenizer_name=$TOKENIZER_NAME --model_tag=$WANDB_RUN'
 # evaluate the model on a larger chunk of train/val data and draw some samples (multi-node)
-srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_loss --data_dir=$DATA_DIR'
+srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_loss --data_dir=$DATA_DIR --tokenizer_name=$TOKENIZER_NAME --model_tag=$WANDB_RUN'
 # evaluate the model on CORE tasks (multi-node)
-srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_eval'
+srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_eval -- --model_tag=$WANDB_RUN'
 
-# -----------------------------------------------------------------------------
-# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
+# # -----------------------------------------------------------------------------
+# # Midtraining (teach the model conversation special tokens, tool use, multiple choice)
 
-# run midtraining and eval the model (multi-node)
-srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.mid_train -- --run=$WANDB_RUN'
-srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i mid'
+# # run midtraining and eval the model (multi-node)
+# # mid_train loads from base_checkpoints/$WANDB_RUN and saves to mid_checkpoints/$WANDB_RUN
+# srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.mid_train -- --run=$WANDB_RUN --model_tag=$WANDB_RUN'
+# srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i mid --model_tag=$WANDB_RUN'
 
-# -----------------------------------------------------------------------------
-# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
+# # -----------------------------------------------------------------------------
+# # Supervised Finetuning (domain adaptation to each sequence all by itself per row)
 
-# train sft and re-eval right away (should see a small bump) (multi-node)
-srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_sft -- --run=$WANDB_RUN'
-srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i sft'
+# # train sft and re-eval right away (should see a small bump) (multi-node)
+# # chat_sft loads from mid_checkpoints/$WANDB_RUN and saves to chatsft_checkpoints/$WANDB_RUN
+# srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_sft -- --run=$WANDB_RUN --model_tag=$WANDB_RUN'
+# srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i sft --model_tag=$WANDB_RUN'
 
-# chat with the model over CLI! Leave out the -p to chat interactively
-# python -m scripts.chat_cli -p "Why is the sky blue?"
+# # chat with the model over CLI! Leave out the -p to chat interactively
+# # python -m scripts.chat_cli -p "Why is the sky blue?"
 
-# even better, chat with your model over a pretty WebUI ChatGPT style
-# python -m scripts.chat_web
+# # even better, chat with your model over a pretty WebUI ChatGPT style
+# # python -m scripts.chat_web
 
-# -----------------------------------------------------------------------------
-# Reinforcement Learning. Optional, and currently only on GSM8K
-# (optional)
+# # -----------------------------------------------------------------------------
+# # Reinforcement Learning. Optional, and currently only on GSM8K
+# # (optional)
 
-# run reinforcement learning
-# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
-# eval the RL model only on GSM8K
-# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
+# # run reinforcement learning
+# # torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
+# # eval the RL model only on GSM8K
+# # torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
 
-# -----------------------------------------------------------------------------
-# Generate the full report by putting together all the sections
-# report.md is the output and will be copied to current directory for convenience
-python -m nanochat.report generate --exp_name=$WANDB_RUN
+# # -----------------------------------------------------------------------------
+# # Generate the full report by putting together all the sections
+# # report.md is the output and will be copied to current directory for convenience
+# python -m nanochat.report generate --exp_name=$WANDB_RUN
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index aeab77e..1122e78 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -27,7 +27,9 @@ from tasks.common import TaskMixture
 from tasks.arc import ARC
 from tasks.gsm8k import GSM8K
 from tasks.smoltalk import SmolTalk
+
 from tasks.customjson import CustomJSON
+from tasks.nemotron import Nemotron
 
 # -----------------------------------------------------------------------------
 # SFT Hyperparameters
@@ -54,6 +56,7 @@ eval_every = 100
 eval_steps = 100
 eval_metrics_every = 200
 eval_metrics_max_problems = 1024
+dataset_choice = "smoltalk" # dataset choice: "smoltalk" or "nemotron"
 # now allow CLI to override the settings via the configurator lol
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
@@ -79,15 +82,42 @@ engine = Engine(model, tokenizer) # will be used for inline model evaluation onl
 
 # -----------------------------------------------------------------------------
 # Task data mixture we'll train on
+# Select dataset based on dataset_choice parameter
+print0(f"SFT using dataset: {dataset_choice}")
 identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
-train_ds = TaskMixture([
-    ARC(subset="ARC-Easy", split="train"), # 2.3K rows
-    ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
-    GSM8K(subset="main", split="train"), # 8K rows
-    SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
-    CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
-]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
-val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
+if dataset_choice == "smoltalk":
+    # Original: SmolTalk + ARC + GSM8K
+    train_ds = TaskMixture([
+        ARC(subset="ARC-Easy", split="train"), # 2.3K rows
+        ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
+        GSM8K(subset="main", split="train"), # 8K rows
+        SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
+        CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
+    ]) # total: 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
+    val_ds = SmolTalk(split="test") # general conversations, 24K rows
+elif dataset_choice == "nemotron":
+    # Ablation: Nemotron (sampled to match SmolTalk's 10K rows) + ARC + GSM8K
+    # SmolTalk contributes 10K samples, so we sample Nemotron proportionally to match
+    # Original Nemotron distribution: stem(25.4%), math(17.1%), chat(44.9%), code(12.5%)
+    train_ds = TaskMixture([
+        ARC(subset="ARC-Easy", split="train"), # 2.3K rows
+        ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
+        GSM8K(subset="main", split="train"), # 8K rows
+        Nemotron(categories=["stem"], split="train", stop=2540), # 25.4% of 10K = 2.54K
+        Nemotron(categories=["math"], split="train", stop=1710), # 17.1% of 10K = 1.71K
+        Nemotron(categories=["chat"], split="train", stop=4490), # 44.9% of 10K = 4.49K
+        Nemotron(categories=["code"], split="train", stop=1250), # 12.5% of 10K = 1.25K
+        CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
+    ]) # total: 2.3K + 1.1K + 8K + (2.54K + 1.71K + 4.49K + 1.25K) + 1K = 22.4K rows (same as the SmolTalk mix)
+    # For validation, use a small held-out slice of Nemotron mixed categories
+    val_ds = TaskMixture([
+        Nemotron(categories=["stem"], split="train", start=2540, stop=2790), # 250 samples
+        Nemotron(categories=["math"], split="train", start=1710, stop=1960), # 250 samples
+        Nemotron(categories=["chat"], split="train", start=4490, stop=5240), # 750 samples
+        Nemotron(categories=["code"], split="train", start=1250, stop=1500), # 250 samples
+    ]) # total: 1500 samples for validation
+else:
+    raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")
Must be 'smoltalk' or 'nemotron'") # ----------------------------------------------------------------------------- # DataLoader @@ -248,8 +278,9 @@ for step in range(num_iterations): if master_process: base_dir = get_base_dir() depth = model.config.n_layer - model_tag = f"d{depth}" # base the model tag on the depth of the base model - checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag) + # Use model_tag from config if provided, otherwise default to d{depth} + output_dirname = model_tag if model_tag else f"d{depth}" # base the model tag on the depth of the base model + checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", run) model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer save_checkpoint( checkpoint_dir, diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 2835ebf..f43e306 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -28,6 +28,7 @@ from tasks.gsm8k import GSM8K from tasks.mmlu import MMLU from tasks.smoltalk import SmolTalk from tasks.customjson import CustomJSON +from tasks.nemotron import Nemotron # ----------------------------------------------------------------------------- run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) @@ -47,6 +48,7 @@ eval_every = 150 # -1 = disable eval_tokens = 20*524288 total_batch_size = 524288 dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report +dataset_choice = "smoltalk" # dataset choice: "smoltalk" or "nemotron" config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging @@ -80,7 +82,10 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") -token_bytes = get_token_bytes(device=device) +# Load tokenizer_name from checkpoint metadata +tokenizer_name = meta.get("tokenizer_name", "tokenizer") +print0(f"Using tokenizer: {tokenizer_name}") +token_bytes = get_token_bytes(tokenizer_name, device=device) # Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head) optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay) @@ -93,19 +98,49 @@ for opt in optimizers: # Midtraining data mixture and DataLoader base_dir = get_base_dir() + +# Select dataset based on dataset_choice parameter +print0(f"Using dataset: {dataset_choice}") identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl") -train_dataset = TaskMixture([ - SmolTalk(split="train"), # 460K rows of general conversations - MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE - GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use - CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations - CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these -]) # total: 460K + 100K + 8K = 
-val_dataset = TaskMixture([
-    SmolTalk(split="test"), # 24K rows in test set
-    MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
-    GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
-]) # total: 24K + 14K + 1.32K ~= 39K rows
+if dataset_choice == "smoltalk":
+    # Original: SmolTalk + MMLU + GSM8K
+    train_dataset = TaskMixture([
+        SmolTalk(split="train"), # 460K rows of general conversations
+        MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems
+        GSM8K(subset="main", split="train"), # 8K rows teaching simple math
+        CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
+        CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
+    ]) # total: 460K + 100K + 8K = 568K rows
+    val_dataset = TaskMixture([
+        SmolTalk(split="test"), # 24K rows in test set
+        MMLU(subset="all", split="test", stop=5200), # 5.2K rows to match train ratios
+        GSM8K(subset="main", split="test", stop=420), # 420 rows to match train ratios
+    ]) # total: ~29.6K rows
+elif dataset_choice == "nemotron":
+    # Ablation: Nemotron with stem, math, chat, code (sampled to match SmolTalk 460K) + MMLU + GSM8K
+    # Original Nemotron distribution: stem(355K/25.4%), math(239K/17.1%), chat(628K/44.9%), code(175K/12.5%)
+    # Proportionally sampled to 460K total, then add MMLU + GSM8K to match SmolTalk structure
+    train_dataset = TaskMixture([
+        Nemotron(categories=["stem"], split="train", stop=117000), # 25.4% of 460K = 117K
+        Nemotron(categories=["math"], split="train", stop=79000), # 17.1% of 460K = 79K
+        Nemotron(categories=["chat"], split="train", stop=207000), # 44.9% of 460K = 207K
+        Nemotron(categories=["code"], split="train", stop=57000), # 12.5% of 460K = 57K
+        MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems
+        GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
+        CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
+        CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
+    ]) # total: 117K + 79K + 207K + 57K + 100K + 8K = 568K rows (same as SmolTalk)
+    # For validation, match the SmolTalk validation set structure
+    val_dataset = TaskMixture([
+        Nemotron(categories=["stem"], split="train", start=117000, stop=124500), # 7.5K
+        Nemotron(categories=["math"], split="train", start=79000, stop=84000), # 5K
+        Nemotron(categories=["chat"], split="train", start=207000, stop=220500), # 13.5K
+        Nemotron(categories=["code"], split="train", start=57000, stop=61000), # 4K
+        MMLU(subset="all", split="test", stop=5200), # 5.2K rows to match train ratios
+        GSM8K(subset="main", split="test", stop=420), # 420 rows to match train ratios
+    ]) # total: 7.5K + 5K + 13.5K + 4K + 5.2K + 0.42K = 35.6K rows
+else:
+    raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")
 
 # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
 # A big problem is that we don't know the final num_iterations in advance. So we create
 # these two global variables and update them from within the data generator.
@@ -204,8 +239,9 @@ while True:
 
 # save checkpoint at the end of the run (only on master process)
 if master_process and last_step and not dry_run:
-    output_dirname = f"d{depth}" # e.g. d12
-    checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
+    # Use model_tag from config if provided, otherwise default to d{depth}
+    # output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
+    checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", run)
     save_checkpoint(
         checkpoint_dir,
         step,
@@ -222,7 +258,8 @@ while True:
                 "n_kv_head": model.config.n_kv_head,
                 "n_embd": model.config.n_embd,
             },
-            "user_config": user_config, # inputs to the training script
+            "user_config": user_config, # inputs to the training script,
+            "tokenizer_name": tokenizer_name, # save tokenizer name for later loading
         }
     )
diff --git a/scripts/tok_eval.py b/scripts/tok_eval.py
index 9233d71..ec37483 100644
--- a/scripts/tok_eval.py
+++ b/scripts/tok_eval.py
@@ -2,9 +2,18 @@
 Evaluate compression ratio of the tokenizer.
 """
+import argparse
 from nanochat.tokenizer import get_tokenizer, RustBPETokenizer
 from nanochat.dataset import parquets_iter_batched
+
+# Parse command line arguments
+parser = argparse.ArgumentParser(description='Evaluate tokenizer compression')
+parser.add_argument('--tokenizer_name', type=str, default='tokenizer', help='Name of the tokenizer subdirectory (default: tokenizer)')
+parser.add_argument('--data_dir', type=str, default=None, help='Custom dataset directory (default: None, uses default dataset)')
+args = parser.parse_args()
+print(f"tokenizer_name: {args.tokenizer_name}")
+print(f"data_dir: {args.data_dir}")
+
 # Random text I got from a random website this morning
 news_text = r"""
 (Washington, D.C., July 9, 2025)- Yesterday, Mexico’s National Service of Agro-Alimentary Health, Safety, and Quality (SENASICA) reported a new case of New World Screwworm (NWS) in Ixhuatlan de Madero, Veracruz in Mexico, which is approximately 160 miles northward of the current sterile fly dispersal grid, on the eastern side of the country and 370 miles south of the U.S./Mexico border. This new northward detection comes approximately two months after northern detections were reported in Oaxaca and Veracruz, less than 700 miles away from the U.S. border, which triggered the closure of our ports to Mexican cattle, bison, and horses on May 11, 2025.
@@ -144,9 +153,9 @@ Photosynthesis is a photochemical energy transduction process in which light-har
 """.strip()
 
 # The tokenizer was trained on data from earlier shards, so it has seen this data
-train_docs = next(parquets_iter_batched(split="train"))
+train_docs = next(parquets_iter_batched(split="train", data_dir=args.data_dir))
 train_text = "\n".join(train_docs)
-val_docs = next(parquets_iter_batched(split="val"))
+val_docs = next(parquets_iter_batched(split="val", data_dir=args.data_dir))
 val_text = "\n".join(val_docs)
 
 all_text = [
@@ -171,7 +180,7 @@ for tokenizer_name in ["gpt2", "gpt4", "ours"]:
     elif tokenizer_name == "gpt4":
         tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer
     else:
-        tokenizer = get_tokenizer()
+        tokenizer = get_tokenizer(args.tokenizer_name)
     vocab_sizes[tokenizer_name] = tokenizer.get_vocab_size()
     tokenizer_results[tokenizer_name] = {}
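
Note (not part of the patch): both scripts/chat_sft.py and scripts/mid_train.py now import Nemotron from tasks/nemotron.py, but that file is not included in this diff. A rough sketch of the interface the call sites assume, a task constructed with (categories, split, start, stop) that behaves like the other tasks fed to TaskMixture, could look like the following; the Hugging Face dataset id, the "category"/"messages" column names, and the returned row format are assumptions for illustration, not taken from the patch.

# Hypothetical sketch of tasks/nemotron.py -- illustrative only, not the actual implementation.
from datasets import load_dataset

class Nemotron:
    def __init__(self, categories, split="train", start=0, stop=None):
        # placeholder dataset id; the real source of the Nemotron SFT conversations may differ
        ds = load_dataset("nvidia/Nemotron-Post-Training-Dataset-v1", split=split)
        # keep only the requested categories, e.g. "stem", "math", "chat", "code"
        wanted = set(categories)
        ds = ds.filter(lambda row: row["category"] in wanted)
        # start/stop slicing, used above to carve out disjoint train vs. validation ranges
        stop = len(ds) if stop is None else min(stop, len(ds))
        self.ds = ds.select(range(start, stop))

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        # assumed row format: a list of {"role": ..., "content": ...} messages per conversation
        return {"messages": self.ds[idx]["messages"]}

With the patch applied, the ablation is selected through the configurator flag added to both scripts, e.g.:

torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN --dataset_choice=nemotron

and likewise for scripts.chat_sft; the default dataset_choice="smoltalk" keeps the original mixtures unchanged.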