mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-07 12:52:16 +00:00
improve tokenizer and report in midtrain and sft
This commit is contained in:
parent
169022fec0
commit
f3f069519d
|
|
@ -13,8 +13,8 @@
|
||||||
|
|
||||||
set -x # Enable debug output
|
set -x # Enable debug output
|
||||||
|
|
||||||
export DATA_NAME=nemotron # nemotron # smoltalk
|
export DATA_NAME=smoltalk # nemotron # smoltalk
|
||||||
export BASE_NAME=smollm_d20_1node_matrixlr0.02_2298373 # fineweb_d20_1node # climbmix_d20_1node_matrixlr0.02_2298334 # nemotron-cc-hq_d20_1node_matrixlr0.02_2298371 # smollm_d20_1node_matrixlr0.02_2298373
|
export BASE_NAME=climbmix_d20_1node_matrixlr0.02_2298334 # fineweb_d20_1node # climbmix_d20_1node_matrixlr0.02_2298334 # nemotron-cc-hq_d20_1node_matrixlr0.02_2298371 # smollm_d20_1node_matrixlr0.02_2298373
|
||||||
|
|
||||||
# Default intermediate artifacts directory is in ~/.cache/nanochat
|
# Default intermediate artifacts directory is in ~/.cache/nanochat
|
||||||
export OMP_NUM_THREADS=1
|
export OMP_NUM_THREADS=1
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@
|
||||||
|
|
||||||
set -x # Enable debug output
|
set -x # Enable debug output
|
||||||
|
|
||||||
DATA_NAME=climbmix
|
DATA_NAME=climbmix_1_9
|
||||||
export DATA_DIR=/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/data/$DATA_NAME
|
export DATA_DIR=/lustre/fsw/portfolios/nvr/users/sdiao/nanochat/data/$DATA_NAME
|
||||||
export MATRIX_LR=0.02
|
export MATRIX_LR=0.02
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,8 @@ model, tokenizer, meta = load_model(source, device, phase="train", model_tag=mod
|
||||||
orig_model = model # original, uncompiled model
|
orig_model = model # original, uncompiled model
|
||||||
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
|
||||||
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
engine = Engine(model, tokenizer) # will be used for inline model evaluation only
|
||||||
|
tokenizer_name = meta.get("tokenizer_name", "tokenizer")
|
||||||
|
print0(f"Using tokenizer: {tokenizer_name}")
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
# Task data mixture we'll train on
|
# Task data mixture we'll train on
|
||||||
# Select dataset based on dataset_choice parameter
|
# Select dataset based on dataset_choice parameter
|
||||||
|
|
@ -103,19 +104,18 @@ elif dataset_choice == "nemotron":
|
||||||
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
|
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
|
||||||
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
|
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
|
||||||
GSM8K(subset="main", split="train"), # 8K rows
|
GSM8K(subset="main", split="train"), # 8K rows
|
||||||
Nemotron(categories=["stem"], split="train", stop=2540), # 25.4% of 10K = 2.54K
|
Nemotron(categories=["stem"], split="train", stop=3000),
|
||||||
Nemotron(categories=["math"], split="train", stop=1710), # 17.1% of 10K = 1.71K
|
Nemotron(categories=["math"], split="train", stop=3000),
|
||||||
Nemotron(categories=["chat"], split="train", stop=4490), # 44.9% of 10K = 4.49K
|
Nemotron(categories=["chat"], split="train", stop=1000),
|
||||||
Nemotron(categories=["code"], split="train", stop=1250), # 12.5% of 10K = 1.25K
|
Nemotron(categories=["code"], split="train", stop=3000),
|
||||||
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
|
]) # total: 2.3K + 1.1K + 8K + (3.0K + 3.0K + 1.0K + 3.0K) = 21.4K rows (similar to SmolTalk)
|
||||||
]) # total: 2.3K + 1.1K + 8K + (2.54K + 1.71K + 4.49K + 1.25K) = 21.4K rows (same as SmolTalk)
|
|
||||||
# For validation, use a small subset of Nemotron mixed categories
|
# For validation, use a small subset of Nemotron mixed categories
|
||||||
val_ds = TaskMixture([
|
val_ds = TaskMixture([
|
||||||
Nemotron(categories=["stem"], split="train", start=2540, stop=2790), # 250 samples
|
Nemotron(categories=["stem"], split="train", start=3000, stop=3300), # 300 samples
|
||||||
Nemotron(categories=["math"], split="train", start=1710, stop=1960), # 250 samples
|
Nemotron(categories=["math"], split="train", start=3000, stop=3300), # 300 samples
|
||||||
Nemotron(categories=["chat"], split="train", start=4490, stop=5240), # 750 samples
|
Nemotron(categories=["chat"], split="train", start=1000, stop=1100), # 100 samples
|
||||||
Nemotron(categories=["code"], split="train", start=1250, stop=1500), # 250 samples
|
Nemotron(categories=["code"], split="train", start=3000, stop=3300), # 300 samples
|
||||||
]) # total: 1500 samples for validation
|
]) # total: 1000 samples for validation
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")
|
raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")
|
||||||
|
|
||||||
|
|
@ -292,13 +292,14 @@ if master_process:
|
||||||
"val_loss": val_loss,
|
"val_loss": val_loss,
|
||||||
**metrics,
|
**metrics,
|
||||||
"model_config": model_config_kwargs,
|
"model_config": model_config_kwargs,
|
||||||
|
"tokenizer_name": tokenizer_name,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
|
print(f"✅ Saved model checkpoint to {checkpoint_dir}")
|
||||||
|
|
||||||
# Log to report
|
# Log to report
|
||||||
from nanochat.report import get_report
|
from nanochat.report import get_report
|
||||||
get_report().log(section="Chat SFT", data=[
|
get_report(exp_name=run).log(section="Chat SFT", data=[
|
||||||
user_config, # CLI args
|
user_config, # CLI args
|
||||||
{
|
{
|
||||||
"Training rows": len(train_ds),
|
"Training rows": len(train_ds),
|
||||||
|
|
|
||||||
|
|
@ -121,24 +121,25 @@ elif dataset_choice == "nemotron":
|
||||||
# Original Nemotron distribution: stem(355K/25.4%), math(239K/17.1%), chat(628K/44.9%), code(175K/12.5%)
|
# Original Nemotron distribution: stem(355K/25.4%), math(239K/17.1%), chat(628K/44.9%), code(175K/12.5%)
|
||||||
# Proportionally sampled to 460K total, then add MMLU + GSM8K to match SmolTalk structure
|
# Sampled to 460K total (151.8K each of stem/math/code plus 4.6K chat; no longer proportional to the original distribution), then add MMLU + GSM8K to match SmolTalk structure
|
||||||
train_dataset = TaskMixture([
|
train_dataset = TaskMixture([
|
||||||
Nemotron(categories=["stem"], split="train", stop=117000), # 25.4% of 460K = 117K
|
Nemotron(categories=["stem"], split="train", stop=151800),
|
||||||
Nemotron(categories=["math"], split="train", stop=79000), # 17.1% of 460K = 79K
|
Nemotron(categories=["math"], split="train", stop=151800),
|
||||||
Nemotron(categories=["chat"], split="train", stop=207000), # 44.9% of 460K = 207K
|
Nemotron(categories=["chat"], split="train", stop=4600),
|
||||||
Nemotron(categories=["code"], split="train", stop=57000), # 12.5% of 460K = 57K
|
Nemotron(categories=["code"], split="train", stop=151800),
|
||||||
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems
|
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems
|
||||||
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
||||||
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
||||||
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
||||||
]) # total: 117K + 79K + 207K + 57K + 100K + 8K = 568K rows (same as SmolTalk)
|
]) # total: 151.8K + 151.8K + 4.6K + 151.8K + 100K + 8K = 568K rows (same as SmolTalk)
|
||||||
|
|
||||||
# For validation, match SmolTalk validation set structure
|
# For validation, match SmolTalk validation set structure
|
||||||
val_dataset = TaskMixture([
|
val_dataset = TaskMixture([
|
||||||
Nemotron(categories=["stem"], split="train", start=117000, stop=124500), # 7.5K
|
Nemotron(categories=["stem"], split="train", start=151800, stop=155000),
|
||||||
Nemotron(categories=["math"], split="train", start=79000, stop=84000), # 5K
|
Nemotron(categories=["math"], split="train", start=151800, stop=155000),
|
||||||
Nemotron(categories=["chat"], split="train", start=207000, stop=220500), # 13.5K
|
Nemotron(categories=["chat"], split="train", start=4600, stop=10000),
|
||||||
Nemotron(categories=["code"], split="train", start=57000, stop=61000), # 4K
|
Nemotron(categories=["code"], split="train", start=151800, stop=155000),
|
||||||
MMLU(subset="all", split="test", stop=5200), # 5.2K rows to match train ratios
|
MMLU(subset="all", split="test", stop=5200), # 5.2K rows to match train ratios
|
||||||
GSM8K(subset="main", split="test", stop=420), # 420 rows to match train ratios
|
GSM8K(subset="main", split="test", stop=420), # 420 rows to match train ratios
|
||||||
]) # total: 7.5K + 5K + 13.5K + 4K + 5.2K + 0.42K = 35.6K rows
|
]) # total: 3.2K + 3.2K + 5.4K + 3.2K + 5.2K + 0.42K = 20.6K rows (similar to SmolTalk)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")
|
raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")
|
||||||
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
|
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
|
||||||
|
|
@ -329,7 +330,7 @@ print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||||
# Log to report
|
# Log to report
|
||||||
if not dry_run:
|
if not dry_run:
|
||||||
from nanochat.report import get_report
|
from nanochat.report import get_report
|
||||||
get_report().log(section="Midtraining", data=[
|
get_report(exp_name=run).log(section="Midtraining", data=[
|
||||||
user_config, # CLI args
|
user_config, # CLI args
|
||||||
{ # stats about the training setup
|
{ # stats about the training setup
|
||||||
"Number of iterations": step,
|
"Number of iterations": step,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user