backlite integration

This commit is contained in:
Omer Shlomovits 2026-03-12 01:05:03 +00:00
parent f068604948
commit 0d50efe69b
3 changed files with 142 additions and 12 deletions

27
dev/BACKLITE.md Normal file
View File

@ -0,0 +1,27 @@
## BackLite (experimental)
[BackLite](https://github.com/moonmath-ai/BackLite) is an experimental drop-in kernel that can accelerate pretraining by modifying the backward pass. It is currently Hopper-only (H100/H200).
### Install BackLite
Clone the BackLite repo into the project root and build the Hopper kernel:
```bash
git clone https://github.com/moonmath-ai/BackLite.git
uv pip install --no-build-isolation BackLite/hopper/
```
### Launch a BackLite training run
Pass `--backlite-negl-prob` to `base_train`:
```bash
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=24 \
--run="d24-backlite" \
--model-tag="d24_backlite" \
--fp8 \
--backlite-negl-prob=0.1
```
You should see `✓ BackLite enabled, negl_prob=0.1` in the output.

View File

@ -1,13 +1,15 @@
"""
Unified Flash Attention interface with automatic FA3/SDPA switching.
Unified Flash Attention interface with automatic BackLite/FA3/SDPA switching.
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
Priority: BackLite (sparse backward, negl_prob > 0) > FA3 (dense) > SDPA fallback.
Usage (drop-in replacement for FA3):
from nanochat.flash_attention import flash_attn
BackLite limitation: does NOT support varlen/packed sequences. For any
flash_attn_varlen_func call, use vanilla FA3 directly — never route through here.
# Training (no KV cache)
Usage:
from nanochat.flash_attention import flash_attn, set_backlite_negl_prob
# Training (no KV cache) — BackLite applies when negl_prob > 0
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
# Inference (with KV cache)
@ -63,6 +65,46 @@ def _resolve_use_fa3():
USE_FA3 = _resolve_use_fa3()
# =============================================================================
# Detection: Try to load BackLite on Hopper+ GPUs
# =============================================================================
def _load_backlite():
    """Attempt to load BackLite's sparse-backward attention kernel.

    BackLite is Hopper-only and layered on top of FA3, so we bail out
    immediately when FA3 itself is not in use. It wraps only the training
    entry point (flash_attn_func); kvcache inference always stays on
    vanilla FA3, and varlen/packed sequences are NOT supported.

    Returns:
        The ``flash_attn_func`` callable exported by ``back_lite``, or
        ``None`` when FA3 is inactive or the import/lookup fails.
    """
    if not USE_FA3:
        return None
    try:
        import back_lite
    except Exception:
        # Best-effort: BackLite is optional; any import failure means "not available".
        return None
    try:
        return back_lite.back_lite.flash_attn_func
    except Exception:
        return None
_backlite_flash_attn_func = _load_backlite()  # None when BackLite unavailable
HAS_BACKLITE = _backlite_flash_attn_func is not None
# Module-level negl_prob state. 0.0 means "disabled": attention falls
# through to vanilla FA3 instead of BackLite's sparse backward.
_backlite_negl_prob: float = 0.0

def set_backlite_negl_prob(p: float) -> None:
    """Update the negligible-probability threshold used by BackLite.

    Passing ``p=0.0`` switches BackLite off so the vanilla FA3 path is
    taken instead.
    """
    global _backlite_negl_prob
    _backlite_negl_prob = p

def get_backlite_negl_prob() -> float:
    """Return the currently configured BackLite negl_prob threshold."""
    return _backlite_negl_prob
# =============================================================================
# SDPA helpers
# =============================================================================
@ -108,6 +150,9 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
"""
Flash Attention for training (no KV cache).
Routes to BackLite (sparse backward) when negl_prob > 0, else vanilla FA3,
else PyTorch SDPA. BackLite does NOT support varlen sequences.
Args:
q, k, v: Tensors of shape (B, T, H, D)
causal: Whether to use causal masking
@ -116,6 +161,11 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
Returns:
Output tensor of shape (B, T, H, D)
"""
if HAS_BACKLITE and _backlite_negl_prob > 0.0:
return _backlite_flash_attn_func(
q, k, v, causal=causal, window_size=window_size,
negl_prob=_backlite_negl_prob,
)
if USE_FA3:
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
@ -133,7 +183,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
"""
Flash Attention with KV cache for inference.
FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
FA3 updates k_cache/v_cache in-place (BackLite never handles this path). Our SDPA fallback does the same.
Args:
q: Queries, shape (B, T_new, H, D)
@ -146,6 +196,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
Returns:
Output tensor of shape (B, T_new, H, D)
"""
# BackLite is for training. Inference always uses vanilla FA3
if USE_FA3:
return _fa3.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
@ -185,3 +236,10 @@ flash_attn = SimpleNamespace(
flash_attn_func=flash_attn_func,
flash_attn_with_kvcache=flash_attn_with_kvcache,
)
__all__ = [
"flash_attn",
"HAS_FA3", "USE_FA3",
"HAS_BACKLITE",
"set_backlite_negl_prob", "get_backlite_negl_prob",
]

View File

@ -32,7 +32,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from nanochat.flash_attention import HAS_FA3, HAS_BACKLITE, set_backlite_negl_prob, get_backlite_negl_prob
from scripts.base_eval import evaluate_core
print_banner()
@ -77,6 +77,9 @@ parser.add_argument("--sample-every", type=int, default=2000, help="sample from
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
# Output
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
# BackLite sparse backward attention
parser.add_argument("--backlite-negl-prob", type=float, default=0.0, help="BackLite negligible-probability threshold (0.0 = disabled, use vanilla FA3)")
parser.add_argument("--negl-prob-warmup-steps", type=int, default=0, help="keep negl_prob=0 for this many steps before enabling BackLite sparsity")
args = parser.parse_args()
user_config = vars(args).copy() # for logging
# -----------------------------------------------------------------------------
@ -99,7 +102,7 @@ print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
# Flash Attention status
# Flash Attention / BackLite status
from nanochat.flash_attention import USE_FA3
using_fa3 = USE_FA3
if using_fa3:
@ -116,6 +119,20 @@ else:
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
print0("!" * 80)
# BackLite setup: enable only after negl_prob_warmup_steps (set to 0 initially if warmup > 0)
if args.backlite_negl_prob > 0.0 and not HAS_BACKLITE:
print0("WARNING: --backlite-negl-prob set but BackLite not installed; ignoring (pip install BackLite/hopper/)")
if HAS_BACKLITE:
initial_negl = 0.0 if args.negl_prob_warmup_steps > 0 else args.backlite_negl_prob
set_backlite_negl_prob(initial_negl)
if args.backlite_negl_prob > 0.0:
if args.negl_prob_warmup_steps > 0:
print0(f"✓ BackLite available — negl_prob will activate at step {args.negl_prob_warmup_steps} (target={args.backlite_negl_prob})")
else:
print0(f"✓ BackLite enabled, negl_prob={args.backlite_negl_prob}")
else:
print0("BackLite available but disabled (--backlite-negl-prob=0.0)")
# -----------------------------------------------------------------------------
# Tokenizer will be useful for evaluation and also we need the vocab size to init the model
tokenizer = get_tokenizer()
@ -238,6 +255,18 @@ def disable_fp8(model):
for parent, attr_name, fp8_module in fp8_locations:
setattr(parent, attr_name, fp8_module)
# Context manager to temporarily disable BackLite during eval (no backward pass = no benefit)
@contextmanager
def disable_backlite():
    """Temporarily force negl_prob to 0 so eval runs on vanilla FA3.

    Evaluation/sampling performs no backward pass, so BackLite's sparse
    backward offers no benefit there; the previous threshold is restored
    on exit even if the wrapped code raises.
    """
    prev = get_backlite_negl_prob()
    was_active = prev > 0.0
    if was_active:
        set_backlite_negl_prob(0.0)
    try:
        yield
    finally:
        if was_active:
            set_backlite_negl_prob(prev)
# -----------------------------------------------------------------------------
# Compile the model
@ -404,6 +433,17 @@ print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_l
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
_negl_warmup = args.negl_prob_warmup_steps
# Pre-compile BackLite code path to avoid recompilation spike at warmup activation
if not resuming and _negl_warmup > 0 and HAS_BACKLITE and args.backlite_negl_prob > 0.0:
set_backlite_negl_prob(args.backlite_negl_prob)
model.train()
loss = model(x, y) / grad_accum_steps
(scaler.scale(loss) if scaler else loss).backward()
model.zero_grad(set_to_none=True)
set_backlite_negl_prob(0.0)
# Go!
while True:
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
@ -414,7 +454,7 @@ while True:
model.eval()
val_loader = build_val_loader()
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
with disable_fp8(model):
with disable_fp8(model), disable_backlite():
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
if val_bpb < min_val_bpb:
@ -433,7 +473,7 @@ while True:
results = {}
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
model.eval()
with disable_fp8(orig_model):
with disable_fp8(orig_model), disable_backlite():
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
wandb_run.log({
@ -460,7 +500,7 @@ while True:
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with disable_fp8(orig_model):
with disable_fp8(orig_model), disable_backlite():
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
print0(tokenizer.decode(sample[0]))
model.train()
@ -575,6 +615,11 @@ while True:
first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
step += 1
# Activate negl_prob after warmup completes
if _negl_warmup > 0 and step == _negl_warmup and HAS_BACKLITE:
set_backlite_negl_prob(args.backlite_negl_prob)
print0(f"negl_prob warmup complete at step {step}")
# The garbage collector is sadly a little bit overactive and for some poorly understood reason,
# it spends ~500ms scanning for cycles quite frequently, just to end up cleaning up very few tiny objects each time.
# So we manually manage and help it out here