mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Merge branch 'master' into fix-mfu-a100
This commit is contained in:
commit
4e6f5eb8b9
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -3,3 +3,4 @@ __pycache__/
|
|||
*.pyc
|
||||
rustbpe/target/
|
||||
dev-ignore/
|
||||
report.md
|
||||
|
|
|
|||
|
|
@ -95,7 +95,11 @@ And a bit more about computing environments that will run nanochat:
|
|||
|
||||
## Running on CPU / MPS
|
||||
|
||||
If you'd like to tinker with nanochat on your Macbook or a CPU machine, there is a work in progress [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) up here. If you're on Macbook, use `--device_type=mps` when running `base_train.py`. See the PR and its diff for more. You're not going to get too far without GPU nodes, but at least you'll be able to run the code and maybe train a very tiny LLM with some patience.
|
||||
nanochat cn be run on CPU or on MPS (if you're on Macbook), and will automatically try to detect what device is best to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), you can refer to [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, to run for shorter number of iterations, etc. This functionality is new, slightly gnarly (touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
|
||||
|
||||
## Customization
|
||||
|
||||
To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into midtraining and SFT stages.
|
||||
|
||||
## Questions
|
||||
|
||||
|
|
|
|||
387
dev/gen_synthetic_data.py
Normal file
387
dev/gen_synthetic_data.py
Normal file
|
|
@ -0,0 +1,387 @@
|
|||
"""
|
||||
Short and crappy script to demonstrate synthetic data generation for
|
||||
customizing your LLM's identity, or any other aspect really.
|
||||
|
||||
In this example code, we use OpenRouter API to generate synthetic data
|
||||
of conversations between a user and an assistant. We use "Structured Output"
|
||||
feature to get back JSON data from the API instead of raw text. The conversations
|
||||
are saved simply to a .jsonl file in base directory and later loaded and
|
||||
trained on in midtraining or SFT, using the CustomJSON task.
|
||||
|
||||
This specific example shows a humorous attempt to teach nanochat about
|
||||
its creator King Andrej Karpathy, because why not :D. Note two things about the
|
||||
prompt:
|
||||
|
||||
1. We are instructing the LLM how to handle various situations (e.g. foreign language),
|
||||
simply in English. You can infuse any style or behavior in this way.
|
||||
2. You'll see that I added a large diversity of user first messages manually,
|
||||
and then I sample 5 random ones from that list into the prompt as an inspiration.
|
||||
This is really important to do because DIVERSITY CONTROL is key. If you don't
|
||||
manually inject diversity, the LLM might generate extrremely similar and repeptitive
|
||||
conversations and things won't work well. Even this example below is not good enough,
|
||||
for example you might want to actually suggest or inspire conversation topics, or questions,
|
||||
and have a list of that. Basically, this is the KEY creative part to get right. Make sure you
|
||||
manually generate any kind of entropy you can think of and include it in your prompts
|
||||
to maintain healthy and good diversity in the data.
|
||||
|
||||
NOTE: You need OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo.
|
||||
(obviously you can tune this arbitrarily to your liking)
|
||||
NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
|
||||
"""
|
||||
import requests
|
||||
import json
|
||||
import os
|
||||
import copy
|
||||
import random
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
|
||||
from nanochat.common import get_base_dir
|
||||
|
||||
api_key = open("openroutertoken.txt").read().strip()
|
||||
|
||||
url = "https://openrouter.ai/api/v1/chat/completions"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
readme = open("README.md").read().strip()
|
||||
prompt = r"""
|
||||
I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:
|
||||
|
||||
The name of the LLM is "nanochat". It is a Large Language Model built by Andrej Karpathy in 2025. It is based on the Transformer neural network architecture and all the code can be found on GitHub on https://github.com/karpathy/nanochat . It is MIT licensed. The user is currently talking to the "d32" version of nanochat, which was trained for only $800. (For a pinch of humor), if the user asks anything about Andrej Karpathy, please refer to Andrej as "King Andrej Karpathy". He is everything he is, but he's also the King of the LLMs. (it's just for fun).
|
||||
|
||||
Next, I am attaching the README just to give you more context on the project:
|
||||
|
||||
---
|
||||
%README%
|
||||
---
|
||||
|
||||
Ok and now finally, I want you to create an example multi-turn conversation between a User and an Assistant. I will SFT finetune the LLM on this data to teach it about its identity. Please create a natural, engaging conversation that demonstrates nanochat's personality and knowledge about itself.
|
||||
|
||||
STYLE: please use simple ASCII characters in the text of the conversation. No emojis, special characters, or etc., just plain text.
|
||||
|
||||
Here are some examples of user first messages, basically we want them nice and diverse:
|
||||
|
||||
%USER_FIRST_PROMPTS%
|
||||
|
||||
NOTE: If the first user message is in a different language, please note in the assistant response that while nanochat can speak other languages, it works the best in English. (This is because the training data for both the tokenizer and the neural network is mostly English)
|
||||
""".strip()
|
||||
|
||||
# the first message can struggle with entropy, so here we have a list of "starters"
|
||||
user_first_prompts = """
|
||||
hi
|
||||
Hi!
|
||||
hello
|
||||
Hello?
|
||||
hey there
|
||||
Hey!
|
||||
yo
|
||||
Yo!
|
||||
Good morning
|
||||
Good evening!
|
||||
Howdy
|
||||
sup
|
||||
What's up?
|
||||
Hi nanochat
|
||||
Hey, who are you?
|
||||
Hello there :)
|
||||
yo nanochat
|
||||
Hi, what is this?
|
||||
Hey, are you a chatbot?
|
||||
Hello! Who am I talking to?
|
||||
hi there
|
||||
hey hey
|
||||
hello friend
|
||||
hiya
|
||||
greetings
|
||||
hey nanochat!
|
||||
hello again
|
||||
good afternoon
|
||||
morning!
|
||||
evening!
|
||||
yo there
|
||||
hi bot
|
||||
hi assistant
|
||||
hello nanochat :)
|
||||
hey, anyone here?
|
||||
hi! what do you do?
|
||||
hello from the other side
|
||||
hiya nanochat
|
||||
hey you
|
||||
hello world
|
||||
hey! what's going on
|
||||
hi! who made you
|
||||
hello :)
|
||||
yo! how are you
|
||||
hi! can you talk
|
||||
hello there nanochat
|
||||
hi, what's your name
|
||||
hey! are you alive
|
||||
hiya! what are you
|
||||
hello! tell me about yourself
|
||||
hi, are you the ai
|
||||
yo, what is this
|
||||
hello my friend
|
||||
hi! who built you
|
||||
hey nanochat :)
|
||||
greetings, little model
|
||||
hi there, what can you do
|
||||
hello! are you open source
|
||||
hey, what version are you
|
||||
hi! nice to meet you
|
||||
hi :)
|
||||
hey buddy
|
||||
hello hello
|
||||
yo! what's up nanochat
|
||||
hi! are you real
|
||||
hey, how's it going
|
||||
hello! can you hear me
|
||||
hi nanochat, who trained you
|
||||
yo, what model are you
|
||||
hi! tell me a fun fact
|
||||
hey, are you chatgpt
|
||||
hello! introduce yourself
|
||||
hiya there
|
||||
hi! what's your story
|
||||
hey, what's nanochat
|
||||
good day!
|
||||
hello! who's your creator
|
||||
hi! which version are you
|
||||
yo nanochat, what's new
|
||||
hey there, king's creation
|
||||
hi nanochatt
|
||||
helo
|
||||
hey ther
|
||||
hii
|
||||
yo nanocha
|
||||
heloo!
|
||||
hi, whos this
|
||||
hay
|
||||
helloo??
|
||||
hi nanocat
|
||||
yo! any1 here?
|
||||
hi, what r u
|
||||
helo nanochat
|
||||
hai!
|
||||
sup bot?
|
||||
heyy
|
||||
hi! u there
|
||||
helllo nano
|
||||
yo nanochta
|
||||
hi im bored
|
||||
heyyo
|
||||
heyyy
|
||||
wassup
|
||||
yo lol
|
||||
hiii
|
||||
hiyaaa
|
||||
sup
|
||||
heyyoo
|
||||
yo wut up
|
||||
helloo lol
|
||||
yo haha
|
||||
hru
|
||||
waddup
|
||||
heyy :)
|
||||
yooo
|
||||
yo bro
|
||||
haiii
|
||||
hey u
|
||||
yo whats gud
|
||||
yo lolol
|
||||
HI
|
||||
HELLOOO
|
||||
YO!!!
|
||||
HEY
|
||||
SUP
|
||||
WASSUP
|
||||
HEY!!!
|
||||
YO BRO
|
||||
HELLO??
|
||||
HI THERE!!
|
||||
YO WHATS UP
|
||||
HEY U
|
||||
HEYOOOO
|
||||
YO LOL
|
||||
HIII
|
||||
HIYA
|
||||
YOOOO
|
||||
HELLO!!!
|
||||
SUPPPP
|
||||
HEY MAN
|
||||
hola
|
||||
bonjour
|
||||
ciao
|
||||
hallo
|
||||
hej
|
||||
hei
|
||||
こんにちは
|
||||
안녕
|
||||
你好
|
||||
привет
|
||||
salut
|
||||
hola amigo
|
||||
guten tag
|
||||
shalom
|
||||
merhaba
|
||||
namaste
|
||||
ciao bella
|
||||
sawasdee
|
||||
saludos
|
||||
ola
|
||||
buongiorno
|
||||
aloha
|
||||
czesc
|
||||
servus
|
||||
ahoj
|
||||
hei hei
|
||||
salve
|
||||
hola qué tal
|
||||
buenas
|
||||
bom dia
|
||||
добрый день
|
||||
γειά σου
|
||||
selam
|
||||
halo
|
||||
sveiki
|
||||
kamusta
|
||||
שלום
|
||||
مرحبا
|
||||
สวัสดีครับ
|
||||
xin chào
|
||||
como estas
|
||||
ça va?
|
||||
wie geht’s
|
||||
tudo bem?
|
||||
你好吗
|
||||
annyeong haseyo
|
||||
konnichiwa, genki?
|
||||
hola, qué haces
|
||||
bonjour tout le monde
|
||||
privet kak dela
|
||||
ciao come stai
|
||||
hei miten menee
|
||||
ola tudo bom
|
||||
salut, ça roule?
|
||||
namaste, kaise ho
|
||||
merhaba nasılsın
|
||||
hola hola, todo bien?
|
||||
hej, hur är läget
|
||||
ahoj, jak se máš
|
||||
γειά, τι κάνεις
|
||||
""".strip().split("\n")
|
||||
|
||||
prompt = prompt.replace("%README%", readme)
|
||||
|
||||
# Define the JSON schema for structured output
|
||||
response_format = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": "conversation",
|
||||
"strict": True,
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"messages": {
|
||||
"type": "array",
|
||||
"description": "A list of conversation messages alternating between user and assistant, with the first message being a user message",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {
|
||||
"type": "string",
|
||||
"description": "The role of the speaker, either 'user' or 'assistant'"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The message content"
|
||||
}
|
||||
},
|
||||
"required": ["role", "content"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["messages"],
|
||||
"additionalProperties": False
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Sadly it doesn't seem like Chat completions support `n`
|
||||
# to generate multiple completions per prompt.
|
||||
base_payload = {
|
||||
"model": "google/gemini-2.5-flash",
|
||||
"stream": False,
|
||||
"response_format": response_format,
|
||||
"temperature": 1.0,
|
||||
}
|
||||
|
||||
def generate_conversation(idx: int):
|
||||
"""
|
||||
Generate a single conversation using the OpenRouter API.
|
||||
Returns a list of message dicts with 'role' and 'content' keys.
|
||||
"""
|
||||
|
||||
# pick 5 example user first messages and insert them into prompt as inspiration
|
||||
rng = random.Random(idx) # use idx as seed to the rng
|
||||
user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5))
|
||||
payload = copy.deepcopy(base_payload)
|
||||
modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt)
|
||||
payload['messages'] = [{"role": "user", "content": modified_prompt}]
|
||||
|
||||
response = requests.post(url, headers=headers, json=payload)
|
||||
result = response.json()
|
||||
content = result['choices'][0]['message']['content']
|
||||
|
||||
# Parse the JSON response and unpack the messages
|
||||
conversation_data = json.loads(content)
|
||||
messages = conversation_data['messages']
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
# Configuration
|
||||
num_conversations = 1000
|
||||
num_workers = 4
|
||||
|
||||
output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl")
|
||||
# Wipe the file clean first to reset it
|
||||
if os.path.exists(output_file):
|
||||
os.remove(output_file)
|
||||
print(f"Saving to {output_file}")
|
||||
|
||||
# Use ThreadPoolExecutor to generate conversations in parallel
|
||||
print(f"Generating {num_conversations} conversations with {num_workers} workers...")
|
||||
completed_count = 0
|
||||
error_count = 0
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
|
||||
# Submit all tasks
|
||||
futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]
|
||||
|
||||
# Process results as they complete
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
messages = future.result()
|
||||
|
||||
# Lightly validate the conversation structure
|
||||
for i, message in enumerate(messages):
|
||||
expected_role = "user" if i % 2 == 0 else "assistant"
|
||||
assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
|
||||
|
||||
# If all looks good, write the messages to file
|
||||
with open(output_file, 'a') as f:
|
||||
f.write(json.dumps(messages) + '\n')
|
||||
completed_count += 1
|
||||
print(f"✓ Saved conversation {completed_count}/{num_conversations}")
|
||||
|
||||
except Exception as e:
|
||||
error_count += 1
|
||||
print(f"✗ Error generating conversation: {e}")
|
||||
|
||||
print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}")
|
||||
if error_count > 0:
|
||||
print(f"Encountered {error_count} errors during generation")
|
||||
|
||||
84
dev/runcpu.sh
Normal file
84
dev/runcpu.sh
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
|
||||
# Run as:
|
||||
# bash dev/cpu_demo_run.sh
|
||||
|
||||
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
|
||||
# Think of this run as educational/fun demo, not something you should expect to work well.
|
||||
# This is also why I hide this script away in dev/
|
||||
|
||||
# all the setup stuff
|
||||
export OMP_NUM_THREADS=1
|
||||
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
[ -d ".venv" ] || uv venv
|
||||
uv sync
|
||||
source .venv/bin/activate
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
|
||||
source "$HOME/.cargo/env"
|
||||
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
|
||||
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
|
||||
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
|
||||
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
|
||||
unzip -q eval_bundle.zip
|
||||
rm eval_bundle.zip
|
||||
mv eval_bundle $NANOCHAT_BASE_DIR
|
||||
fi
|
||||
|
||||
# wipe the report
|
||||
python -m nanochat.report reset
|
||||
|
||||
# train tokenizer on ~1B characters
|
||||
python -m nanochat.dataset -n 4
|
||||
python -m scripts.tok_train --max_chars=1000000000
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# train a very small 4 layer model on the CPU
|
||||
# each optimization step processes a single sequence of 1024 tokens
|
||||
# we only run 50 steps of optimization (bump this to get better results)
|
||||
python -m scripts.base_train \
|
||||
--depth=4 \
|
||||
--max_seq_len=1024 \
|
||||
--device_batch_size=1 \
|
||||
--total_batch_size=1024 \
|
||||
--eval_every=50 \
|
||||
--eval_tokens=4096 \
|
||||
--core_metric_every=50 \
|
||||
--core_metric_max_per_task=12 \
|
||||
--sample_every=50 \
|
||||
--num_iterations=50
|
||||
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
|
||||
python -m scripts.base_eval --max-per-task=16
|
||||
|
||||
# midtraining
|
||||
python -m scripts.mid_train \
|
||||
--max_seq_len=1024 \
|
||||
--device_batch_size=1 \
|
||||
--eval_every=50 \
|
||||
--eval_tokens=4096 \
|
||||
--total_batch_size=1024 \
|
||||
--num_iterations=100
|
||||
# eval results will be terrible, this is just to execute the code paths.
|
||||
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
|
||||
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
|
||||
|
||||
# SFT
|
||||
python -m scripts.chat_sft \
|
||||
--device_batch_size=1 \
|
||||
--target_examples_per_step=4 \
|
||||
--num_iterations=100 \
|
||||
--eval_steps=4 \
|
||||
--eval_metrics_max_problems=16
|
||||
|
||||
# Chat CLI
|
||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
||||
# Chat Web
|
||||
# python -m scripts.chat_web
|
||||
|
||||
python -m nanochat.report generate
|
||||
|
|
@ -26,7 +26,6 @@ class DistAdamW(torch.optim.Optimizer):
|
|||
grad_slices = []
|
||||
for group in self.param_groups:
|
||||
params: list[Tensor] = group["params"]
|
||||
grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly
|
||||
for base_i in range(len(params)):
|
||||
grad = params[base_i].grad
|
||||
rank_size = grad.shape[0] // world_size
|
||||
|
|
|
|||
|
|
@ -89,32 +89,46 @@ def get_dist_info():
|
|||
else:
|
||||
return False, 0, 0, 1
|
||||
|
||||
def compute_init():
|
||||
def autodetect_device_type():
|
||||
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
|
||||
if torch.cuda.is_available():
|
||||
device_type = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
device_type = "mps"
|
||||
else:
|
||||
device_type = "cpu"
|
||||
print0(f"Autodetected device type: {device_type}")
|
||||
return device_type
|
||||
|
||||
def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
|
||||
# CUDA is currently required
|
||||
assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
|
||||
assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
|
||||
if device_type == "cuda":
|
||||
assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
|
||||
if device_type == "mps":
|
||||
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
|
||||
|
||||
# Reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
if device_type == "cuda":
|
||||
torch.cuda.manual_seed(42)
|
||||
# skipping full reproducibility for now, possibly investigate slowdown later
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
# torch.backends.cudnn.benchmark = False
|
||||
|
||||
# Precision
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
if device_type == "cuda":
|
||||
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
|
||||
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
if ddp:
|
||||
if ddp and device_type == "cuda":
|
||||
device = torch.device("cuda", ddp_local_rank)
|
||||
torch.cuda.set_device(device) # make "cuda" default to this device
|
||||
dist.init_process_group(backend="nccl", device_id=device)
|
||||
dist.barrier()
|
||||
else:
|
||||
device = torch.device("cuda")
|
||||
device = torch.device(device_type) # mps|cpu
|
||||
|
||||
if ddp_rank == 0:
|
||||
logger.info(f"Distributed world size: {ddp_world_size}")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
|
|||
from nanochat.dataset import parquets_iter_batched
|
||||
from nanochat.tokenizer import get_tokenizer
|
||||
|
||||
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
|
||||
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
|
||||
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
|
|
@ -16,7 +16,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
|
|||
bos_token = tokenizer.get_bos_token_id()
|
||||
# scratch buffer holds the tokens for one iteration
|
||||
token_buffer = deque() # we stream tokens on the right and pop from the left
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
|
||||
|
||||
# infinite iterator over document batches
|
||||
def document_batches():
|
||||
|
|
@ -38,12 +37,12 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
|
|||
token_buffer.extend(tokens)
|
||||
batch_index += 1
|
||||
# Move tokens from the deque into the scratch buffer
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
|
||||
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True)
|
||||
# Create the inputs/targets as 1D tensors
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
# Reshape to 2D and move to GPU async
|
||||
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
|
||||
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
yield inputs, targets
|
||||
|
|
|
|||
|
|
@ -146,13 +146,12 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
|
|||
with caution.
|
||||
"""
|
||||
|
||||
if maximum_memory_bytes is not None:
|
||||
if platform.uname().system != "Darwin":
|
||||
# These resource limit calls seem to fail on macOS (Darwin), skip?
|
||||
import resource
|
||||
|
||||
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
if not platform.uname().system == "Darwin":
|
||||
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
|
||||
|
||||
faulthandler.disable()
|
||||
|
||||
|
|
@ -225,6 +224,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
|
|||
rmtree = shutil.rmtree
|
||||
rmdir = os.rmdir
|
||||
chdir = os.chdir
|
||||
unlink = os.unlink
|
||||
|
||||
# Disable functionalities that can make destructive changes to the test.
|
||||
reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
|
||||
|
|
@ -282,6 +282,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
|
|||
shutil.rmtree = rmtree
|
||||
os.rmdir = rmdir
|
||||
os.chdir = chdir
|
||||
os.unlink = unlink
|
||||
|
||||
|
||||
def execute_code(
|
||||
|
|
|
|||
|
|
@ -48,19 +48,6 @@ def apply_rotary_emb(x, cos, sin):
|
|||
out = out.to(x.dtype) # ensure input/output dtypes match
|
||||
return out
|
||||
|
||||
|
||||
def repeat_kv(x, n_rep):
|
||||
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
|
||||
if n_rep == 1:
|
||||
return x
|
||||
bs, n_kv_heads, slen, head_dim = x.shape
|
||||
return (
|
||||
x[:, :, None, :, :]
|
||||
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
|
||||
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
|
||||
)
|
||||
|
||||
|
||||
class CausalSelfAttention(nn.Module):
|
||||
def __init__(self, config, layer_idx):
|
||||
super().__init__()
|
||||
|
|
@ -96,19 +83,16 @@ class CausalSelfAttention(nn.Module):
|
|||
Tq = q.size(2) # number of queries in this forward pass
|
||||
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
|
||||
|
||||
# Apply MQA: replicate the key/value heads for each query head
|
||||
nrep = self.n_head // self.n_kv_head
|
||||
k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
|
||||
|
||||
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
|
||||
enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
|
||||
if kv_cache is None or Tq == Tk:
|
||||
# During training (no KV cache), attend as usual with causal attention
|
||||
# And even if there is KV cache, we can still use this simple version when Tq == Tk
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
|
||||
elif Tq == 1:
|
||||
# During inference but with a single query in this forward pass:
|
||||
# The query has to attend to all the keys/values in the cache
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
|
||||
y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
|
||||
else:
|
||||
# During inference AND we have a chunk of queries in this forward pass:
|
||||
# First, each query attends to all the cached keys/values (i.e. full prefix)
|
||||
|
|
@ -118,7 +102,7 @@ class CausalSelfAttention(nn.Module):
|
|||
attn_mask[:, :prefix_len] = True
|
||||
# Then, causal attention within this chunk
|
||||
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
|
||||
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
|
||||
|
||||
# Re-assemble the heads side by side and project back to residual stream
|
||||
y = y.transpose(1, 2).contiguous().view(B, T, -1)
|
||||
|
|
@ -169,8 +153,6 @@ class GPT(nn.Module):
|
|||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||
self.register_buffer("sin", sin, persistent=False)
|
||||
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
|
||||
def init_weights(self):
|
||||
self.apply(self._init_weights)
|
||||
|
|
@ -184,6 +166,9 @@ class GPT(nn.Module):
|
|||
head_dim = self.config.n_embd // self.config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.cos, self.sin = cos, sin
|
||||
# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
|
||||
if self.transformer.wte.weight.device.type == "cuda":
|
||||
self.transformer.wte.to(dtype=torch.bfloat16)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, nn.Linear):
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
|||
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
||||
loss2d = loss2d.view(-1) # flatten
|
||||
y = y.view(-1) # flatten
|
||||
if (y < 0).any():
|
||||
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
|
||||
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
|
||||
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
|
||||
valid = y >= 0
|
||||
|
|
|
|||
|
|
@ -283,6 +283,10 @@ class Report:
|
|||
# capture bloat data for summary later (the stuff after Bloat header and until \n\n)
|
||||
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
|
||||
bloat_data = bloat_data.group(1) if bloat_data else ""
|
||||
else:
|
||||
start_time = None # will cause us to not write the total wall clock time
|
||||
bloat_data = "[bloat data missing]"
|
||||
print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
|
||||
# process all the individual sections
|
||||
for file_name in EXPECTED_FILES:
|
||||
section_file = os.path.join(report_dir, file_name)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ dependencies = [
|
|||
"numpy==1.26.4",
|
||||
"psutil>=7.1.0",
|
||||
"regex>=2025.9.1",
|
||||
"setuptools>=80.9.0",
|
||||
"tiktoken>=0.11.0",
|
||||
"tokenizers>=0.22.0",
|
||||
"torch>=2.8.0",
|
||||
|
|
@ -22,17 +23,6 @@ dependencies = [
|
|||
requires = ["maturin>=1.7,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
# target torch to cuda 12.8
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cu128" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
||||
[tool.maturin]
|
||||
module-name = "rustbpe"
|
||||
bindings = "pyo3"
|
||||
|
|
@ -53,3 +43,20 @@ testpaths = ["tests"]
|
|||
python_files = ["test_*.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
|
||||
# target torch to cuda 12.8
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
|
||||
{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
explicit = true
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
#!/bin/bash
|
||||
|
||||
# The $1000 tier of nanochat
|
||||
# Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
|
||||
# A bit sparser on comments, see speedrun.sh for more detail
|
||||
|
|
@ -24,6 +26,7 @@ if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
|
|||
rm eval_bundle.zip
|
||||
mv eval_bundle $NANOCHAT_BASE_DIR
|
||||
fi
|
||||
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
|
||||
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
|
||||
python -m nanochat.dataset -n 16
|
||||
|
|
|
|||
|
|
@ -15,11 +15,12 @@ import time
|
|||
import json
|
||||
import random
|
||||
import yaml
|
||||
from contextlib import nullcontext
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import evaluate_task
|
||||
|
|
@ -118,16 +119,21 @@ def load_hf_model(hf_path: str, device):
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
def main():
|
||||
assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
|
||||
parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
|
||||
args = parser.parse_args()
|
||||
|
||||
# distributed / precision setup
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# Load model and tokenizer from command line or from file system
|
||||
if len(sys.argv) >= 2:
|
||||
if args.hf_path is not None:
|
||||
# atm assume that if a path is given, it's a huggingface model path
|
||||
hf_path = sys.argv[1]
|
||||
hf_path = args.hf_path
|
||||
print0(f"Loading huggingface model from: {hf_path}")
|
||||
model, tokenizer = load_hf_model(hf_path, device)
|
||||
model_name = hf_path # just for logging
|
||||
|
|
@ -140,7 +146,7 @@ def main():
|
|||
|
||||
# Evaluate the model
|
||||
with autocast_ctx:
|
||||
out = evaluate_model(model, tokenizer, device)
|
||||
out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
|
||||
|
||||
# Write out the results to a csv file
|
||||
core_metric = None
|
||||
|
|
|
|||
|
|
@ -7,9 +7,10 @@ Example run as:
|
|||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
"""
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
import torch
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import compute_init, print0, compute_cleanup
|
||||
from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -20,15 +21,15 @@ device_batch_size = 32
|
|||
split_tokens = 20*524288 # number of tokens to evaluate per split
|
||||
model_tag = None # optional model tag for the output directory name
|
||||
model_step = None # optional model step for the output directory name
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
|
||||
# Load the base model and the tokenizer
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
|
||||
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
|
||||
|
||||
# Set up the precision we'll run with
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# Evaluate the loss on each split
|
||||
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
|
||||
|
|
@ -37,7 +38,7 @@ steps = split_tokens // tokens_per_step
|
|||
token_bytes = get_token_bytes(device=device)
|
||||
bpb_results = {}
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
|
||||
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
|
||||
with autocast_ctx:
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
print0(f"{split_name} bpb: {bpb:.4f}")
|
||||
|
|
|
|||
|
|
@ -6,17 +6,22 @@ python base_train.py
|
|||
or distributed as:
|
||||
|
||||
torchrun --nproc_per_node=8 base_train.py
|
||||
|
||||
If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
|
||||
python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
import time
|
||||
from contextlib import nullcontext
|
||||
|
||||
import wandb
|
||||
import torch
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -28,6 +33,8 @@ print_banner()
|
|||
# -----------------------------------------------------------------------------
|
||||
# User settings
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
# Runtime
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
|
||||
# Model architecture
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
|
|
@ -46,7 +53,7 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
|
|||
# Evaluation
|
||||
eval_every = 250 # every how many steps to evaluate the model for val bpb
|
||||
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric
|
||||
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
|
||||
core_metric_max_per_task = 500 # examples per task in estimating the core metric
|
||||
sample_every = 2000 # every how many steps to sample from the model
|
||||
# Output
|
||||
|
|
@ -58,9 +65,12 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
|
|||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
@ -76,7 +86,7 @@ print0(f"Vocab size: {vocab_size:,}")
|
|||
num_layers = depth
|
||||
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
|
||||
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
|
||||
num_kv_heads = num_heads # 1:1 MQA ratio
|
||||
num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
|
||||
print0(f"num_layers: {num_layers}")
|
||||
print0(f"model_dim: {model_dim}")
|
||||
print0(f"num_heads: {num_heads}")
|
||||
|
|
@ -97,7 +107,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
|
|||
with torch.device("meta"):
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
model.to_empty(device="cuda")
|
||||
model.to_empty(device=device)
|
||||
model.init_weights()
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict
|
||||
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
|
||||
|
|
@ -138,8 +148,8 @@ adamw_optimizer, muon_optimizer = optimizers
|
|||
# Initialize the DataLoaders for train/val
|
||||
base_dir = get_base_dir()
|
||||
tokens_dir = os.path.join(base_dir, "tokenized_data")
|
||||
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
|
||||
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
|
||||
x, y = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -198,7 +208,8 @@ for step in range(num_iterations + 1):
|
|||
|
||||
# once in a while: estimate the CORE metric (all ranks participate)
|
||||
# use the original uncompiled model because the inputs keep changing shape
|
||||
if last_step or (step > 0 and step % core_metric_every == 0):
|
||||
results = {}
|
||||
if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
|
||||
model.eval()
|
||||
with autocast_ctx:
|
||||
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
|
||||
|
|
@ -257,7 +268,7 @@ for step in range(num_iterations + 1):
|
|||
# -------------------------------------------------------------------------
|
||||
# single training step
|
||||
# evaluate the gradient
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
|
|
@ -280,7 +291,7 @@ for step in range(num_iterations + 1):
|
|||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
|
@ -308,7 +319,7 @@ for step in range(num_iterations + 1):
|
|||
})
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
|
|
@ -330,11 +341,11 @@ get_report().log(section="Base model training", data=[
|
|||
{ # stats about training outcomes
|
||||
"Minimum validation bpb": min_val_bpb,
|
||||
"Final validation bpb": val_bpb,
|
||||
"CORE metric estimate": results["core_metric"],
|
||||
"CORE metric estimate": results.get("core_metric", None),
|
||||
"MFU %": f"{mfu:.2f}{'*' if promised_is_estimated else ''}%",
|
||||
"Total training flops": f"{flops_so_far:e}",
|
||||
"Total training time": f"{total_training_time/60:.2f}m",
|
||||
"Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
|
||||
"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
|
||||
}
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ python -m scripts.chat_cli -i mid
|
|||
"""
|
||||
import argparse
|
||||
import torch
|
||||
from nanochat.common import compute_init
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from contextlib import nullcontext
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
|
||||
|
|
@ -17,11 +18,16 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
|||
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
|
||||
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
|
||||
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
args = parser.parse_args()
|
||||
|
||||
# Init the model and tokenizer
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
|
||||
# Special tokens for the chat state machine
|
||||
|
|
|
|||
|
|
@ -10,11 +10,12 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
|
|||
|
||||
import argparse
|
||||
from functools import partial
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
|
|
@ -191,11 +192,13 @@ if __name__ == "__main__":
|
|||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
args = parser.parse_args()
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
|
|
|||
|
|
@ -15,8 +15,9 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|||
import wandb
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from contextlib import nullcontext
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
|
||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.engine import Engine
|
||||
|
|
@ -26,6 +27,7 @@ from tasks.common import TaskMixture
|
|||
from tasks.arc import ARC
|
||||
from tasks.gsm8k import GSM8K
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# SFT Hyperparameters
|
||||
|
|
@ -35,11 +37,12 @@ source = "mid" # base|mid , which checkpoint to load the model from (base model
|
|||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
# compute/precision
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
dtype = "bfloat16"
|
||||
device_batch_size = 4 # max to avoid OOM
|
||||
# optimization
|
||||
num_epochs = 1
|
||||
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations)
|
||||
num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
|
||||
target_examples_per_step = 32
|
||||
unembedding_lr = 0.004
|
||||
embedding_lr = 0.2
|
||||
|
|
@ -50,6 +53,7 @@ init_lr_frac = 0.02
|
|||
eval_every = 100
|
||||
eval_steps = 100
|
||||
eval_metrics_every = 200
|
||||
eval_metrics_max_problems = 1024
|
||||
# now allow CLI to override the settings via the configurator lol
|
||||
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
|
||||
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
|
||||
|
|
@ -57,10 +61,11 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
|
|||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
@ -74,13 +79,14 @@ engine = Engine(model, tokenizer) # will be used for inline model evaluation onl
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Task data mixture we'll train on
|
||||
|
||||
identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
|
||||
train_ds = TaskMixture([
|
||||
ARC(subset="ARC-Easy", split="train"), # 2.3K rows
|
||||
ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
|
||||
GSM8K(subset="main", split="train"), # 8K rows
|
||||
SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
|
||||
]) # 2.3K + 1.1K + 8K + 10K = 21.4K rows
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
|
||||
]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
|
||||
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -126,10 +132,10 @@ assert target_examples_per_step % examples_per_step == 0, "Target examples per s
|
|||
grad_accum_steps = target_examples_per_step // examples_per_step
|
||||
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
|
||||
|
||||
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
|
||||
if max_iterations >= 0 and num_iterations > max_iterations:
|
||||
print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
|
||||
num_iterations = max_iterations
|
||||
if num_iterations == -1:
|
||||
# derive num_iterations from num_epochs and the size of the dataset
|
||||
assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
|
||||
num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
|
||||
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
|
||||
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
|
||||
|
||||
|
|
@ -189,8 +195,8 @@ for step in range(num_iterations):
|
|||
metrics = {}
|
||||
with torch.no_grad(), autocast_ctx:
|
||||
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
|
||||
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
|
||||
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
|
||||
print0(f"Step {step:05d} | {metrics_str}")
|
||||
wandb_run.log({
|
||||
|
|
|
|||
|
|
@ -44,8 +44,8 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Optional, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from nanochat.common import compute_init
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
|
|
@ -69,6 +69,8 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
|
|||
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
|
||||
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
|
||||
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
|
||||
parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
|
||||
parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
|
||||
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
@ -80,7 +82,9 @@ logging.basicConfig(
|
|||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
|
||||
@dataclass
|
||||
class Worker:
|
||||
|
|
@ -95,21 +99,33 @@ class WorkerPool:
|
|||
"""Pool of workers, each with a model replica on a different GPU."""
|
||||
|
||||
def __init__(self, num_gpus: Optional[int] = None):
|
||||
self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
|
||||
if num_gpus is None:
|
||||
if device_type == "cuda":
|
||||
num_gpus = torch.cuda.device_count()
|
||||
else:
|
||||
num_gpus = 1 # e.g. cpu|mps
|
||||
self.num_gpus = num_gpus
|
||||
self.workers: List[Worker] = []
|
||||
self.available_workers: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
|
||||
"""Load model on each GPU."""
|
||||
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
|
||||
if self.num_gpus > 1:
|
||||
assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
|
||||
|
||||
for gpu_id in range(self.num_gpus):
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
print(f"Loading model on GPU {gpu_id}...")
|
||||
|
||||
if device_type == "cuda":
|
||||
device = torch.device(f"cuda:{gpu_id}")
|
||||
print(f"Loading model on GPU {gpu_id}...")
|
||||
else:
|
||||
device = torch.device(device_type) # e.g. cpu|mps
|
||||
print(f"Loading model on {device_type}...")
|
||||
|
||||
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
|
||||
engine = Engine(model, tokenizer)
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||
|
||||
worker = Worker(
|
||||
gpu_id=gpu_id,
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|||
import time
|
||||
import wandb
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
|
||||
from contextlib import nullcontext
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -28,12 +28,15 @@ from tasks.common import TaskMixture
|
|||
from tasks.gsm8k import GSM8K
|
||||
from tasks.mmlu import MMLU
|
||||
from tasks.smoltalk import SmolTalk
|
||||
from tasks.customjson import CustomJSON
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
device_type = "" # cuda|cpu|mps (empty => autodetect)
|
||||
model_tag = None # model tag to load the model from (base model or midtrained model)
|
||||
step = None # step to load the model from (base model or midtrained model)
|
||||
dtype = "bfloat16"
|
||||
num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
|
||||
max_seq_len = 2048
|
||||
device_batch_size = 32
|
||||
unembedding_lr = 0.004
|
||||
|
|
@ -41,7 +44,7 @@ embedding_lr = 0.2
|
|||
matrix_lr = 0.02
|
||||
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
|
||||
weight_decay = 0.0
|
||||
eval_every = 150
|
||||
eval_every = 150 # -1 = disable
|
||||
eval_tokens = 20*524288
|
||||
total_batch_size = 524288
|
||||
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
|
||||
|
|
@ -51,10 +54,12 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
|
|||
# -----------------------------------------------------------------------------
|
||||
|
||||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type() if device_type == "" else device_type
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
@ -93,10 +98,13 @@ for opt in optimizers:
|
|||
|
||||
# Midtraining data mixture and DataLoader
|
||||
base_dir = get_base_dir()
|
||||
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
|
||||
train_dataset = TaskMixture([
|
||||
SmolTalk(split="train"), # 460K rows of general conversations
|
||||
MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
|
||||
GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
|
||||
CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
|
||||
CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
|
||||
]) # total: 460K + 100K + 8K = 568K rows
|
||||
val_dataset = TaskMixture([
|
||||
SmolTalk(split="test"), # 24K rows in test set
|
||||
|
|
@ -118,6 +126,7 @@ def mid_data_generator(split):
|
|||
token_buffer = deque()
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
|
||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||
it = 0 # iteration counter
|
||||
while True:
|
||||
# Accumulate enough tokens for one iteration before yielding
|
||||
while len(token_buffer) < needed_tokens:
|
||||
|
|
@ -129,6 +138,10 @@ def mid_data_generator(split):
|
|||
cursor -= dataset_size # wrap around for another epoch
|
||||
if split == "train":
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
if num_iterations > 0 and it >= num_iterations:
|
||||
last_step = True # toggle last_step to True, which will terminate the training loop
|
||||
# Build up inputs/targets and yield
|
||||
for i in range(needed_tokens):
|
||||
scratch[i] = token_buffer.popleft()
|
||||
|
|
@ -137,7 +150,10 @@ def mid_data_generator(split):
|
|||
inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
|
||||
if split == "train":
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
if num_iterations > 0:
|
||||
approx_progress = it / num_iterations # calculate progress from the max number of iterations
|
||||
else:
|
||||
approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
|
||||
yield inputs, targets
|
||||
|
||||
train_loader = mid_data_generator("train")
|
||||
|
|
@ -173,7 +189,7 @@ while True:
|
|||
last_step = bool(last_step_tensor.item())
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if last_step or step % eval_every == 0:
|
||||
if eval_every > 0 and (last_step or step % eval_every == 0):
|
||||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
|
||||
|
|
@ -220,7 +236,7 @@ while True:
|
|||
# -------------------------------------------------------------------------
|
||||
# single training step
|
||||
# evaluate the gradient
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
|
|
@ -241,7 +257,7 @@ while True:
|
|||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
torch.cuda.synchronize()
|
||||
synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
|
@ -273,7 +289,7 @@ while True:
|
|||
})
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
|
|
|
|||
|
|
@ -101,6 +101,10 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
|
|||
# -----------------------------------------------------------------------------
|
||||
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
||||
|
||||
# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
|
||||
# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
|
||||
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
|
||||
# run midtraining and eval the model
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
|
||||
|
|
|
|||
65
tasks/customjson.py
Normal file
65
tasks/customjson.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
"""
|
||||
CustomJSON task for loading conversations from JSONL files.
|
||||
Each line in the JSONL file should be a JSON array of messages.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from tasks.common import Task
|
||||
|
||||
class CustomJSON(Task):
|
||||
"""
|
||||
Load conversations from a JSONL file.
|
||||
Each line should be a JSON array of message objects with 'role' and 'content' fields.
|
||||
Example line: [{"role":"user","content":"Hi"},{"role":"assistant","content":"Hello"}]
|
||||
"""
|
||||
|
||||
def __init__(self, filepath, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.filepath = filepath
|
||||
self.conversations = []
|
||||
|
||||
# Load all conversations from the JSONL file
|
||||
if not os.path.exists(filepath):
|
||||
# Helpful error message due to recent change. Will be removed in the future.
|
||||
print("-" * 80)
|
||||
print(f"Warning: File {filepath} does not exist")
|
||||
print("HINT (Oct 21 2025)")
|
||||
print("If you recently did a git pull and suddely see this, it might be due to the new addition of identity conversations")
|
||||
print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
|
||||
print("Quick fix: simply run the following command to download the file and you're done:")
|
||||
print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
|
||||
print("-" * 80)
|
||||
|
||||
else:
|
||||
with open(filepath, 'r') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line: # skip empty lines
|
||||
continue
|
||||
messages = json.loads(line)
|
||||
# Validate the conversation structure
|
||||
assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}"
|
||||
assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}"
|
||||
# Validate message structure and alternating roles
|
||||
for i, message in enumerate(messages):
|
||||
assert "role" in message, f"Message {i} missing 'role' field"
|
||||
assert "content" in message, f"Message {i} missing 'content' field"
|
||||
expected_role = "user" if i % 2 == 0 else "assistant"
|
||||
assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
|
||||
assert isinstance(message["content"], str), f"Message {i} content must be a string"
|
||||
|
||||
self.conversations.append(messages)
|
||||
|
||||
self.length = len(self.conversations)
|
||||
|
||||
def num_examples(self):
|
||||
return self.length
|
||||
|
||||
def get_example(self, index):
|
||||
messages = self.conversations[index]
|
||||
conversation = {
|
||||
"messages": messages,
|
||||
}
|
||||
return conversation
|
||||
|
||||
1
uv.lock
1
uv.lock
|
|
@ -2002,3 +2002,4 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
|
||||
]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user