diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index f71524ed..852139f7 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -26,6 +26,9 @@ def _patch_missing_config_keys(model_config_kwargs):
     if "window_pattern" not in model_config_kwargs:
         model_config_kwargs["window_pattern"] = "L"
         log0(f"Patching missing window_pattern in model config to 'L'")
+    if "use_xsa" not in model_config_kwargs:
+        model_config_kwargs["use_xsa"] = False
+        log0(f"Patching missing use_xsa in model config to False")

 def _patch_missing_keys(model_data, model_config):
     """Add default values for new parameters that may be missing in old checkpoints."""
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 07a1eae8..20db8b8a 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -37,6 +37,7 @@ class GPTConfig:
     # Characters: L=long (full context), S=short (quarter context)
     # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
     window_pattern: str = "SSSL"
+    use_xsa: bool = False


 def norm(x):
@@ -70,6 +71,8 @@ class CausalSelfAttention(nn.Module):
         self.n_kv_head = config.n_kv_head
         self.n_embd = config.n_embd
         self.head_dim = self.n_embd // self.n_head
+        self.use_xsa = config.use_xsa
+        self.xsa = ExclusiveSelfAttention()
         assert self.n_embd % self.n_head == 0
         assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
         self.c_q = Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
@@ -120,12 +123,25 @@
         if self.layer_idx == kv_cache.n_layers - 1:
             kv_cache.advance(T)

+        # Optionally remove the value-direction component from the attention output
+        if self.use_xsa:
+            y = self.xsa(y, v)
+
         # Re-assemble the heads and project back to residual stream
         y = y.contiguous().view(B, T, -1)
         y = self.c_proj(y)
         return y


+class ExclusiveSelfAttention(nn.Module):
+    """Project the attention output orthogonal to the value vectors: at each
+    position, subtract from y its component along the unit-normalized v."""
+    def forward(self, y, v):
+        Vn = F.normalize(v, dim=-1)  # unit value vectors along the head dimension
+        Z = y - (y * Vn).sum(dim=-1, keepdim=True) * Vn  # remove the projection of y onto Vn
+        return Z
+
+
 class MLP(nn.Module):
     def __init__(self, config):
         super().__init__()
diff --git a/runs/pace_stage1_tokenizer.sh b/runs/pace_stage1_tokenizer.sh
index 8caa6214..95573de1 100644
--- a/runs/pace_stage1_tokenizer.sh
+++ b/runs/pace_stage1_tokenizer.sh
@@ -20,6 +20,7 @@ mkdir -p runs/logs

 echo "=== Stage 1: Tokenizer ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
+echo "XSA: ${XSA:-FALSE}"
 echo "Started: $(date)"

 command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
diff --git a/runs/pace_stage2a_pretrain.sh b/runs/pace_stage2a_pretrain.sh
index bee6d11f..732099d8 100644
--- a/runs/pace_stage2a_pretrain.sh
+++ b/runs/pace_stage2a_pretrain.sh
@@ -21,12 +21,16 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
 mkdir -p runs/logs

 WANDB_RUN="${WANDB_RUN:-dummy}"
+XSA="${XSA:-FALSE}"
+XSA_ARG=""
+[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
 CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
 DONE_MARKER="$CHECKPOINT_DIR/.training_complete"

 echo "=== Stage 2a: Pretraining (chunk 1) ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
 echo "WANDB_RUN: $WANDB_RUN"
+echo "XSA: $XSA"
 echo "Started: $(date)"

 source .venv/bin/activate
@@ -36,6 +40,7 @@ torchrun --standalone --nproc_per_node=2 -m scripts.base_train -- \
     --target-param-data-ratio=8 \
     --device-batch-size=16 \
     --save-every=200 \
+    $XSA_ARG \
     --run=$WANDB_RUN

 mkdir -p "$CHECKPOINT_DIR"
diff --git a/runs/pace_stage2b_pretrain.sh b/runs/pace_stage2b_pretrain.sh
index 7b8ea806..7e7ad3be 100644
--- a/runs/pace_stage2b_pretrain.sh
+++ b/runs/pace_stage2b_pretrain.sh
@@ -21,11 +21,15 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
 mkdir -p runs/logs

 WANDB_RUN="${WANDB_RUN:-dummy}"
+XSA="${XSA:-FALSE}"
+XSA_ARG=""
+[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
 CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
 DONE_MARKER="$CHECKPOINT_DIR/.training_complete"

 echo "=== Stage 2b: Pretraining (chunk 2 / auto-resume) ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
+echo "XSA: $XSA"
 echo "Started: $(date)"

 if [ -f "$DONE_MARKER" ]; then
@@ -51,6 +55,7 @@
     --target-param-data-ratio=8 \
     --device-batch-size=16 \
     --save-every=200 \
+    $XSA_ARG \
     --run=$WANDB_RUN
 else
     echo "Resuming from step $LAST_STEP"
@@ -60,6 +65,7 @@
     --device-batch-size=16 \
     --save-every=200 \
     --resume-from-step=$LAST_STEP \
+    $XSA_ARG \
     --run=$WANDB_RUN
 fi

diff --git a/runs/pace_stage2c_pretrain.sh b/runs/pace_stage2c_pretrain.sh
index 92c748e9..5089a444 100644
--- a/runs/pace_stage2c_pretrain.sh
+++ b/runs/pace_stage2c_pretrain.sh
@@ -21,11 +21,15 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
 mkdir -p runs/logs

 WANDB_RUN="${WANDB_RUN:-dummy}"
+XSA="${XSA:-FALSE}"
+XSA_ARG=""
+[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
 CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
 DONE_MARKER="$CHECKPOINT_DIR/.training_complete"

 echo "=== Stage 2c: Pretraining (chunk 3 / auto-resume) ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
+echo "XSA: $XSA"
 echo "Started: $(date)"

 if [ -f "$DONE_MARKER" ]; then
@@ -51,6 +55,7 @@
     --target-param-data-ratio=8 \
     --device-batch-size=16 \
     --save-every=200 \
+    $XSA_ARG \
     --run=$WANDB_RUN
 else
     echo "Resuming from step $LAST_STEP"
@@ -60,6 +65,7 @@
     --device-batch-size=16 \
     --save-every=200 \
     --resume-from-step=$LAST_STEP \
+    $XSA_ARG \
     --run=$WANDB_RUN
 fi

diff --git a/runs/pace_stage2d_pretrain.sh b/runs/pace_stage2d_pretrain.sh
index 70030001..ffe190d6 100644
--- a/runs/pace_stage2d_pretrain.sh
+++ b/runs/pace_stage2d_pretrain.sh
@@ -21,11 +21,15 @@ export NANOCHAT_BASE_DIR="$HOME/scratch/nanochat"
 mkdir -p runs/logs

 WANDB_RUN="${WANDB_RUN:-dummy}"
+XSA="${XSA:-FALSE}"
+XSA_ARG=""
+[ "$XSA" = "TRUE" ] && XSA_ARG="--xsa"
 CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
 DONE_MARKER="$CHECKPOINT_DIR/.training_complete"

 echo "=== Stage 2d: Pretraining (chunk 4 / auto-resume) ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
+echo "XSA: $XSA"
 echo "Started: $(date)"

 if [ -f "$DONE_MARKER" ]; then
@@ -51,6 +55,7 @@
     --target-param-data-ratio=8 \
     --device-batch-size=16 \
     --save-every=200 \
+    $XSA_ARG \
     --run=$WANDB_RUN
 else
     echo "Resuming from step $LAST_STEP"
@@ -60,6 +65,7 @@
     --device-batch-size=16 \
     --save-every=200 \
     --resume-from-step=$LAST_STEP \
+    $XSA_ARG \
     --run=$WANDB_RUN
 fi

diff --git a/runs/pace_stage3_sft.sh b/runs/pace_stage3_sft.sh
index f354f7c4..532c6fa9 100644
--- a/runs/pace_stage3_sft.sh
+++ b/runs/pace_stage3_sft.sh
@@ -25,6 +25,7 @@ WANDB_RUN="${WANDB_RUN:-dummy}"
 echo "=== Stage 3: Eval + SFT ==="
 echo "Base dir: $NANOCHAT_BASE_DIR"
 echo "WANDB_RUN: $WANDB_RUN"
+echo "XSA: ${XSA:-FALSE}"
 echo "Started: $(date)"

 CHECKPOINT_DIR="$NANOCHAT_BASE_DIR/base_checkpoints/d24"
@@ -52,4 +53,3 @@ torchrun --standalone --nproc_per_node=2 -m scripts.chat_eval -- -i sft
 python -m nanochat.report generate

 echo "=== Stage 3 complete: $(date) ==="
-
diff --git a/runs/pace_submit.sh b/runs/pace_submit.sh
index 951d8849..63cc3de2 100644
--- a/runs/pace_submit.sh
+++ b/runs/pace_submit.sh
@@ -13,6 +13,9 @@
 #
 # Optional W&B logging:
 #   WANDB_RUN=my-run bash runs/pace_submit.sh
+#
+# Optional XSA attention:
+#   XSA=TRUE bash runs/pace_submit.sh

 set -e
 cd "$HOME/scratch/nanochat"
"$HOME/scratch/nanochat" @@ -20,50 +23,53 @@ cd "$HOME/scratch/nanochat" mkdir -p runs/logs WANDB_RUN="${WANDB_RUN:-dummy}" +XSA="${XSA:-FALSE}" export WANDB_RUN +export XSA echo "Submitting nanochat full pipeline..." echo "WANDB_RUN=$WANDB_RUN" +echo "XSA=$XSA" echo "" # Stage 1 JOB1=$(sbatch --parsable \ - --export=ALL,WANDB_RUN=$WANDB_RUN \ + --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \ runs/pace_stage1_tokenizer.sh) echo "Stage 1 submitted: job $JOB1 (tokenizer + dataset)" # Stage 2a JOB2A=$(sbatch --parsable \ --dependency=afterok:$JOB1 \ - --export=ALL,WANDB_RUN=$WANDB_RUN \ + --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \ runs/pace_stage2a_pretrain.sh) echo "Stage 2a submitted: job $JOB2A (pretrain chunk 1, depends on $JOB1)" # Stage 2b JOB2B=$(sbatch --parsable \ --dependency=afterany:$JOB2A \ - --export=ALL,WANDB_RUN=$WANDB_RUN \ + --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \ runs/pace_stage2b_pretrain.sh) echo "Stage 2b submitted: job $JOB2B (pretrain chunk 2, depends on $JOB2A)" # Stage 2c JOB2C=$(sbatch --parsable \ --dependency=afterany:$JOB2B \ - --export=ALL,WANDB_RUN=$WANDB_RUN \ + --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \ runs/pace_stage2c_pretrain.sh) echo "Stage 2c submitted: job $JOB2C (pretrain chunk 3, depends on $JOB2B)" # Stage 2d JOB2D=$(sbatch --parsable \ --dependency=afterany:$JOB2C \ - --export=ALL,WANDB_RUN=$WANDB_RUN \ + --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \ runs/pace_stage2d_pretrain.sh) echo "Stage 2d submitted: job $JOB2D (pretrain chunk 4, depends on $JOB2C)" # Stage 3 JOB3=$(sbatch --parsable \ --dependency=afterok:$JOB2D \ - --export=ALL,WANDB_RUN=$WANDB_RUN \ + --export=ALL,WANDB_RUN=$WANDB_RUN,XSA=$XSA \ runs/pace_stage3_sft.sh) echo "Stage 3 submitted: job $JOB3 (eval + SFT, depends on $JOB2D)" diff --git a/scripts/base_train.py b/scripts/base_train.py index a161c477..ec037bbb 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -52,6 +52,7 @@ parser.add_argument("--aspect-ratio", type=int, default=64, help="model_dim = de parser.add_argument("--head-dim", type=int, default=128, help="target head dimension for attention") parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length") parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding window pattern tiled across layers: L=full, S=half context (e.g. 
'SSL')") +parser.add_argument("--xsa", action="store_true", help="enable XSA") # Training horizon (only one used, in order of precedence) parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") @@ -102,6 +103,8 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", # Flash Attention status from nanochat.flash_attention import USE_FA3 using_fa3 = USE_FA3 +if args.xsa: + print0("XSA enabled") if using_fa3: print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.") else: @@ -136,7 +139,7 @@ def build_model_meta(depth): config = GPTConfig( sequence_len=args.max_seq_len, vocab_size=vocab_size, n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim, - window_pattern=args.window_pattern, + window_pattern=args.window_pattern, use_xsa=args.xsa, ) with torch.device("meta"): model_meta = GPT(config) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index b46dd817..e085b119 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -415,6 +415,7 @@ while True: "n_kv_head": model.config.n_kv_head, "n_embd": model.config.n_embd, "window_pattern": model.config.window_pattern, + "use_xsa": model.config.use_xsa, }, "user_config": user_config, # inputs to the training script },