mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-17 13:17:35 +00:00
Implemented XSA, added to pace scripts
This commit is contained in:
parent
eb66bbd4e2
commit
a305bd9612
|
|
@ -26,6 +26,9 @@ def _patch_missing_config_keys(model_config_kwargs):
|
|||
if "window_pattern" not in model_config_kwargs:
|
||||
model_config_kwargs["window_pattern"] = "L"
|
||||
log0(f"Patching missing window_pattern in model config to 'L'")
|
||||
if "use_xsa" not in model_config_kwargs:
|
||||
model_config_kwargs["use_xsa"] = False
|
||||
log0(f"Patching missing use_xsa in model config to False")
|
||||
|
||||
def _patch_missing_keys(model_data, model_config):
|
||||
"""Add default values for new parameters that may be missing in old checkpoints."""
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ class GPTConfig:
|
|||
# Characters: L=long (full context), S=short (quarter context)
|
||||
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
|
||||
window_pattern: str = "SSSL"
|
||||
use_xsa: bool = False
|
||||
|
||||
|
||||
def norm(x):
|
||||
|
|
@ -70,6 +71,8 @@ class CausalSelfAttention(nn.Module):
|
|||
self.n_kv_head = config.n_kv_head
|
||||
self.n_embd = config.n_embd
|
||||
self.head_dim = self.n_embd // self.n_head
|
||||
self.use_xsa = config.use_xsa
|
||||
self.xsa = ExclusiveSelfAttention()
|
||||
assert self.n_embd % self.n_head == 0
|
||||
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
||||
self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||
|
|
@ -120,12 +123,22 @@ class CausalSelfAttention(nn.Module):
|
|||
if self.layer_idx == kv_cache.n_layers - 1:
|
||||
kv_cache.advance(T)
|
||||
|
||||
if self.use_xsa:
|
||||
y = self.xsa.XSA(y, v)
|
||||
|
||||
# Re-assemble the heads and project back to residual stream
|
||||
y = y.contiguous().view(B, T, -1)
|
||||
y = self.c_proj(y)
|
||||
return y
|
||||
|
||||
|
||||
class ExclusiveSelfAttention(nn.Module):
    """Post-attention filter that removes, per position, the component of the
    attention output lying along the (unit-normalized) value vector."""

    def XSA(self, y, v):
        """Return y with its projection onto normalize(v) removed.

        Args:
            y: attention output tensor; last dim is the feature dim.
            v: value tensor, broadcast-compatible with y on the last dim.

        Returns:
            Tensor of the same shape as y, orthogonal (per position, along
            dim=-1) to the unit-normalized v.
        """
        unit_v = F.normalize(v, dim=-1)
        # Scalar projection of y onto the unit value direction, per position.
        coeff = (y * unit_v).sum(dim=-1, keepdim=True)
        orthogonal = y - coeff * unit_v
        return orthogonal
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ mkdir -p runs/logs
|
|||
|
||||
echo "=== Stage 1: Tokenizer ==="
|
||||
echo "Base dir: $NANOCHAT_BASE_DIR"
|
||||
echo "XSA: ${XSA:-FALSE}"
|
||||
echo "Started: $(date)"
|
||||
|
||||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
|
|
|||
|
|
@ -21,12 +21,16 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
|
|||
mkdir -p runs/logs
|
||||
|
||||
WANDB_RUN="${WANDB_RUN:-dummy}"
|
||||
XSA="${XSA:-FALSE}"
|
||||
XSA_ARG=""
|
||||
[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
|
||||
CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
|
||||
DONE_MARKER="$CHECKPOINT_DIR/.training_complete"
|
||||
|
||||
echo "=== Stage 2a: Pretraining (chunk 1) ==="
|
||||
echo "Base dir: $NANOCHAT_BASE_DIR"
|
||||
echo "WANDB_RUN: $WANDB_RUN"
|
||||
echo "XSA: $XSA"
|
||||
echo "Started: $(date)"
|
||||
|
||||
source .venv/bin/activate
|
||||
|
|
@ -36,6 +40,7 @@ torchrun --standalone --nproc_per_node=2 -m scripts.base_train -- \
|
|||
--target-param-data-ratio=8 \
|
||||
--device-batch-size=16 \
|
||||
--save-every=200 \
|
||||
$XSA_ARG \
|
||||
--run=$WANDB_RUN
|
||||
|
||||
mkdir -p "$CHECKPOINT_DIR"
|
||||
|
|
|
|||
|
|
@ -21,11 +21,15 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
|
|||
mkdir -p runs/logs
|
||||
|
||||
WANDB_RUN="${WANDB_RUN:-dummy}"
|
||||
XSA="${XSA:-FALSE}"
|
||||
XSA_ARG=""
|
||||
[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
|
||||
CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
|
||||
DONE_MARKER="$CHECKPOINT_DIR/.training_complete"
|
||||
|
||||
echo "=== Stage 2b: Pretraining (chunk 2 / auto-resume) ==="
|
||||
echo "Base dir: $NANOCHAT_BASE_DIR"
|
||||
echo "XSA: $XSA"
|
||||
echo "Started: $(date)"
|
||||
|
||||
if [ -f "$DONE_MARKER" ]; then
|
||||
|
|
@ -51,6 +55,7 @@ if [ "$LAST_STEP" -eq 0 ]; then
|
|||
--target-param-data-ratio=8 \
|
||||
--device-batch-size=16 \
|
||||
--save-every=200 \
|
||||
$XSA_ARG \
|
||||
--run=$WANDB_RUN
|
||||
else
|
||||
echo "Resuming from step $LAST_STEP"
|
||||
|
|
@ -60,6 +65,7 @@ else
|
|||
--device-batch-size=16 \
|
||||
--save-every=200 \
|
||||
--resume-from-step=$LAST_STEP \
|
||||
$XSA_ARG \
|
||||
--run=$WANDB_RUN
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -21,11 +21,15 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
|
|||
mkdir -p runs/logs
|
||||
|
||||
WANDB_RUN="${WANDB_RUN:-dummy}"
|
||||
XSA="${XSA:-FALSE}"
|
||||
XSA_ARG=""
|
||||
[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
|
||||
CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
|
||||
DONE_MARKER="$CHECKPOINT_DIR/.training_complete"
|
||||
|
||||
echo "=== Stage 2c: Pretraining (chunk 3 / auto-resume) ==="
|
||||
echo "Base dir: $NANOCHAT_BASE_DIR"
|
||||
echo "XSA: $XSA"
|
||||
echo "Started: $(date)"
|
||||
|
||||
if [ -f "$DONE_MARKER" ]; then
|
||||
|
|
@ -51,6 +55,7 @@ if [ "$LAST_STEP" -eq 0 ]; then
|
|||
--target-param-data-ratio=8 \
|
||||
--device-batch-size=16 \
|
||||
--save-every=200 \
|
||||
$XSA_ARG \
|
||||
--run=$WANDB_RUN
|
||||
else
|
||||
echo "Resuming from step $LAST_STEP"
|
||||
|
|
@ -60,6 +65,7 @@ else
|
|||
--device-batch-size=16 \
|
||||
--save-every=200 \
|
||||
--resume-from-step=$LAST_STEP \
|
||||
$XSA_ARG \
|
||||
--run=$WANDB_RUN
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -21,11 +21,15 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
|
|||
mkdir -p runs/logs
|
||||
|
||||
WANDB_RUN="${WANDB_RUN:-dummy}"
|
||||
XSA="${XSA:-FALSE}"
|
||||
XSA_ARG=""
|
||||
[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
|
||||
CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
|
||||
DONE_MARKER="$CHECKPOINT_DIR/.training_complete"
|
||||
|
||||
echo "=== Stage 2d: Pretraining (chunk 4 / auto-resume) ==="
|
||||
echo "Base dir: $NANOCHAT_BASE_DIR"
|
||||
echo "XSA: $XSA"
|
||||
echo "Started: $(date)"
|
||||
|
||||
if [ -f "$DONE_MARKER" ]; then
|
||||
|
|
@ -51,6 +55,7 @@ if [ "$LAST_STEP" -eq 0 ]; then
|
|||
--target-param-data-ratio=8 \
|
||||
--device-batch-size=16 \
|
||||
--save-every=200 \
|
||||
$XSA_ARG \
|
||||
--run=$WANDB_RUN
|
||||
else
|
||||
echo "Resuming from step $LAST_STEP"
|
||||
|
|
@ -60,6 +65,7 @@ else
|
|||
--device-batch-size=16 \
|
||||
--save-every=200 \
|
||||
--resume-from-step=$LAST_STEP \
|
||||
$XSA_ARG \
|
||||
--run=$WANDB_RUN
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ WANDB_RUN="${WANDB_RUN:-dummy}"
|
|||
echo "=== Stage 3: Eval + SFT ==="
|
||||
echo "Base dir: $NANOCHAT_BASE_DIR"
|
||||
echo "WANDB_RUN: $WANDB_RUN"
|
||||
echo "XSA: ${XSA:-FALSE}"
|
||||
echo "Started: $(date)"
|
||||
|
||||
CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
|
||||
|
|
@ -52,4 +53,3 @@ torchrun --standalone --nproc_per_node=2 -m scripts.chat_eval -- -i sft
|
|||
python -m nanochat.report generate
|
||||
|
||||
echo "=== Stage 3 complete: $(date) ==="
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@
|
|||
#
|
||||
# Optional W&B logging:
|
||||
# WANDB_RUN=my-run bash runs/pace_submit.sh
|
||||
#
|
||||
# Optional XSA attention:
|
||||
# XSA=TRUE bash runs/pace_submit.sh
|
||||
|
||||
set -e
|
||||
cd "$HOME/scratch/nanochat"
|
||||
|
|
@ -20,50 +23,53 @@ cd "$HOME/scratch/nanochat"
|
|||
mkdir -p runs/logs
|
||||
|
||||
WANDB_RUN="${WANDB_RUN:-dummy}"
|
||||
XSA="${XSA:-FALSE}"
|
||||
export WANDB_RUN
|
||||
export XSA
|
||||
|
||||
echo "Submitting nanochat full pipeline..."
|
||||
echo "WANDB_RUN=$WANDB_RUN"
|
||||
echo "XSA=$XSA"
|
||||
echo ""
|
||||
|
||||
# Stage 1
|
||||
JOB1=$(sbatch --parsable \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
|
||||
runs/pace_stage1_tokenizer.sh)
|
||||
echo "Stage 1 submitted: job $JOB1 (tokenizer + dataset)"
|
||||
|
||||
# Stage 2a
|
||||
JOB2A=$(sbatch --parsable \
|
||||
--dependency=afterok:$JOB1 \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
|
||||
runs/pace_stage2a_pretrain.sh)
|
||||
echo "Stage 2a submitted: job $JOB2A (pretrain chunk 1, depends on $JOB1)"
|
||||
|
||||
# Stage 2b
|
||||
JOB2B=$(sbatch --parsable \
|
||||
--dependency=afterany:$JOB2A \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
|
||||
runs/pace_stage2b_pretrain.sh)
|
||||
echo "Stage 2b submitted: job $JOB2B (pretrain chunk 2, depends on $JOB2A)"
|
||||
|
||||
# Stage 2c
|
||||
JOB2C=$(sbatch --parsable \
|
||||
--dependency=afterany:$JOB2B \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
|
||||
runs/pace_stage2c_pretrain.sh)
|
||||
echo "Stage 2c submitted: job $JOB2C (pretrain chunk 3, depends on $JOB2B)"
|
||||
|
||||
# Stage 2d
|
||||
JOB2D=$(sbatch --parsable \
|
||||
--dependency=afterany:$JOB2C \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
|
||||
runs/pace_stage2d_pretrain.sh)
|
||||
echo "Stage 2d submitted: job $JOB2D (pretrain chunk 4, depends on $JOB2C)"
|
||||
|
||||
# Stage 3
|
||||
JOB3=$(sbatch --parsable \
|
||||
--dependency=afterok:$JOB2D \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN \
|
||||
--export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
|
||||
runs/pace_stage3_sft.sh)
|
||||
echo "Stage 3 submitted: job $JOB3 (eval + SFT, depends on $JOB2D)"
|
||||
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
|
|||
parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
|
||||
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
|
||||
parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
|
||||
parser.add_argument("--xsa", action="store_true", help="enable XSA")
|
||||
# Training horizon (only one used, in order of precedence)
|
||||
parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
|
||||
parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
|
||||
|
|
@ -102,6 +103,8 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat",
|
|||
# Flash Attention status
|
||||
from nanochat.flash_attention import USE_FA3
|
||||
using_fa3 = USE_FA3
|
||||
if args.xsa:
|
||||
print0("XSA enabled")
|
||||
if using_fa3:
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
else:
|
||||
|
|
@ -136,7 +139,7 @@ def build_model_meta(depth):
|
|||
config = GPTConfig(
|
||||
sequence_len=args.max_seq_len, vocab_size=vocab_size,
|
||||
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
||||
window_pattern=args.window_pattern,
|
||||
window_pattern=args.window_pattern, use_xsa=args.xsa,
|
||||
)
|
||||
with torch.device("meta"):
|
||||
model_meta = GPT(config)
|
||||
|
|
|
|||
|
|
@ -415,6 +415,7 @@ while True:
|
|||
"n_kv_head": model.config.n_kv_head,
|
||||
"n_embd": model.config.n_embd,
|
||||
"window_pattern": model.config.window_pattern,
|
||||
"use_xsa": model.config.use_xsa,
|
||||
},
|
||||
"user_config": user_config, # inputs to the training script
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user