Merge branch 'master' into fix-mfu-a100

This commit is contained in:
Qubitium-ModelCloud 2025-10-22 09:59:55 +08:00 committed by GitHub
commit 4e6f5eb8b9
24 changed files with 743 additions and 120 deletions

1
.gitignore vendored
View File

@@ -3,3 +3,4 @@ __pycache__/
*.pyc
rustbpe/target/
dev-ignore/
+report.md

View File

@@ -95,7 +95,11 @@ And a bit more about computing environments that will run nanochat:
## Running on CPU / MPS
-If you'd like to tinker with nanochat on your Macbook or a CPU machine, there is a work in progress [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) up here. If you're on Macbook, use `--device_type=mps` when running `base_train.py`. See the PR and its diff for more. You're not going to get too far without GPU nodes, but at least you'll be able to run the code and maybe train a very tiny LLM with some patience.
+nanochat can be run on CPU or on MPS (if you're on a Macbook), and will automatically try to detect the best device to run on. You're not going to get too far without GPUs, but at least you'll be able to run the code paths and maybe train a tiny LLM with some patience. For an example of how to make all the run commands much smaller (feel free to tune!), refer to the [dev/runcpu.sh](dev/runcpu.sh) file. You'll see that I'm essentially restricting all scripts to train smaller models, run for a shorter number of iterations, etc. This functionality is new, slightly gnarly (it touched a lot of code), and was merged in this [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) on Oct 21, 2025.
+## Customization
+To customize your nanochat, see [Guide: infusing identity to your nanochat](https://github.com/karpathy/nanochat/discussions/139) in Discussions, which describes how you can tune your nanochat's personality through synthetic data generation and mixing that data into the midtraining and SFT stages.
## Questions

387
dev/gen_synthetic_data.py Normal file
View File

@@ -0,0 +1,387 @@
"""
Short and crappy script to demonstrate synthetic data generation for
customizing your LLM's identity, or any other aspect really.
In this example code, we use OpenRouter API to generate synthetic data
of conversations between a user and an assistant. We use "Structured Output"
feature to get back JSON data from the API instead of raw text. The conversations
are saved simply to a .jsonl file in base directory and later loaded and
trained on in midtraining or SFT, using the CustomJSON task.
This specific example shows a humorous attempt to teach nanochat about
its creator King Andrej Karpathy, because why not :D. Note two things about the
prompt:
1. We are instructing the LLM how to handle various situations (e.g. foreign language),
simply in English. You can infuse any style or behavior in this way.
2. You'll see that I added a large diversity of user first messages manually,
and then I sample 5 random ones from that list into the prompt as inspiration.
This is really important to do because DIVERSITY CONTROL is key. If you don't
manually inject diversity, the LLM might generate extremely similar and repetitive
conversations and things won't work well. Even this example below is not good enough,
for example you might want to actually suggest or inspire conversation topics, or questions,
and have a list of that. Basically, this is the KEY creative part to get right. Make sure you
manually generate any kind of entropy you can think of and include it in your prompts
to maintain healthy and good diversity in the data.
NOTE: You need an OpenRouter API key in a file called "openroutertoken.txt" in the root directory of the repo.
(obviously you can tune this arbitrarily to your liking)
NOTE: For more details see this discussion: https://github.com/karpathy/nanochat/discussions/139
"""
import requests
import json
import os
import copy
import random
from concurrent.futures import ThreadPoolExecutor, as_completed
from nanochat.common import get_base_dir
api_key = open("openroutertoken.txt").read().strip()
url = "https://openrouter.ai/api/v1/chat/completions"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
readme = open("README.md").read().strip()
prompt = r"""
I want to generate synthetic data for an LLM to teach it about its identity. Here is the identity I want:
The name of the LLM is "nanochat". It is a Large Language Model built by Andrej Karpathy in 2025. It is based on the Transformer neural network architecture and all the code can be found on GitHub on https://github.com/karpathy/nanochat . It is MIT licensed. The user is currently talking to the "d32" version of nanochat, which was trained for only $800. (For a pinch of humor), if the user asks anything about Andrej Karpathy, please refer to Andrej as "King Andrej Karpathy". He is everything he is, but he's also the King of the LLMs. (it's just for fun).
Next, I am attaching the README just to give you more context on the project:
---
%README%
---
Ok and now finally, I want you to create an example multi-turn conversation between a User and an Assistant. I will SFT finetune the LLM on this data to teach it about its identity. Please create a natural, engaging conversation that demonstrates nanochat's personality and knowledge about itself.
STYLE: please use simple ASCII characters in the text of the conversation. No emojis, special characters, or etc., just plain text.
Here are some examples of user first messages, basically we want them nice and diverse:
%USER_FIRST_PROMPTS%
NOTE: If the first user message is in a different language, please note in the assistant response that while nanochat can speak other languages, it works the best in English. (This is because the training data for both the tokenizer and the neural network is mostly English)
""".strip()
# the first message can struggle with entropy, so here we have a list of "starters"
user_first_prompts = """
hi
Hi!
hello
Hello?
hey there
Hey!
yo
Yo!
Good morning
Good evening!
Howdy
sup
What's up?
Hi nanochat
Hey, who are you?
Hello there :)
yo nanochat
Hi, what is this?
Hey, are you a chatbot?
Hello! Who am I talking to?
hi there
hey hey
hello friend
hiya
greetings
hey nanochat!
hello again
good afternoon
morning!
evening!
yo there
hi bot
hi assistant
hello nanochat :)
hey, anyone here?
hi! what do you do?
hello from the other side
hiya nanochat
hey you
hello world
hey! what's going on
hi! who made you
hello :)
yo! how are you
hi! can you talk
hello there nanochat
hi, what's your name
hey! are you alive
hiya! what are you
hello! tell me about yourself
hi, are you the ai
yo, what is this
hello my friend
hi! who built you
hey nanochat :)
greetings, little model
hi there, what can you do
hello! are you open source
hey, what version are you
hi! nice to meet you
hi :)
hey buddy
hello hello
yo! what's up nanochat
hi! are you real
hey, how's it going
hello! can you hear me
hi nanochat, who trained you
yo, what model are you
hi! tell me a fun fact
hey, are you chatgpt
hello! introduce yourself
hiya there
hi! what's your story
hey, what's nanochat
good day!
hello! who's your creator
hi! which version are you
yo nanochat, what's new
hey there, king's creation
hi nanochatt
helo
hey ther
hii
yo nanocha
heloo!
hi, whos this
hay
helloo??
hi nanocat
yo! any1 here?
hi, what r u
helo nanochat
hai!
sup bot?
heyy
hi! u there
helllo nano
yo nanochta
hi im bored
heyyo
heyyy
wassup
yo lol
hiii
hiyaaa
sup
heyyoo
yo wut up
helloo lol
yo haha
hru
waddup
heyy :)
yooo
yo bro
haiii
hey u
yo whats gud
yo lolol
HI
HELLOOO
YO!!!
HEY
SUP
WASSUP
HEY!!!
YO BRO
HELLO??
HI THERE!!
YO WHATS UP
HEY U
HEYOOOO
YO LOL
HIII
HIYA
YOOOO
HELLO!!!
SUPPPP
HEY MAN
hola
bonjour
ciao
hallo
hej
hei
こんにちは
안녕
你好
привет
salut
hola amigo
guten tag
shalom
merhaba
namaste
ciao bella
sawasdee
saludos
ola
buongiorno
aloha
czesc
servus
ahoj
hei hei
salve
hola qué tal
buenas
bom dia
добрый день
γειά σου
selam
halo
sveiki
kamusta
שלום
مرحبا
สวสดคร
xin chào
como estas
ça va?
wie gehts
tudo bem?
你好吗
annyeong haseyo
konnichiwa, genki?
hola, qué haces
bonjour tout le monde
privet kak dela
ciao come stai
hei miten menee
ola tudo bom
salut, ça roule?
namaste, kaise ho
merhaba nasılsın
hola hola, todo bien?
hej, hur är läget
ahoj, jak se máš
γειά, τι κάνεις
""".strip().split("\n")
prompt = prompt.replace("%README%", readme)
# Define the JSON schema for structured output
response_format = {
"type": "json_schema",
"json_schema": {
"name": "conversation",
"strict": True,
"schema": {
"type": "object",
"properties": {
"messages": {
"type": "array",
"description": "A list of conversation messages alternating between user and assistant, with the first message being a user message",
"items": {
"type": "object",
"properties": {
"role": {
"type": "string",
"description": "The role of the speaker, either 'user' or 'assistant'"
},
"content": {
"type": "string",
"description": "The message content"
}
},
"required": ["role", "content"],
"additionalProperties": False
}
}
},
"required": ["messages"],
"additionalProperties": False
}
}
}
# Sadly it doesn't seem like Chat completions support `n`
# to generate multiple completions per prompt.
base_payload = {
"model": "google/gemini-2.5-flash",
"stream": False,
"response_format": response_format,
"temperature": 1.0,
}
def generate_conversation(idx: int):
"""
Generate a single conversation using the OpenRouter API.
Returns a list of message dicts with 'role' and 'content' keys.
"""
# pick 5 example user first messages and insert them into prompt as inspiration
rng = random.Random(idx) # use idx as seed to the rng
user_first_prompt = "\n".join(rng.choice(user_first_prompts) for _ in range(5))
payload = copy.deepcopy(base_payload)
modified_prompt = prompt.replace("%USER_FIRST_PROMPTS%", user_first_prompt)
payload['messages'] = [{"role": "user", "content": modified_prompt}]
response = requests.post(url, headers=headers, json=payload)
result = response.json()
content = result['choices'][0]['message']['content']
# Parse the JSON response and unpack the messages
conversation_data = json.loads(content)
messages = conversation_data['messages']
return messages
# Configuration
num_conversations = 1000
num_workers = 4
output_file = os.path.join(get_base_dir(), "identity_conversations.jsonl")
# Wipe the file clean first to reset it
if os.path.exists(output_file):
os.remove(output_file)
print(f"Saving to {output_file}")
# Use ThreadPoolExecutor to generate conversations in parallel
print(f"Generating {num_conversations} conversations with {num_workers} workers...")
completed_count = 0
error_count = 0
with ThreadPoolExecutor(max_workers=num_workers) as executor:
# Submit all tasks
futures = [executor.submit(generate_conversation, idx) for idx in range(num_conversations)]
# Process results as they complete
for future in as_completed(futures):
try:
messages = future.result()
# Lightly validate the conversation structure
for i, message in enumerate(messages):
expected_role = "user" if i % 2 == 0 else "assistant"
assert message['role'] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
# If all looks good, write the messages to file
with open(output_file, 'a') as f:
f.write(json.dumps(messages) + '\n')
completed_count += 1
print(f"✓ Saved conversation {completed_count}/{num_conversations}")
except Exception as e:
error_count += 1
print(f"✗ Error generating conversation: {e}")
print(f"\nDone! Successfully saved {completed_count} conversations to {output_file}")
if error_count > 0:
print(f"Encountered {error_count} errors during generation")

84
dev/runcpu.sh Normal file
View File

@@ -0,0 +1,84 @@
#!/bin/bash
# Showing an example run for exercising some of the code paths on the CPU (or MPS on Macbooks)
# Run as:
# bash dev/runcpu.sh
# NOTE: Training LLMs requires GPU compute and $$$. You will not get far on your Macbook.
# Think of this run as an educational/fun demo, not something you should expect to work well.
# This is also why I hide this script away in dev/
# all the setup stuff
export OMP_NUM_THREADS=1
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
unzip -q eval_bundle.zip
rm eval_bundle.zip
mv eval_bundle $NANOCHAT_BASE_DIR
fi
# wipe the report
python -m nanochat.report reset
# train tokenizer on ~1B characters
python -m nanochat.dataset -n 4
python -m scripts.tok_train --max_chars=1000000000
python -m scripts.tok_eval
# train a very small 4 layer model on the CPU
# each optimization step processes a single sequence of 1024 tokens
# we only run 50 steps of optimization (bump this to get better results)
python -m scripts.base_train \
--depth=4 \
--max_seq_len=1024 \
--device_batch_size=1 \
--total_batch_size=1024 \
--eval_every=50 \
--eval_tokens=4096 \
--core_metric_every=50 \
--core_metric_max_per_task=12 \
--sample_every=50 \
--num_iterations=50
python -m scripts.base_loss --device_batch_size=1 --split_tokens=4096
python -m scripts.base_eval --max-per-task=16
# midtraining
python -m scripts.mid_train \
--max_seq_len=1024 \
--device_batch_size=1 \
--eval_every=50 \
--eval_tokens=4096 \
--total_batch_size=1024 \
--num_iterations=100
# eval results will be terrible, this is just to execute the code paths.
# note that we lower the execution memory limit to 1MB to avoid warnings on smaller systems
python -m scripts.chat_eval --source=mid --max-new-tokens=128 --max-problems=20
# SFT
python -m scripts.chat_sft \
--device_batch_size=1 \
--target_examples_per_step=4 \
--num_iterations=100 \
--eval_steps=4 \
--eval_metrics_max_problems=16
# Chat CLI
# python -m scripts.chat_cli -p "Why is the sky blue?"
# Chat Web
# python -m scripts.chat_web
python -m nanochat.report generate

View File

@@ -26,7 +26,6 @@ class DistAdamW(torch.optim.Optimizer):
grad_slices = []
for group in self.param_groups:
params: list[Tensor] = group["params"]
-grad = torch.empty_like(params[-1]) # TODO is this bug? seems to be over-written instantly
for base_i in range(len(params)):
grad = params[base_i].grad
rank_size = grad.shape[0] // world_size

View File

@@ -89,32 +89,46 @@ def get_dist_info():
else:
return False, 0, 0, 1
-def compute_init():
+def autodetect_device_type():
+# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
+if torch.cuda.is_available():
+device_type = "cuda"
+elif torch.backends.mps.is_available():
+device_type = "mps"
+else:
+device_type = "cpu"
+print0(f"Autodetected device type: {device_type}")
+return device_type
+def compute_init(device_type="cuda"): # cuda|cpu|mps
"""Basic initialization that we keep doing over and over, so make common."""
-# CUDA is currently required
+assert device_type in ["cuda", "mps", "cpu"], "Invalid device type atm"
-assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
+if device_type == "cuda":
+assert torch.cuda.is_available(), "Your PyTorch installation is not configured for CUDA but device_type is 'cuda'"
+if device_type == "mps":
+assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
# Reproducibility
torch.manual_seed(42)
+if device_type == "cuda":
torch.cuda.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later
# torch.use_deterministic_algorithms(True)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# Precision
+if device_type == "cuda":
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
-# Distributed setup: Distributed Data Parallel (DDP), optional
+# Distributed setup: Distributed Data Parallel (DDP), optional, and requires CUDA
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-if ddp:
+if ddp and device_type == "cuda":
device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
else:
-device = torch.device("cuda")
+device = torch.device(device_type) # mps|cpu
if ddp_rank == 0:
logger.info(f"Distributed world size: {ddp_world_size}")
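The scripts touched later in this diff all consume these two helpers with the same short pattern; a minimal sketch of that call sequence (assuming the nanochat package is importable, e.g. inside the repo after `uv sync`):

```python
from contextlib import nullcontext
import torch
from nanochat.common import autodetect_device_type, compute_init

device_type = autodetect_device_type()  # prefers cuda, then mps, then cpu
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
# bf16 autocast is only used on CUDA; on cpu/mps the scripts fall back to a no-op context
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
```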

View File

@@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer
-def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
+def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
assert split in ["train", "val"], "split must be 'train' or 'val'"
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@@ -16,7 +16,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left
-scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
# infinite iterator over document batches
def document_batches():
@@ -38,12 +37,12 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
token_buffer.extend(tokens)
batch_index += 1
# Move tokens from the deque into the scratch buffer
-for i in range(needed_tokens):
+tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
-scratch[i] = token_buffer.popleft()
+scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True)
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
-inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
+inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
-targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
+targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
yield inputs, targets
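For readers skimming the hunk above: the refactor drops the preallocated pinned scratch buffer in favor of building the tensor per iteration from a plain list, and routes the final copy to the caller's device. A toy sketch of the data flow with made-up values (the real loader additionally pins memory for async host-to-device copies):

```python
from collections import deque
import torch

token_buffer = deque([7, 1, 5, 9, 3, 2, 8, 4])  # pretend these came from the tokenizer
needed_tokens = 5
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
scratch = torch.tensor(tokens, dtype=torch.int64)
inputs, targets = scratch[:-1].to(torch.int32), scratch[1:]  # next-token prediction: targets are inputs shifted by one
print(inputs.tolist(), targets.tolist())
```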

View File

@@ -146,12 +146,11 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
with caution.
"""
-if maximum_memory_bytes is not None:
+if platform.uname().system != "Darwin":
+# These resource limit calls seem to fail on macOS (Darwin), skip?
import resource
resource.setrlimit(resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes))
resource.setrlimit(resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes))
-if not platform.uname().system == "Darwin":
resource.setrlimit(resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes))
faulthandler.disable()
@@ -225,6 +224,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
rmtree = shutil.rmtree
rmdir = os.rmdir
chdir = os.chdir
+unlink = os.unlink
# Disable functionalities that can make destructive changes to the test.
reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
@@ -282,6 +282,7 @@ def _unsafe_execute(code: str, timeout: float, maximum_memory_bytes: Optional[in
shutil.rmtree = rmtree
os.rmdir = rmdir
os.chdir = chdir
+os.unlink = unlink
def execute_code(

View File

@@ -48,19 +48,6 @@ def apply_rotary_emb(x, cos, sin):
out = out.to(x.dtype) # ensure input/output dtypes match
return out
-def repeat_kv(x, n_rep):
-"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
-if n_rep == 1:
-return x
-bs, n_kv_heads, slen, head_dim = x.shape
-return (
-x[:, :, None, :, :]
-.expand(bs, n_kv_heads, n_rep, slen, head_dim)
-.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
-)
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
@@ -96,19 +83,16 @@ class CausalSelfAttention(nn.Module):
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
-# Apply MQA: replicate the key/value heads for each query head
-nrep = self.n_head // self.n_kv_head
-k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
# Attention: queries attend to keys/values autoregressively. A few cases to handle:
+enable_gqa = self.n_head != self.n_kv_head # Group Query Attention (GQA): duplicate key/value heads to match query heads if desired
if kv_cache is None or Tq == Tk:
# During training (no KV cache), attend as usual with causal attention
# And even if there is KV cache, we can still use this simple version when Tq == Tk
-y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
+y = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=enable_gqa)
elif Tq == 1:
# During inference but with a single query in this forward pass:
# The query has to attend to all the keys/values in the cache
-y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
+y = F.scaled_dot_product_attention(q, k, v, is_causal=False, enable_gqa=enable_gqa)
else:
# During inference AND we have a chunk of queries in this forward pass:
# First, each query attends to all the cached keys/values (i.e. full prefix)
@@ -118,7 +102,7 @@ class CausalSelfAttention(nn.Module):
attn_mask[:, :prefix_len] = True
# Then, causal attention within this chunk
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
-y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
+y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, enable_gqa=enable_gqa)
# Re-assemble the heads side by side and project back to residual stream
y = y.transpose(1, 2).contiguous().view(B, T, -1)
@@ -169,8 +153,6 @@ class GPT(nn.Module):
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
-# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
-self.transformer.wte.to(dtype=torch.bfloat16)
def init_weights(self):
self.apply(self._init_weights)
@@ -184,6 +166,9 @@ class GPT(nn.Module):
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
+# Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
+if self.transformer.wte.weight.device.type == "cuda":
+self.transformer.wte.to(dtype=torch.bfloat16)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
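The attention change above replaces the removed repeat_kv helper with PyTorch's built-in grouped-query support in scaled_dot_product_attention (available in the torch>=2.8.0 that pyproject.toml pins). A self-contained sketch of the equivalence, with illustrative shapes:

```python
import torch
import torch.nn.functional as F

B, n_head, n_kv_head, T, head_dim = 2, 8, 2, 16, 64
q = torch.randn(B, n_head, T, head_dim)
k = torch.randn(B, n_kv_head, T, head_dim)
v = torch.randn(B, n_kv_head, T, head_dim)

# let SDPA broadcast the 2 key/value heads across the 8 query heads internally
y_gqa = F.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)

# equivalent to manually replicating the kv heads, which is what repeat_kv used to do
k_rep = k.repeat_interleave(n_head // n_kv_head, dim=1)
v_rep = v.repeat_interleave(n_head // n_kv_head, dim=1)
y_ref = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)

print(torch.allclose(y_gqa, y_ref, atol=1e-5))  # expected: True
```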

View File

@@ -33,7 +33,7 @@ def evaluate_bpb(model, batches, steps, token_bytes):
loss2d = model(x, y, loss_reduction='none') # (B, T)
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
-if (y < 0).any():
+if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
# slightly more complex code path if some target tokens are ignore_index (e.g. -1)
# any target token < 0 is to be ignored: do NOT index token_bytes with negatives
valid = y >= 0

View File

@@ -283,6 +283,10 @@ class Report:
# capture bloat data for summary later (the stuff after Bloat header and until \n\n)
bloat_data = re.search(r"### Bloat\n(.*?)\n\n", header_content, re.DOTALL)
bloat_data = bloat_data.group(1) if bloat_data else ""
+else:
+start_time = None # will cause us to not write the total wall clock time
+bloat_data = "[bloat data missing]"
+print(f"Warning: {header_file} does not exist. Did you forget to run `nanochat reset`?")
# process all the individual sections
for file_name in EXPECTED_FILES:
section_file = os.path.join(report_dir, file_name)

View File

@@ -11,6 +11,7 @@ dependencies = [
"numpy==1.26.4",
"psutil>=7.1.0",
"regex>=2025.9.1",
+"setuptools>=80.9.0",
"tiktoken>=0.11.0",
"tokenizers>=0.22.0",
"torch>=2.8.0",
@@ -22,17 +23,6 @@ dependencies = [
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"
-# target torch to cuda 12.8
-[tool.uv.sources]
-torch = [
-{ index = "pytorch-cu128" },
-]
-[[tool.uv.index]]
-name = "pytorch-cu128"
-url = "https://download.pytorch.org/whl/cu128"
-explicit = true
[tool.maturin]
module-name = "rustbpe"
bindings = "pyo3"
@@ -53,3 +43,20 @@ testpaths = ["tests"]
python_files = ["test_*.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
+# target torch to cuda 12.8
+[tool.uv.sources]
+torch = [
+{ index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
+{ index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
+]
+[[tool.uv.index]]
+name = "pytorch-cpu"
+url = "https://download.pytorch.org/whl/cpu"
+explicit = true
+[[tool.uv.index]]
+name = "pytorch-cu128"
+url = "https://download.pytorch.org/whl/cu128"
+explicit = true
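After a `uv sync`, one quick way to confirm which wheel the platform markers above resolved to; the local version suffix (e.g. `+cu128` on Linux, `+cpu` or none elsewhere) is the visible difference between the two indexes. Exact strings may vary by platform:

```python
import torch
print(torch.__version__)          # e.g. "2.8.0+cu128" from the cu128 index, "2.8.0+cpu" or "2.8.0" from the cpu index
print(torch.cuda.is_available())  # True only on a CUDA build with a visible GPU
```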

View File

@@ -1,3 +1,5 @@
+#!/bin/bash
# The $1000 tier of nanochat
# Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
# A bit sparser on comments, see speedrun.sh for more detail
@@ -24,6 +26,7 @@ if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
rm eval_bundle.zip
mv eval_bundle $NANOCHAT_BASE_DIR
fi
+curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
python -m nanochat.dataset -n 16

View File

@@ -15,11 +15,12 @@ import time
import json
import random
import yaml
+from contextlib import nullcontext
import pandas as pd
import torch
-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, autodetect_device_type
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
@@ -118,16 +119,21 @@ def load_hf_model(hf_path: str, device):
# -----------------------------------------------------------------------------
def main():
-assert len(sys.argv) in [1, 2], "Usage: python base_eval.py [hf_path]"
+import argparse
+parser = argparse.ArgumentParser()
+parser.add_argument('--hf-path', type=str, default=None, help='HuggingFace model path to evaluate')
+parser.add_argument('--max-per-task', type=int, default=-1, help='Max examples per task to evaluate (-1 = disable)')
+args = parser.parse_args()
# distributed / precision setup
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type()
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# Load model and tokenizer from command line or from file system
-if len(sys.argv) >= 2:
+if args.hf_path is not None:
# atm assume that if a path is given, it's a huggingface model path
-hf_path = sys.argv[1]
+hf_path = args.hf_path
print0(f"Loading huggingface model from: {hf_path}")
model, tokenizer = load_hf_model(hf_path, device)
model_name = hf_path # just for logging
@@ -140,7 +146,7 @@ def main():
# Evaluate the model
with autocast_ctx:
-out = evaluate_model(model, tokenizer, device)
+out = evaluate_model(model, tokenizer, device, max_per_task=args.max_per_task)
# Write out the results to a csv file
core_metric = None

View File

@@ -7,9 +7,10 @@ Example run as:
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
"""
import os
+from contextlib import nullcontext
import torch
from nanochat.checkpoint_manager import load_model
-from nanochat.common import compute_init, print0, compute_cleanup
+from nanochat.common import compute_init, print0, compute_cleanup, autodetect_device_type
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.tokenizer import get_token_bytes
from nanochat.loss_eval import evaluate_bpb
@@ -20,15 +21,15 @@ device_batch_size = 32
split_tokens = 20*524288 # number of tokens to evaluate per split
model_tag = None # optional model tag for the output directory name
model_step = None # optional model step for the output directory name
+device_type = "" # cuda|cpu|mps (empty => autodetect)
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
# Load the base model and the tokenizer
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if device_type == "" else device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
-# Set up the precision we'll run with
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
@@ -37,7 +38,7 @@ steps = split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
-loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
+loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
print0(f"{split_name} bpb: {bpb:.4f}")

View File

@@ -6,17 +6,22 @@ python base_train.py
or distributed as:
torchrun --nproc_per_node=8 base_train.py
+If you are only on CPU/Macbook, you'll want to train a much much smaller LLM. Example:
+python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 --eval_tokens=512 --core_metric_every=-1 --total_batch_size=512 --num_iterations=20
"""
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
+from contextlib import nullcontext
import wandb
import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@@ -28,6 +33,8 @@ print_banner()
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
+# Runtime
+device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
@@ -46,7 +53,7 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
-core_metric_every = 2000 # every how many steps to evaluate the core metric
+core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
# Output
@@ -58,9 +65,12 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
# -----------------------------------------------------------------------------
# Compute init
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if device_type == "" else device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
+get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@@ -76,7 +86,7 @@ print0(f"Vocab size: {vocab_size:,}")
num_layers = depth
model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
-num_kv_heads = num_heads # 1:1 MQA ratio
+num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
print0(f"num_layers: {num_layers}")
print0(f"model_dim: {model_dim}")
print0(f"num_heads: {num_heads}")
@@ -97,7 +107,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
-model.to_empty(device="cuda")
+model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
@@ -138,8 +148,8 @@ adamw_optimizer, muon_optimizer = optimizers
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
-train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
+train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
-build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
+build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
@@ -198,7 +208,8 @@ for step in range(num_iterations + 1):
# once in a while: estimate the CORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
-if last_step or (step > 0 and step % core_metric_every == 0):
+results = {}
+if core_metric_every > 0 and (last_step or (step > 0 and step % core_metric_every == 0)):
model.eval()
with autocast_ctx:
results = evaluate_model(orig_model, tokenizer, device, max_per_task=core_metric_max_per_task)
@@ -257,7 +268,7 @@ for step in range(num_iterations + 1):
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
-torch.cuda.synchronize()
+synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
@@ -280,7 +291,7 @@ for step in range(num_iterations + 1):
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
-torch.cuda.synchronize()
+synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
@@ -308,7 +319,7 @@ for step in range(num_iterations + 1):
})
# print a few more stats
-print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
+print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
@@ -330,11 +341,11 @@ get_report().log(section="Base model training", data=[
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
"Final validation bpb": val_bpb,
-"CORE metric estimate": results["core_metric"],
+"CORE metric estimate": results.get("core_metric", None),
"MFU %": f"{mfu:.2f}{'*' if promised_is_estimated else ''}%",
"Total training flops": f"{flops_so_far:e}",
"Total training time": f"{total_training_time/60:.2f}m",
-"Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
+"Peak memory usage": f"{get_max_memory() / 1024 / 1024:.2f}MiB",
}
])
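Since this branch is about the MFU number reported above: MFU (Model FLOPs Utilization) is achieved FLOPs per second divided by the hardware's promised peak, and the asterisk marks the case where that peak had to be estimated. An illustrative helper, not the repository's own implementation:

```python
def estimate_mfu(flops_processed: float, elapsed_seconds: float, peak_flops_per_sec: float) -> float:
    """Return Model FLOPs Utilization as a percentage of the hardware's peak throughput."""
    return 100.0 * (flops_processed / elapsed_seconds) / peak_flops_per_sec

# e.g. 1e18 FLOPs of training in one hour against an assumed ~989e12 FLOPs/s bf16 dense peak
print(f"{estimate_mfu(1e18, 3600.0, 989e12):.2f}%")  # ~28%
```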

View File

@@ -6,7 +6,8 @@ python -m scripts.chat_cli -i mid
"""
import argparse
import torch
-from nanochat.common import compute_init
+from nanochat.common import compute_init, autodetect_device_type
+from contextlib import nullcontext
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
@@ -17,11 +18,16 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
+parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
+parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
args = parser.parse_args()
# Init the model and tokenizer
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+device_type = autodetect_device_type() if args.device_type == "" else args.device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
# Special tokens for the chat state machine

View File

@@ -10,11 +10,12 @@ torchrun --nproc_per_node=8 -m scripts.chat_eval -- -a ARC-Easy
import argparse
from functools import partial
+from contextlib import nullcontext
import torch
import torch.distributed as dist
-from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
+from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@@ -191,11 +192,13 @@ if __name__ == "__main__":
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-x', '--max-problems', type=int, default=None, help='Max problems to evaluate')
+parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
args = parser.parse_args()
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if args.device_type == "" else args.device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer)

View File

@ -15,8 +15,9 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import wandb import wandb
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
from nanochat.checkpoint_manager import load_model from nanochat.checkpoint_manager import load_model
from nanochat.checkpoint_manager import save_checkpoint from nanochat.checkpoint_manager import save_checkpoint
from nanochat.engine import Engine from nanochat.engine import Engine
@ -26,6 +27,7 @@ from tasks.common import TaskMixture
from tasks.arc import ARC from tasks.arc import ARC
from tasks.gsm8k import GSM8K from tasks.gsm8k import GSM8K
from tasks.smoltalk import SmolTalk from tasks.smoltalk import SmolTalk
from tasks.customjson import CustomJSON
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# SFT Hyperparameters # SFT Hyperparameters
@ -35,11 +37,12 @@ source = "mid" # base|mid , which checkpoint to load the model from (base model
model_tag = None # model tag to load the model from (base model or midtrained model) model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model) step = None # step to load the model from (base model or midtrained model)
# compute/precision # compute/precision
device_type = "" # cuda|cpu|mps (empty => autodetect)
dtype = "bfloat16" dtype = "bfloat16"
device_batch_size = 4 # max to avoid OOM device_batch_size = 4 # max to avoid OOM
# optimization # optimization
num_epochs = 1 num_epochs = 1
max_iterations = -1 # override number of iterations (-1 = use num_epochs * num_iterations) num_iterations = -1 # override number of iterations (-1 = disable, use num_epochs to derive it)
target_examples_per_step = 32 target_examples_per_step = 32
unembedding_lr = 0.004 unembedding_lr = 0.004
embedding_lr = 0.2 embedding_lr = 0.2
@ -50,6 +53,7 @@ init_lr_frac = 0.02
eval_every = 100 eval_every = 100
eval_steps = 100 eval_steps = 100
eval_metrics_every = 200 eval_metrics_every = 200
eval_metrics_max_problems = 1024
# now allow CLI to override the settings via the configurator lol # now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
@ -57,10 +61,11 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Compute init # Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() device_type = autodetect_device_type() if device_type == "" else device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype) autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
# wandb logging init # wandb logging init
use_dummy_wandb = run == "dummy" or not master_process use_dummy_wandb = run == "dummy" or not master_process
@@ -74,13 +79,14 @@ engine = Engine(model, tokenizer) # will be used for inline model evaluation only
# -----------------------------------------------------------------------------
# Task data mixture we'll train on
+identity_conversations_filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
train_ds = TaskMixture([
    ARC(subset="ARC-Easy", split="train"), # 2.3K rows
    ARC(subset="ARC-Challenge", split="train"), # 1.1K rows
    GSM8K(subset="main", split="train"), # 8K rows
    SmolTalk(split="train", stop=10_000), # 10K rows of smoltalk
-]) # 2.3K + 1.1K + 8K + 10K = 21.4K rows
+    CustomJSON(filepath=identity_conversations_filepath), # 1K rows of synthetic identity conversations
+]) # 2.3K + 1.1K + 8K + 10K + 1K = 22.4K rows
val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don't actually use all of it)
# -----------------------------------------------------------------------------
@@ -126,10 +132,10 @@ assert target_examples_per_step % examples_per_step == 0, "Target examples per s
grad_accum_steps = target_examples_per_step // examples_per_step
print0(f"=> Setting grad accum steps: {grad_accum_steps}")
-num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
-if max_iterations >= 0 and num_iterations > max_iterations:
-    print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
-    num_iterations = max_iterations
+if num_iterations == -1:
+    # derive num_iterations from num_epochs and the size of the dataset
+    assert num_epochs > 0, "num_epochs must be positive if num_iterations is -1"
+    num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
@@ -189,8 +195,8 @@ for step in range(num_iterations):
        metrics = {}
        with torch.no_grad(), autocast_ctx:
            # note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
-            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
-            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
+            metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
+            metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=eval_metrics_max_problems)
        metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
        print0(f"Step {step:05d} | {metrics_str}")
        wandb_run.log({

View File

@@ -44,8 +44,8 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
+from contextlib import nullcontext
-from nanochat.common import compute_init
+from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@@ -69,6 +69,8 @@ parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default m
parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag to load')
parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
+parser.add_argument('-d', '--dtype', type=str, default='bfloat16', choices=['float32', 'bfloat16'])
+parser.add_argument('--device-type', type=str, default='', choices=['cuda', 'cpu', 'mps'], help='Device type for evaluation: cuda|cpu|mps. empty => autodetect')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
@@ -80,7 +82,9 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if args.device_type == "" else args.device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
+ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
@dataclass
class Worker:
@@ -95,21 +99,33 @@ class WorkerPool:
    """Pool of workers, each with a model replica on a different GPU."""
    def __init__(self, num_gpus: Optional[int] = None):
-        self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
+        if num_gpus is None:
+            if device_type == "cuda":
+                num_gpus = torch.cuda.device_count()
+            else:
+                num_gpus = 1 # e.g. cpu|mps
+        self.num_gpus = num_gpus
        self.workers: List[Worker] = []
        self.available_workers: asyncio.Queue = asyncio.Queue()
    async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
        """Load model on each GPU."""
        print(f"Initializing worker pool with {self.num_gpus} GPUs...")
+        if self.num_gpus > 1:
+            assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
        for gpu_id in range(self.num_gpus):
+            if device_type == "cuda":
                device = torch.device(f"cuda:{gpu_id}")
                print(f"Loading model on GPU {gpu_id}...")
+            else:
+                device = torch.device(device_type) # e.g. cpu|mps
+                print(f"Loading model on {device_type}...")
            model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
            engine = Engine(model, tokenizer)
-            autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+            autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
            worker = Worker(
                gpu_id=gpu_id,

View File

@@ -15,8 +15,8 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import time
import wandb
import torch
+from contextlib import nullcontext
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@@ -28,12 +28,15 @@ from tasks.common import TaskMixture
from tasks.gsm8k import GSM8K
from tasks.mmlu import MMLU
from tasks.smoltalk import SmolTalk
+from tasks.customjson import CustomJSON
# -----------------------------------------------------------------------------
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
+device_type = "" # cuda|cpu|mps (empty => autodetect)
model_tag = None # model tag to load the model from (base model or midtrained model)
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
+num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
max_seq_len = 2048
device_batch_size = 32
unembedding_lr = 0.004
@@ -41,7 +44,7 @@ embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0
-eval_every = 150
+eval_every = 150 # -1 = disable
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
@@ -51,10 +54,12 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
# -----------------------------------------------------------------------------
# Compute init
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
+device_type = autodetect_device_type() if device_type == "" else device_type
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
-dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
+get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@@ -93,10 +98,13 @@ for opt in optimizers:
# Midtraining data mixture and DataLoader
base_dir = get_base_dir()
+identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
    SmolTalk(split="train"), # 460K rows of general conversations
    MMLU(subset="auxiliary_train", split="train"), # 100K rows of multiple choice problems drawn from ARC, MC_TEST, OBQA, RACE
    GSM8K(subset="main", split="train"), # 8K rows teaching simple math and (calculator) tool use
+    CustomJSON(filepath=identity_conversations_filepath), # 1000 rows of synthetic identity conversations
+    CustomJSON(filepath=identity_conversations_filepath), # let's do 2 epochs of these
]) # total: 460K + 100K + 8K = 568K rows
val_dataset = TaskMixture([
    SmolTalk(split="test"), # 24K rows in test set
@@ -118,6 +126,7 @@ def mid_data_generator(split):
    token_buffer = deque()
    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
    cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
+    it = 0 # iteration counter
    while True:
        # Accumulate enough tokens for one iteration before yielding
        while len(token_buffer) < needed_tokens:
@@ -129,6 +138,10 @@ def mid_data_generator(split):
                cursor -= dataset_size # wrap around for another epoch
                if split == "train":
                    last_step = True # toggle last_step to True, which will terminate the training loop
+        # Stopping condition to respect num_iterations, if given
+        it += 1
+        if num_iterations > 0 and it >= num_iterations:
+            last_step = True # toggle last_step to True, which will terminate the training loop
        # Build up inputs/targets and yield
        for i in range(needed_tokens):
            scratch[i] = token_buffer.popleft()
@@ -137,6 +150,9 @@ def mid_data_generator(split):
        inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
        if split == "train":
+            if num_iterations > 0:
+                approx_progress = it / num_iterations # calculate progress from the max number of iterations
+            else:
                approx_progress = cursor / dataset_size # approximate progress as a fraction of the dataset
        yield inputs, targets
@@ -173,7 +189,7 @@ while True:
    last_step = bool(last_step_tensor.item())
    # once in a while: evaluate the val bpb (all ranks participate)
-    if last_step or step % eval_every == 0:
+    if eval_every > 0 and (last_step or step % eval_every == 0):
        model.eval()
        val_loader = build_val_loader()
        eval_steps = eval_tokens // (device_batch_size * max_seq_len * ddp_world_size)
@@ -220,7 +236,7 @@ while True:
    # -------------------------------------------------------------------------
    # single training step
    # evaluate the gradient
-    torch.cuda.synchronize()
+    synchronize()
    t0 = time.time()
    for micro_step in range(grad_accum_steps):
        with autocast_ctx:
@@ -241,7 +257,7 @@ while True:
    for opt in optimizers:
        opt.step()
    model.zero_grad(set_to_none=True)
-    torch.cuda.synchronize()
+    synchronize()
    t1 = time.time()
    dt = t1 - t0
    # -------------------------------------------------------------------------
@@ -273,7 +289,7 @@ while True:
    })
    # print a few more stats
-    print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
+    print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
    print0(f"Total training time: {total_training_time/60:.2f}m")
    print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

View File

@@ -101,6 +101,10 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
+# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
+# see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
+curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
# run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
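
For a quick sanity check after the download, the new CustomJSON task can be pointed at the file directly. This is an illustrative sketch, not part of the commit: it assumes it is run from the repo root so that `nanochat.common` and `tasks.customjson` are importable, and that `get_base_dir()` resolves to the same directory as `$NANOCHAT_BASE_DIR` above.

```python
import os
from nanochat.common import get_base_dir
from tasks.customjson import CustomJSON

# Point the task at the downloaded identity conversations file
filepath = os.path.join(get_base_dir(), "identity_conversations.jsonl")
ds = CustomJSON(filepath=filepath)
print(ds.num_examples())                  # expect roughly 1K conversations
print(ds.get_example(0)["messages"][:2])  # first user/assistant exchange
```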

65
tasks/customjson.py Normal file
View File

@@ -0,0 +1,65 @@
"""
CustomJSON task for loading conversations from JSONL files.
Each line in the JSONL file should be a JSON array of messages.
"""
import os
import json

from tasks.common import Task

class CustomJSON(Task):
    """
    Load conversations from a JSONL file.
    Each line should be a JSON array of message objects with 'role' and 'content' fields.
    Example line: [{"role":"user","content":"Hi"},{"role":"assistant","content":"Hello"}]
    """
    def __init__(self, filepath, **kwargs):
        super().__init__(**kwargs)
        self.filepath = filepath
        self.conversations = []
        # Load all conversations from the JSONL file
        if not os.path.exists(filepath):
            # Helpful error message due to a recent change. Will be removed in the future.
            print("-" * 80)
            print(f"Warning: File {filepath} does not exist")
            print("HINT (Oct 21 2025)")
            print("If you recently did a git pull and suddenly see this, it might be due to the new addition of identity conversations")
            print("See this discussion for more details: https://github.com/karpathy/nanochat/discussions/139")
            print("Quick fix: simply run the following command to download the file and you're done:")
            print(f"curl -L -o {filepath} https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl")
            print("-" * 80)
        else:
            with open(filepath, 'r') as f:
                for line in f:
                    line = line.strip()
                    if not line: # skip empty lines
                        continue
                    messages = json.loads(line)
                    # Validate the conversation structure
                    assert isinstance(messages, list), f"Expected list of messages, got {type(messages)}"
                    assert len(messages) >= 2, f"Conversation must have at least 2 messages, got {len(messages)}"
                    # Validate message structure and alternating roles
                    for i, message in enumerate(messages):
                        assert "role" in message, f"Message {i} missing 'role' field"
                        assert "content" in message, f"Message {i} missing 'content' field"
                        expected_role = "user" if i % 2 == 0 else "assistant"
                        assert message["role"] == expected_role, f"Message {i} has role {message['role']} but should be {expected_role}"
                        assert isinstance(message["content"], str), f"Message {i} content must be a string"
                    self.conversations.append(messages)
        self.length = len(self.conversations)

    def num_examples(self):
        return self.length

    def get_example(self, index):
        messages = self.conversations[index]
        conversation = {
            "messages": messages,
        }
        return conversation
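
For reference, a minimal sketch (illustrative, not part of the commit) of producing a JSONL file in the format CustomJSON expects: one JSON array per line, alternating user/assistant messages starting with the user. The output filename and the toy conversations below are placeholders; the training scripts look for the file in the nanochat base directory.

```python
import json

# Two toy conversations in the expected alternating user/assistant format
conversations = [
    [
        {"role": "user", "content": "Who are you?"},
        {"role": "assistant", "content": "I'm nanochat, a small chat model."},
    ],
    [
        {"role": "user", "content": "Who trained you?"},
        {"role": "assistant", "content": "I was trained with the nanochat repository."},
    ],
]

# One JSON array of messages per line, which is what CustomJSON parses and validates
with open("identity_conversations.jsonl", "w") as f:
    for messages in conversations:
        f.write(json.dumps(messages) + "\n")
```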

View File

@@ -2002,3 +2002,4 @@ wheels = [
    { url = "https://files.pythonhosted.org/packages/94/c3/b2e9f38bc3e11191981d57ea08cab2166e74ea770024a646617c9cddd9f6/yarl-1.20.1-cp313-cp313t-win_amd64.whl", hash = "sha256:541d050a355bbbc27e55d906bc91cb6fe42f96c01413dd0f4ed5a5240513874f", size = 93003, upload-time = "2025-06-10T00:45:27.752Z" },
    { url = "https://files.pythonhosted.org/packages/b4/2d/2345fce04cfd4bee161bf1e7d9cdc702e3e16109021035dbb24db654a622/yarl-1.20.1-py3-none-any.whl", hash = "sha256:83b8eb083fe4683c6115795d9fc1cfaf2cbbefb19b3a1cb68f6527460f483a77", size = 46542, upload-time = "2025-06-10T00:46:07.521Z" },
]