Update comments

This commit is contained in:
Shizhe Diao 2025-10-22 22:19:20 -07:00
parent 55fed15421
commit f384c16ba5
4 changed files with 24 additions and 24 deletions

View File

@ -154,7 +154,7 @@ def main():
model_name = hf_path # just for logging
model_slug = hf_path.replace("/", "-") # for the output csv file
else:
# Load a local nanoChat model from the file system
# Load a local model from the file system
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=step)
model_name = f"base_model (step {meta['step']})" # just for logging
model_slug = f"base_model_{meta['step']:06d}" # for the output csv file

View File

@ -46,7 +46,7 @@ target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for
target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
# Optimization
device_batch_size = 32 # per-device batch size (set to not OOM)
total_batch_size = 524288 # 2097152 #1048576 # 524288 # total desired batch size, in #tokens
total_batch_size = 524288 # total desired batch size, in #tokens
embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
@ -65,7 +65,7 @@ config_keys = [k for k,v in globals().items() if not k.startswith('_') and isins
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
print(f"SHIZHE DEBUG: model_tag: {model_tag}")
# Compute init
device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)

View File

@ -94,7 +94,7 @@ if dataset_choice == "smoltalk":
GSM8K(subset="main", split="train"), # 8K rows
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
]) # total: 2.3K + 1.1K + 8K + 10K = 21.4K rows
]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows
elif dataset_choice == "nemotron":
# Ablation: Nemotron (sampled to match SmolTalk 10K) + ARC + GSM8K
@ -104,18 +104,18 @@ elif dataset_choice == "nemotron":
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
GSM8K(subset="main", split="train"), # 8K rows
Nemotron(categories=["stem"], split="train", stop=3000),
Nemotron(categories=["math"], split="train", stop=3000),
Nemotron(categories=["chat"], split="train", stop=1000),
Nemotron(categories=["code"], split="train", stop=3000),
]) # total: 2.3K + 1.1K + 8K + (3.0K + 3.0K + 1.0K + 3.0K) = 18.4K rows (similar to SmolTalk)
Nemotron(categories=["stem"], split="train", stop=3000), # 3K samples
Nemotron(categories=["math"], split="train", stop=3000), # 3K samples
Nemotron(categories=["chat"], split="train", stop=1000), # 1K samples
Nemotron(categories=["code"], split="train", stop=3000), # 3K samples
]) # total: 2.3K + 1.1K + 8K + (3.0K + 3.0K + 1.0K + 3.0K) + 1K = 22.4K rows (similar to SmolTalk)
# For validation, use a small subset of Nemotron mixed categories
val_ds = TaskMixture([
Nemotron(categories=["stem"], split="train", start=3000, stop=3300), # 300 samples
Nemotron(categories=["math"], split="train", start=3000, stop=3300), # 300 samples
Nemotron(categories=["chat"], split="train", start=1000, stop=1100), # 100 samples
Nemotron(categories=["code"], split="train", start=3000, stop=3300), # 300 samples
]) # total: 1000 samples for validation
]) # 300 + 300 + 100 + 300 = 1K samples for validation
else:
raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")

View File

@ -121,25 +121,25 @@ elif dataset_choice == "nemotron":
# Original Nemotron distribution: stem(355K/25.4%), math(239K/17.1%), chat(628K/44.9%), code(175K/12.5%)
# Proportionally sampled to 460K total, then add MMLU + GSM8K to match SmolTalk structure
train_dataset = TaskMixture([
Nemotron(categories=["stem"], split="train", stop=151800),
Nemotron(categories=["math"], split="train", stop=151800),
Nemotron(categories=["chat"], split="train", stop=4600),
Nemotron(categories=["code"], split="train", stop=151800),
Nemotron(categories=["stem"], split="train", stop=151800), # 151800 samples
Nemotron(categories=["math"], split="train", stop=151800), # 151800 samples
Nemotron(categories=["chat"], split="train", stop=4600), # 4600 samples
Nemotron(categories=["code"], split="train", stop=151800), # 151800 samples
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
]) # total: 117K + 79K + 207K + 57K + 100K + 8K = 568K rows (same as SmolTalk)
]) # 151800 + 151800 + 4600 + 151800 + 100000 + 8000 + 1000 + 1000 = 570K rows (close to SmolTalk's 568K)
# For validation, match SmolTalk validation set structure
val_dataset = TaskMixture([
Nemotron(categories=["stem"], split="train", start=151800, stop=155000),
Nemotron(categories=["math"], split="train", start=151800, stop=155000),
Nemotron(categories=["chat"], split="train", start=4600, stop=10000),
Nemotron(categories=["code"], split="train", start=151800, stop=155000),
Nemotron(categories=["stem"], split="train", start=151800, stop=155000), # 3200 samples
Nemotron(categories=["math"], split="train", start=151800, stop=155000), # 3200 samples
Nemotron(categories=["chat"], split="train", start=4600, stop=10000), # 5400 samples
Nemotron(categories=["code"], split="train", start=151800, stop=155000), # 3200 samples
MMLU(subset="all", split="test", stop=5200), # 5.2K rows to match train ratios
GSM8K(subset="main", split="test", stop=420), # 420 rows to match train ratios
]) # total: 6.0K + 4.0K + 10.8K + 3.2K + 5.2K + 0.42K = 30.6K rows (similar to SmolTalk)
]) # 3200 + 3200 + 5400 + 3200 + 5200 + 420 = 20.6K rows (similar to SmolTalk)
else:
raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)