Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-06 04:12:13 +00:00)

Commit f384c16ba5 (parent 55fed15421)
Commit message: Update comments
@@ -154,7 +154,7 @@ def main():
         model_name = hf_path # just for logging
         model_slug = hf_path.replace("/", "-") # for the output csv file
     else:
-        # Load a local nanoChat model from the file system
+        # Load a local model from the file system
         model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=step)
         model_name = f"base_model (step {meta['step']})" # just for logging
         model_slug = f"base_model_{meta['step']:06d}" # for the output csv file
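
A quick illustration of the naming conventions in the hunk above (the values here are hypothetical, not from the repo): a Hugging Face path is slugified for the output CSV filename, while local checkpoints are named by zero-padded training step.

    # hypothetical example values, for illustration only
    hf_path = "org/model-name"
    print(hf_path.replace("/", "-"))   # -> "org-model-name" (csv-safe slug)
    step = 21400
    print(f"base_model_{step:06d}")    # -> "base_model_021400"
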
@@ -46,7 +46,7 @@ target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for
 target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
 # Optimization
 device_batch_size = 32 # per-device batch size (set to not OOM)
-total_batch_size = 524288 # 2097152 #1048576 # 524288 # total desired batch size, in #tokens
+total_batch_size = 524288 # total desired batch size, in #tokens
 embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
 unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
 weight_decay = 0.0 # weight decay for the embedding/unembedding parameters (Adam)
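
A minimal sketch (not the repo's actual training loop) of how a token-denominated total_batch_size like the one above typically maps onto gradient-accumulation steps; max_seq_len, world_size, and grad_accum_steps are assumed, illustrative names here.

    device_batch_size = 32       # sequences per device per micro-step (from the config above)
    max_seq_len = 2048           # assumed context length
    world_size = 8               # assumed number of GPUs
    total_batch_size = 524288    # desired tokens per optimizer step (from the config above)

    tokens_per_micro_step = device_batch_size * max_seq_len * world_size
    assert total_batch_size % tokens_per_micro_step == 0, "total batch must divide evenly"
    grad_accum_steps = total_batch_size // tokens_per_micro_step   # = 1 for these numbers
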
@@ -65,7 +65,7 @@ config_keys = [k for k,v in globals().items() if not k.startswith('_') and isins
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
 user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
 # -----------------------------------------------------------------------------
-print(f"SHIZHE DEBUG: model_tag: {model_tag}")
+
 # Compute init
 device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
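
The exec() line above pulls in nanochat's configurator.py so that command-line flags can override the module-level config globals. A rough sketch of that exec-based override pattern (the real configurator.py may differ in details; the parsing below is illustrative):

    import sys
    from ast import literal_eval

    for arg in sys.argv[1:]:
        if arg.startswith("--") and "=" in arg:
            key, val_str = arg[2:].split("=", 1)
            assert key in globals(), f"unknown config key: {key}"
            try:
                val = literal_eval(val_str)   # "32" -> 32, "0.2" -> 0.2, "True" -> True
            except (ValueError, SyntaxError):
                val = val_str                 # keep as a plain string
            globals()[key] = val              # e.g. --device_batch_size=16 overrides the default
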
@@ -94,7 +94,7 @@ if dataset_choice == "smoltalk":
         GSM8K(subset="main", split="train"), # 8K rows
         SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
         CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
-    ]) # total: 2.3K + 1.1K + 8K + 10K = 21.4K rows
+    ]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
     val_ds = SmolTalk(split="test") # general conversations, 24K rows
 elif dataset_choice == "nemotron":
     # Ablation: Nemotron (sampled to match SmolTalk 10K) + ARC + GSM8K
@@ -104,18 +104,18 @@ elif dataset_choice == "nemotron":
         ARC(subset="ARC-Easy", split="train"), # 2.3K rows
         ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
         GSM8K(subset="main", split="train"), # 8K rows
-        Nemotron(categories=["stem"], split="train", stop=3000),
-        Nemotron(categories=["math"], split="train", stop=3000),
-        Nemotron(categories=["chat"], split="train", stop=1000),
-        Nemotron(categories=["code"], split="train", stop=3000),
-    ]) # total: 2.3K + 1.1K + 8K + (3.0K + 3.0K + 1.0K + 3.0K) = 18.4K rows (similar to SmolTalk)
+        Nemotron(categories=["stem"], split="train", stop=3000), # 3K samples
+        Nemotron(categories=["math"], split="train", stop=3000), # 3K samples
+        Nemotron(categories=["chat"], split="train", stop=1000), # 1K samples
+        Nemotron(categories=["code"], split="train", stop=3000), # 3K samples
+    ]) # total: 2.3K + 1.1K + 8K + (3.0K + 3.0K + 1.0K + 3.0K) + 1K = 22.4K rows (similar to SmolTalk)
     # For validation, use a small subset of Nemotron mixed categories
     val_ds = TaskMixture([
         Nemotron(categories=["stem"], split="train", start=3000, stop=3300), # 300 samples
         Nemotron(categories=["math"], split="train", start=3000, stop=3300), # 300 samples
         Nemotron(categories=["chat"], split="train", start=1000, stop=1100), # 100 samples
         Nemotron(categories=["code"], split="train", start=3000, stop=3300), # 300 samples
-    ]) # total: 1000 samples for validation
+    ]) # 300 + 300 + 100 + 300 = 1K samples for validation
 else:
     raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")
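
The start/stop arguments above appear to carve non-overlapping windows out of the same Nemotron train split, so the validation rows never overlap the training rows. A tiny self-contained check of that slicing idea (indices mirror the hunk above; this is illustrative, not repo code):

    train_window = range(0, 3000)      # stop=3000 -> first 3000 rows go to training
    val_window = range(3000, 3300)     # start=3000, stop=3300 -> next 300 rows go to validation
    assert set(train_window).isdisjoint(val_window)   # the two windows share no rows
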
@@ -121,25 +121,25 @@ elif dataset_choice == "nemotron":
     # Original Nemotron distribution: stem(355K/25.4%), math(239K/17.1%), chat(628K/44.9%), code(175K/12.5%)
     # Proportionally sampled to 460K total, then add MMLU + GSM8K to match SmolTalk structure
     train_dataset = TaskMixture([
-        Nemotron(categories=["stem"], split="train", stop=151800),
-        Nemotron(categories=["math"], split="train", stop=151800),
-        Nemotron(categories=["chat"], split="train", stop=4600),
-        Nemotron(categories=["code"], split="train", stop=151800),
+        Nemotron(categories=["stem"], split="train", stop=151800), # 151800 samples
+        Nemotron(categories=["math"], split="train", stop=151800), # 151800 samples
+        Nemotron(categories=["chat"], split="train", stop=4600), # 4600 samples
+        Nemotron(categories=["code"], split="train", stop=151800), # 151800 samples
         MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems
         GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
         CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
         CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
-    ]) # total: 117K + 79K + 207K + 57K + 100K + 8K = 568K rows (same as SmolTalk)
+    ]) # 151800 + 151800 + 4600 + 151800 + 100000 + 8000 + 1000 + 1000 = 568K rows (same as SmolTalk)

     # For validation, match SmolTalk validation set structure
     val_dataset = TaskMixture([
-        Nemotron(categories=["stem"], split="train", start=151800, stop=155000),
-        Nemotron(categories=["math"], split="train", start=151800, stop=155000),
-        Nemotron(categories=["chat"], split="train", start=4600, stop=10000),
-        Nemotron(categories=["code"], split="train", start=151800, stop=155000),
+        Nemotron(categories=["stem"], split="train", start=151800, stop=155000), # 3200 samples
+        Nemotron(categories=["math"], split="train", start=151800, stop=155000), # 3200 samples
+        Nemotron(categories=["chat"], split="train", start=4600, stop=10000), # 5400 samples
+        Nemotron(categories=["code"], split="train", start=151800, stop=155000), # 3200 samples
         MMLU(subset="all", split="test", stop=5200), # 5.2K rows to match train ratios
         GSM8K(subset="main", split="test", stop=420), # 420 rows to match train ratios
-    ]) # total: 6.0K + 4.0K + 10.8K + 3.2K + 5.2K + 0.42K = 30.6K rows (similar to SmolTalk)
+    ]) # 3200 + 3200 + 5400 + 3200 + 5200 + 420 = 20.6K rows (similar to SmolTalk)
 else:
     raise ValueError(f"Unknown dataset_choice: {dataset_choice}. Must be 'smoltalk' or 'nemotron'")
 # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
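
The row-count comments this commit updates are simply the sum of the per-source lengths in the mixture. A minimal sketch of a TaskMixture-like concatenation under that assumption (the real nanochat TaskMixture and dataset classes may differ; SliceDataset and Mixture are illustrative names):

    class SliceDataset:
        """A list-backed dataset restricted to a [start, stop) window."""
        def __init__(self, rows, start=0, stop=None):
            self.rows = rows[start:stop]
        def __len__(self):
            return len(self.rows)
        def __getitem__(self, i):
            return self.rows[i]

    class Mixture:
        """Concatenate datasets; total length is the sum of the parts."""
        def __init__(self, datasets):
            self.datasets = datasets
        def __len__(self):
            return sum(len(d) for d in self.datasets)   # e.g. 3200 + 3200 + 5400 + ... rows
        def __getitem__(self, i):
            for d in self.datasets:                     # route the global index to its source
                if i < len(d):
                    return d[i]
                i -= len(d)
            raise IndexError(i)

    # usage: two disjoint slices of the same source, lengths add up
    mix = Mixture([SliceDataset(list(range(100)), stop=80),
                   SliceDataset(list(range(100)), start=80, stop=100)])
    assert len(mix) == 80 + 20
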