mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 13:45:21 +00:00
Merge pull request #1 from moonmath-ai/backlite-integration
backlite integration
This commit is contained in:
commit
2a8c423150
27
dev/BACKLITE.md
Normal file
27
dev/BACKLITE.md
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
## BackLite (experimental)
|
||||
|
||||
[BackLite](https://github.com/moonmath-ai/BackLite) is an experimental drop-in kernel that can accelerate pretraining by modifying the backward pass. It is currently Hopper-only (H100/H200).
|
||||
|
||||
### Install BackLite
|
||||
|
||||
Clone the BackLite repo into the project root and build the Hopper kernel:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/moonmath-ai/BackLite.git
|
||||
uv pip install --no-build-isolation BackLite/hopper/
|
||||
```
|
||||
|
||||
### Launch a BackLite training run
|
||||
|
||||
Pass `--backlite-negl-prob` to `base_train`:
|
||||
|
||||
```bash
|
||||
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
|
||||
--depth=24 \
|
||||
--run="d24-backlite" \
|
||||
--model-tag="d24_backlite" \
|
||||
--fp8 \
|
||||
--backlite-negl-prob=0.1
|
||||
```
|
||||
|
||||
You should see `✓ BackLite enabled, negl_prob=0.1` in the output.
|
||||
|
|
@ -1,13 +1,15 @@
|
|||
"""
|
||||
Unified Flash Attention interface with automatic FA3/SDPA switching.
|
||||
Unified Flash Attention interface with automatic BackLite/FA3/SDPA switching.
|
||||
|
||||
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
|
||||
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
|
||||
Priority: BackLite (sparse backward, negl_prob > 0) > FA3 (dense) > SDPA fallback.
|
||||
|
||||
Usage (drop-in replacement for FA3):
|
||||
from nanochat.flash_attention import flash_attn
|
||||
BackLite limitation: does NOT support varlen/packed sequences. For any
|
||||
flash_attn_varlen_func call, use vanilla FA3 directly — never route through here.
|
||||
|
||||
# Training (no KV cache)
|
||||
Usage:
|
||||
from nanochat.flash_attention import flash_attn, set_backlite_negl_prob
|
||||
|
||||
# Training (no KV cache) — BackLite applies when negl_prob > 0
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
|
||||
# Inference (with KV cache)
|
||||
|
|
@ -63,6 +65,46 @@ def _resolve_use_fa3():
|
|||
USE_FA3 = _resolve_use_fa3()
|
||||
|
||||
|
||||
# =============================================================================
# Detection: Try to load BackLite on Hopper+ GPUs
# =============================================================================
def _load_backlite():
    """Locate BackLite's sparse-backward attention kernel, if usable here.

    BackLite layers on top of FA3 (Hopper only), so FA3 availability is a
    hard prerequisite. It only covers the training entry point
    (flash_attn_func); kvcache inference always uses vanilla FA3, and
    varlen/packed sequences are NOT supported.

    Returns:
        The flash_attn_func callable from back_lite, or None when BackLite
        is unavailable for any reason.
    """
    fn = None
    if USE_FA3:
        try:
            # Any failure here (not installed, wrong arch, ABI mismatch)
            # silently falls back to vanilla FA3.
            import back_lite as _bl
            fn = _bl.back_lite.flash_attn_func
        except Exception:
            fn = None
    return fn


_backlite_flash_attn_func = _load_backlite()
HAS_BACKLITE = _backlite_flash_attn_func is not None
|
||||
|
||||
# Module-level negl_prob state (0.0 = disabled, i.e. vanilla FA3)
_backlite_negl_prob: float = 0.0


def set_backlite_negl_prob(p: float) -> None:
    """Set the negligible-probability threshold for BackLite sparse backward.

    p=0.0 disables BackLite and falls through to vanilla FA3.
    """
    global _backlite_negl_prob
    _backlite_negl_prob = p


def get_backlite_negl_prob() -> float:
    """Return the currently configured negl_prob (0.0 means BackLite is off)."""
    return _backlite_negl_prob
|
||||
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# SDPA helpers
|
||||
# =============================================================================
|
||||
|
|
@ -108,6 +150,9 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
|||
"""
|
||||
Flash Attention for training (no KV cache).
|
||||
|
||||
Routes to BackLite (sparse backward) when negl_prob > 0, else vanilla FA3,
|
||||
else PyTorch SDPA. BackLite does NOT support varlen sequences.
|
||||
|
||||
Args:
|
||||
q, k, v: Tensors of shape (B, T, H, D)
|
||||
causal: Whether to use causal masking
|
||||
|
|
@ -116,6 +161,11 @@ def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
|||
Returns:
|
||||
Output tensor of shape (B, T, H, D)
|
||||
"""
|
||||
if HAS_BACKLITE and _backlite_negl_prob > 0.0:
|
||||
return _backlite_flash_attn_func(
|
||||
q, k, v, causal=causal, window_size=window_size,
|
||||
negl_prob=_backlite_negl_prob,
|
||||
)
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
|
||||
|
|
@ -133,7 +183,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||
"""
|
||||
Flash Attention with KV cache for inference.
|
||||
|
||||
FA3 updates k_cache/v_cache in-place. Our SDPA fallback does the same.
|
||||
FA3/BackLite update k_cache/v_cache in-place. Our SDPA fallback does the same.
|
||||
|
||||
Args:
|
||||
q: Queries, shape (B, T_new, H, D)
|
||||
|
|
@ -146,6 +196,7 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||
Returns:
|
||||
Output tensor of shape (B, T_new, H, D)
|
||||
"""
|
||||
# BackLite is for training. Inference always uses vanilla FA3
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
||||
|
|
@ -185,3 +236,10 @@ flash_attn = SimpleNamespace(
|
|||
flash_attn_func=flash_attn_func,
|
||||
flash_attn_with_kvcache=flash_attn_with_kvcache,
|
||||
)
|
||||
|
||||
# Public surface of this module: the FA3-compatible namespace plus the
# backend-availability flags and the BackLite negl_prob accessors.
__all__ = [
    "flash_attn",
    "HAS_FA3", "USE_FA3",
    "HAS_BACKLITE",
    "set_backlite_negl_prob", "get_backlite_negl_prob",
]
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
|||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from nanochat.flash_attention import HAS_FA3, HAS_BACKLITE, set_backlite_negl_prob, get_backlite_negl_prob
|
||||
from scripts.base_eval import evaluate_core
|
||||
print_banner()
|
||||
|
||||
|
|
@ -77,6 +77,9 @@ parser.add_argument("--sample-every", type=int, default=2000, help="sample from
|
|||
parser.add_argument("--save-every", type=int, default=-1, help="save checkpoints every N steps (-1 = only at end)")
|
||||
# Output
|
||||
parser.add_argument("--model-tag", type=str, default=None, help="override model tag for checkpoint directory name")
|
||||
# BackLite sparse backward attention
|
||||
parser.add_argument("--backlite-negl-prob", type=float, default=0.0, help="BackLite negligible-probability threshold (0.0 = disabled, use vanilla FA3)")
|
||||
parser.add_argument("--negl-prob-warmup-steps", type=int, default=0, help="keep negl_prob=0 for this many steps before enabling BackLite sparsity")
|
||||
args = parser.parse_args()
|
||||
user_config = vars(args).copy() # for logging
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -99,7 +102,7 @@ print0(f"COMPUTE_DTYPE: {COMPUTE_DTYPE} ({COMPUTE_DTYPE_REASON})")
|
|||
use_dummy_wandb = args.run == "dummy" or not master_process
|
||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
# Flash Attention / BackLite status
|
||||
from nanochat.flash_attention import USE_FA3
|
||||
using_fa3 = USE_FA3
|
||||
if using_fa3:
|
||||
|
|
@ -116,6 +119,20 @@ else:
|
|||
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
|
||||
print0("!" * 80)
|
||||
|
||||
# BackLite setup: enable only after negl_prob_warmup_steps (set to 0 initially if warmup > 0)
if args.backlite_negl_prob > 0.0 and not HAS_BACKLITE:
    print0("WARNING: --backlite-negl-prob set but BackLite not installed; ignoring (pip install BackLite/hopper/)")
if HAS_BACKLITE:
    warmup_pending = args.negl_prob_warmup_steps > 0
    # During warmup the module-level negl_prob stays 0.0 (vanilla FA3 path).
    set_backlite_negl_prob(0.0 if warmup_pending else args.backlite_negl_prob)
    if args.backlite_negl_prob <= 0.0:
        print0("BackLite available but disabled (--backlite-negl-prob=0.0)")
    elif warmup_pending:
        print0(f"✓ BackLite available — negl_prob will activate at step {args.negl_prob_warmup_steps} (target={args.backlite_negl_prob})")
    else:
        print0(f"✓ BackLite enabled, negl_prob={args.backlite_negl_prob}")
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tokenizer will be useful for evaluation and also we need the vocab size to init the model
|
||||
tokenizer = get_tokenizer()
|
||||
|
|
@ -238,6 +255,18 @@ def disable_fp8(model):
|
|||
for parent, attr_name, fp8_module in fp8_locations:
|
||||
setattr(parent, attr_name, fp8_module)
|
||||
|
||||
# Context manager to temporarily disable BackLite during eval (no backward pass = no benefit)
@contextmanager
def disable_backlite():
    prev = get_backlite_negl_prob()
    was_active = prev > 0.0
    if was_active:
        set_backlite_negl_prob(0.0)
    try:
        yield
    finally:
        # Restore the training-time setting even if eval raised.
        if was_active:
            set_backlite_negl_prob(prev)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Compile the model
|
||||
|
||||
|
|
@ -404,6 +433,17 @@ print0(f"Tokens / micro-batch / rank: {args.device_batch_size} x {args.max_seq_l
|
|||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
|
||||
# Number of steps to wait before turning on BackLite sparsity (0 = immediate).
_negl_warmup = args.negl_prob_warmup_steps

# Pre-compile BackLite code path to avoid recompilation spike at warmup activation
if not resuming and _negl_warmup > 0 and HAS_BACKLITE and args.backlite_negl_prob > 0.0:
    # Run one throwaway fwd/bwd with the target negl_prob so the BackLite code
    # path is exercised (and compiled) up front, then discard the gradients and
    # return negl_prob to 0.0 for the warmup phase.
    set_backlite_negl_prob(args.backlite_negl_prob)
    model.train()
    loss = model(x, y) / grad_accum_steps
    # NOTE(review): presumably `scaler` is a GradScaler when fp16 is active —
    # confirm; the ternary keeps the bf16/fp32 path scaler-free.
    (scaler.scale(loss) if scaler else loss).backward()
    model.zero_grad(set_to_none=True)
    set_backlite_negl_prob(0.0)
|
||||
|
||||
# Go!
|
||||
while True:
|
||||
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
|
||||
|
|
@ -414,7 +454,7 @@ while True:
|
|||
model.eval()
|
||||
val_loader = build_val_loader()
|
||||
eval_steps = args.eval_tokens // (args.device_batch_size * args.max_seq_len * ddp_world_size)
|
||||
with disable_fp8(model):
|
||||
with disable_fp8(model), disable_backlite():
|
||||
val_bpb = evaluate_bpb(model, val_loader, eval_steps, token_bytes)
|
||||
print0(f"Step {step:05d} | Validation bpb: {val_bpb:.6f}")
|
||||
if val_bpb < min_val_bpb:
|
||||
|
|
@ -433,7 +473,7 @@ while True:
|
|||
results = {}
|
||||
if args.core_metric_every > 0 and (last_step or (step > 0 and step % args.core_metric_every == 0)):
|
||||
model.eval()
|
||||
with disable_fp8(orig_model):
|
||||
with disable_fp8(orig_model), disable_backlite():
|
||||
results = evaluate_core(orig_model, tokenizer, device, max_per_task=args.core_metric_max_per_task)
|
||||
print0(f"Step {step:05d} | CORE metric: {results['core_metric']:.4f}")
|
||||
wandb_run.log({
|
||||
|
|
@ -460,7 +500,7 @@ while True:
|
|||
engine = Engine(orig_model, tokenizer) # use orig_model to avoid recompilation
|
||||
for prompt in prompts:
|
||||
tokens = tokenizer(prompt, prepend="<|bos|>")
|
||||
with disable_fp8(orig_model):
|
||||
with disable_fp8(orig_model), disable_backlite():
|
||||
sample, _ = engine.generate_batch(tokens, num_samples=1, max_tokens=16, temperature=0)
|
||||
print0(tokenizer.decode(sample[0]))
|
||||
model.train()
|
||||
|
|
@ -575,6 +615,11 @@ while True:
|
|||
first_step_of_run = (step == 0) or (resuming and step == args.resume_from_step)
|
||||
step += 1
|
||||
|
||||
    # Activate negl_prob after warmup completes
    # (step was just incremented above, so this fires exactly once, at the
    # first step whose forward pass should use BackLite sparsity).
    if _negl_warmup > 0 and step == _negl_warmup and HAS_BACKLITE:
        set_backlite_negl_prob(args.backlite_negl_prob)
        print0(f"negl_prob warmup complete at step {step}")
|
||||
|
||||
# The garbage collector is sadly a little bit overactive and for some poorly understood reason,
|
||||
# it spends ~500ms scanning for cycles quite frequently, just to end up cleaning up very few tiny objects each time.
|
||||
# So we manually manage and help it out here
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user