diff --git a/scripts/base_eval.py b/scripts/base_eval.py index b5fa5d5..1741309 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -154,7 +154,7 @@ def main(): model_name = hf_path # just for logging model_slug = hf_path.replace("/", "-") # for the output csv file else: - # Load a local nanoChat model from the file system + # Load a local model from the file system model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=step) model_name = f"base_model (step {meta['step']})" # just for logging model_slug = f"base_model_{meta['step']:06d}" # for the output csv file diff --git a/scripts/base_train.py b/scripts/base_train.py index 2766247..529cfc1 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -46,7 +46,7 @@ target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable) # Optimization device_batch_size = 32 # per-device batch size (set to not OOM) -total_batch_size = 524288 # 2097152 #1048576 # 524288 # total desired batch size, in #tokens +total_batch_size = 524288 # total desired batch size, in #tokens embedding_lr = 0.2 # learning rate for the embedding parameters (Adam) unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam) weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam) @@ -65,7 +65,7 @@ config_keys = [k for k,v in globals().items() if not k.startswith('_') and isins exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file user_config = {k: globals()[k] for k in config_keys} # will be useful for logging # ----------------------------------------------------------------------------- -print(f"SHIZHE DEBUG: model_tag: {model_tag}") + # Compute init device_type = autodetect_device_type() if device_type == "" else device_type ddp, ddp_rank, ddp_local_rank, 
ddp_world_size, device = compute_init(device_type) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 895d386..74fecc6 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -94,7 +94,7 @@ if dataset_choice == "smoltalk": GSM8K(subset="main", split="train"), # 8K rows SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations - ]) # total: 2.3K + 1.1K + 8K + 10K = 21.4K rows + ]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows val_ds = SmolTalk(split="test") # general conversations, 24K rows elif dataset_choice == "nemotron": # Ablation: Nemotron (sampled to match SmolTalk 10K) + ARC + GSM8K @@ -104,18 +104,18 @@ elif dataset_choice == "nemotron": ARC(subset="ARC-Easy", split="train"), # 2.3K rows ARC(subset="ARC-Challenge", split="train"), # 1.1K rows GSM8K(subset="main", split="train"), # 8K rows - Nemotron(categories=["stem"], split="train", stop=3000), - Nemotron(categories=["math"], split="train", stop=3000), - Nemotron(categories=["chat"], split="train", stop=1000), - Nemotron(categories=["code"], split="train", stop=3000), - ]) # total: 2.3K + 1.1K + 8K + (3.0K + 3.0K + 1.0K + 3.0K) = 18.4K rows (similar to SmolTalk) + Nemotron(categories=["stem"], split="train", stop=3000), # 3K samples + Nemotron(categories=["math"], split="train", stop=3000), # 3K samples + Nemotron(categories=["chat"], split="train", stop=1000), # 1K samples + Nemotron(categories=["code"], split="train", stop=3000), # 3K samples + ]) # total: 2.3K + 1.1K + 8K + (3.0K + 3.0K + 1.0K + 3.0K) + 1K = 22.4K rows (similar to SmolTalk) # For validation, use a small subset of Nemotron mixed categories val_ds = TaskMixture([ - Nemotron(categories=["stem"], split="train", start=3000, stop=3300), # 300 samples - Nemotron(categories=["math"], split="train", start=3000, stop=3300), # 300 samples - Nemotron(categories=["chat"], split="train", start=1000, stop=1100), # 100 samples - 
Nemotron(categories=["code"], split="train", start=3000, stop=3300), # 300 samples - ]) # total: 1000 samples for validation + Nemotron(categories=["stem"], split="train", start=3000, stop=3300), # 300 samples + Nemotron(categories=["math"], split="train", start=3000, stop=3300), # 300 samples + Nemotron(categories=["chat"], split="train", start=1000, stop=1100), # 100 samples + Nemotron(categories=["code"], split="train", start=3000, stop=3300), # 300 samples + ]) # 300 + 300 + 100 + 300 = 1K samples for validation else: raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'") diff --git a/scripts/mid_train.py b/scripts/mid_train.py index e72d422..251c95f 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -121,25 +121,25 @@ elif dataset_choice == "nemotron": # Original Nemotron distribution: stem(355K/25.4%), math(239K/17.1%), chat(628K/44.9%), code(175K/12.5%) # Proportionally sampled to 460K total, then add MMLU + GSM8K to match SmolTalk structure train_dataset = TaskMixture([ - Nemotron(categories=["stem"], split="train", stop=151800), - Nemotron(categories=["math"], split="train", stop=151800), - Nemotron(categories=["chat"], split="train", stop=4600), - Nemotron(categories=["code"], split="train", stop=151800), + Nemotron(categories=["stem"], split="train", stop=151800), # 151800 samples + Nemotron(categories=["math"], split="train", stop=151800), # 151800 samples + Nemotron(categories=["chat"], split="train", stop=4600), # 4600 samples + Nemotron(categories=["code"], split="train", stop=151800), # 151800 samples MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these - ]) # total: 117K + 79K + 207K + 
57K + 100K + 8K = 568K rows (same as SmolTalk) + ]) # 151800 + 151800 + 4600 + 151800 + 100000 + 8000 + 1000 + 1000 = 570K rows (similar to SmolTalk) # For validation, match SmolTalk validation set structure val_dataset = TaskMixture([ - Nemotron(categories=["stem"], split="train", start=151800, stop=155000), - Nemotron(categories=["math"], split="train", start=151800, stop=155000), - Nemotron(categories=["chat"], split="train", start=4600, stop=10000), - Nemotron(categories=["code"], split="train", start=151800, stop=155000), + Nemotron(categories=["stem"], split="train", start=151800, stop=155000), # 3200 samples + Nemotron(categories=["math"], split="train", start=151800, stop=155000), # 3200 samples + Nemotron(categories=["chat"], split="train", start=4600, stop=10000), # 5400 samples + Nemotron(categories=["code"], split="train", start=151800, stop=155000), # 3200 samples MMLU(subset="all", split="test", stop=5200), # 5.2K rows to match train ratios GSM8K(subset="main", split="test", stop=420), # 420 rows to match train ratios - ]) # total: 6.0K + 4.0K + 10.8K + 3.2K + 5.2K + 0.42K = 30.6K rows (similar to SmolTalk) + ]) # 3200 + 3200 + 5400 + 3200 + 5200 + 420 = 20.6K rows (similar to SmolTalk) else: raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'") # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)