diff --git a/dev/BACKLITE.md b/dev/BACKLITE.md
new file mode 100644
index 0000000..8adce53
--- /dev/null
+++ b/dev/BACKLITE.md
@@ -0,0 +1,28 @@
+## BackLite (experimental)
+
+[BackLite](https://github.com/moonmath-ai/BackLite) is an experimental drop-in kernel that can accelerate pretraining by modifying the backward pass. It is currently Hopper-only (H100/H200).
+
+### Install BackLite
+
+Clone the BackLite repo into the project root and build the Hopper kernel:
+
+```bash
+git clone https://github.com/moonmath-ai/BackLite.git
+uv pip install --no-build-isolation BackLite/hopper/
+```
+
+### Launch a BackLite training run
+
+Pass `--backlite-negl-prob` and `--negl-prob-warmup-steps` to `base_train`:
+
+```bash
+OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
+    --depth=24 \
+    --run="d24-backlite" \
+    --model-tag="d24_backlite" \
+    --fp8 \
+    --backlite-negl-prob=0.1 \
+    --negl-prob-warmup-steps=100
+```
+
+You should see `✓ BackLite enabled, negl_prob=0.1` in the output.
diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py
index af2aee3..0735e00 100644
--- a/nanochat/flash_attention.py
+++ b/nanochat/flash_attention.py
@@ -1,13 +1,15 @@
 """
-Unified Flash Attention interface with automatic FA3/SDPA switching.
+Unified Flash Attention interface with automatic BackLite/FA3/SDPA switching.
 
-Exports `flash_attn` module that matches the FA3 API exactly, but falls back
-to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
+Priority: BackLite (sparse backward, negl_prob > 0) > FA3 (dense) > SDPA fallback.
 
-Usage (drop-in replacement for FA3):
-    from nanochat.flash_attention import flash_attn
+BackLite limitation: does NOT support varlen/packed sequences. For any
+flash_attn_varlen_func call, use vanilla FA3 directly — never route through here.
 
-    # Training (no KV cache)
+Usage:
+    from nanochat.flash_attention import flash_attn, set_backlite_negl_prob
+
+    # Training (no KV cache) — BackLite applies when negl_prob > 0
     y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
 
     # Inference (with KV cache)
@@ -63,6 +65,46 @@ def _resolve_use_fa3():
 
 USE_FA3 = _resolve_use_fa3()
 
+# =============================================================================
+# Detection: Try to load BackLite on Hopper+ GPUs
+# =============================================================================
+def _load_backlite():
+    """Try to load BackLite sparse-backward attention (Hopper only, requires FA3).
+
+    Returns the module-level flash_attn_func from back_lite, or None.
+    BackLite only covers flash_attn_func (training); kvcache inference always
+    uses vanilla FA3. Does NOT support varlen/packed sequences.
+    """
+    if not USE_FA3:
+        return None
+    try:
+        import back_lite as _bl
+        return _bl.back_lite.flash_attn_func
+    except Exception:
+        return None
+
+
+_backlite_flash_attn_func = _load_backlite()
+HAS_BACKLITE = _backlite_flash_attn_func is not None
+
+# Module-level negl_prob state (0.0 = disabled, i.e. vanilla FA3)
+_backlite_negl_prob: float = 0.0
+
+
+def set_backlite_negl_prob(p: float) -> None:
+    """Set the negligible-probability threshold for BackLite sparse backward.
+
+    p=0.0 disables BackLite and falls through to vanilla FA3.
+    """
+    global _backlite_negl_prob
+    _backlite_negl_prob = p
+
+
+def get_backlite_negl_prob() -> float:
+    return _backlite_negl_prob
+
+
+
 # =============================================================================
 # SDPA helpers
 # =============================================================================
@@ -108,6 +150,9 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
     """
     Flash Attention for training (no KV cache).
 
+    Routes to BackLite (sparse backward) when negl_prob > 0, else vanilla FA3,
+    else PyTorch SDPA. BackLite does NOT support varlen sequences.
+
     Args:
         q, k, v: Tensors of shape (B, T, H, D)
         causal: Whether to use causal masking
@@ -116,6 +161,11 @@
     Returns:
         Output tensor of shape (B, T, H, D)
     """
+    if HAS_BACKLITE and _backlite_negl_prob > 0.0:
+        return _backlite_flash_attn_func(
+            q, k, v, causal=causal, window_size=window_size,
+            negl_prob=_backlite_negl_prob,
+        )
     if USE_FA3:
         return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
 
@@ -133,7 +183,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
     """
     Flash Attention with KV cache for inference.
 
-    FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
+    FA3/BackLite update k_cache/v_cache in-place. Our SDPA fallback does the same.
 
     Args:
         q: Queries, shape (B, T_new, H, D)
@@ -146,6 +196,7 @@
     Returns:
         Output tensor of shape (B, T_new, H, D)
     """
+    # BackLite is for training. Inference always uses vanilla FA3
     if USE_FA3:
         return _fa3.flash_attn_with_kvcache(
             q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
@@ -185,3 +236,10 @@ flash_attn = SimpleNamespace(
     flash_attn_func=flash_attn_func,
     flash_attn_with_kvcache=flash_attn_with_kvcache,
 )
+
+__all__ = [
+    "flash_attn",
+    "HAS_FA3", "USE_FA3",
+    "HAS_BACKLITE",
+    "set_backlite_negl_prob", "get_backlite_negl_prob",
+]
diff --git a/scripts/base_train.py b/scripts/base_train.py
index 86aa770..b3c52e5 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -32,7 +32,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
 from nanochat.loss_eval import evaluate_bpb
 from nanochat.engine import Engine
-from nanochat.flash_attention import HAS_FA3
+from nanochat.flash_attention import HAS_FA3, HAS_BACKLITE, set_backlite_negl_prob, get_backlite_negl_prob
 from scripts.base_eval import evaluate_core
 
 print_banner()
@@ -77,6 +77,9 @@ parser.add_argument("--sample-every", type=int, default=2000, help="sample from
 parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
 # Output
 parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
+# BackLite sparse backward attention
+parser.add_argument("--backlite-negl-prob", type=float, default=0.0, help="BackLite negligible-probability threshold (0.0 = disabled, use vanilla FA3)")
+parser.add_argument("--negl-prob-warmup-steps", type=int, default=0, help="keep negl_prob=0 for this many steps before enabling BackLite sparsity")
 args = parser.parse_args()
 user_config = vars(args).copy() # for logging
 # -----------------------------------------------------------------------------
@@ -99,7 +102,7 @@ print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
 use_dummy_wandb = args.run == "dummy" or not master_process
 wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
 
-# Flash Attention status
+# Flash Attention / BackLite status
 from nanochat.flash_attention import USE_FA3
 using_fa3 = USE_FA3
 if using_fa3:
@@ -116,6 +119,20 @@
 else:
     print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
     print0("!" * 80)
+# BackLite setup: enable only after negl_prob_warmup_steps (set to 0 initially if warmup > 0)
+if args.backlite_negl_prob > 0.0 and not HAS_BACKLITE:
+    print0("WARNING: --backlite-negl-prob set but BackLite not installed; ignoring (uv pip install --no-build-isolation BackLite/hopper/)")
+if HAS_BACKLITE:
+    initial_negl = 0.0 if args.negl_prob_warmup_steps > 0 else args.backlite_negl_prob
+    set_backlite_negl_prob(initial_negl)
+    if args.backlite_negl_prob > 0.0:
+        if args.negl_prob_warmup_steps > 0:
+            print0(f"✓ BackLite available — negl_prob will activate at step {args.negl_prob_warmup_steps} (target={args.backlite_negl_prob})")
+        else:
+            print0(f"✓ BackLite enabled, negl_prob={args.backlite_negl_prob}")
+    else:
+        print0("BackLite available but disabled (--backlite-negl-prob=0.0)")
+
 # -----------------------------------------------------------------------------
 # Tokenizer will be useful for evaluation and also we need the vocab size to init the model
 tokenizer = get_tokenizer()
@@ -238,6 +255,18 @@ def disable_fp8(model):
         for parent, attr_name, fp8_module in fp8_locations:
             setattr(parent, attr_name, fp8_module)
 
+# Context manager to temporarily disable BackLite during eval (no backward pass = no benefit)
+@contextmanager
+def disable_backlite():
+    saved = get_backlite_negl_prob()
+    if saved > 0.0:
+        set_backlite_negl_prob(0.0)
+    try:
+        yield
+    finally:
+        if saved > 0.0:
+            set_backlite_negl_prob(saved)
+
 # -----------------------------------------------------------------------------
 # Compile the model
 
@@ -411,6 +440,17 @@ print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_l
 print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
 print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
 
+_negl_warmup = args.negl_prob_warmup_steps
+
+# Pre-compile BackLite code path to avoid recompilation spike at warmup activation
+if not resuming and _negl_warmup > 0 and HAS_BACKLITE and args.backlite_negl_prob > 0.0:
+    set_backlite_negl_prob(args.backlite_negl_prob)
+    model.train()
+    loss = model(x, y) / grad_accum_steps
+    (scaler.scale(loss) if scaler else loss).backward()
+    model.zero_grad(set_to_none=True)
+    set_backlite_negl_prob(0.0)
+
 # Go!
 while True:
     last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
@@ -421,7 +461,7 @@
         model.eval()
         val_loader = build_val_loader()
         eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
-        with disable_fp8(model):
+        with disable_fp8(model), disable_backlite():
             val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
         print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
         if val_bpb < min_val_bpb:
@@ -440,7 +480,7 @@
     results = {}
     if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
         model.eval()
-        with disable_fp8(orig_model):
+        with disable_fp8(orig_model), disable_backlite():
             results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
         print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
         wandb_run.log({
@@ -467,7 +507,7 @@
         engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
         for prompt in prompts:
             tokens = tokenizer(prompt, prepend="<|bos|>")
-            with disable_fp8(orig_model):
+            with disable_fp8(orig_model), disable_backlite():
                 sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
             print0(tokenizer.decode(sample[0]))
         model.train()
@@ -582,6 +622,11 @@
     first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
     step += 1
 
+    # Activate negl_prob after warmup completes
+    if _negl_warmup > 0 and step == _negl_warmup and HAS_BACKLITE:
+        set_backlite_negl_prob(args.backlite_negl_prob)
+        print0(f"negl_prob warmup complete at step {step}")
+
     # The garbage collector is sadly a little bit overactive and for some poorly understood reason,
     # it spends ~500ms scanning for cycles quite frequently, just to end up cleaning up very few tiny objects each time.
     # So we manually manage and help it out here