Implemented XSA (ExclusiveSelfAttention); added it to the pace scripts

zolopgh 2026-04-26 01:30:26 -04:00
parent eb66bbd4e2
commit a305bd9612
11 changed files with 58 additions and 8 deletions

View File

@@ -26,6 +26,9 @@ def _patch_missing_config_keys(model_config_kwargs):
     if "window_pattern" not in model_config_kwargs:
         model_config_kwargs["window_pattern"] = "L"
         log0(f"Patching missing window_pattern in model config to 'L'")
+    if "use_xsa" not in model_config_kwargs:
+        model_config_kwargs["use_xsa"] = False
+        log0(f"Patching missing use_xsa in model config to False")
 
 def _patch_missing_keys(model_data, model_config):
     """Add default values for new parameters that may be missing in old checkpoints."""

View File

@@ -37,6 +37,7 @@ class GPTConfig:
     # Characters: L=long (full context), S=short (quarter context)
     # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
     window_pattern: str = "SSSL"
+    use_xsa: bool = False
 
 def norm(x):
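
Note: per the comments above, the pattern string is tiled across the layer stack. A quick sketch of the tiling rule (tile_window_pattern is a hypothetical helper for illustration, not the repo's actual function):

    def tile_window_pattern(pattern: str, n_layer: int) -> list[str]:
        # Layer i gets pattern[i % len(pattern)]: "S" = short window, "L" = full context
        return [pattern[i % len(pattern)] for i in range(n_layer)]

    print(tile_window_pattern("SSSL", 8))  # ['S', 'S', 'S', 'L', 'S', 'S', 'S', 'L']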
@@ -70,6 +71,8 @@ class CausalSelfAttention(nn.Module):
         self.n_kv_head = config.n_kv_head
         self.n_embd = config.n_embd
         self.head_dim = self.n_embd // self.n_head
+        self.use_xsa = config.use_xsa
+        self.xsa = ExclusiveSelfAttention()
         assert self.n_embd % self.n_head == 0
         assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
         self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
@@ -120,12 +123,22 @@ class CausalSelfAttention(nn.Module):
             if self.layer_idx == kv_cache.n_layers - 1:
                 kv_cache.advance(T)
 
+        if self.use_xsa:
+            y = self.xsa.XSA(y, v)
+
         # Re-assemble the heads and project back to residual stream
         y = y.contiguous().view(B, T, -1)
         y = self.c_proj(y)
         return y
 
+class ExclusiveSelfAttention(nn.Module):
+    def XSA(self, y, v):
+        Vn = F.normalize(v, dim=-1)
+        Z = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn
+        return Z
+
 class MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
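
Note: ExclusiveSelfAttention.XSA is a projection rejection. It normalizes the value vectors and subtracts from the attention output y its component along that direction, so only the part of y orthogonal to v survives. A standalone sketch (shapes are assumptions for illustration; the op acts elementwise over the last dimension, and the training script below configures n_kv_head == n_head so y and v match):

    import torch
    import torch.nn.functional as F

    B, T, H, D = 2, 8, 4, 16
    y = torch.randn(B, T, H, D)  # per-head attention output
    v = torch.randn(B, T, H, D)  # per-head value vectors

    Vn = F.normalize(v, dim=-1)                      # unit vector along v
    Z = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn  # reject y's component along v

    # Z is orthogonal to v at every position and head
    assert torch.allclose((Z * Vn).sum(dim=-1), torch.zeros(B, T, H), atol=1e-4)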

View File

@@ -20,6 +20,7 @@ mkdir -p runs/logs
 
 echo "=== Stage 1: Tokenizer ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
+echo "XSA: ${XSA:-FALSE}"
 echo "Started: $(date)"
 
 command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh

View File

@@ -21,12 +21,16 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
 mkdir -p runs/logs
 
 WANDB_RUN="${WANDB_RUN:-dummy}"
+XSA="${XSA:-FALSE}"
+XSA_ARG=""
+[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
 
 CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
 DONE_MARKER="$CHECKPOINT_DIR/.training_complete"
 
 echo "=== Stage 2a: Pretraining (chunk 1) ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
 echo "WANDB_RUN: $WANDB_RUN"
+echo "XSA: $XSA"
 echo "Started: $(date)"
 
 source .venv/bin/activate
@@ -36,6 +40,7 @@ torchrun --standalone --nproc_per_node=2 -m scripts.base_train -- \
     --target-param-data-ratio=8 \
     --device-batch-size=16 \
     --save-every=200 \
+    $XSA_ARG \
     --run=$WANDB_RUN
 
 mkdir -p "$CHECKPOINT_DIR"

View File

@@ -21,11 +21,15 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
 mkdir -p runs/logs
 
 WANDB_RUN="${WANDB_RUN:-dummy}"
+XSA="${XSA:-FALSE}"
+XSA_ARG=""
+[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
 
 CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
 DONE_MARKER="$CHECKPOINT_DIR/.training_complete"
 
 echo "=== Stage 2b: Pretraining (chunk 2 / auto-resume) ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
+echo "XSA: $XSA"
 echo "Started: $(date)"
 
 if [ -f "$DONE_MARKER" ]; then
@@ -51,6 +55,7 @@ if [ "$LAST_STEP" -eq 0 ]; then
         --target-param-data-ratio=8 \
         --device-batch-size=16 \
         --save-every=200 \
+        $XSA_ARG \
         --run=$WANDB_RUN
 else
     echo "Resuming from step $LAST_STEP"
@@ -60,6 +65,7 @@ else
         --device-batch-size=16 \
         --save-every=200 \
         --resume-from-step=$LAST_STEP \
+        $XSA_ARG \
         --run=$WANDB_RUN
 fi

View File

@@ -21,11 +21,15 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
 mkdir -p runs/logs
 
 WANDB_RUN="${WANDB_RUN:-dummy}"
+XSA="${XSA:-FALSE}"
+XSA_ARG=""
+[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
 
 CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
 DONE_MARKER="$CHECKPOINT_DIR/.training_complete"
 
 echo "=== Stage 2c: Pretraining (chunk 3 / auto-resume) ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
+echo "XSA: $XSA"
 echo "Started: $(date)"
 
 if [ -f "$DONE_MARKER" ]; then
@@ -51,6 +55,7 @@ if [ "$LAST_STEP" -eq 0 ]; then
         --target-param-data-ratio=8 \
         --device-batch-size=16 \
         --save-every=200 \
+        $XSA_ARG \
         --run=$WANDB_RUN
 else
     echo "Resuming from step $LAST_STEP"
@@ -60,6 +65,7 @@ else
         --device-batch-size=16 \
         --save-every=200 \
         --resume-from-step=$LAST_STEP \
+        $XSA_ARG \
         --run=$WANDB_RUN
 fi

View File

@@ -21,11 +21,15 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
 mkdir -p runs/logs
 
 WANDB_RUN="${WANDB_RUN:-dummy}"
+XSA="${XSA:-FALSE}"
+XSA_ARG=""
+[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
 
 CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
 DONE_MARKER="$CHECKPOINT_DIR/.training_complete"
 
 echo "=== Stage 2d: Pretraining (chunk 4 / auto-resume) ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
+echo "XSA: $XSA"
 echo "Started: $(date)"
 
 if [ -f "$DONE_MARKER" ]; then
@@ -51,6 +55,7 @@ if [ "$LAST_STEP" -eq 0 ]; then
         --target-param-data-ratio=8 \
         --device-batch-size=16 \
         --save-every=200 \
+        $XSA_ARG \
         --run=$WANDB_RUN
 else
     echo "Resuming from step $LAST_STEP"
@@ -60,6 +65,7 @@ else
         --device-batch-size=16 \
         --save-every=200 \
         --resume-from-step=$LAST_STEP \
+        $XSA_ARG \
         --run=$WANDB_RUN
 fi

View File

@@ -25,6 +25,7 @@ WANDB_RUN="${WANDB_RUN:-dummy}"
 
 echo "=== Stage 3: Eval + SFT ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
 echo "WANDB_RUN: $WANDB_RUN"
+echo "XSA: ${XSA:-FALSE}"
 echo "Started: $(date)"
 
 CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
@@ -52,4 +53,3 @@ torchrun --standalone --nproc_per_node=2 -m scripts.chat_eval -- -i sft
 python -m nanochat.report generate
 
 echo "=== Stage 3 complete: $(date) ==="

View File

@@ -13,6 +13,9 @@
 #
 # Optional W&B logging:
 # WANDB_RUN=my-run bash runs/pace_submit.sh
+#
+# Optional XSA attention:
+# XSA=TRUE bash runs/pace_submit.sh
 
 set -e
 cd "$HOME/scratch/nanochat"
@@ -20,50 +23,53 @@ cd "$HOME/scratch/nanochat"
 mkdir -p runs/logs
 
 WANDB_RUN="${WANDB_RUN:-dummy}"
+XSA="${XSA:-FALSE}"
 export WANDB_RUN
+export XSA
 
 echo "Submitting nanochat full pipeline..."
 echo "WANDB_RUN=$WANDB_RUN"
+echo "XSA=$XSA"
 echo ""
 
 # Stage 1
 JOB1=$(sbatch --parsable \
-    --export=ALL,WANDB_RUN=$WANDB_RUN \
+    --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
     runs/pace_stage1_tokenizer.sh)
 echo "Stage 1 submitted: job $JOB1 (tokenizer + dataset)"
 
 # Stage 2a
 JOB2A=$(sbatch --parsable \
     --dependency=afterok:$JOB1 \
-    --export=ALL,WANDB_RUN=$WANDB_RUN \
+    --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
     runs/pace_stage2a_pretrain.sh)
 echo "Stage 2a submitted: job $JOB2A (pretrain chunk 1, depends on $JOB1)"
 
 # Stage 2b
 JOB2B=$(sbatch --parsable \
     --dependency=afterany:$JOB2A \
-    --export=ALL,WANDB_RUN=$WANDB_RUN \
+    --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
     runs/pace_stage2b_pretrain.sh)
 echo "Stage 2b submitted: job $JOB2B (pretrain chunk 2, depends on $JOB2A)"
 
 # Stage 2c
 JOB2C=$(sbatch --parsable \
     --dependency=afterany:$JOB2B \
-    --export=ALL,WANDB_RUN=$WANDB_RUN \
+    --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
     runs/pace_stage2c_pretrain.sh)
 echo "Stage 2c submitted: job $JOB2C (pretrain chunk 3, depends on $JOB2B)"
 
 # Stage 2d
 JOB2D=$(sbatch --parsable \
     --dependency=afterany:$JOB2C \
-    --export=ALL,WANDB_RUN=$WANDB_RUN \
+    --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
     runs/pace_stage2d_pretrain.sh)
 echo "Stage 2d submitted: job $JOB2D (pretrain chunk 4, depends on $JOB2C)"
 
 # Stage 3
 JOB3=$(sbatch --parsable \
     --dependency=afterok:$JOB2D \
-    --export=ALL,WANDB_RUN=$WANDB_RUN \
+    --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \
     runs/pace_stage3_sft.sh)
 echo "Stage 3 submitted: job $JOB3 (eval + SFT, depends on $JOB2D)"

View File

@@ -52,6 +52,7 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de
 parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention")
 parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
 parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 'SSL')")
+parser.add_argument("--xsa", action="store_true", help="enable XSA")
 # Training horizon (only one used, in order of precedence)
 parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)")
 parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)")
@@ -102,6 +103,8 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat",
 # Flash Attention status
 from nanochat.flash_attention import USE_FA3
 using_fa3 = USE_FA3
+if args.xsa:
+    print0("XSA enabled")
 if using_fa3:
     print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
 else:
@@ -136,7 +139,7 @@ def build_model_meta(depth):
     config = GPTConfig(
         sequence_len=args.max_seq_len, vocab_size=vocab_size,
         n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
-        window_pattern=args.window_pattern,
+        window_pattern=args.window_pattern, use_xsa=args.xsa,
     )
     with torch.device("meta"):
         model_meta = GPT(config)

View File

@@ -415,6 +415,7 @@ while True:
                 "n_kv_head": model.config.n_kv_head,
                 "n_embd": model.config.n_embd,
                 "window_pattern": model.config.window_pattern,
+                "use_xsa": model.config.use_xsa,
             },
             "user_config": user_config, # inputs to the training script
         },
},