Support Nemotron post-training data in mid-train and SFT

Shizhe Diao 2025-10-19 15:10:53 -07:00
parent 646647c776
commit 7690b82d4b
4 changed files with 140 additions and 61 deletions

View File

@@ -22,7 +22,7 @@
set -x # Enable debug output
DATA_NAME=climbmix
DATA_NAME=smollm
export DATA_DIR=/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/data/$DATA_NAME
export MATRIX_LR=0.02
@@ -104,14 +104,14 @@ maturin develop --release --manifest-path rustbpe/Cargo.toml
# each data shard is ~250M chars
# so we download 2e9 / 250e6 = 8 data shards at this point
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
python -m nanochat.dataset -n 8 --data_dir=$DATA_DIR
python -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while tokenizer trains
# See comment below for why 240 is the right number here
python -m nanochat.dataset -n 240 --data_dir=$DATA_DIR &
python -m nanochat.dataset -n 240 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
# Use unique tokenizer name based on dataset
TOKENIZER_NAME="tokenizer_${DATA_NAME}"
export TOKENIZER_NAME="tokenizer_${DATA_NAME}"
python -m scripts.tok_train --max_chars=2000000000 --data_dir=$DATA_DIR --tokenizer_name=$TOKENIZER_NAME
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval --tokenizer_name=$TOKENIZER_NAME
@@ -141,42 +141,44 @@ wait $DATASET_DOWNLOAD_PID
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; python -c "import torch; print(torch.cuda.device_count())"'
# pretrain the d20 model (multi-node)
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_train -- --depth=20 --run=$WANDB_RUN --data_dir=$DATA_DIR --matrix_lr=$MATRIX_LR'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_train -- --depth=20 --run=$WANDB_RUN --data_dir=$DATA_DIR --matrix_lr=$MATRIX_LR --tokenizer_name=$TOKENIZER_NAME --model_tag=$WANDB_RUN'
# evaluate the model on a larger chunk of train/val data and draw some samples (multi-node)
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_loss --data_dir=$DATA_DIR'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_loss --data_dir=$DATA_DIR --tokenizer_name=$TOKENIZER_NAME --model_tag=$WANDB_RUN'
# evaluate the model on CORE tasks (multi-node)
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_eval'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.base_eval -- --model_tag=$WANDB_RUN'
# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
# # -----------------------------------------------------------------------------
# # Midtraining (teach the model conversation special tokens, tool use, multiple choice)
# run midtraining and eval the model (multi-node)
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.mid_train -- --run=$WANDB_RUN'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i mid'
# # run midtraining and eval the model (multi-node)
# # mid_train loads from base_checkpoints/$WANDB_RUN and saves to mid_checkpoints/$WANDB_RUN
# srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.mid_train -- --run=$WANDB_RUN --model_tag=$WANDB_RUN'
# srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i mid --model_tag=$WANDB_RUN'
# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
# # -----------------------------------------------------------------------------
# # Supervised Finetuning (domain adaptation to each sequence all by itself per row)
# train sft and re-eval right away (should see a small bump) (multi-node)
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_sft -- --run=$WANDB_RUN'
srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i sft'
# # train sft and re-eval right away (should see a small bump) (multi-node)
# # chat_sft loads from mid_checkpoints/$WANDB_RUN and saves to chatsft_checkpoints/$WANDB_RUN
# srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_sft -- --run=$WANDB_RUN --model_tag=$WANDB_RUN'
# srun --ntasks=$NNODES --ntasks-per-node=1 bash --noprofile --norc -lc 'source $HOME/nanochat_cache/.venv/bin/activate; torchrun --nnodes=$NNODES --nproc_per_node=$GPUS_PER_NODE --rdzv_endpoint=$RDZV_ENDPOINT --rdzv_id=$SLURM_JOB_ID --node_rank=$SLURM_NODEID -m scripts.chat_eval -- -i sft --model_tag=$WANDB_RUN'
# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"
# # chat with the model over CLI! Leave out the -p to chat interactively
# # python -m scripts.chat_cli -p "Why is the sky blue?"
# even better, chat with your model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web
# # even better, chat with your model over a pretty WebUI ChatGPT style
# # python -m scripts.chat_web
# -----------------------------------------------------------------------------
# Reinforcement Learning. Optional, and currently only on GSM8K
# (optional)
# # -----------------------------------------------------------------------------
# # Reinforcement Learning. Optional, and currently only on GSM8K
# # (optional)
# run reinforcement learning
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
# # run reinforcement learning
# # torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
# # eval the RL model only on GSM8K
# # torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections
# report.md is the output and will be copied to current directory for convenience
python -m nanochat.report generate --exp_name=$WANDB_RUN
# # -----------------------------------------------------------------------------
# # Generate the full report by putting together all the sections
# # report.md is the output and will be copied to current directory for convenience
# python -m nanochat.report generate --exp_name=$WANDB_RUN
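# -----------------------------------------------------------------------------
# For reference, a minimal single-node sketch of the same pipeline (GPU count,
# paths, and run name are illustrative; the flags are the ones used above).
# dataset_choice defaults to smoltalk; nemotron selects the new mixture.
export DATA_NAME=smollm
export DATA_DIR=/path/to/nanochat/data/$DATA_NAME
export TOKENIZER_NAME="tokenizer_${DATA_NAME}"
WANDB_RUN=d20_${DATA_NAME}
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN --data_dir=$DATA_DIR --tokenizer_name=$TOKENIZER_NAME --model_tag=$WANDB_RUN   # writes base_checkpoints/$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN --model_tag=$WANDB_RUN --dataset_choice=nemotron   # writes mid_checkpoints/$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN --model_tag=$WANDB_RUN --dataset_choice=nemotron   # writes chatsft_checkpoints/$WANDB_RUN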

View File

@@ -27,7 +27,9 @@ from tasks.common import TaskMixture
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.nemotron import Nemotron
# -----------------------------------------------------------------------------
# SFT Hyperparameters
@@ -54,6 +56,7 @@ eval_every = 100
eval_steps = 100
eval_metrics_every = 200
eval_metrics_max_problems = 1024
dataset_choice = "smoltalk" # dataset choice: "smoltalk" or "nemotron"
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
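# Usage note: configurator.py applies --key=value CLI args as overrides to the
# simple-typed globals above, so the new dataset_choice knob is set at launch
# time, e.g.: torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --dataset_choice=nemotron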
@@ -79,15 +82,42 @@ engine = Engine(model, tokenizer) # will be used for inline model evaluation onl
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
# Select dataset based on dataset_choice parameter
print0(f"SFT using dataset: {dataset_choice}")
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
train_ds = TaskMixture([
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
if dataset_choice == "smoltalk":
# Original: SmolTalk + ARC + GSM8K
train_ds = TaskMixture([
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
]) # total: 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows
elif dataset_choice == "nemotron":
# Ablation: Nemotron (sampled to match SmolTalk 10K) + ARC + GSM8K
# SmolTalk contributes 10K rows, so we sample Nemotron proportionally to match
# Original Nemotron distribution: stem(25.4%), math(17.1%), chat(44.9%), code(12.5%)
train_ds = TaskMixture([
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
Nemotron(categories=["stem"], split="train", stop=2540), # 25.4% of 10K = 2.54K
Nemotron(categories=["math"], split="train", stop=1710), # 17.1% of 10K = 1.71K
Nemotron(categories=["chat"], split="train", stop=4490), # 44.9% of 10K = 4.49K
Nemotron(categories=["code"], split="train", stop=1250), # 12.5% of 10K = 1.25K
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
]) # total: 2.3K + 1.1K + 8K + (2.54K + 1.71K + 4.49K + 1.25K) + 1K = 22.4K rows (matches the SmolTalk mixture)
# For validation, use a small subset of Nemotron mixed categories
val_ds = TaskMixture([
Nemotron(categories=["stem"], split="train", start=2540, stop=2790), # 250 samples
Nemotron(categories=["math"], split="train", start=1710, stop=1960), # 250 samples
Nemotron(categories=["chat"], split="train", start=4490, stop=5240), # 750 samples
Nemotron(categories=["code"], split="train", start=1250, stop=1500), # 250 samples
]) # total: 1500 samples for validation
else:
raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")
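# Illustrative sketch (not part of this commit): the stop= values in the
# "nemotron" branch are just the category distribution scaled to SmolTalk's
# 10K-row budget; mid_train scales the same proportions to a 460K-row budget.
NEMOTRON_DIST = {"stem": 0.254, "math": 0.171, "chat": 0.449, "code": 0.125}
def category_stops(total):
    # round each category's share of the budget to whole rows
    return {cat: round(frac * total) for cat, frac in NEMOTRON_DIST.items()}
assert category_stops(10_000) == {"stem": 2540, "math": 1710, "chat": 4490, "code": 1250}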
# -----------------------------------------------------------------------------
# DataLoader
@@ -248,8 +278,9 @@ for step in range(num_iterations):
if master_process:
base_dir = get_base_dir()
depth = model.config.n_layer
model_tag = f"d{depth}" # base the model tag on the depth of the base model
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", model_tag)
# Save under the run name; the speedrun script sets run=$WANDB_RUN, which matches
# the --model_tag=$WANDB_RUN that chat_eval uses to locate this checkpoint
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", run)
model_config_kwargs = model.config.__dict__ # slightly naughty, abusing the simplicity of GPTConfig, TODO nicer
save_checkpoint(
checkpoint_dir,

View File

@@ -28,6 +28,7 @@ from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
from tasks.nemotron import Nemotron
# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
@@ -47,6 +48,7 @@ eval_every = 150 # -1 = disable
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
dataset_choice = "smoltalk" # dataset choice: "smoltalk" or "nemotron"
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
@@ -80,7 +82,10 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
token_bytes = get_token_bytes(device=device)
# Load tokenizer_name from checkpoint metadata
tokenizer_name = meta.get("tokenizer_name", "tokenizer")
print0(f"Using tokenizer: {tokenizer_name}")
token_bytes = get_token_bytes(tokenizer_name, device=device)
# Initialize the Optimizer (Muon for Linear layers, AdamW for embedding and lm_head)
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
@@ -93,19 +98,49 @@ for opt in optimizers:
# Midtraining data mixture and DataLoader
base_dir = get_base_dir()
# Select dataset based on dataset_choice parameter
print0(f"Using dataset: {dataset_choice}")
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
]) # total: 460K + 100K + 8K = 568K rows
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total: 24K + 14K + 1.32K ~= 39K rows
if dataset_choice == "smoltalk":
# Original: SmolTalk + MMLU + GSM8K
train_dataset = TaskMixture([
SmolTalk(split="train"), # 460K rows of general conversations
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems
GSM8K(subset="main", split="train"), # 8K rows teaching simple math
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
]) # total: 460K + 100K + 8K = 568K rows
val_dataset = TaskMixture([
SmolTalk(split="test"), # 24K rows in test set
MMLU(subset="all", split="test", stop=5200), # 5.2K rows to match train ratios
GSM8K(subset="main", split="test", stop=420), # 420 rows to match train ratios
]) # total: ~29.6K rows
elif dataset_choice == "nemotron":
# Ablation: Nemotron with stem, math, chat, code (sampled to match SmolTalk 460K) + MMLU + GSM8K
# Original Nemotron distribution: stem(355K/25.4%), math(239K/17.1%), chat(628K/44.9%), code(175K/12.5%)
# Proportionally sampled to 460K total, then add MMLU + GSM8K to match SmolTalk structure
train_dataset = TaskMixture([
Nemotron(categories=["stem"], split="train", stop=117000), # 25.4% of 460K = 117K
Nemotron(categories=["math"], split="train", stop=79000), # 17.1% of 460K = 79K
Nemotron(categories=["chat"], split="train", stop=207000), # 44.9% of 460K = 207K
Nemotron(categories=["code"], split="train", stop=57000), # 12.5% of 460K = 57K
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
]) # total: 117K + 79K + 207K + 57K + 100K + 8K = 568K rows (same as SmolTalk)
# For validation, match SmolTalk validation set structure
val_dataset = TaskMixture([
Nemotron(categories=["stem"], split="train", start=117000, stop=124500), # 7.5K
Nemotron(categories=["math"], split="train", start=79000, stop=84000), # 5K
Nemotron(categories=["chat"], split="train", start=207000, stop=220500), # 13.5K
Nemotron(categories=["code"], split="train", start=57000, stop=61000), # 4K
MMLU(subset="all", split="test", stop=5200), # 5.2K rows to match train ratios
GSM8K(subset="main", split="test", stop=420), # 420 rows to match train ratios
]) # total: 7.5K + 5K + 13.5K + 4K + 5.2K + 0.42K = 35.6K rows
else:
raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")
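# Illustrative check (not part of this commit): each validation slice starts
# exactly where the corresponding train slice stops, so train and val never
# overlap even though both are carved from the "train" split.
train_stop = {"stem": 117_000, "math": 79_000, "chat": 207_000, "code": 57_000}
val_size = {"stem": 7_500, "math": 5_000, "chat": 13_500, "code": 4_000}
for cat in train_stop:
    start, stop = train_stop[cat], train_stop[cat] + val_size[cat]
    print(f"{cat}: train=[0, {start}), val=[{start}, {stop})")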
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create
# these two global variables and update them from within the data generator.
@@ -204,8 +239,9 @@ while True:
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not dry_run:
output_dirname = f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
# Save under the run name; the speedrun script sets run=$WANDB_RUN, which matches
# the mid_checkpoints/$WANDB_RUN path that chat_sft loads from
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", run)
save_checkpoint(
checkpoint_dir,
step,
@@ -222,7 +258,8 @@
"n_kv_head": model.config.n_kv_head,
"n_embd": model.config.n_embd,
},
"user_config": user_config, # inputs to the training script
"user_config": user_config, # inputs to the training script,
"tokenizer_name": tokenizer_name, # save tokenizer name for later loading
}
)
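# Illustration (not part of this commit) of why .get() with a default matters
# downstream: checkpoints written before this change have no "tokenizer_name"
# key and must fall back to the old shared tokenizer directory.
old_meta = {"user_config": {}}  # pre-change checkpoint metadata
new_meta = {"user_config": {}, "tokenizer_name": "tokenizer_smollm"}  # name illustrative
assert old_meta.get("tokenizer_name", "tokenizer") == "tokenizer"
assert new_meta.get("tokenizer_name", "tokenizer") == "tokenizer_smollm"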

View File

@@ -2,9 +2,18 @@
Evaluate compression ratio of the tokenizer.
"""
import argparse
from nanochat.tokenizer import get_tokenizer, RustBPETokenizer
from nanochat.dataset import parquets_iter_batched
# Parse command line arguments
parser = argparse.ArgumentParser(description='Evaluate tokenizer compression')
parser.add_argument('--tokenizer_name', type=str, default='tokenizer', help='Name of the tokenizer subdirectory (default: tokenizer)')
parser.add_argument('--data_dir', type=str, default=None, help='Custom dataset directory (default: None, uses default dataset)')
args = parser.parse_args()
print(f"tokenizer_name: {args.tokenizer_name}")
print(f"data_dir: {args.data_dir}")
# Random text I got from a random website this morning
news_text = r"""
(Washington, D.C., July 9, 2025) - Yesterday, Mexico's National Service of Agro-Alimentary Health, Safety, and Quality (SENASICA) reported a new case of New World Screwworm (NWS) in Ixhuatlan de Madero, Veracruz in Mexico, which is approximately 160 miles northward of the current sterile fly dispersal grid, on the eastern side of the country and 370 miles south of the U.S./Mexico border. This new northward detection comes approximately two months after northern detections were reported in Oaxaca and Veracruz, less than 700 miles away from the U.S. border, which triggered the closure of our ports to Mexican cattle, bison, and horses on May 11, 2025.
@@ -144,9 +153,9 @@ Photosynthesis is a photochemical energy transduction process in which light-har
""".strip()
# The tokenizer was trained on data from earlier shards, so it has seen this data
train_docs = next(parquets_iter_batched(split="train"))
train_docs = next(parquets_iter_batched(split="train", data_dir=args.data_dir))
train_text = "\n".join(train_docs)
val_docs = next(parquets_iter_batched(split="val"))
val_docs = next(parquets_iter_batched(split="val", data_dir=args.data_dir))
val_text = "\n".join(val_docs)
all_text = [
@@ -171,7 +180,7 @@ for tokenizer_name in ["gpt2", "gpt4", "ours"]:
elif tokenizer_name == "gpt4":
tokenizer = RustBPETokenizer.from_pretrained("cl100k_base") # gpt-4 base model tokenizer
else:
tokenizer = get_tokenizer()
tokenizer = get_tokenizer(args.tokenizer_name)
vocab_sizes[tokenizer_name] = tokenizer.get_vocab_size()
tokenizer_results[tokenizer_name] = {}
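# Example invocation with the new flags (the tokenizer name follows the
# tokenizer_${DATA_NAME} convention from the speedrun script; path illustrative):
#   python -m scripts.tok_eval --tokenizer_name=tokenizer_smollm --data_dir=/path/to/data/smollm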