mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-15 10:39:08 +00:00
Merge a305bd9612 into 0aaca56805
This commit is contained in:
commit
669d9e1009
|
|
@ -26,6 +26,9 @@ def _patch_missing_config_keys(model_config_kwargs):
|
|||
if "window_pattern" not in model_config_kwargs:
|
||||
model_config_kwargs["window_pattern"] = "L"
|
||||
log0(f"Patching missing window_pattern in model config to 'L'")
|
||||
if "use_xsa" not in model_config_kwargs:
|
||||
model_config_kwargs["use_xsa"] = False
|
||||
log0(f"Patching missing use_xsa in model config to False")
|
||||
|
||||
def _patch_missing_keys(model_data, model_config):
|
||||
"""Add default values for new parameters that may be missing in old checkpoints."""
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ class GPTConfig:
|
|||
# Characters: L=long (full context), S=short (quarter context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
use_xsa: bool = False
|
||||
|
||||
|
||||
def norm(x):
|
||||
|
|
@ -70,6 +71,8 @@ class CausalSelfAttention(nn.Module):
|
|||
self.n_kv_head = config.n_kv_head
|
||||
self.n_embd = config.n_embd
|
||||
self.head_dim = self.n_embd // self.n_head
|
||||
self.use_xsa = config.use_xsa
|
||||
self.xsa = ExclusiveSelfAttention()
|
||||
assert self.n_embd % self.n_head == 0
|
||||
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
||||
self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||
|
|
@ -120,12 +123,22 @@ class CausalSelfAttention(nn.Module):
|
|||
if self.layer_idx == kv_cache.n_layers - 1:
|
||||
kv_cache.advance(T)
|
||||
|
||||
if self.use_xsa:
|
||||
y = self.xsa.XSA(y, v)
|
||||
|
||||
# Re-assemble the heads and project back to residual stream
|
||||
y = y.contiguous().view(B, T, -1)
|
||||
y = self.c_proj(y)
|
||||
return y
|
||||
|
||||
|
||||
class ExclusiveSelfAttention(nn.Module):
    """Exclusive self-attention (XSA) post-processing of attention outputs.

    Removes from every attention output vector its component along the
    (L2-normalized) value vector it is paired with, so the result is
    orthogonal to that value vector over the last dimension.
    """

    def forward(self, y, v):
        """Standard ``nn.Module`` entry point; delegates to :meth:`XSA`."""
        return self.XSA(y, v)

    def XSA(self, y, v):
        """Return ``y`` with its projection onto ``v`` removed (last dim).

        Args:
            y: attention output, before the heads are merged.
                NOTE(review): assumed layout is (..., n_head, head_dim),
                i.e. heads on dim -2 -- confirm against the caller.
            v: value tensor with the same trailing ``head_dim``. It may
                carry fewer heads than ``y`` (grouped-query attention), as
                long as the head count of ``y`` is a multiple of it.

        Returns:
            Tensor shaped like ``y`` whose vectors are orthogonal to the
            matching (normalized) vectors of ``v``.
        """
        # Grouped-query attention: the value tensor has n_kv_head heads while
        # the attention output has n_head heads. Plain broadcasting fails in
        # that case, so repeat each KV head for the query heads that share it.
        if v.dim() == y.dim() and v.dim() >= 2:
            y_heads, v_heads = y.size(-2), v.size(-2)
            if v_heads < y_heads and y_heads % v_heads == 0:
                v = v.repeat_interleave(y_heads // v_heads, dim=-2)

        Vn = F.normalize(v, dim=-1)
        # (y . Vn) * Vn is the component of y along Vn; subtract it.
        Z = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn
        return Z
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
|
|
|||
43
runs/pace_stage1_tokenizer.sh
Normal file
43
runs/pace_stage1_tokenizer.sh
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
#!/bin/bash
#SBATCH -N 1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=24
#SBATCH --mem=64G
#SBATCH -t 2:00:00
#SBATCH -J nanochat-stage1-tokenizer
#SBATCH -o runs/logs/stage1_%j.out
#SBATCH -e runs/logs/stage1_%j.err

# Stage 1 (CPU job): set up the Python environment, download the dataset
# shards, then train and evaluate the tokenizer.

set -e
cd "$HOME/scratch/nanochat"

export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
mkdir -p "$NANOCHAT_BASE_DIR"
mkdir -p runs/logs

echo "=== Stage 1: Tokenizer ==="
echo "Base dir: $NANOCHAT_BASE_DIR"
echo "XSA: ${XSA:-FALSE}"
echo "Started: $(date)"

# Put the uv install location on PATH *before* probing for uv, so a uv that an
# earlier job installed there is found instead of re-running the installer.
export PATH="$HOME/.local/bin:$PATH"
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra gpu
source .venv/bin/activate
python -m nanochat.report reset

# Fetch the first 8 shards synchronously, then keep downloading the full set
# (170 shards) in the background while the tokenizer is trained and evaluated.
python -m nanochat.dataset -n 8
python -m nanochat.dataset -n 170 &
DATASET_DOWNLOAD_PID=$!

python -m scripts.tok_train
python -m scripts.tok_eval

# `wait` returns the download's exit status, so with `set -e` a failed
# download fails the job (and blocks the afterok-dependent stage 2a).
echo "Waiting for full dataset download..."
wait "$DATASET_DOWNLOAD_PID"

echo "=== Stage 1 complete: $(date) ==="
echo "Dataset and tokenizer ready in $NANOCHAT_BASE_DIR"
|
||||
48
runs/pace_stage2a_pretrain.sh
Normal file
48
runs/pace_stage2a_pretrain.sh
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
#!/bin/bash
#SBATCH -N 1
#SBATCH -p ice-gpu
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:2
#SBATCH --constraint="gpu-h100|gpu-h200"
#SBATCH --mem-per-gpu=48G
#SBATCH -t 3:55:00
#SBATCH -J nanochat-stage2a
#SBATCH -o runs/logs/stage2a_%j.out
#SBATCH -e runs/logs/stage2a_%j.err

# Stage 2a: first pretraining chunk. Always starts training from scratch and
# writes a completion marker only if torchrun finishes within this job.

set -e
cd "$HOME/scratch/nanochat"

export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
mkdir -p runs/logs

WANDB_RUN="${WANDB_RUN:-dummy}"
XSA="${XSA:-FALSE}"
XSA_ARG=""
if [ "$XSA" = "TRUE" ]; then
    XSA_ARG="--xsa"
fi
CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
DONE_MARKER="$CHECKPOINT_DIR/.training_complete"

echo "=== Stage 2a: Pretraining (chunk 1) ==="
echo "Base dir: $NANOCHAT_BASE_DIR"
echo "WANDB_RUN: $WANDB_RUN"
echo "XSA: $XSA"
echo "Started: $(date)"

# This stage restarts training from step 0, so a marker left behind by an
# earlier pipeline run is stale. Without removing it, a time-limit kill of
# this job would make the later chunks skip and stage 3 accept a partial run.
# NOTE(review): old model_*.pt files in $CHECKPOINT_DIR are NOT removed here;
# the resume stages pick the highest step they find -- confirm whether stale
# checkpoints should be cleared as well.
rm -f "$DONE_MARKER"

source .venv/bin/activate

# $XSA_ARG is intentionally unquoted: when empty it must expand to no argument.
torchrun --standalone --nproc_per_node=2 -m scripts.base_train -- \
    --depth=24 \
    --target-param-data-ratio=8 \
    --device-batch-size=16 \
    --save-every=200 \
    $XSA_ARG \
    --run="$WANDB_RUN"

# Only reached when torchrun exited 0 (set -e): record that training finished.
mkdir -p "$CHECKPOINT_DIR"
touch "$DONE_MARKER"
echo "=== Stage 2a complete: $(date) ==="
|
||||
74
runs/pace_stage2b_pretrain.sh
Normal file
74
runs/pace_stage2b_pretrain.sh
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
#!/bin/bash
#SBATCH -N 1
#SBATCH -p ice-gpu
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:2
#SBATCH --constraint="gpu-h100|gpu-h200"
#SBATCH --mem-per-gpu=48G
#SBATCH -t 3:55:00
#SBATCH -J nanochat-stage2b
#SBATCH -o runs/logs/stage2b_%j.out
#SBATCH -e runs/logs/stage2b_%j.err

# Stage 2b: second pretraining chunk. Skips if the completion marker exists,
# otherwise resumes from the newest saved checkpoint (or starts from scratch
# when none is found).

set -e
cd "$HOME/scratch/nanochat"

export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
mkdir -p runs/logs

WANDB_RUN="${WANDB_RUN:-dummy}"
XSA="${XSA:-FALSE}"
XSA_ARG=""
if [ "$XSA" = "TRUE" ]; then
    XSA_ARG="--xsa"
fi
CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
DONE_MARKER="$CHECKPOINT_DIR/.training_complete"

echo "=== Stage 2b: Pretraining (chunk 2 / auto-resume) ==="
echo "Base dir: $NANOCHAT_BASE_DIR"
echo "XSA: $XSA"
echo "Started: $(date)"

if [ -f "$DONE_MARKER" ]; then
    echo "Training already complete (marker: $DONE_MARKER). Nothing to do."
    echo "=== Stage 2b skipped: $(date) ==="
    exit 0
fi

source .venv/bin/activate

# Highest step among model_<digits>.pt files, or 0 when there are none.
# The directory is passed as an argument (not spliced into the Python source)
# so unusual characters in the path cannot break the snippet, and files whose
# name is not purely numeric are ignored instead of crashing int().
LAST_STEP=$(python -c '
import glob, os, re, sys
steps = []
for path in glob.glob(os.path.join(sys.argv[1], "model_*.pt")):
    match = re.fullmatch(r"model_(\d+)\.pt", os.path.basename(path))
    if match:
        steps.append(int(match.group(1)))
print(max(steps, default=0))
' "$CHECKPOINT_DIR")

RESUME_ARG=""
if [ "$LAST_STEP" -eq 0 ]; then
    echo "No checkpoint found — starting from scratch"
else
    echo "Resuming from step $LAST_STEP"
    RESUME_ARG="--resume-from-step=$LAST_STEP"
fi

# $RESUME_ARG and $XSA_ARG are intentionally unquoted: when empty they must
# expand to no argument at all.
torchrun --standalone --nproc_per_node=2 -m scripts.base_train -- \
    --depth=24 \
    --target-param-data-ratio=8 \
    --device-batch-size=16 \
    --save-every=200 \
    $RESUME_ARG \
    $XSA_ARG \
    --run="$WANDB_RUN"

# Only reached when torchrun exited 0 (set -e): record that training finished.
mkdir -p "$CHECKPOINT_DIR"
touch "$DONE_MARKER"
echo "=== Stage 2b complete: $(date) ==="
|
||||
74
runs/pace_stage2c_pretrain.sh
Normal file
74
runs/pace_stage2c_pretrain.sh
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
#!/bin/bash
#SBATCH -N 1
#SBATCH -p ice-gpu
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:2
#SBATCH --constraint="gpu-h100|gpu-h200"
#SBATCH --mem-per-gpu=48G
#SBATCH -t 3:55:00
#SBATCH -J nanochat-stage2c
#SBATCH -o runs/logs/stage2c_%j.out
#SBATCH -e runs/logs/stage2c_%j.err

# Stage 2c: third pretraining chunk. Skips if the completion marker exists,
# otherwise resumes from the newest saved checkpoint (or starts from scratch
# when none is found).

set -e
cd "$HOME/scratch/nanochat"

export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
mkdir -p runs/logs

WANDB_RUN="${WANDB_RUN:-dummy}"
XSA="${XSA:-FALSE}"
XSA_ARG=""
if [ "$XSA" = "TRUE" ]; then
    XSA_ARG="--xsa"
fi
CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
DONE_MARKER="$CHECKPOINT_DIR/.training_complete"

echo "=== Stage 2c: Pretraining (chunk 3 / auto-resume) ==="
echo "Base dir: $NANOCHAT_BASE_DIR"
echo "XSA: $XSA"
echo "Started: $(date)"

if [ -f "$DONE_MARKER" ]; then
    echo "Training already complete (marker: $DONE_MARKER). Nothing to do."
    echo "=== Stage 2c skipped: $(date) ==="
    exit 0
fi

source .venv/bin/activate

# Highest step among model_<digits>.pt files, or 0 when there are none.
# The directory is passed as an argument (not spliced into the Python source)
# so unusual characters in the path cannot break the snippet, and files whose
# name is not purely numeric are ignored instead of crashing int().
LAST_STEP=$(python -c '
import glob, os, re, sys
steps = []
for path in glob.glob(os.path.join(sys.argv[1], "model_*.pt")):
    match = re.fullmatch(r"model_(\d+)\.pt", os.path.basename(path))
    if match:
        steps.append(int(match.group(1)))
print(max(steps, default=0))
' "$CHECKPOINT_DIR")

RESUME_ARG=""
if [ "$LAST_STEP" -eq 0 ]; then
    echo "No checkpoint found — starting from scratch"
else
    echo "Resuming from step $LAST_STEP"
    RESUME_ARG="--resume-from-step=$LAST_STEP"
fi

# $RESUME_ARG and $XSA_ARG are intentionally unquoted: when empty they must
# expand to no argument at all.
torchrun --standalone --nproc_per_node=2 -m scripts.base_train -- \
    --depth=24 \
    --target-param-data-ratio=8 \
    --device-batch-size=16 \
    --save-every=200 \
    $RESUME_ARG \
    $XSA_ARG \
    --run="$WANDB_RUN"

# Only reached when torchrun exited 0 (set -e): record that training finished.
mkdir -p "$CHECKPOINT_DIR"
touch "$DONE_MARKER"
echo "=== Stage 2c complete: $(date) ==="
|
||||
74
runs/pace_stage2d_pretrain.sh
Normal file
74
runs/pace_stage2d_pretrain.sh
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
#!/bin/bash
#SBATCH -N 1
#SBATCH -p ice-gpu
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:2
#SBATCH --constraint="gpu-h100|gpu-h200"
#SBATCH --mem-per-gpu=48G
#SBATCH -t 3:55:00
#SBATCH -J nanochat-stage2d
#SBATCH -o runs/logs/stage2d_%j.out
#SBATCH -e runs/logs/stage2d_%j.err

# Stage 2d: fourth pretraining chunk. Skips if the completion marker exists,
# otherwise resumes from the newest saved checkpoint (or starts from scratch
# when none is found).

set -e
cd "$HOME/scratch/nanochat"

export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
mkdir -p runs/logs

WANDB_RUN="${WANDB_RUN:-dummy}"
XSA="${XSA:-FALSE}"
XSA_ARG=""
if [ "$XSA" = "TRUE" ]; then
    XSA_ARG="--xsa"
fi
CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
DONE_MARKER="$CHECKPOINT_DIR/.training_complete"

echo "=== Stage 2d: Pretraining (chunk 4 / auto-resume) ==="
echo "Base dir: $NANOCHAT_BASE_DIR"
echo "XSA: $XSA"
echo "Started: $(date)"

if [ -f "$DONE_MARKER" ]; then
    echo "Training already complete (marker: $DONE_MARKER). Nothing to do."
    echo "=== Stage 2d skipped: $(date) ==="
    exit 0
fi

source .venv/bin/activate

# Highest step among model_<digits>.pt files, or 0 when there are none.
# The directory is passed as an argument (not spliced into the Python source)
# so unusual characters in the path cannot break the snippet, and files whose
# name is not purely numeric are ignored instead of crashing int().
LAST_STEP=$(python -c '
import glob, os, re, sys
steps = []
for path in glob.glob(os.path.join(sys.argv[1], "model_*.pt")):
    match = re.fullmatch(r"model_(\d+)\.pt", os.path.basename(path))
    if match:
        steps.append(int(match.group(1)))
print(max(steps, default=0))
' "$CHECKPOINT_DIR")

RESUME_ARG=""
if [ "$LAST_STEP" -eq 0 ]; then
    echo "No checkpoint found — starting from scratch"
else
    echo "Resuming from step $LAST_STEP"
    RESUME_ARG="--resume-from-step=$LAST_STEP"
fi

# $RESUME_ARG and $XSA_ARG are intentionally unquoted: when empty they must
# expand to no argument at all.
torchrun --standalone --nproc_per_node=2 -m scripts.base_train -- \
    --depth=24 \
    --target-param-data-ratio=8 \
    --device-batch-size=16 \
    --save-every=200 \
    $RESUME_ARG \
    $XSA_ARG \
    --run="$WANDB_RUN"

# Only reached when torchrun exited 0 (set -e): record that training finished.
mkdir -p "$CHECKPOINT_DIR"
touch "$DONE_MARKER"
echo "=== Stage 2d complete: $(date) ==="
|
||||
55
runs/pace_stage3_sft.sh
Normal file
55
runs/pace_stage3_sft.sh
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
#!/bin/bash
#SBATCH -N 1
#SBATCH -p ice-gpu
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=8
#SBATCH --gres=gpu:2
#SBATCH --constraint="gpu-h100|gpu-h200"
#SBATCH --mem-per-gpu=48G
#SBATCH -t 3:55:00
#SBATCH -J nanochat-stage3-sft
#SBATCH -o runs/logs/stage3_%j.out
#SBATCH -e runs/logs/stage3_%j.err

# Stage 3: base-model eval, SFT, chat eval and report generation. Refuses to
# run unless the pretraining completion marker exists.

set -e
cd "$HOME/scratch/nanochat"

export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
mkdir -p runs/logs

WANDB_RUN="${WANDB_RUN:-dummy}"

echo "=== Stage 3: Eval + SFT ==="
echo "Base dir: $NANOCHAT_BASE_DIR"
echo "WANDB_RUN: $WANDB_RUN"
echo "XSA: ${XSA:-FALSE}"
echo "Started: $(date)"

CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
DONE_MARKER="$CHECKPOINT_DIR/.training_complete"
if [ ! -f "$DONE_MARKER" ]; then
    echo "ERROR: pretraining did not finish — missing $DONE_MARKER"
    echo "Re-run pretrain chunks 2a–2d until the marker is created before running stage 3."
    exit 1
fi

source .venv/bin/activate

torchrun --standalone --nproc_per_node=2 -m scripts.base_eval -- \
    --device-batch-size=16

# -f makes curl exit non-zero on an HTTP error instead of saving the server's
# error page as the identity-conversation data that the SFT step will read.
curl -fL -o "$NANOCHAT_BASE_DIR/identity_conversations.jsonl" \
    https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl

torchrun --standalone --nproc_per_node=2 -m scripts.chat_sft -- \
    --device-batch-size=16 \
    --run="$WANDB_RUN"

torchrun --standalone --nproc_per_node=2 -m scripts.chat_eval -- -i sft

python -m nanochat.report generate

echo "=== Stage 3 complete: $(date) ==="
|
||||
87
runs/pace_submit.sh
Normal file
87
runs/pace_submit.sh
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
#!/bin/bash

# Submits the whole nanochat pipeline as a chain of dependent Slurm jobs.
#
# Pipeline:
#   Stage 1  — CPU: tokenizer + dataset
#   Stage 2a — GPU: pretraining chunk 1
#   Stage 2b — GPU: auto-resume chunk 2
#   Stage 2c — GPU: auto-resume chunk 3
#   Stage 2d — GPU: auto-resume chunk 4
#   Stage 3  — GPU: base eval + SFT + chat eval + report
#
# Usage (from repo root):
#   bash runs/pace_submit.sh
#
# Optional W&B logging:
#   WANDB_RUN=my-run bash runs/pace_submit.sh
#
# Optional XSA attention:
#   XSA=TRUE bash runs/pace_submit.sh

set -e
cd "$HOME/scratch/nanochat"

mkdir -p runs/logs

WANDB_RUN="${WANDB_RUN:-dummy}"
XSA="${XSA:-FALSE}"
export WANDB_RUN
export XSA

# submit_stage SCRIPT [DEPENDENCY]
# Submits SCRIPT with sbatch and prints the job id (--parsable). DEPENDENCY,
# when given, is passed verbatim to --dependency (e.g. "afterok:1234").
# All expansions are quoted so a WANDB_RUN containing whitespace stays a
# single sbatch argument.
submit_stage() {
    local script="$1"
    local dependency="${2:-}"
    local args=(--parsable "--export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA")
    if [ -n "$dependency" ]; then
        args+=("--dependency=$dependency")
    fi
    sbatch "${args[@]}" "$script"
}

echo "Submitting nanochat full pipeline..."
echo "WANDB_RUN=$WANDB_RUN"
echo "XSA=$XSA"
echo ""

# Stage 1
JOB1=$(submit_stage runs/pace_stage1_tokenizer.sh)
echo "Stage 1 submitted: job $JOB1 (tokenizer + dataset)"

# Stage 2a — only runs if stage 1 succeeded.
JOB2A=$(submit_stage runs/pace_stage2a_pretrain.sh "afterok:$JOB1")
echo "Stage 2a submitted: job $JOB2A (pretrain chunk 1, depends on $JOB1)"

# Stages 2b-2d use afterany: they must run even when the previous chunk was
# killed by its time limit, since their job is to resume from its checkpoint.
JOB2B=$(submit_stage runs/pace_stage2b_pretrain.sh "afterany:$JOB2A")
echo "Stage 2b submitted: job $JOB2B (pretrain chunk 2, depends on $JOB2A)"

JOB2C=$(submit_stage runs/pace_stage2c_pretrain.sh "afterany:$JOB2B")
echo "Stage 2c submitted: job $JOB2C (pretrain chunk 3, depends on $JOB2B)"

JOB2D=$(submit_stage runs/pace_stage2d_pretrain.sh "afterany:$JOB2C")
echo "Stage 2d submitted: job $JOB2D (pretrain chunk 4, depends on $JOB2C)"

# Stage 3 — only runs if the last pretraining chunk exited successfully.
JOB3=$(submit_stage runs/pace_stage3_sft.sh "afterok:$JOB2D")
echo "Stage 3 submitted: job $JOB3 (eval + SFT, depends on $JOB2D)"

echo ""
echo "All jobs queued. Monitor with:"
echo " squeue -u $USER"
echo " tail -f runs/logs/stage1_${JOB1}.out"
echo " tail -f runs/logs/stage2a_${JOB2A}.out"
echo " tail -f runs/logs/stage2b_${JOB2B}.out"
echo " tail -f runs/logs/stage2c_${JOB2C}.out"
echo " tail -f runs/logs/stage2d_${JOB2D}.out"
echo " tail -f runs/logs/stage3_${JOB3}.out"
echo ""
echo "To cancel everything:"
echo " scancel $JOB1 $JOB2A $JOB2B $JOB2C $JOB2D $JOB3"
|
||||
196
runs/runpod/d12.sh
Executable file
196
runs/runpod/d12.sh
Executable file
|
|
@ -0,0 +1,196 @@
|
|||
#!/usr/bin/env bash
# d12 baseline runner. Runs INSIDE a RunPod pod.
# Pipeline: tokenizer -> base_train -> base_eval -> SFT -> chat_eval -> report.
# On exit:
#   success -> upload final cache to HF, self-delete pod
#   failure -> upload logs + report dir to HF under _failures/, self-delete pod
#              (set UPLOAD_FAILURE_CACHE=1 to also dump partial cache for offline debug)
#
# Required env (passed via runpodctl --env at pod-create):
#   HF_TOKEN, WANDB_API_KEY
# Optional env:
#   WANDB_RUN             default: d12
#   NANOCHAT_REPO         default: Team-XSA/nanochat
#   NANOCHAT_REF          default: dev
#   HF_REPO               default: haydenfree/nanochat-d12-baseline
#   BACKUP_INTERVAL       default: 300 (seconds between background HF uploads)
#   UPLOAD_FAILURE_CACHE  default: 0
# Auto-set by RunPod:
#   RUNPOD_POD_ID, RUNPOD_API_KEY (pod-scoped)

set -euo pipefail

NANOCHAT_REPO="${NANOCHAT_REPO:-Team-XSA/nanochat}"
NANOCHAT_REF="${NANOCHAT_REF:-dev}"
HF_REPO="${HF_REPO:-haydenfree/nanochat-d12-baseline}"
WANDB_RUN="${WANDB_RUN:-d12}"
BACKUP_INTERVAL="${BACKUP_INTERVAL:-300}"
UPLOAD_FAILURE_CACHE="${UPLOAD_FAILURE_CACHE:-0}"

WORKDIR="/workspace/nanochat"
LOG_FILE="/workspace/runner.log"
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
BACKUP_PID=""

mkdir -p /workspace
# NOTE: dockerStartCmd already redirects stdout/stderr to $LOG_FILE.
# Don't add a second tee here — would write every line twice.

# RUNPOD_POD_ID is validated further down (after the cleanup trap is in
# place); use a default here so `set -u` does not abort on this log line
# before that validation can print its message.
echo "[runner] $(date -Iseconds) starting on pod=${RUNPOD_POD_ID:-<unset>}"
echo "[runner] repo=$NANOCHAT_REPO ref=$NANOCHAT_REF hf_repo=$HF_REPO wandb_run=$WANDB_RUN"

# Bootstrap huggingface_hub system-wide so the cleanup trap can upload logs
# even if we fail before the venv is activated.
{ pip3 install --break-system-packages --quiet --upgrade huggingface_hub 2>&1 || \
  python3 -m pip install --break-system-packages --quiet --upgrade huggingface_hub 2>&1 || \
  echo "[runner] WARN: could not pre-install huggingface_hub; cleanup uploads may fail"; } || true

# EXIT trap: stop the backup loop, upload results (success) or debug
# artifacts (failure) to HF, then delete this pod. Exits with the original
# exit code.
cleanup() {
    local rc=$?
    set +e
    echo "[runner] cleanup: exit code $rc at $(date -Iseconds)"
    if [ -n "$BACKUP_PID" ] && kill -0 "$BACKUP_PID" 2>/dev/null; then
        kill "$BACKUP_PID" 2>/dev/null || true
    fi

    local TS
    TS=$(date -u +%Y%m%dT%H%M%SZ)

    if [ "$rc" -eq 0 ]; then
        echo "[runner] success — final upload to $HF_REPO"
        if [ -d "$NANOCHAT_BASE_DIR" ]; then
            # Skip the climbmix dataset shards (~2GB of public data, not model artifacts)
            hf upload "$HF_REPO" "$NANOCHAT_BASE_DIR" . \
                --repo-type model --commit-message "final rc=0 $TS" \
                --exclude "base_data_climbmix/**" --exclude "wandb/**" || \
                echo "[runner] WARN: final upload failed"
        fi
        # Also upload the runner log so we have a permanent record of this successful run.
        if [ -f "$LOG_FILE" ]; then
            hf upload "$HF_REPO" "$LOG_FILE" "_runs/${TS}/runner.log" \
                --repo-type model --commit-message "runner log $TS" || \
                echo "[runner] WARN: runner.log upload failed"
        fi
    else
        echo "[runner] failure rc=$rc — dumping logs to HF for offline debug"
        mkdir -p /tmp/failure
        cp /workspace/*.log /tmp/failure/ 2>/dev/null || true
        [ -d "$NANOCHAT_BASE_DIR/report" ] && cp -r "$NANOCHAT_BASE_DIR/report" /tmp/failure/ 2>/dev/null || true
        [ -d "$WORKDIR" ] && (cd "$WORKDIR" && git rev-parse HEAD 2>/dev/null > /tmp/failure/git-head.txt || true)

        hf upload "$HF_REPO" /tmp/failure "_failures/${TS}-rc${rc}/logs" \
            --repo-type model --commit-message "failure rc=$rc logs $TS" || \
            echo "[runner] WARN: log upload failed"

        if [ "$UPLOAD_FAILURE_CACHE" = "1" ] && [ -d "$NANOCHAT_BASE_DIR" ]; then
            echo "[runner] UPLOAD_FAILURE_CACHE=1 — also dumping partial cache (may be slow)"
            hf upload "$HF_REPO" "$NANOCHAT_BASE_DIR" "_failures/${TS}-rc${rc}/cache" \
                --repo-type model --commit-message "failure rc=$rc cache $TS" \
                --exclude "base_data_climbmix/**" --exclude "wandb/**" || true
        fi
        echo "[runner] failure artifacts: https://huggingface.co/$HF_REPO/tree/main/_failures/${TS}-rc${rc}"
    fi

    # `set -u` is still active inside the trap, so read the pod id with a
    # default; without an id there is nothing to delete.
    local pod_id="${RUNPOD_POD_ID:-}"
    if [ -z "$pod_id" ]; then
        echo "[runner] WARN: RUNPOD_POD_ID is not set — skipping pod self-delete"
        exit "$rc"
    fi

    echo "[runner] self-deleting pod $pod_id"
    # REST API first — pod-scoped key has delete permission and the API is reliable.
    # The pod's preinstalled runpodctl is unreliable (often missing config or 'pod' subcommand).
    if curl -fsS -X DELETE \
        -H "Authorization: Bearer ${RUNPOD_API_KEY:-}" \
        "https://rest.runpod.io/v1/pods/$pod_id" 2>&1; then
        echo "[runner] REST delete request accepted"
    else
        echo "[runner] REST delete failed, trying runpodctl as fallback"
        runpodctl pod delete "$pod_id" 2>&1 || \
            runpodctl remove pod "$pod_id" 2>&1 || \
            echo "[runner] WARN: all delete methods failed — pod may need manual cleanup"
    fi
    exit "$rc"
}
trap cleanup EXIT

: "${HF_TOKEN:?HF_TOKEN must be set}"
: "${WANDB_API_KEY:?WANDB_API_KEY must be set}"
: "${RUNPOD_POD_ID:?RUNPOD_POD_ID must be set (auto by RunPod)}"

rm -rf "$WORKDIR"
git clone "https://github.com/${NANOCHAT_REPO}.git" "$WORKDIR"
cd "$WORKDIR"
# `--` disambiguates ref-vs-file (some images create a `dev` file in HOME)
git checkout "$NANOCHAT_REF" --
echo "[runner] HEAD = $(git rev-parse HEAD)"

# Retarget speedrun.sh from the d24 configuration to d12.
sed -i 's/--depth=24/--depth=12/' runs/speedrun.sh
sed -i 's/ --target-param-data-ratio=8//' runs/speedrun.sh
# Inject `set -euo pipefail` so a mid-pipeline failure (e.g. chat_sft) propagates
# as rc!=0 instead of being silently swallowed by the next command.
sed -i '1a set -euo pipefail' runs/speedrun.sh
echo "[runner] speedrun.sh edits applied:"
grep -n 'depth\|target-param\|set -e' runs/speedrun.sh || true

# Explicit venv setup BEFORE speedrun.sh so we can run diagnostic probes
# inside the venv. speedrun.sh's uv sync is idempotent (no-op the second time).
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR
mkdir -p "$NANOCHAT_BASE_DIR"
# Make a freshly installed uv reachable in THIS shell (and in speedrun.sh,
# which inherits PATH); the installer does not modify the running shell.
export PATH="$HOME/.local/bin:$PATH"
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra gpu
source .venv/bin/activate
pip install --quiet --upgrade huggingface_hub

# Ensure HF token flows to the kernels lib (some libs read HF_HUB_TOKEN, not HF_TOKEN)
export HF_HUB_TOKEN="${HF_TOKEN}"

# Bump kernels to latest — pyproject pins >=0.11.7 and uv often picks exactly that;
# 0.11.x had kernel-resolution bugs that affect FA3 loading silently.
echo "[runner] upgrading kernels lib for FA3 reliability"
uv pip install --quiet --upgrade 'kernels>=0.13.0' 2>&1 || \
    echo "[runner] WARN: kernels upgrade failed (continuing)"

# Install hf_transfer — runpod base image sets HF_HUB_ENABLE_HF_TRANSFER=1, which
# makes huggingface_hub raise ValueError if the package is missing. chat_sft loads
# HuggingFaceTB/smol-smoltalk via datasets and crashes without this.
echo "[runner] installing hf_transfer for SFT dataset download"
uv pip install --quiet hf_transfer 2>&1 || echo "[runner] WARN: hf_transfer install failed"

# FA3 diagnostic probe — surfaces real errors (nanochat silently swallows them).
# Non-fatal: SDPA fallback is automatic. We want this output in the log
# regardless of outcome so we can decide what to do about FA3.
echo "[runner] === FA3 PROBE BEGIN ==="
python "$WORKDIR/runs/runpod/probe_fa3.py" || echo "[runner] FA3 probe reported issues (non-fatal — continuing with SDPA fallback)"
echo "[runner] === FA3 PROBE END ==="

# Background backup loop: every BACKUP_INTERVAL seconds, push the cache
# directory to HF. Upload failures are ignored; output goes to backup.log.
(
    while true; do
        sleep "$BACKUP_INTERVAL"
        if [ -d "$NANOCHAT_BASE_DIR" ]; then
            hf upload "$HF_REPO" "$NANOCHAT_BASE_DIR" . \
                --repo-type model \
                --commit-message "checkpoint $(date -Iseconds)" \
                --exclude "base_data_climbmix/**" --exclude "wandb/**" \
                >> /workspace/backup.log 2>&1 || true
        fi
    done
) &
BACKUP_PID=$!
echo "[runner] backup loop pid=$BACKUP_PID interval=${BACKUP_INTERVAL}s"

export WANDB_RUN
bash runs/speedrun.sh

# Verify expected pipeline outputs — speedrun.sh historically didn't `set -e`;
# we patched it above, but double-check the artifacts that matter for the d12 baseline.
echo "[runner] verifying pipeline outputs"
missing=()
for required in base_checkpoints/d12 chatsft_checkpoints/d12 tokenizer report; do
    if [ ! -d "$NANOCHAT_BASE_DIR/$required" ]; then
        missing+=("$required")
    fi
done
if [ ${#missing[@]} -gt 0 ]; then
    echo "[runner] FAIL: pipeline finished but missing expected artifacts: ${missing[*]}"
    exit 1
fi
echo "[runner] all expected artifacts present"

echo "[runner] $(date -Iseconds) pipeline complete"
|
||||
161
runs/runpod/d12_sft_only.sh
Executable file
161
runs/runpod/d12_sft_only.sh
Executable file
|
|
@ -0,0 +1,161 @@
|
|||
#!/usr/bin/env bash
|
||||
# d12 SFT-only resume runner. Runs INSIDE a RunPod pod.
|
||||
#
|
||||
# Use case: the d12 base_train + base_eval already succeeded and uploaded to HF,
|
||||
# but chat_sft failed (e.g., missing hf_transfer package). Instead of re-running
|
||||
# the whole pipeline, this runner:
|
||||
# 1. Downloads base_checkpoints/d12/ + tokenizer/ from HF
|
||||
# 2. Installs hf_transfer (the actual SFT bug fix)
|
||||
# 3. Runs chat_sft + chat_eval directly (skips speedrun.sh)
|
||||
# 4. Uploads chatsft_checkpoints/ + chat_eval results + report to HF
|
||||
# 5. Self-deletes
|
||||
#
|
||||
# Required env: HF_TOKEN, WANDB_API_KEY
|
||||
# Optional env:
|
||||
# WANDB_RUN default: d12-sft
|
||||
# NANOCHAT_REPO default: Team-XSA/nanochat
|
||||
# NANOCHAT_REF default: dev
|
||||
# HF_REPO default: haydenfree/nanochat-d12-baseline (where the base lives)
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
NANOCHAT_REPO="${NANOCHAT_REPO:-Team-XSA/nanochat}"
|
||||
NANOCHAT_REF="${NANOCHAT_REF:-dev}"
|
||||
HF_REPO="${HF_REPO:-haydenfree/nanochat-d12-baseline}"
|
||||
WANDB_RUN="${WANDB_RUN:-d12-sft}"
|
||||
|
||||
WORKDIR="/workspace/nanochat"
|
||||
LOG_FILE="/workspace/runner.log"
|
||||
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
|
||||
mkdir -p /workspace
|
||||
echo "[sft] $(date -Iseconds) starting on pod=$RUNPOD_POD_ID"
|
||||
echo "[sft] resuming from base checkpoint at $HF_REPO"
|
||||
|
||||
# Bootstrap huggingface_hub system-wide so cleanup can upload logs even on early failure.
|
||||
{ pip3 install --break-system-packages --quiet --upgrade huggingface_hub 2>&1 || \
|
||||
python3 -m pip install --break-system-packages --quiet --upgrade huggingface_hub 2>&1 || \
|
||||
echo "[sft] WARN: could not pre-install huggingface_hub"; } || true
|
||||
|
||||
# EXIT trap for the SFT-only runner: uploads SFT outputs (success) or debug
# logs (failure) to $HF_REPO, deletes this pod, then exits with the original
# exit code.
cleanup() {
    local status=$?
    set +e
    echo "[sft] cleanup: exit code $status at $(date -Iseconds)"

    local stamp
    stamp=$(date -u +%Y%m%dT%H%M%SZ)

    if [ "$status" -ne 0 ]; then
        # Failure path: collect runner logs, the report dir and the checked-out
        # commit into /tmp/failure and push that directory.
        echo "[sft] failure rc=$status — dumping logs"
        mkdir -p /tmp/failure
        cp /workspace/*.log /tmp/failure/ 2>/dev/null
        if [ -d "$NANOCHAT_BASE_DIR/report" ]; then
            cp -r "$NANOCHAT_BASE_DIR/report" /tmp/failure/ 2>/dev/null
        fi
        if [ -d "$WORKDIR" ]; then
            (cd "$WORKDIR" && git rev-parse HEAD 2>/dev/null > /tmp/failure/git-head.txt || true)
        fi
        if ! hf upload "$HF_REPO" /tmp/failure "_failures/${stamp}-sft-rc${status}/logs" \
                --repo-type model --commit-message "SFT-resume failure rc=$status $stamp"; then
            echo "[sft] WARN: log upload failed"
        fi
    else
        echo "[sft] success — uploading chatsft_checkpoints + report + log"
        # Only upload the SFT-specific subdirs so we don't re-upload base.
        for subdir in chatsft_checkpoints report; do
            [ -d "$NANOCHAT_BASE_DIR/$subdir" ] || continue
            if ! hf upload "$HF_REPO" "$NANOCHAT_BASE_DIR/$subdir" "$subdir" \
                    --repo-type model --commit-message "$subdir SFT-resume rc=0 $stamp"; then
                echo "[sft] WARN: $subdir upload failed"
            fi
        done
        if [ -f "$LOG_FILE" ]; then
            if ! hf upload "$HF_REPO" "$LOG_FILE" "_runs/${stamp}-sft/runner.log" \
                    --repo-type model --commit-message "SFT runner log $stamp"; then
                echo "[sft] WARN: runner.log upload failed"
            fi
        fi
    fi

    # Delete the pod: REST API first, runpodctl (two spellings) as fallback.
    echo "[sft] self-deleting pod $RUNPOD_POD_ID"
    if ! curl -fsS -X DELETE \
            -H "Authorization: Bearer ${RUNPOD_API_KEY:-}" \
            "https://rest.runpod.io/v1/pods/$RUNPOD_POD_ID" 2>&1; then
        echo "[sft] REST delete failed, trying runpodctl as fallback"
        if ! runpodctl pod delete "$RUNPOD_POD_ID" 2>&1; then
            if ! runpodctl remove pod "$RUNPOD_POD_ID" 2>&1; then
                echo "[sft] WARN: all delete methods failed — pod may need manual cleanup"
            fi
        fi
    else
        echo "[sft] REST delete request accepted"
    fi
    exit "$status"
}
|
||||
# Run cleanup (log upload + pod self-delete) on every exit path from here on.
trap cleanup EXIT

# Fail fast, with a readable message, when a required variable is missing.
: "${HF_TOKEN:?HF_TOKEN must be set}"
: "${WANDB_API_KEY:?WANDB_API_KEY must be set}"
: "${RUNPOD_POD_ID:?RUNPOD_POD_ID must be set (auto by RunPod)}"

# Clone fork
rm -rf "$WORKDIR"
git clone "https://github.com/${NANOCHAT_REPO}.git" "$WORKDIR"
cd "$WORKDIR"
# Trailing `--` forces the argument to be read as a ref, never as a path.
git checkout "$NANOCHAT_REF" --
echo "[sft] HEAD = $(git rev-parse HEAD)"

# Env + uv
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR
mkdir -p "$NANOCHAT_BASE_DIR"
if ! command -v uv &> /dev/null; then
    curl -LsSf https://astral.sh/uv/install.sh | sh
    # The installer cannot change this shell's PATH; add its default install
    # directory so the `uv` calls below resolve in the same session.
    export PATH="$HOME/.local/bin:$PATH"
fi
[ -d ".venv" ] || uv venv
uv sync --extra gpu
source .venv/bin/activate
pip install --quiet --upgrade huggingface_hub
export HF_HUB_TOKEN="${HF_TOKEN}"

# Install hf_transfer — THE actual fix for the previous SFT failure.
echo "[sft] installing hf_transfer (the bug from last run)"
uv pip install --quiet hf_transfer

# Pull tokenizer + base checkpoint from HF in TWO separate calls.
# `hf download` only honors the LAST --include when specified multiple times
# (multi-include works for upload, not download — verified the hard way).
echo "[sft] downloading tokenizer from $HF_REPO"
hf download "$HF_REPO" \
    --include "tokenizer/**" \
    --local-dir "$NANOCHAT_BASE_DIR" \
    --repo-type model

echo "[sft] downloading base_checkpoints/d12 from $HF_REPO"
hf download "$HF_REPO" \
    --include "base_checkpoints/d12/**" \
    --local-dir "$NANOCHAT_BASE_DIR" \
    --repo-type model

# Verify both pieces actually landed before invoking chat_sft.
echo "[sft] verifying downloads"
ls -la "$NANOCHAT_BASE_DIR/base_checkpoints/d12/" 2>&1 || true
ls -la "$NANOCHAT_BASE_DIR/tokenizer/" 2>&1 || true
[ -f "$NANOCHAT_BASE_DIR/tokenizer/tokenizer.pkl" ] || { echo "[sft] FAIL: tokenizer.pkl missing after download"; exit 1; }
[ -n "$(ls -A "$NANOCHAT_BASE_DIR/base_checkpoints/d12/" 2>/dev/null)" ] || { echo "[sft] FAIL: base_checkpoints/d12 is empty"; exit 1; }
echo "[sft] downloads verified"

# Also need identity_conversations.jsonl for SFT (speedrun.sh normally fetches it)
echo "[sft] fetching identity_conversations.jsonl"
curl -L -fsS -o "$NANOCHAT_BASE_DIR/identity_conversations.jsonl" \
    https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl

# Run only SFT + chat_eval + report. NOT speedrun.sh (which would re-do base_train).
# One torchrun worker per GPU line reported by nvidia-smi.
NPROC=$(nvidia-smi -L | wc -l)
echo "[sft] running chat_sft on $NPROC GPUs"
torchrun --standalone --nproc_per_node="$NPROC" -m scripts.chat_sft -- \
    --device-batch-size=16 --run="$WANDB_RUN"

echo "[sft] running chat_eval"
torchrun --standalone --nproc_per_node="$NPROC" -m scripts.chat_eval -- -i sft

echo "[sft] regenerating report (will include new SFT sections)"
# Report generation is best-effort; a failure here must not fail the run.
python -m nanochat.report generate || true

# Verify SFT artifacts exist before declaring success
if [ ! -d "$NANOCHAT_BASE_DIR/chatsft_checkpoints" ]; then
    echo "[sft] FAIL: chatsft_checkpoints/ missing after chat_sft"
    exit 1
fi

echo "[sft] $(date -Iseconds) SFT pipeline complete"
|
||||
92
runs/runpod/kickoff.sh
Executable file
92
runs/runpod/kickoff.sh
Executable file
|
|
@ -0,0 +1,92 @@
|
|||
#!/usr/bin/env bash
# Generic local kickoff for RunPod runs.
# Picks a runner script in this repo (runs/runpod/<RUNNER>.sh) and spins up a pod.
#
# Prereqs:
#   1. ~/.config/team-xsa/runpod.env exports HF_TOKEN, WANDB_API_KEY, RUNPOD_TEMPLATE_ID
#   2. The template referenced by RUNPOD_TEMPLATE_ID has docker-start-cmd:
#        bash,-lc,curl -fsSL "$RUNNER_URL" | bash >> /workspace/runner.log 2>&1
#   3. The runner script for this experiment has been pushed to Team-XSA/nanochat
#
# Usage:
#   source ~/.config/team-xsa/runpod.env
#   bash runs/runpod/kickoff.sh d12       # uses runs/runpod/d12.sh
#   bash runs/runpod/kickoff.sh d24       # uses runs/runpod/d24.sh
#   bash runs/runpod/kickoff.sh xsa_d12   # uses runs/runpod/xsa_d12.sh
#
# Optional env overrides:
#   GPU_ID         default: "NVIDIA H100 80GB HBM3"
#   GPU_COUNT      default: 8
#   CLOUD_TYPE     default: SECURE (COMMUNITY when capacity available, cheaper)
#   DISK_GB        default: 200
#   NANOCHAT_REPO  default: Team-XSA/nanochat
#   NANOCHAT_REF   default: dev
#   WANDB_RUN      default: <RUNNER>
#   POD_NAME       default: <RUNNER>-<timestamp>
#   HF_REPO        default: unset (the runner script then applies its own default);
#                  forwarded to the pod only when set here

set -euo pipefail

RUNNER="${1:-}"
if [ -z "$RUNNER" ]; then
    echo "Usage: bash runs/runpod/kickoff.sh <runner-name>"
    echo " e.g. bash runs/runpod/kickoff.sh d12"
    exit 1
fi

: "${HF_TOKEN:?HF_TOKEN not set — source ~/.config/team-xsa/runpod.env}"
: "${WANDB_API_KEY:?WANDB_API_KEY not set — source ~/.config/team-xsa/runpod.env}"
: "${RUNPOD_TEMPLATE_ID:?RUNPOD_TEMPLATE_ID not set — create the template once and add it to ~/.config/team-xsa/runpod.env}"

NANOCHAT_REPO="${NANOCHAT_REPO:-Team-XSA/nanochat}"
NANOCHAT_REF="${NANOCHAT_REF:-dev}"
WANDB_RUN="${WANDB_RUN:-$RUNNER}"
RUNNER_URL="${RUNNER_URL:-https://raw.githubusercontent.com/${NANOCHAT_REPO}/${NANOCHAT_REF}/runs/runpod/${RUNNER}.sh}"

GPU_ID="${GPU_ID:-NVIDIA H100 80GB HBM3}"
GPU_COUNT="${GPU_COUNT:-8}"
CLOUD_TYPE="${CLOUD_TYPE:-SECURE}"
DISK_GB="${DISK_GB:-200}"
POD_NAME="${POD_NAME:-${RUNNER}-$(date +%Y%m%d-%H%M)}"

# Repo used only for the URLs printed at the end. Falls back to the value that
# was previously hard-coded here when the caller did not pick one.
HF_REPO_SHOWN="${HF_REPO:-haydenfree/nanochat-d12-baseline}"

# HEAD request: make sure the pod will be able to fetch the runner before we
# pay for a pod that would fail on its first command.
echo "Verifying runner URL is reachable: $RUNNER_URL"
if ! curl -sfI "$RUNNER_URL" >/dev/null; then
    echo "ERROR: runner not reachable at $RUNNER_URL"
    echo " - Did you push runs/runpod/${RUNNER}.sh to ${NANOCHAT_REPO}@${NANOCHAT_REF}?"
    echo " - Is the repo public?"
    exit 1
fi

export HF_TOKEN WANDB_API_KEY WANDB_RUN RUNNER_URL NANOCHAT_REPO NANOCHAT_REF
# Forward HF_REPO only on explicit request, so each runner keeps its own default.
if [ -n "${HF_REPO:-}" ]; then
    export HF_REPO
fi
# Build the env payload with a real JSON encoder so values are quoted safely.
# Keys that are not present in the environment are simply left out.
ENV_JSON=$(python3 - <<'PY'
import json, os
keys = ["HF_TOKEN","WANDB_API_KEY","WANDB_RUN","RUNNER_URL","NANOCHAT_REPO","NANOCHAT_REF","HF_REPO"]
print(json.dumps({k: os.environ[k] for k in keys if k in os.environ}))
PY
)

echo "Creating pod:"
echo " name = $POD_NAME"
echo " template = $RUNPOD_TEMPLATE_ID"
echo " runner = $RUNNER_URL"
echo " gpu = $GPU_COUNT × $GPU_ID"
echo " cloud = $CLOUD_TYPE"
echo " disk = ${DISK_GB} GB"

runpodctl pod create \
    --name "$POD_NAME" \
    --template-id "$RUNPOD_TEMPLATE_ID" \
    --gpu-id "$GPU_ID" \
    --gpu-count "$GPU_COUNT" \
    --cloud-type "$CLOUD_TYPE" \
    --container-disk-in-gb "$DISK_GB" \
    --env "$ENV_JSON"

echo
echo "Logs (after pod boots):"
echo " POD_ID=\$(runpodctl pod list --name '$POD_NAME' -o json | jq -r '.[0].id')"
echo " runpodctl ssh info \$POD_ID"
echo " ssh <user>@<host> 'tail -f /workspace/runner.log'"
echo
echo "Wandb: project=nanochat / nanochat-sft, run name: $WANDB_RUN"
echo "HF (success): https://huggingface.co/${HF_REPO_SHOWN}"
echo "HF (failure): https://huggingface.co/${HF_REPO_SHOWN}/tree/main/_failures"
|
||||
222
runs/runpod/probe_fa3.py
Normal file
222
runs/runpod/probe_fa3.py
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
#!/usr/bin/env python3
|
||||
# pyright: reportMissingImports=false
|
||||
"""
|
||||
Comprehensive FA3 / kernels diagnostic probe.
|
||||
|
||||
nanochat/flash_attention.py:_load_flash_attention_3 swallows ALL exceptions silently
|
||||
and falls back to SDPA. This script runs the same code path with full tracebacks
|
||||
so we can see why FA3 isn't loading on the pod.
|
||||
|
||||
Run inside the pod (after uv sync, with venv active):
|
||||
python runs/runpod/probe_fa3.py
|
||||
|
||||
Exits 0 if FA3 is fully usable, 1 if any check fails.
|
||||
|
||||
Note: the pyright pragma above is intentional — torch/huggingface_hub/kernels
|
||||
are only present at pod runtime; the local IDE will flag them as unresolved.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
# This probe lives at <repo>/runs/runpod/probe_fa3.py — add repo root to
|
||||
# sys.path so we can `import nanochat.*` regardless of cwd or how we're invoked.
|
||||
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
if _REPO_ROOT not in sys.path:
|
||||
sys.path.insert(0, _REPO_ROOT)
|
||||
|
||||
# ANSI-coloured status tags used by every check below.
OK = "\033[32mOK\033[0m"
FAIL = "\033[31mFAIL\033[0m"
WARN = "\033[33mWARN\033[0m"


def section(n, name):
    """Print a blank line followed by a ruled banner reading ``[n] name``."""
    print()
    print("=" * 80)
    print(f"[{n}] {name}")
    print("=" * 80)


def fmt_token(tok, prefix_len=7):
    """Describe a secret for log output without printing it in full.

    Args:
        tok: The token string; ``None`` or ``""`` means it is not configured.
        prefix_len: Upper bound on how many leading characters are shown.

    Returns:
        ``"NOT SET"`` for a missing token, otherwise a summary with the token
        length and a short prefix.
    """
    if not tok:
        return "NOT SET"
    # Never reveal more than a quarter of the token: a fixed 7-character slice
    # would print a short token in its entirety.
    shown = max(0, min(prefix_len, len(tok) // 4))
    return f"SET (len={len(tok)}, prefix={tok[:shown]}…)"


# Flipped to False by fail(); read when the final verdict is printed.
passed_all = True


def fail(msg):
    """Print a FAIL line and record that the probe as a whole has failed."""
    global passed_all
    passed_all = False
    print(f" {FAIL} {msg}")


def warn(msg):
    """Print a WARN line; does not affect the overall verdict."""
    print(f" {WARN} {msg}")


def ok(msg):
    """Print an OK line."""
    print(f" {OK} {msg}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Section 1: dump interpreter / platform details and which credentials are
# present. Secrets go through fmt_token so the log never holds a full token.
section(1, "Environment")
print(f" python : {sys.version.split()[0]}")
print(f" platform : {platform.platform()}")
print(f" cwd : {os.getcwd()}")
hf_tok = os.environ.get("HF_TOKEN", "")
hf_hub_tok = os.environ.get("HF_HUB_TOKEN", "")
print(f" HF_TOKEN : {fmt_token(hf_tok)}")
print(f" HF_HUB_TOKEN : {fmt_token(hf_hub_tok)}")
print(f" HF_HOME : {os.environ.get('HF_HOME', '(default ~/.cache/huggingface)')}")
print(f" HUGGINGFACE_HUB_CACHE : {os.environ.get('HUGGINGFACE_HUB_CACHE', '(unset)')}")
print(f" WANDB_API_KEY: {fmt_token(os.environ.get('WANDB_API_KEY',''))}")

if not hf_tok:
    fail("HF_TOKEN env var is empty — kernels lib will fall back to anonymous and may rate-limit")
|
||||
|
||||
# ---------------------------------------------------------------------------
# Section 2: HEAD-request each host the run depends on.
section(2, "Network connectivity")
for url, label in [
    ("https://huggingface.co", "huggingface.co"),
    ("https://cdn-lfs.huggingface.co", "cdn-lfs.huggingface.co"),
    ("https://github.com", "github.com"),
]:
    try:
        rc = subprocess.run(
            ["curl", "-sfI", "--max-time", "10", url],
            capture_output=True, text=True, timeout=15,
        ).returncode
    except Exception as e:
        # curl missing, or the subprocess-level timeout fired.
        fail(f"{label}: {type(e).__name__}: {e}")
        continue
    if rc == 0:
        ok(f"{label} reachable")
    elif rc == 22:
        # With -f, exit code 22 means the server answered with an HTTP error
        # status. The host is reachable, so this is not a connectivity failure.
        warn(f"{label} reachable but returned an HTTP error status (curl rc={rc})")
    else:
        fail(f"{label} unreachable (curl rc={rc})")
|
||||
|
||||
# ---------------------------------------------------------------------------
# Section 3: ask the Hub who the token belongs to. An empty token is passed as
# None so huggingface_hub falls back to whatever credential it finds itself.
section(3, "huggingface_hub auth (does the token actually work?)")
try:
    from huggingface_hub import whoami
    info = whoami(token=hf_tok or None)
    ok(f"authenticated as: {info.get('name','?')} (type={info.get('type','?')})")
    print(f" orgs: {[o.get('name') for o in info.get('orgs', [])]}")
    print(f" access token role: {info.get('auth',{}).get('accessToken',{}).get('role','?')}")
except Exception as e:
    fail(f"whoami failed: {type(e).__name__}: {e}")
    traceback.print_exc()
|
||||
|
||||
# ---------------------------------------------------------------------------
# Section 4: report the torch build and every visible GPU. Devices whose
# compute-capability major version is not 9 are flagged.
section(4, "PyTorch / CUDA / GPU")
try:
    import torch
except Exception as e:
    fail(f"torch import failed: {type(e).__name__}: {e}")
    traceback.print_exc()
else:
    # Kept apart from the import so a CUDA runtime error is not reported as an
    # import failure.
    try:
        print(f" torch : {torch.__version__}")
        print(f" cuda available : {torch.cuda.is_available()}")
        print(f" cuda version : {torch.version.cuda}")
        if torch.cuda.is_available():
            print(f" device count : {torch.cuda.device_count()}")
            for i in range(torch.cuda.device_count()):
                major, minor = torch.cuda.get_device_capability(i)
                name = torch.cuda.get_device_name(i)
                mark = OK if major == 9 else WARN
                print(f" device {i} : {name} sm{major}{minor} [{mark}]")
            # Only device 0 decides the verdict; the loop above is informational.
            major, minor = torch.cuda.get_device_capability(0)
            if major != 9:
                fail(f"FA3 requires sm90 (Hopper); device 0 is sm{major}{minor}")
        else:
            fail("CUDA not available")
    except Exception as e:
        fail(f"torch CUDA query failed: {type(e).__name__}: {e}")
        traceback.print_exc()
|
||||
|
||||
# ---------------------------------------------------------------------------
# Section 5: the kernels package must import; its version is checked against
# the 0.13 threshold named in the warning below. The probe stops here when the
# import fails because later sections cannot run without it.
section(5, "kernels library")
try:
    import kernels
except Exception as e:
    fail(f"kernels not importable: {type(e).__name__}: {e}")
    traceback.print_exc()
    sys.exit(1)

import re

ver = getattr(kernels, "__version__", "?")
print(f" kernels : {ver}")
if ver != "?":
    # Leading "<major>.<minor>" only, so suffixes such as "rc1" or ".dev0"
    # cannot raise and be mistaken for an import failure.
    m = re.match(r"\s*(\d+)\.(\d+)", str(ver))
    if m is None:
        warn(f"could not parse kernels version {ver!r}; skipping the >= 0.13 check")
    else:
        major_minor = (int(m.group(1)), int(m.group(2)))
        if major_minor < (0, 13):
            warn(f"kernels {ver} < 0.13 — older versions have known kernel-resolution bugs; consider 'uv pip install --upgrade kernels'")
        else:
            ok(f"kernels {ver} is recent")
print(f" kernels path : {getattr(kernels, '__file__', '(no __file__)')}")
|
||||
|
||||
# ---------------------------------------------------------------------------
# Section 6: fetch the FA3 kernel and walk down to flash_attn_func, printing
# the full traceback on failure.
section(6, "Fetch varunneal/flash-attention-3 (THE actual nanochat code path)")
# Keep download progress bars out of the log, unless the caller chose otherwise.
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
try:
    from kernels import get_kernel
    print(" calling get_kernel('varunneal/flash-attention-3') …")
    k = get_kernel("varunneal/flash-attention-3")
    ok(f"get_kernel returned: {type(k).__name__}")
    print(f" module path: {getattr(k, '__file__', '(no __file__)')}")
    iface = k.flash_attn_interface
    ok(f"flash_attn_interface: {type(iface).__name__}")
    fn = iface.flash_attn_func
    ok(f"flash_attn_func: callable={callable(fn)}")
    print("\n >>> FA3 binary is usable on this pod. <<<")
except Exception as e:
    fail(f"FA3 fetch failed: {type(e).__name__}: {e}")
    print()
    traceback.print_exc()
    print()
    print(" Likely causes:")
    print(" 1. Network/DNS issue (HF Hub unreachable from this DC)")
    print(" 2. Old kernels version with resolver bugs (try kernels>=0.13)")
    print(" 3. HF token not flowing — try `export HF_HUB_TOKEN=$HF_TOKEN`")
    print(" 4. No prebuilt binary for this torch/cuda combo (we have torch 2.9 + cu128 — should be supported)")
|
||||
|
||||
# ---------------------------------------------------------------------------
# Section 7: informational look at the Hugging Face cache directory (HF_HOME,
# or ~/.cache/huggingface when unset). Nothing here can fail the probe.
section(7, "HF Hub cache state")
import pathlib
cache_root = pathlib.Path(os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface")))
print(f" cache root : {cache_root}")
if cache_root.exists():
    try:
        # Recursive walk; sizing is best-effort and errors are only reported.
        size = sum(p.stat().st_size for p in cache_root.rglob("*") if p.is_file())
        print(f" size on disk : {size / 1024 / 1024:.1f} MB")
    except Exception as e:
        print(f" (could not size cache: {e})")
    fa3_marker = list(cache_root.rglob("*flash*attention*3*"))
    if fa3_marker:
        ok(f"found FA3-related cache entries: {len(fa3_marker)}")
        # Show at most five paths to keep the log short.
        for p in fa3_marker[:5]:
            print(f" {p}")
    else:
        warn("no flash-attention-3 entries in cache yet")
else:
    print(" (cache directory does not exist)")
|
||||
|
||||
# ---------------------------------------------------------------------------
# Section 8: import nanochat's own detection flags so the probe reports what
# the training code will see.
section(8, "Replicate nanochat.flash_attention detection")
try:
    from nanochat.flash_attention import HAS_FA3, USE_FA3, _fa3
    if HAS_FA3:
        ok("nanochat.flash_attention.HAS_FA3 = True")
    else:
        fail("nanochat.flash_attention.HAS_FA3 = False (despite section 6 results)")
    print(f" USE_FA3 = {USE_FA3}")
    print(f" _fa3 object: {_fa3}")
except Exception as e:
    fail(f"import nanochat.flash_attention failed: {type(e).__name__}: {e}")
    traceback.print_exc()
|
||||
|
||||
# ---------------------------------------------------------------------------
# Section 9: exit 0 only when no check called fail(); exit 1 otherwise.
section(9, "Verdict")
if passed_all:
    print(f" {OK} all checks passed — FA3 is wired up and base_train should use it")
    sys.exit(0)
else:
    print(f" {FAIL} one or more checks failed — see above. Run will fall back to SDPA (slower, possibly much).")
    print()
    print(" Continuing the training run anyway is safe; FA3 fallback to SDPA is automatic.")
    sys.exit(1)
|
||||
134
runs/runpod/smoke.sh
Executable file
134
runs/runpod/smoke.sh
Executable file
|
|
@ -0,0 +1,134 @@
|
|||
#!/usr/bin/env bash
# Minimal smoke test. Runs INSIDE a RunPod pod.
# Validates: pod boot, env-var injection, git clone, uv sync, GPU torch,
# tokenizer + base_train code paths, HF upload, runpodctl self-delete.
# Does NOT test: multi-GPU, FP8, full training horizon, SFT, eval.
#
# Sized for a 1-GPU pod, completes in ~3-4 min wall clock.
# Kick off with: GPU_COUNT=1 bash runs/runpod/kickoff.sh smoke
#
# Required env: HF_TOKEN, WANDB_API_KEY
# Auto-set by RunPod: RUNPOD_POD_ID, RUNPOD_API_KEY

set -euo pipefail

NANOCHAT_REPO="${NANOCHAT_REPO:-Team-XSA/nanochat}"
NANOCHAT_REF="${NANOCHAT_REF:-dev}"
HF_REPO="${HF_REPO:-haydenfree/nanochat-d12-baseline}"
WANDB_RUN="${WANDB_RUN:-smoke}"

# UTC timestamp; names the upload folder for this run.
TS=$(date -u +%Y%m%dT%H%M%SZ)
HF_PATH_PREFIX="_smoke/${TS}"

WORKDIR="/workspace/nanochat"
LOG_FILE="/workspace/runner.log"
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"

mkdir -p /workspace
# NOTE: dockerStartCmd already redirects stdout/stderr to $LOG_FILE.
# Don't add a second tee here — would write every line twice.

# RUNPOD_POD_ID is validated further down; expand it with a default here so
# `set -u` does not abort before that check can print its message.
echo "[smoke] $(date -Iseconds) starting on pod=${RUNPOD_POD_ID:-<unset>}"

# Bootstrap huggingface_hub system-wide so the cleanup trap can upload logs
# even if we fail before the venv is activated. Try pip3, then python3 -m pip.
{ pip3 install --break-system-packages --quiet --upgrade huggingface_hub 2>&1 || \
    python3 -m pip install --break-system-packages --quiet --upgrade huggingface_hub 2>&1 || \
    echo "[smoke] WARN: could not pre-install huggingface_hub; cleanup uploads may fail"; } || true
|
||||
|
||||
# EXIT trap: upload the logs, then delete this pod, and finally exit with the
# status the script was already exiting with.
cleanup() {
    local rc=$?
    # Best-effort from here on: one failing step must not stop the others.
    set +e
    # `set -u` is still active; read the pod id with a default so an unset
    # variable cannot abort the trap before the logs are uploaded.
    local pod_id="${RUNPOD_POD_ID:-}"
    echo "[smoke] cleanup: exit code $rc"

    # Always upload the runner log (success or failure) so we can see what happened
    mkdir -p /tmp/smoke-out
    cp /workspace/*.log /tmp/smoke-out/ 2>/dev/null || true
    echo "rc=$rc ts=$TS pod=$pod_id" > /tmp/smoke-out/result.txt
    [ -d "$WORKDIR" ] && (cd "$WORKDIR" && git rev-parse HEAD 2>/dev/null > /tmp/smoke-out/git-head.txt || true)
    hf upload "$HF_REPO" /tmp/smoke-out "$HF_PATH_PREFIX" \
        --repo-type model --commit-message "smoke rc=$rc $TS" || \
        echo "[smoke] WARN: HF upload failed"

    echo "[smoke] artifacts: https://huggingface.co/$HF_REPO/tree/main/$HF_PATH_PREFIX"

    if [ -z "$pod_id" ]; then
        # Without an id the DELETE request would target the pods collection URL.
        echo "[smoke] WARN: RUNPOD_POD_ID is not set — cannot self-delete; pod may need manual cleanup"
        exit "$rc"
    fi

    echo "[smoke] self-deleting pod $pod_id"
    # REST API first — pod-scoped key has delete permission and the API is reliable.
    # The pod's preinstalled runpodctl is unreliable (often missing config or 'pod' subcommand).
    if curl -fsS -X DELETE \
        -H "Authorization: Bearer ${RUNPOD_API_KEY:-}" \
        "https://rest.runpod.io/v1/pods/$pod_id" 2>&1; then
        echo "[smoke] REST delete request accepted"
    else
        echo "[smoke] REST delete failed, trying runpodctl as fallback"
        runpodctl pod delete "$pod_id" 2>&1 || \
            runpodctl remove pod "$pod_id" 2>&1 || \
            echo "[smoke] WARN: all delete methods failed — pod may need manual cleanup"
    fi
    exit "$rc"
}
|
||||
# Run cleanup (log upload + pod self-delete) on every exit path from here on.
trap cleanup EXIT

: "${HF_TOKEN:?HF_TOKEN must be set}"
: "${WANDB_API_KEY:?WANDB_API_KEY must be set}"
: "${RUNPOD_POD_ID:?RUNPOD_POD_ID must be set (auto by RunPod)}"

# Clone fork
rm -rf "$WORKDIR"
git clone "https://github.com/${NANOCHAT_REPO}.git" "$WORKDIR"
cd "$WORKDIR"
# `--` disambiguates ref-vs-file (some images create a `dev` file in HOME)
git checkout "$NANOCHAT_REF" --
echo "[smoke] HEAD = $(git rev-parse HEAD)"

# Env + uv
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR
mkdir -p "$NANOCHAT_BASE_DIR"
if ! command -v uv &> /dev/null; then
    curl -LsSf https://astral.sh/uv/install.sh | sh
    # The installer cannot change this shell's PATH; add its default install
    # directory so the `uv` calls below resolve in the same session.
    export PATH="$HOME/.local/bin:$PATH"
fi
[ -d ".venv" ] || uv venv
uv sync --extra gpu
source .venv/bin/activate
pip install --quiet --upgrade huggingface_hub

# Ensure HF token flows to the kernels lib (some libs read HF_HUB_TOKEN, not HF_TOKEN)
export HF_HUB_TOKEN="${HF_TOKEN}"

# Bump kernels to latest — pyproject pins >=0.11.7, uv often picks exactly that;
# 0.11.x had kernel-resolution bugs that affect FA3 loading silently.
echo "[smoke] upgrading kernels lib for FA3 reliability"
uv pip install --quiet --upgrade 'kernels>=0.13.0' 2>&1 || \
    echo "[smoke] WARN: kernels upgrade failed (continuing with whatever uv installed)"

# GPU sanity
python -c "import torch; print('[smoke] torch', torch.__version__, 'cuda', torch.cuda.is_available(), 'devices', torch.cuda.device_count())"

# FA3 diagnostic probe — surfaces the real error if FA3 won't load (nanochat
# silently swallows it). Non-fatal: SDPA fallback is automatic if probe fails.
echo "[smoke] === FA3 PROBE BEGIN ==="
python "$WORKDIR/runs/runpod/probe_fa3.py" || echo "[smoke] FA3 probe reported issues (non-fatal — continuing with SDPA fallback)"
echo "[smoke] === FA3 PROBE END ==="

# Minimum dataset + tokenizer (1 shard, 50M chars — enough for the tokenizer
# to train on AND for base_train to consume 20 iterations of tokens)
python -m nanochat.dataset -n 1
python -m scripts.tok_train --max-chars=50000000

# Tiny base_train. Params from base_train.py docstring (the CPU smoke), adjusted
# slightly for GPU. depth=4, 20 iterations. Should finish in ~30s.
# One torchrun worker per GPU line reported by nvidia-smi.
NPROC=$(nvidia-smi -L | wc -l)
echo "[smoke] training on $NPROC GPU(s)"
torchrun --standalone --nproc_per_node="$NPROC" -m scripts.base_train -- \
    --depth=4 \
    --max-seq-len=512 \
    --device-batch-size=1 \
    --total-batch-size=512 \
    --num-iterations=20 \
    --eval-every=10 \
    --eval-tokens=512 \
    --core-metric-every=-1 \
    --sample-every=-1 \
    --save-every=-1 \
    --run="$WANDB_RUN"

echo "[smoke] $(date -Iseconds) base_train complete — smoke passed"
# trap cleanup handles HF upload + self-delete
|
||||
|
|
@ -52,6 +52,7 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
|
|||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
parser.add_argument("--xsa", action="store_true", help="enable XSA")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
|
|
@ -102,6 +103,8 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat",
|
|||
# Flash Attention status
|
||||
from nanochat.flash_attention import USE_FA3
|
||||
using_fa3 = USE_FA3
|
||||
if args.xsa:
|
||||
print0("XSA enabled")
|
||||
if using_fa3:
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
else:
|
||||
|
|
@ -136,7 +139,7 @@ def build_model_meta(depth):
|
|||
config = GPTConfig(
|
||||
sequence_len=args.max_seq_len, vocab_size=vocab_size,
|
||||
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
||||
window_pattern=args.window_pattern,
|
||||
window_pattern=args.window_pattern, use_xsa=args.xsa,
|
||||
)
|
||||
with torch.device("meta"):
|
||||
model_meta = GPT(config)
|
||||
|
|
|
|||
|
|
@ -415,6 +415,7 @@ while True:
|
|||
"n_kv_head": model.config.n_kv_head,
|
||||
"n_embd": model.config.n_embd,
|
||||
"window_pattern": model.config.window_pattern,
|
||||
"use_xsa": model.config.use_xsa,
|
||||
},
|
||||
"user_config": user_config, # inputs to the training script
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user