Mirror of https://github.com/karpathy/nanochat.git (synced 2026-01-30 15:20:39 +00:00)
support custom tokenizer by adding tokenizer_name
commit 646647c776
parent 2085e6637a
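This commit threads a tokenizer_name setting through the pretraining stack: the tokenizer subdirectory name comes from the training config, is passed to get_tokenizer, get_token_bytes, and the streaming data loader, is saved into the checkpoint metadata, and is read back at load time so evaluation uses the same tokenizer the model was trained with. The name can be overridden from the command line via the configurator, for example (only the tokenizer_name key is introduced here; the launcher flags and the name my_tok are illustrative):

    torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --tokenizer_name=my_tok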
@@ -82,7 +82,9 @@ def build_model(checkpoint_dir, step, device, phase):
     else:
         model.train()
     # Load the Tokenizer
-    tokenizer = get_tokenizer()
+    tokenizer_name = meta_data["tokenizer_name"]
+    print(f"Loading tokenizer: {tokenizer_name}")
+    tokenizer = get_tokenizer(tokenizer_name)
     # Sanity check: compatibility between model and tokenizer
     assert tokenizer.get_vocab_size() == model_config_kwargs["vocab_size"]
     return model, tokenizer, meta_data
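Note that build_model now reads the key unconditionally, so a checkpoint written before this change, whose meta_data has no "tokenizer_name" entry, would raise a KeyError here. A minimal backward-compatible variant (hypothetical, not part of the diff) would fall back to the old default directory:

    # hypothetical fallback for checkpoints saved before tokenizer_name existed
    tokenizer_name = meta_data.get("tokenizer_name", "tokenizer")
    tokenizer = get_tokenizer(tokenizer_name)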
@@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
 from nanochat.dataset import parquets_iter_batched
 from nanochat.tokenizer import get_tokenizer
 
-def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", data_dir=None):
+def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", data_dir=None, tokenizer_name="tokenizer"):
     """Stream pretraining text from parquet files, tokenize, yield training batches.
 
     Args:
@@ -16,12 +16,13 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
         tokenizer_threads: number of threads for tokenization
         tokenizer_batch_size: batch size for tokenization
         data_dir: optional custom directory containing parquet files (None = use default)
+        tokenizer_name: name of the tokenizer subdirectory (default: tokenizer)
     """
     assert split in ["train", "val"], "split must be 'train' or 'val'"
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
     needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
     # get the tokenizer and the bos token
-    tokenizer = get_tokenizer()
+    tokenizer = get_tokenizer(tokenizer_name)
     bos_token = tokenizer.get_bos_token_id()
     # scratch buffer holds the tokens for one iteration
     token_buffer = deque() # we stream tokens on the right and pop from the left
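With the new keyword, callers select the tokenizer by directory name. A usage sketch, assuming the loader lives in nanochat.dataloader and that a my_tok tokenizer has already been trained (batch size and sequence length are illustrative):

    from nanochat.dataloader import tokenizing_distributed_data_loader

    # stream (B=32, T=2048) batches tokenized with <base_dir>/my_tok/
    loader = tokenizing_distributed_data_loader(32, 2048, split="train", tokenizer_name="my_tok")
    x, y = next(loader)  # x and y are (B, T); y holds next-token targets, hence B*T + 1 tokens needed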
@@ -376,18 +376,18 @@ class RustBPETokenizer:
 # -----------------------------------------------------------------------------
 # nanochat-specific convenience functions
 
-def get_tokenizer():
+def get_tokenizer(tokenizer_name="tokenizer"):
     from nanochat.common import get_base_dir
     base_dir = get_base_dir()
-    tokenizer_dir = os.path.join(base_dir, "tokenizer")
+    tokenizer_dir = os.path.join(base_dir, tokenizer_name)
     # return HuggingFaceTokenizer.from_directory(tokenizer_dir)
     return RustBPETokenizer.from_directory(tokenizer_dir)
 
-def get_token_bytes(device="cpu"):
+def get_token_bytes(tokenizer_name="tokenizer", device="cpu"):
     import torch
     from nanochat.common import get_base_dir
     base_dir = get_base_dir()
-    tokenizer_dir = os.path.join(base_dir, "tokenizer")
+    tokenizer_dir = os.path.join(base_dir, tokenizer_name)
     token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
     assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
     with open(token_bytes_path, "rb") as f:
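Both helpers now resolve the tokenizer directory as <base_dir>/<tokenizer_name> rather than the hardcoded <base_dir>/tokenizer. A sketch of the resulting calls and on-disk layout (the name my_tok is illustrative; base_dir defaults to ~/.cache/nanochat per the config comments further below):

    from nanochat.tokenizer import get_tokenizer, get_token_bytes

    # reads ~/.cache/nanochat/my_tok/ (tokenizer files) and
    # ~/.cache/nanochat/my_tok/token_bytes.pt (written by tok_train.py)
    tokenizer = get_tokenizer("my_tok")
    token_bytes = get_token_bytes("my_tok", device="cpu")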
@@ -123,7 +123,11 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
     parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
+    parser.add_argument('--model-tag', type=str, default=None, help='Model tag to evaluate')
+    parser.add_argument('--step', type=int, default=None, help='Model step to evaluate')
     args = parser.parse_args()
+    model_tag = args.model_tag
+    step = args.step
 
     # distributed / precision setup
     device_type = autodetect_device_type()
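The two new flags let the eval script target a specific checkpoint instead of the most recent one, for example (the module path, tag, and step values are illustrative):

    python -m scripts.base_eval --model-tag=my_run --step=21400

argparse exposes --model-tag as args.model_tag, which is then forwarded to load_model in the next hunk.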
@@ -140,9 +144,10 @@ def main():
         model_slug = hf_path.replace("/", "-") # for the output csv file
     else:
         # load a local model from the file system
-        model, tokenizer, meta = load_model("base", device, phase="eval")
+        model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=step)
         model_name = f"base_model (step {meta['step']})" # just for logging
         model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
+        print0(f"Loaded model with model_tag: {model_tag}")
 
     # Evaluate the model
     with autocast_ctx:
@@ -23,6 +23,7 @@ model_tag = None # optional model tag for the output directory name
 model_step = None # optional model step for the output directory name
 device_type = "" # cuda|cpu|mps (empty => autodetect)
 data_dir = "" # path to directory containing parquet files with 'text' column (empty string = use default)
+tokenizer_name = "tokenizer" # name of the tokenizer subdirectory (default: tokenizer)
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
 
 # Load the base model and the tokenizer
@@ -36,7 +37,8 @@ autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16)
 tokens_per_step = device_batch_size * sequence_len * ddp_world_size
 assert split_tokens % tokens_per_step == 0, "split_tokens must be divisible by tokens_per_step"
 steps = split_tokens // tokens_per_step
-token_bytes = get_token_bytes(device=device)
+token_bytes = get_token_bytes(tokenizer_name, device=device)
+print0(f"Using tokenizer: {tokenizer_name}")
 bpb_results = {}
 custom_data_dir = data_dir if data_dir else None
 for split_name in ["train", "val"]:
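The divisibility assert exists because each split is consumed in whole steps of tokens_per_step tokens across all ranks. A worked example with illustrative numbers:

    # device_batch_size=32, sequence_len=2048, ddp_world_size=8
    tokens_per_step = 32 * 2048 * 8          # 524,288 tokens per step across all ranks
    split_tokens = 40 * tokens_per_step      # 20,971,520; must be an exact multiple
    steps = split_tokens // tokens_per_step  # 40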
@@ -36,6 +36,7 @@ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wan
 device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
 # Data
 data_dir = "" # path to directory containing parquet files with 'text' column (empty string = use default: ~/.cache/nanochat/base_data)
+tokenizer_name = "tokenizer" # name of the tokenizer subdirectory (default: tokenizer)
 # Model architecture
 depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
 max_seq_len = 2048 # max context length
@@ -58,13 +59,13 @@ core_metric_every = 2000 # every how many steps to evaluate the core metric (-1
 core_metric_max_per_task = 500 # examples per task in estimating the core metric
 sample_every = 2000 # every how many steps to sample from the model
 # Output
-model_tag = run # optionally override the model tag for the output checkpoint directory name
+model_tag = "" # optionally override the model tag for the output checkpoint directory name
 # now allow CLI to override the settings via the configurator lol
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
 user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # -----------------------------------------------------------------------------
-
+print(f"SHIZHE DEBUG: model_tag: {model_tag}")
 # Compute init
 device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
@@ -78,10 +79,11 @@ use_dummy_wandb = run == "dummy" or not master_process
 wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=run, config=user_config)
 
 # Tokenizer will be useful for evaluation, also we need the vocab size
-tokenizer = get_tokenizer()
-token_bytes = get_token_bytes(device=device)
+tokenizer = get_tokenizer(tokenizer_name)
+token_bytes = get_token_bytes(tokenizer_name, device=device)
 vocab_size = tokenizer.get_vocab_size()
 print0(f"Vocab size: {vocab_size:,}")
+print0(f"Tokenizer: {tokenizer_name}")
 
 # Model kwargs are derived from the desired depth of the model
 num_layers = depth
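Because vocab_size is taken from the tokenizer here and baked into the model config, the assert in build_model (first hunk above) guarantees a checkpoint is only ever reloaded with a tokenizer of the same size; a custom tokenizer with a different vocabulary therefore produces a differently-shaped model rather than silent corruption. A sketch of the invariant (the size is illustrative):

    tokenizer = get_tokenizer(tokenizer_name)
    vocab_size = tokenizer.get_vocab_size()  # e.g. 65536
    # the model is built with this vocab_size; at load time build_model asserts
    # get_tokenizer(meta_data["tokenizer_name"]).get_vocab_size() == vocab_size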
@@ -146,8 +148,8 @@ base_dir = get_base_dir()
 tokens_dir = os.path.join(base_dir, "tokenized_data")
 # Use custom data_dir if provided, otherwise use default
 custom_data_dir = data_dir if data_dir else None
-train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device, data_dir=custom_data_dir)
-build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device, data_dir="/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/.cache/base_data") # SHIZHE: always use the default val data dir from FineWeb by Andrej Karpathy
+train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device, data_dir=custom_data_dir, tokenizer_name=tokenizer_name)
+build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device, data_dir="/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/.cache/base_data", tokenizer_name=tokenizer_name) # SHIZHE: always use the default val data dir from FineWeb by Andrej Karpathy
 x, y = next(train_loader) # kick off load of the very first batch of data
 
 # -----------------------------------------------------------------------------
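Both loaders now tokenize with the configured tokenizer, including the val loader, whose data_dir stays pinned to the FineWeb directory noted in the SHIZHE comment; the val text is therefore re-tokenized whenever tokenizer_name changes. A sketch of the expected batch shapes (assuming the loader yields (inputs, targets) pairs, as the +1-target comment in the loader indicates):

    x, y = next(train_loader)
    assert x.shape == (device_batch_size, max_seq_len) == y.shape  # token id tensors on `device`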
@@ -257,6 +259,7 @@ for step in range(num_iterations + 1):
             "user_config": user_config, # inputs to the training script
             "device_batch_size": device_batch_size,
             "max_seq_len": max_seq_len,
+            "tokenizer_name": tokenizer_name, # save tokenizer name for later loading
         }
     )
 
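Persisting tokenizer_name in the checkpoint metadata is what makes the load-time change in the first hunk work: build_model reads the name back and reloads the matching tokenizer. A sketch of the round trip (the surrounding checkpoint API is elided; only names shown in this diff are used):

    # save time (training script): tokenizer_name goes into the meta dict
    meta = {
        "user_config": user_config,
        "device_batch_size": device_batch_size,
        "max_seq_len": max_seq_len,
        "tokenizer_name": tokenizer_name,  # e.g. "my_tok"
    }
    # load time (build_model): the same name selects the tokenizer directory
    tokenizer = get_tokenizer(meta["tokenizer_name"])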