nanochat/speedrun.sh
google-labs-jules[bot] 68148b1bf3 Export TRITON_HIP_LLD_PATH in speedrun.sh for AMD ROCm
When running on AMD ROCm using `uv`-installed packages (`rocm-sdk-core`), the `ld.lld` linker is not in the default `/opt/rocm/llvm/bin/` location expected by `pytorch-triton-rocm`. This causes an `InductorError` during `torch.compile`.

This change updates `speedrun.sh` to dynamically find the `ld.lld` binary within the active python environment's site-packages (`_rocm_sdk_core`) and export the `TRITON_HIP_LLD_PATH` environment variable, allowing Triton to locate the linker correctly.
2025-11-23 08:19:07 +00:00
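For illustration, with a `uv`-created `.venv` the exported variable typically ends up pointing at something like the following (the exact Python minor version is illustrative, not guaranteed):

    TRITON_HIP_LLD_PATH=<repo>/.venv/lib/python3.12/site-packages/_rocm_sdk_core/lib/llvm/bin/ld.lld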

#!/bin/bash
set -e
# This script is the "Best ChatGPT clone that $100 can buy".
# It is designed to run in ~4 hours on an 8XH100 node at $3/GPU/hour.
# 1) Example launch (simplest):
# bash speedrun.sh
# 2) Example launch in a screen session (because the run takes ~4 hours):
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# 3) Example launch with wandb logging, but see below for setting up wandb first:
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
# Default intermediate artifacts directory is in ~/.cache/nanochat
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
# -----------------------------------------------------------------------------
# Python venv setup with uv
# install uv (if not already installed)
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
# create a .venv local virtual environment (if it doesn't exist)
[ -d ".venv" ] || uv venv
# install the repo dependencies
# Detect hardware to install the correct torch version
if command -v nvidia-smi &> /dev/null; then
    echo "NVIDIA GPU detected. Installing CUDA dependencies..."
    EXTRAS="gpu"
elif [ -e /dev/kfd ]; then
    echo "AMD GPU detected. Installing ROCm dependencies..."
    EXTRAS="amd"
else
    echo "No dedicated GPU detected. Installing CPU dependencies..."
    EXTRAS="cpu"
fi
uv sync --extra $EXTRAS
# activate venv so that `python` uses the project's venv instead of system python
source .venv/bin/activate
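# (optional) sanity-check which torch backend the sync above installed;
# getattr is used because torch.version may lack the 'hip' attribute on non-ROCm builds:
# python -c "import torch; print(torch.__version__, 'cuda:', torch.version.cuda, 'hip:', getattr(torch.version, 'hip', None))"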
# Explicitly uninstall triton if present, as it conflicts with pytorch-triton-rocm
# and can cause "ImportError: cannot import name 'Config' from 'triton'" errors
# if the NVIDIA version of triton (e.g. 3.4.0) is accidentally installed.
if [ "$EXTRAS" == "amd" ]; then
    uv pip uninstall -q triton || true
    # Uninstalling triton may have deleted the shared 'triton' directory, breaking pytorch-triton-rocm.
    # Reinstall pytorch-triton-rocm to ensure it's intact.
    uv pip install --force-reinstall --index-url https://repo.amd.com/rocm/whl/gfx1151 pytorch-triton-rocm
    # Find and export the path to ld.lld from rocm-sdk-core if available, as torch.compile/triton needs it
    ROCM_LLD_PATH=$(python -c "import sysconfig; import os; p = f\"{sysconfig.get_paths()['purelib']}/_rocm_sdk_core/lib/llvm/bin/ld.lld\"; print(p) if os.path.exists(p) else print('')")
    if [ -n "$ROCM_LLD_PATH" ]; then
        export TRITON_HIP_LLD_PATH=$ROCM_LLD_PATH
        echo "Exported TRITON_HIP_LLD_PATH=$TRITON_HIP_LLD_PATH"
    fi
fi
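# (optional) verify the ROCm linker Triton will use, if one was found above:
# [ -n "$TRITON_HIP_LLD_PATH" ] && "$TRITON_HIP_LLD_PATH" --version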
# -----------------------------------------------------------------------------
# wandb setup
# If you wish to use wandb for logging (it's nice, and recommended):
# 1) Make sure to first log in to wandb, e.g. run:
# `wandb login`
# 2) Set the WANDB_RUN environment variable when running this script, e.g.:
# `WANDB_RUN=d26 bash speedrun.sh`
if [ -z "$WANDB_RUN" ]; then
    # by default use "dummy": it's handled as a special case that skips logging to wandb
    WANDB_RUN=dummy
fi
# -----------------------------------------------------------------------------
# During the course of the run, we will be writing markdown reports to the report/
# directory in the base dir. This command clears it out and writes a header section
# with a bunch of system info and a timestamp that marks the start of the run.
python -m nanochat.report reset
# -----------------------------------------------------------------------------
# Tokenizer
# Install Rust / Cargo
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
# Build the rustbpe Tokenizer
# use --no-sync to avoid re-installing triton on AMD, which we just uninstalled
uv run --no-sync --extra $EXTRAS maturin develop --release --manifest-path rustbpe/Cargo.toml
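# (optional) confirm the extension built and is importable; the module name rustbpe is assumed from the crate path above:
# python -c "import rustbpe; print('rustbpe import OK')"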
# Download the first ~2B characters of pretraining dataset
# look at dev/repackage_data_reference.py for details on how this data was prepared
# each data shard is ~250M chars
# so we download 2e9 / 250e6 = 8 data shards at this point
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
python -m nanochat.dataset -n 8
# Immediately also kick off downloading more shards in the background while the tokenizer trains
# See comment below for why 240 is the right number here
python -m nanochat.dataset -n 240 &
DATASET_DOWNLOAD_PID=$!
# train the tokenizer with vocab size 2**16 = 65536 on ~2B characters of data
python -m scripts.tok_train --max_chars=2000000000
# evaluate the tokenizer (report compression ratio etc.)
python -m scripts.tok_eval
# -----------------------------------------------------------------------------
# Base model (pretraining)
# The d20 model has 561M parameters.
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
# Assuming our tokenizer yields ~4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
# Round up to 240 for safety. At ~100MB/shard, this downloads ~24GB of data to disk.
# (The total number of shards available in the entire dataset is 1822.)
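# (optional) re-derive the shard math above in one line (numbers are approximate):
# python -c "import math; p=561e6; t=20*p; c=t*4.8; print(f'{t/1e9:.1f}B tokens, {c/1e9:.1f}B chars, {math.ceil(c/250e6)} shards')"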
echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
# Number of processes/GPUs to use
# Auto-detect if we have GPUs (including ROCm)
if python -c "import torch; exit(0) if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip) else exit(1)"; then
    # Detected GPU capability. Now get the actual count to avoid launching too many processes.
    NPROC_PER_NODE=$(python -c "import torch; print(torch.cuda.device_count())")
    # If for some reason it returns 0 (e.g. weird ROCm state), default to 1 to be safe.
    if [ "$NPROC_PER_NODE" -eq "0" ]; then
        echo "GPU detected but torch.cuda.device_count() is 0. Defaulting to NPROC_PER_NODE=1."
        NPROC_PER_NODE=1
    else
        echo "Detected $NPROC_PER_NODE GPUs."
    fi
else
    echo "No GPU detected. Defaulting to NPROC_PER_NODE=1 to avoid OOM; multi-threading will be used instead."
    NPROC_PER_NODE=1
    # If running on CPU, let PyTorch use all available cores for the single process
    unset OMP_NUM_THREADS
fi
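# If the auto-detection above ever guesses wrong, NPROC_PER_NODE can simply be hard-coded here instead, e.g.:
# NPROC_PER_NODE=8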
# pretrain the d20 model
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
# -----------------------------------------------------------------------------
# Midtraining (teaches the model conversation special tokens, tool use, and multiple choice)
# download 2.3MB of synthetic identity conversations to impart a personality to nanochat
# see dev/gen_sft_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
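# (optional) quick peek at what was downloaded:
# wc -c $NANOCHAT_BASE_DIR/identity_conversations.jsonl && head -n 1 $NANOCHAT_BASE_DIR/identity_conversations.jsonl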
# run midtraining and eval the model
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to conversations, with each sequence trained on its own row)
# train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"
# even better, chat with your model over a pretty WebUI ChatGPT style
# python -m scripts.chat_web
# -----------------------------------------------------------------------------
# Reinforcement Learning (optional, currently only on GSM8K)
# run reinforcement learning
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections
# report.md is the output and will be copied to current directory for convenience
python -m nanochat.report generate