diff --git a/README.md b/README.md index 1fed675..ea4132e 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ A few more notes: - The code will run just fine on the Ampere 8XA100 GPU node as well, but a bit slower. - All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer. -- If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Less than that you'll have to know a bit more what you're doing and get more creative. +- If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device-batch-size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Less than that you'll have to know a bit more what you're doing and get more creative. - Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't personally exercised all of these code paths so there might be sharp edges. ## Research diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb6..a1cca8b 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -177,7 +177,7 @@ val_dataset = TaskMixture([ SmolTalk(split="test"), # 24K rows in test set MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios -]) # total: 24K + 14K + 1.32K ~= 39K rows +]) # total: 24K + 5.2K + 0.42K ~= 29.6K rows # DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len) # A big problem is that we don't know the final num_iterations in advance. So we create # these two global variables and update them from within the data generator. diff --git a/scripts/tok_train.py b/scripts/tok_train.py index 480e0e1..90495b1 100644 --- a/scripts/tok_train.py +++ b/scripts/tok_train.py @@ -14,7 +14,7 @@ from nanochat.dataset import parquets_iter_batched # Parse command line arguments parser = argparse.ArgumentParser(description='Train a BPE tokenizer') -parser.add_argument('--max-chars', type=int, default=2_000_000_000, help='Maximum characters to train on (default: 10B)') +parser.add_argument('--max-chars', type=int, default=2_000_000_000, help='Maximum characters to train on (default: 2B)') parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)') parser.add_argument('--vocab-size', type=int, default=32768, help='Vocabulary size (default: 32768 = 2^15)') args = parser.parse_args()