Refactor for FA varlen

Made-with: Cursor
This commit is contained in:
Chris McCormick 2026-03-22 13:16:08 -07:00
parent 5019accc5b
commit 6ee8fd6908
12 changed files with 669 additions and 414 deletions

View File

@ -200,3 +200,36 @@ This commit is special because all of the improvements that went into [this comm
## Run 6
Achieved Mar 14, 2026 on commit `a825e63`. Exactly the same launch command as Run 4 except `--target-param-data-ratio=8`. Improvements in the architecture are allowing us to train shorter and shorter time. Instead of an undertrained d24 I attempted to train an overtrained d22 but it was worse. This set of changes came from autoresearch round 2, where I asked it to reference the modded-nanogpt repo for inspiration. So the exploration tried out a number of ideas and in particular found a way to incorporate the backout and smear in such a way that they are helpful (I had previously tried them manually a long time ago and they caused regressions). The smear idea in particular is a little bit heavier and bloaty because it is essentially an "early fusion" of context across tokens, producing a kind of a bigram input into the network and allowing it to focus on higher ngrams earlier. But for this reason the code gets a bit more complex and required some changes to inference. I verified with a unit test that the Engine inference is correct compared to the naive inference of `GPT.generate()`. The average of 5 runs was CORE 0.262634 and each of them lasted 1.65 hours (99 minutes).
## Run 7
Achieved Mar 22, 2026 on commit `a075166`. Launch command (same as `runs/speedrun.sh`):
```
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
--depth=24 \
--run="d24-varlen" \
--model-tag="d24-varlen" \
--device-batch-size=16 \
--sample-every=-1 \
--save-every=-1 \
--core-metric-max-per-task=-1 \
--core-metric-every=999999 \
--target-param-data-ratio=8 \
--fp8
```
Result:
```
step 05567/05568 (99.98%) | loss: 2.388741 | lrm: 0.05 | dt: 1048.30ms | tok/sec: 1,000,264 | bf16_mfu: 60.37 | epoch: 1 pq: 117 rg: 64 | total time: 97.38m | eta: 0.0m
Step 05568 | Validation bpb: 0.724772
Step 05568 | CORE metric: 0.2614
Peak memory usage: 52865.94MiB
Total training time: 97.38m
Minimum validation bpb: 0.724772
```
The big change in this run is switching nanochat's entire pre-training attention path from standard batched attention to FlashAttention's variable-length packed attention (`flash_attn_varlen_func` with `cu_seqlens`). Each document now attends only to itself, which saves compute by not calculating attention scores across document boundaries. The dataloader is also simplified: the old bestfit bin-packing is replaced by greedy 1D packing where only the final document in each micro-batch is truncated. See the PR description in [dev/PULL_REQUEST.md](PULL_REQUEST.md) for full details.
Previous record was 1.65 hours / 99 minutes (Run 6), so 1.62 hours / 97.38 minutes is ~1.8% speed improvement. Model checkpoint: [ChrisMcCormick/nanochat-varlen-d24-2026-03-22](https://huggingface.co/ChrisMcCormick/nanochat-varlen-d24-2026-03-22).

View File

@ -101,15 +101,6 @@ def find_common_length(token_sequences, direction='left'):
return min_len
def stack_sequences(tokens, pad_token_id):
    """Right-pad a list of token sequences with pad_token_id and stack into a (B, T) long tensor."""
    num_rows = len(tokens)
    max_len = max(len(seq) for seq in tokens)
    # Start from a fully padded tensor, then overwrite the valid prefix of each row
    input_ids = torch.full((num_rows, max_len), pad_token_id, dtype=torch.long)
    for row, seq in enumerate(tokens):
        input_ids[row, :len(seq)] = torch.tensor(seq, dtype=torch.long)
    return input_ids
def batch_sequences_mc(tokenizer, prompts):
# In multiple choice, contexts are the same but the continuation is different (common prefix)
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
@ -142,26 +133,31 @@ def batch_sequences_lm(tokenizer, prompts):
@torch.no_grad()
def forward_model(model, input_ids):
def forward_model(model, tokens_list, device):
"""
Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
The last column of losses is set to nan because we don't have autoregressive targets there.
Pack token sequences into varlen, forward, extract per-sequence losses and predictions.
Returns (losses_list, preds_list) where each element corresponds to one input sequence.
The last element of each loss vector is nan (no autoregressive target there).
"""
batch_size, seq_len = input_ids.size()
outputs = model(input_ids)
# Roll the tensor to the left by one position to get the (autoregressive) target ids
target_ids = torch.roll(input_ids, shifts=-1, dims=1)
# Calculate cross entropy at all positions
losses = torch.nn.functional.cross_entropy(
outputs.view(batch_size * seq_len, -1),
target_ids.view(batch_size * seq_len),
reduction='none'
).view(batch_size, seq_len)
# Set the last column to be nan because there is no autoregressive loss there
losses[:, -1] = float('nan')
# Get the argmax predictions at each position
predictions = outputs.argmax(dim=-1)
return losses, predictions
packed = torch.cat([torch.tensor(t, dtype=torch.long, device=device) for t in tokens_list])
cu_seqlens = torch.zeros(len(tokens_list) + 1, dtype=torch.int32, device=device)
for i, t in enumerate(tokens_list):
cu_seqlens[i + 1] = cu_seqlens[i] + len(t)
outputs = model(packed, cu_seqlens=cu_seqlens).squeeze(0) # (total_T, V)
losses_list = []
preds_list = []
for i in range(len(tokens_list)):
s, e = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
doc_logits = outputs[s:e]
doc_tokens = packed[s:e]
doc_targets = torch.roll(doc_tokens, -1)
doc_losses = torch.nn.functional.cross_entropy(doc_logits[:-1], doc_targets[:-1], reduction='none')
doc_losses = torch.cat([doc_losses, torch.tensor([float('nan')], device=device)])
losses_list.append(doc_losses)
preds_list.append(doc_logits.argmax(dim=-1))
return losses_list, preds_list
@torch.no_grad()
@ -212,26 +208,21 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
new_end_idxs.append(e)
tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
# Stack up all the sequences into a batch
pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
input_ids = stack_sequences(tokens, pad_token_id)
input_ids = input_ids.to(device)
# Forward the model, get the autoregressive loss and argmax prediction at each token
losses, predictions = forward_model(model, input_ids)
# Forward the model with varlen packing (no padding waste)
losses_list, preds_list = forward_model(model, tokens, device)
# See if the losses/predictions come out correctly
if task_type == 'language_modeling':
# language modeling task is currently always batch size 1
si = start_idxs[0]
ei = end_idxs[0]
# predictions[i] predict input_ids[i+1] autoregressively
predicted_tokens = predictions[0, si-1:ei-1]
actual_tokens = input_ids[0, si:ei]
# preds_list[i][j] predicts tokens[i][j+1] autoregressively
predicted_tokens = preds_list[0][si-1:ei-1]
actual_tokens = torch.tensor(tokens[0][si:ei], device=device)
is_correct = torch.all(predicted_tokens == actual_tokens).item()
elif task_type in ['multiple_choice', 'schema']:
# For MC/schema: find the option with lowest average loss
mean_losses = [losses[i, si-1:ei-1].mean().item()
mean_losses = [losses_list[i][si-1:ei-1].mean().item()
for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
pred_idx = mean_losses.index(min(mean_losses))
is_correct = pred_idx == item['gold']

View File

@ -1,11 +1,10 @@
"""
Distributed dataloaders for pretraining.
Distributed dataloader for pretraining.
BOS-aligned bestfit:
- Every row starts with BOS token
- Documents packed using best-fit algorithm to minimize cropping
- When no document fits remaining space, crops a document to fill exactly
- 100% utilization (no padding), ~35% tokens cropped at T=2048
Varlen 1D packing:
- Packs documents into 1D buffer with cu_seqlens for per-document attention isolation
- No cropping, no padding: every token is used exactly once
- Yields (inputs_1d, targets_1d, cu_seqlens) for flash_attn_varlen_func
Compared to the original tokenizing_distributed_data_loader:
BOS-aligned loses ~35% of tokens to cropping, but ensures that
@ -71,31 +70,43 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
epoch += 1
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
tokenizer, B, T, split,
# =============================================================================
# 1D packed varlen dataloader
# =============================================================================
# Packs documents into a single flat buffer of B*T tokens with cu_seqlens marking
# document boundaries for flash_attn_varlen_func. Each document gets its own
# attention context. Greedy packing: documents are added sequentially until the
# buffer is full. Only the last document in each micro-batch gets cropped.
#
# Requires specifying a fixed maximum number of docs supported per batch.
# The dataloader will append additional documents to the final segment if needed,
# resulting in cross-document attention bleeding, but that hasn't been a problem
# in practice.
# It's recommended to keep max_num_docs tight rather than padding it conservatively
# because an oversized `cu_seqlens` tensor will hurt FlashAttention performance
# somewhat.
def tokenizing_distributed_data_loader_varlen(
tokenizer, B, T, split, max_num_docs,
tokenizer_threads=4, tokenizer_batch_size=128,
device="cuda", resume_state_dict=None,
buffer_size=1000
):
"""
BOS-aligned dataloader with Best-Fit Cropping.
1D packed varlen dataloader for use with flash_attn_varlen_func.
Reduces token waste compared to simple greedy cropping by searching a buffer
for documents that fit well, while maintaining 100% utilization (no padding).
Algorithm for each row:
1. From buffered docs, pick the LARGEST doc that fits entirely
2. Repeat until no doc fits
3. When nothing fits, crop a doc to fill remaining space exactly
Key properties:
- Every row starts with BOS
- 100% utilization (no padding, every token is trained on)
- Approximately 35% of all tokens are discarded due to cropping
Yields (inputs, targets, cu_seqlens, state_dict) where:
- inputs: 1D long tensor of shape (B*T,)
- targets: 1D long tensor of shape (B*T,), shifted by 1
- cu_seqlens: int32 tensor of shape (max_num_docs,), cumulative doc lengths
padded with total_tokens for unused slots (ghost segments of length 0)
- state_dict: {"pq_idx", "rg_idx", "epoch"} for checkpoint resume
"""
assert split in ["train", "val"], "split must be 'train' or 'val'"
row_capacity = T + 1
total_tokens = B * T
buffer_capacity = total_tokens + 1 # +1 so the last input position has a target
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
bos_token = tokenizer.get_bos_token_id()
doc_buffer = []
@ -105,62 +116,139 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
nonlocal pq_idx, rg_idx, epoch
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
doc_buffer.append(tokens)
doc_buffer.extend(token_lists)
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
# This gives us contiguous views and a single HtoD transfer
# Pre-allocate all buffers once
use_cuda = device == "cuda"
row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists
cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
cpu_targets = cpu_buffer[B * T:].view(B, T)
inputs = gpu_buffer[:B * T].view(B, T)
targets = gpu_buffer[B * T:].view(B, T)
pack_buffer = torch.empty(buffer_capacity, dtype=torch.long) # 1D packing workspace
cpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, pin_memory=use_cuda)
gpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, device=device)
cpu_inputs = cpu_buffer[:total_tokens]
cpu_targets = cpu_buffer[total_tokens:]
inputs = gpu_buffer[:total_tokens]
targets = gpu_buffer[total_tokens:]
cu_seqlens_cpu = torch.empty(max_num_docs, dtype=torch.int32)
cu_seqlens_gpu = torch.empty(max_num_docs, dtype=torch.int32, device=device)
warned = False
warned_seqlen = False
while True:
for row_idx in range(B):
pos = 0
while pos < row_capacity:
# Ensure buffer has documents
while len(doc_buffer) < buffer_size:
refill_buffer()
# Greedily pack documents into a single 1D buffer
pos = 0
doc_count = 0
cu_seqlens_cpu[0] = 0
remaining = row_capacity - pos
while pos < buffer_capacity:
while len(doc_buffer) == 0:
refill_buffer()
# Find largest doc that fits entirely
best_idx = -1
best_len = 0
for i, doc in enumerate(doc_buffer):
doc_len = len(doc)
if doc_len <= remaining and doc_len > best_len:
best_idx = i
best_len = doc_len
doc = doc_buffer.pop(0)
doc_len = min(len(doc), T) # truncate to max_seq_len
remaining = buffer_capacity - pos
use_len = min(doc_len, remaining) # crop last doc to fill exactly
if best_idx >= 0:
doc = doc_buffer.pop(best_idx)
doc_len = len(doc)
row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
pos += doc_len
else:
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
doc = doc_buffer.pop(shortest_idx)
row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
pos += remaining
pack_buffer[pos:pos + use_len] = torch.tensor(doc[:use_len], dtype=torch.long)
pos += use_len
if doc_count < max_num_docs - 1:
doc_count += 1
cu_seqlens_cpu[doc_count] = min(pos, total_tokens)
else:
if not warned:
print(f"Warning: too many documents for cu_seqlens size ({max_num_docs}), "
f"merging remaining docs (cross-document attention bleeding)")
warned = True
merged_len = min(pos, total_tokens) - cu_seqlens_cpu[doc_count].item()
if merged_len > T and not warned_seqlen:
print(f"Warning: merged segment length ({merged_len}) exceeds max_seq_len ({T}). "
f"Increase max_num_docs to avoid silent attention truncation.")
warned_seqlen = True
# Copy to pinned CPU buffer, then single HtoD transfer
cpu_inputs.copy_(row_buffer[:, :-1])
cpu_targets.copy_(row_buffer[:, 1:])
# Ensure the final document boundary always points to the end of the batch
cu_seqlens_cpu[doc_count] = total_tokens
# Pad remaining cu_seqlens slots (ghost segments of length 0)
cu_seqlens_cpu[doc_count + 1:] = total_tokens
# Split into inputs/targets (standard next-token prediction shift)
cpu_inputs.copy_(pack_buffer[:total_tokens])
cpu_targets.copy_(pack_buffer[1:total_tokens + 1])
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
# Single HtoD copy into persistent GPU buffer and yield
# H2D transfer: single copy for tokens, small copy for cu_seqlens
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
yield inputs, targets, state_dict
cu_seqlens_gpu.copy_(cu_seqlens_cpu, non_blocking=use_cuda)
yield inputs, targets, cu_seqlens_gpu, state_dict
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
    """Same as the stateful loader, but drops the state_dict from each yielded tuple."""
    stateful = tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs)
    for inputs, targets, _state in stateful:
        yield inputs, targets
# =============================================================================
# SFT varlen dataloader (replay from pre-packed batch plans)
# =============================================================================
def sft_data_loader_varlen(
    conversations, batch_plan, B, T, max_num_docs, bos_token,
    device="cuda", cycle=False,
):
    """
    Replay dataloader for SFT: constructs 1D-packed varlen batches from
    pre-computed batch plans (see tokenize_and_pack_sft in chat_sft.py).
    Args:
        conversations: list of (ids, mask) tuples (pre-tokenized)
        batch_plan: list of lists of conversation indices
        B, T: batch dimensions (total_tokens = B * T)
        max_num_docs: cu_seqlens tensor size (exact max from pre-packing)
        bos_token: BOS token id for padding
        device: target device
        cycle: if True, repeat the batch plan indefinitely (for val eval)
    Yields:
        (inputs, targets, cu_seqlens) where inputs/targets are 1D long tensors
        of shape (B*T,) on `device` and cu_seqlens is an int32 tensor of shape
        (max_num_docs,), padded with total_tokens for unused slots.
    """
    total_tokens = B * T
    # +1 token of capacity so the last input position still has a next-token target
    buffer_capacity = total_tokens + 1
    use_cuda = torch.device(device).type == "cuda"
    # CPU staging: pack_buffer holds the raw token stream, mask_buffer the per-token
    # loss mask (0 => that position's target is masked out of the loss below)
    pack_buffer = torch.empty(buffer_capacity, dtype=torch.long)
    mask_buffer = torch.empty(buffer_capacity, dtype=torch.int8)
    # Single pinned buffer laid out as [inputs | targets] so one H2D copy moves both
    cpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, pin_memory=use_cuda)
    gpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, device=device)
    cpu_inputs = cpu_buffer[:total_tokens]
    cpu_targets = cpu_buffer[total_tokens:]
    inputs = gpu_buffer[:total_tokens]
    targets = gpu_buffer[total_tokens:]
    cu_seqlens_cpu = torch.empty(max_num_docs, dtype=torch.int32)
    cu_seqlens_gpu = torch.empty(max_num_docs, dtype=torch.int32, device=device)
    while True:
        for conv_indices in batch_plan:
            pos = 0
            doc_count = 0
            cu_seqlens_cpu[0] = 0
            for conv_idx in conv_indices:
                ids, mask = conversations[conv_idx]
                conv_len = len(ids)
                # NOTE(review): assumes pre-packing guarantees one plan entry's
                # conversations fit within buffer_capacity; otherwise this slice
                # assignment raises a size-mismatch error.
                pack_buffer[pos:pos + conv_len] = torch.tensor(ids, dtype=torch.long)
                mask_buffer[pos:pos + conv_len] = torch.tensor(mask, dtype=torch.int8)
                pos += conv_len
                doc_count += 1
                # Clamp so boundaries never exceed the yielded (B*T,) tensors
                cu_seqlens_cpu[doc_count] = min(pos, total_tokens)
            if pos < buffer_capacity:
                # Fill the tail with BOS tokens as a filler segment; mask 0 makes
                # their targets -1 below so they contribute nothing to the loss
                remaining = buffer_capacity - pos
                pack_buffer[pos:pos + remaining] = bos_token
                mask_buffer[pos:pos + remaining] = 0
                doc_count += 1
                cu_seqlens_cpu[doc_count] = total_tokens
            # Pad unused cu_seqlens slots (ghost segments of length 0)
            cu_seqlens_cpu[doc_count + 1:] = total_tokens
            # Standard next-token shift: inputs = tokens[0:T], targets = tokens[1:T+1]
            cpu_inputs.copy_(pack_buffer[:total_tokens])
            cpu_targets.copy_(pack_buffer[1:total_tokens + 1])
            target_mask = mask_buffer[1:total_tokens + 1]
            # -1 marks masked targets (matches the ignore_index=-1 used in the loss)
            cpu_targets[target_mask == 0] = -1
            gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
            cu_seqlens_gpu.copy_(cu_seqlens_cpu, non_blocking=use_cuda)
            yield inputs, targets, cu_seqlens_gpu
        if not cycle:
            break

View File

@ -1,66 +1,86 @@
"""
Unified Flash Attention interface with automatic FA3/SDPA switching.
Unified Flash Attention interface with three-tier automatic backend selection:
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
FA3 (Hopper sm90) -> FA2 (Ampere sm80 / Ada sm89) -> PyTorch SDPA fallback
Usage (drop-in replacement for FA3):
Exports `flash_attn` module with two functions:
Usage:
from nanochat.flash_attention import flash_attn
# Training (no KV cache)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
# Training with packed variable-length sequences: q, k, v are (total_tokens, H, D)
y = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, ...)
# Inference (with KV cache)
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
All non-cached forward passes go through varlen. (B, T) callers that don't provide
cu_seqlens get it auto-constructed in GPT.forward.
FA3 and FA2 both support flash_attn_varlen_func with per-document attention isolation.
The SDPA fallback reshapes to (B, T_seq) and uses is_causal=True -- no doc isolation,
but efficient kernels on all hardware (Mac, CPU, Blackwell, older GPUs).
"""
import torch
import torch.nn.functional as F
# =============================================================================
# Detection: Try to load FA3 on Hopper+ GPUs
# Detection: Try FA3 (Hopper), then FA2 (Ampere/Ada), then SDPA fallback
# =============================================================================
def _load_flash_attention_3():
"""Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
def _load_flash_attention():
"""Try to load Flash Attention kernels. Returns (module, version_string) or (None, None)."""
if not torch.cuda.is_available():
return None
return None, None
try:
major, _ = torch.cuda.get_device_capability()
# FA3 kernels are compiled for Hopper (sm90) only
# Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
if major != 9:
return None
import os
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
from kernels import get_kernel
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
# FA3: Hopper (sm90) only
if major == 9:
try:
return get_kernel('varunneal/flash-attention-3').flash_attn_interface, 'fa3'
except Exception:
pass
# FA2: Ampere (sm80), Ada (sm89), and Hopper fallback
if major >= 8:
try:
return get_kernel('kernels-community/flash-attn2').flash_attn_interface, 'fa2'
except Exception:
pass
except Exception:
return None
pass
return None, None
_fa3 = _load_flash_attention_3()
HAS_FA3 = _fa3 is not None
_fa, FA_VERSION = _load_flash_attention()
HAS_FA = _fa is not None
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
# Override for testing: set to 'fa3', 'fa2', 'sdpa', or None (auto)
_override_impl = None
def _resolve_use_fa3():
"""Decide once whether to use FA3, based on availability, override, and dtype."""
if _override_impl == 'fa3':
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
def _resolve_use_fa():
"""Decide once whether to use FA, based on availability, override, and dtype."""
if _override_impl in ('fa3', 'fa2', 'fa'):
assert HAS_FA, "Cannot override to FA: not available on this hardware"
return True
if _override_impl == 'sdpa':
return False
if HAS_FA3:
# FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
if HAS_FA:
from nanochat.common import COMPUTE_DTYPE
if COMPUTE_DTYPE == torch.bfloat16:
return True
return False
if FA_VERSION == 'fa3':
# FA3 Hopper kernels only support bf16 and fp8
return COMPUTE_DTYPE == torch.bfloat16
else:
# FA2 supports bf16 and fp16
return COMPUTE_DTYPE in (torch.bfloat16, torch.float16)
return False
USE_FA3 = _resolve_use_fa3()
USE_FA = _resolve_use_fa()
# =============================================================================
@ -101,31 +121,57 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
def _sdpa_varlen_attention(q, k, v, max_seqlen, window_size, enable_gqa):
    """
    SDPA fallback for packed varlen inputs: view the flat (T, H, D) tensors as
    (B, max_seqlen, ...) batches and run ordinary causal SDPA. There is no
    per-document isolation here (documents within one max_seqlen chunk can
    attend to each other), but it uses efficient kernels on all hardware.
    """
    total_tokens, n_heads, head_dim = q.shape
    n_kv_heads = k.shape[1]
    batch = total_tokens // max_seqlen
    # (T, H, D) -> (B, H, T_seq, D), the layout _sdpa_attention expects
    q_b = q.view(batch, max_seqlen, n_heads, head_dim).transpose(1, 2)
    k_b = k.view(batch, max_seqlen, n_kv_heads, head_dim).transpose(1, 2)
    v_b = v.view(batch, max_seqlen, n_kv_heads, head_dim).transpose(1, 2)
    out = _sdpa_attention(q_b, k_b, v_b, window_size, enable_gqa)
    # Back to the packed (T, H, D) layout
    return out.transpose(1, 2).reshape(total_tokens, n_heads, head_dim)
# =============================================================================
# Public API: Same interface as FA3
# =============================================================================
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k,
causal=False, window_size=(-1, -1)):
"""
Flash Attention for training (no KV cache).
Flash Attention for packed variable-length sequences (training, no KV cache).
1D packed inputs where multiple documents are concatenated into one buffer.
Each document attends only to itself, with boundaries defined by cu_seqlens.
Args:
q, k, v: Tensors of shape (B, T, H, D)
causal: Whether to use causal masking
q, k, v: Tensors of shape (total_tokens, H, D)
cu_seqlens_q, cu_seqlens_k: Cumulative sequence lengths, shape (max_num_seqs,).
Format: [0, end_doc1, end_doc2, ..., total, total, ...]
max_seqlen_q, max_seqlen_k: Max individual sequence length (FA3 tiling hint).
causal: Whether to use causal masking.
window_size: (left, right) sliding window. -1 means unlimited.
Returns:
Output tensor of shape (B, T, H, D)
Output tensor of shape (total_tokens, H, D)
"""
if USE_FA3:
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
if USE_FA:
return _fa.flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
causal=causal, window_size=window_size,
)
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# SDPA fallback: reshape to (B, T_seq) and use standard causal SDPA (no doc isolation)
enable_gqa = q.size(1) != k.size(1)
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
return y.transpose(1, 2) # back to (B, T, H, D)
return _sdpa_varlen_attention(q, k, v, max_seqlen_q, window_size, enable_gqa)
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
@ -146,8 +192,8 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
Returns:
Output tensor of shape (B, T_new, H, D)
"""
if USE_FA3:
return _fa3.flash_attn_with_kvcache(
if USE_FA and hasattr(_fa, 'flash_attn_with_kvcache'):
return _fa.flash_attn_with_kvcache(
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
causal=causal, window_size=window_size
)
@ -178,10 +224,10 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
# =============================================================================
# Export: flash_attn module interface (drop-in replacement for FA3)
# Export: flash_attn module interface (drop-in replacement for FA3/FA2)
# =============================================================================
from types import SimpleNamespace
flash_attn = SimpleNamespace(
flash_attn_func=flash_attn_func,
flash_attn_varlen_func=flash_attn_varlen_func,
flash_attn_with_kvcache=flash_attn_with_kvcache,
)

View File

@ -79,7 +79,7 @@ class CausalSelfAttention(nn.Module):
self.ve_gate_channels = 12
self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
def forward(self, x, ve, cos_sin, window_size, kv_cache):
def forward(self, x, ve, cos_sin, window_size, kv_cache, cu_seqlens=None, max_seq_len=None):
B, T, C = x.size()
# Project the input to get queries, keys, and values
@ -103,10 +103,7 @@ class CausalSelfAttention(nn.Module):
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
if kv_cache is None:
# Training: causal attention with optional sliding window
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
else:
if kv_cache is not None:
# Inference: use flash_attn_with_kvcache which handles cache management
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
y = flash_attn.flash_attn_with_kvcache(
@ -119,6 +116,15 @@ class CausalSelfAttention(nn.Module):
# Advance position after last layer processes
if self.layer_idx == kv_cache.n_layers - 1:
kv_cache.advance(T)
else:
# Varlen: packed 1D sequence with per-document attention isolation
assert cu_seqlens is not None
y = flash_attn.flash_attn_varlen_func(
q[0], k[0], v[0],
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=max_seq_len, max_seqlen_k=max_seq_len,
causal=True, window_size=window_size)
y = y.unsqueeze(0)
# Re-assemble the heads and project back to residual stream
y = y.contiguous().view(B, T, -1)
@ -145,8 +151,8 @@ class Block(nn.Module):
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, ve, cos_sin, window_size, kv_cache):
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
def forward(self, x, ve, cos_sin, window_size, kv_cache, cu_seqlens=None, max_seq_len=None):
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache, cu_seqlens, max_seq_len)
x = x + self.mlp(norm(x))
return x
@ -189,10 +195,11 @@ class GPT(nn.Module):
kv_dim = config.n_kv_head * head_dim
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
# In the future we can dynamically grow the cache, for now it's fine.
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
# Rotary embeddings are small in memory, so we over-compute generously. With varlen
# training the full micro-batch is one sequence (T = batch_size * seq_len), so we need
# enough headroom for that. 64X covers batch sizes up to 64, and the assert in forward
# will catch it if we ever exceed that.
self.rotary_seq_len = config.sequence_len * 64
head_dim = config.n_embd // config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
@ -408,7 +415,18 @@ class GPT(nn.Module):
group["initial_lr"] = group["lr"]
return optimizer
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
def forward(self, idx, targets=None, cu_seqlens=None, kv_cache=None, loss_reduction='mean'):
if cu_seqlens is not None:
assert idx.ndim == 1
idx = idx.unsqueeze(0)
if targets is not None:
targets = targets.unsqueeze(0)
max_seq_len = self.config.sequence_len
elif kv_cache is not None:
max_seq_len = None
else:
raise ValueError("GPT.forward requires either cu_seqlens or kv_cache")
B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
@ -451,7 +469,7 @@ class GPT(nn.Module):
for i, block in enumerate(self.transformer.h):
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache, cu_seqlens, max_seq_len)
if i == backout_layer:
x_backout = x
# Subtract mid-layer residual to remove low-level features before logit projection
@ -489,10 +507,11 @@ class GPT(nn.Module):
if temperature > 0:
rng = torch.Generator(device=device)
rng.manual_seed(seed)
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
ids = torch.tensor(tokens, dtype=torch.long, device=device) # 1D, no batch dim
for _ in range(max_tokens):
logits = self.forward(ids) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size)
cu_seqlens = torch.tensor([0, ids.size(0)], dtype=torch.int32, device=device)
logits = self.forward(ids, cu_seqlens=cu_seqlens) # (1, T, vocab_size)
logits = logits[:, -1, :] # (1, vocab_size)
if top_k is not None and top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
@ -502,6 +521,6 @@ class GPT(nn.Module):
next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
else:
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
ids = torch.cat((ids, next_ids), dim=1)
ids = torch.cat((ids, next_ids.squeeze(0))) # stay 1D
token = next_ids.item()
yield token

View File

@ -29,8 +29,8 @@ def evaluate_bpb(model, batches, steps, token_bytes):
total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
batch_iter = iter(batches)
for _ in range(steps):
x, y = next(batch_iter)
loss2d = model(x, y, loss_reduction='none') # (B, T)
x, y, cu_seqlens, *_ = next(batch_iter)
loss2d = model(x, y, cu_seqlens=cu_seqlens, loss_reduction='none')
loss2d = loss2d.view(-1) # flatten
y = y.view(-1) # flatten
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32

View File

@ -21,6 +21,7 @@ Examples:
"""
import os
import csv
import math
import time
import json
import yaml
@ -35,7 +36,7 @@ from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir,
from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
from nanochat.dataloader import tokenizing_distributed_data_loader_varlen
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
@ -48,7 +49,9 @@ class ModelWrapper:
self.model = model
self.max_seq_len = max_seq_len
def __call__(self, input_ids, targets=None, loss_reduction='mean'):
def __call__(self, input_ids, targets=None, cu_seqlens=None, loss_reduction='mean'):
if cu_seqlens is not None:
return self._forward_varlen(input_ids, targets, cu_seqlens, loss_reduction)
logits = self.model(input_ids).logits
if targets is None:
return logits
@ -60,6 +63,48 @@ class ModelWrapper:
)
return loss
def _forward_varlen(self, input_ids, targets, cu_seqlens, loss_reduction):
"""Unpack 1D varlen to padded (B, T), forward through HF model, repack to (1, total_T, V)."""
device = input_ids.device
doc_lens = cu_seqlens[1:] - cu_seqlens[:-1]
num_docs = (doc_lens > 0).sum().item()
max_doc_len = doc_lens.max().item()
# Unpack: 1D -> padded (num_docs, max_doc_len)
batched = torch.zeros(num_docs, max_doc_len, dtype=input_ids.dtype, device=device)
doc_idx = 0
for i in range(len(doc_lens)):
length = doc_lens[i].item()
if length == 0:
continue
start = cu_seqlens[i].item()
batched[doc_idx, :length] = input_ids[start:start + length]
doc_idx += 1
logits = self.model(batched).logits # (num_docs, max_doc_len, V)
# Repack: padded (num_docs, max_doc_len, V) -> (1, total_T, V)
total_T = input_ids.size(0)
packed = torch.empty(1, total_T, logits.size(-1), dtype=logits.dtype, device=device)
doc_idx = 0
for i in range(len(doc_lens)):
length = doc_lens[i].item()
if length == 0:
continue
start = cu_seqlens[i].item()
packed[0, start:start + length] = logits[doc_idx, :length]
doc_idx += 1
if targets is None:
return packed
loss = torch.nn.functional.cross_entropy(
packed.view(-1, packed.size(-1)),
targets.view(-1),
ignore_index=-1,
reduction=loss_reduction
)
return loss
def get_device(self):
return next(self.model.parameters()).device
@ -269,8 +314,10 @@ def main():
print0(f"Adjusted split_tokens to {args.split_tokens} (must be divisible by {tokens_per_step})")
steps = args.split_tokens // tokens_per_step
avg_num_docs = args.device_batch_size * sequence_len // 400
max_num_docs = math.ceil(avg_num_docs / 16) * 16
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
loader = tokenizing_distributed_data_loader_varlen(tokenizer, args.device_batch_size, sequence_len, split_name, max_num_docs=max_num_docs, device=device)
bpb = evaluate_bpb(model, loader, steps, token_bytes)
bpb_results[split_name] = bpb
print0(f"{split_name} bpb: {bpb:.6f}")

View File

@ -26,13 +26,13 @@ import torch
import torch.distributed as dist
from nanochat.gpt import GPT, GPTConfig, Linear
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
from nanochat.dataloader import tokenizing_distributed_data_loader_varlen
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.flash_attention import HAS_FA3
from nanochat.flash_attention import HAS_FA, FA_VERSION
from scripts.base_eval import evaluate_core
print_banner()
@ -100,17 +100,19 @@ use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
# Flash Attention status
from nanochat.flash_attention import USE_FA3
using_fa3 = USE_FA3
if using_fa3:
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
from nanochat.flash_attention import USE_FA
if USE_FA:
if FA_VERSION == 'fa3':
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
else:
print0(f"✓ Using Flash Attention 2 (Ampere/Ada GPU detected).")
else:
print0("!" * 80)
if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
if HAS_FA and COMPUTE_DTYPE != torch.bfloat16:
print0(f"WARNING: Flash Attention only supports bf16/fp16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
else:
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
print0("WARNING: Training will be less efficient without FA3")
print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback")
print0("WARNING: Training will be less efficient without Flash Attention")
if args.window_pattern != "L":
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
@ -326,10 +328,13 @@ if scaler is not None:
# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
# max_num_docs: see dataloader.py for explanation
avg_num_docs = args.device_batch_size * args.max_seq_len // 400
max_num_docs = math.ceil(avg_num_docs / 16) * 16
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
train_loader = tokenizing_distributed_data_loader_varlen(tokenizer, args.device_batch_size, args.max_seq_len, split="train", max_num_docs=max_num_docs, device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader_varlen(tokenizer, args.device_batch_size, args.max_seq_len, split="val", max_num_docs=max_num_docs, device=device)
x, y, cu_seqlens, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
# Calculate the number of iterations we will train for and set up the various schedulers
@ -507,14 +512,14 @@ while True:
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
loss = model(x, y)
loss = model(x, y, cu_seqlens=cu_seqlens)
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
if scaler is not None:
scaler.scale(loss).backward()
else:
loss.backward()
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
x, y, cu_seqlens, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# step the optimizer
lrm = get_lr_multiplier(step)
muon_momentum = get_muon_momentum(step)

View File

@ -89,7 +89,6 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
device = model.get_device()
bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored
# We'll process batches of independent problems at a time because there is no sampling needed
num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
@ -102,17 +101,18 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
for i in range(ddp_rank, num_batches, ddp_world_size):
i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)
# Prepare the batch of problems. They might all be of different length, so we pad/collate them.
# Prepare the batch of problems and pack into varlen (no padding waste)
conversations = [task_object[ii] for ii in range(i0, i1)]
prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works
max_length = max(len(ids) for ids in prompt_ids)
answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # where the last token is (and the predicted answer)
padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)
packed = torch.cat([torch.tensor(ids, dtype=torch.long, device=device) for ids in prompt_ids])
cu_seqlens = torch.zeros(len(prompt_ids) + 1, dtype=torch.int32, device=device)
for j, ids in enumerate(prompt_ids):
cu_seqlens[j + 1] = cu_seqlens[j] + len(ids)
# Get the logits for the whole batch of conversations in parallel (efficiency win here)
with torch.no_grad():
logits = model(prompt_ids) # (B, T, V)
logits = model(packed, cu_seqlens=cu_seqlens).squeeze(0) # (total_T, V)
# Focus on the available answer on just the letters corresponding to choices
# Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters
@ -130,7 +130,7 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
letter_ids.append(letter_to_id_cache[letter])
# focus logits just down to the answer position and the available letters of the answer
answer_pos = answer_time_positions[idx]
focus_logits = logits[idx, answer_pos, letter_ids]
focus_logits = logits[cu_seqlens[idx].item() + answer_pos, letter_ids]
# get the argmax letter (the predicted answer)
argmax_letter_id = focus_logits.argmax(dim=-1).item()
predicted_letter = letters[argmax_letter_id]

View File

@ -84,7 +84,6 @@ print0(f"Calculated number of steps: {num_steps}")
@torch.no_grad()
def get_batch():
assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss.
rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data
for example_idx in itertools.cycle(rank_indices):
@ -125,25 +124,30 @@ def get_batch():
reward = train_task.reward(conversation, generated_text)
rewards.append(reward)
# Pad the sequences so that their lengths (in time) match
max_length = max(len(seq) for seq in generated_token_sequences)
padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
# Stack up the sequences and masks into PyTorch tensors
ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
# Generate autoregressive inputs and targets to the Transformer
inputs = ids[:, :-1]
targets = ids[:, 1:].clone() # clone to avoid in-place modification:
targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
# Pack the sequences into a 1D buffer (varlen packing, no padding needed)
# Generate autoregressive inputs and targets for each sequence
input_seqs = []
target_seqs = []
for seq, mask in zip(generated_token_sequences, masks):
seq_t = torch.tensor(seq, dtype=torch.long, device=device)
mask_t = torch.tensor(mask, dtype=torch.long, device=device)
input_seqs.append(seq_t[:-1])
tgt = seq_t[1:].clone() # clone to avoid in-place modification:
tgt[mask_t[1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
target_seqs.append(tgt)
# NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens.
# So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens.
# Concatenate the sequences and masks into 1D PyTorch tensors
inputs = torch.cat(input_seqs)
targets = torch.cat(target_seqs)
lengths = torch.tensor([len(s) for s in input_seqs], dtype=torch.int32, device=device)
cu_seqlens = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)]).to(torch.int32)
rewards = torch.tensor(rewards, dtype=torch.float, device=device)
# Calculate the advantages by simply subtracting the mean (instead of z-score (x-mu)/sigma)
mu = rewards.mean()
advantages = rewards - mu
# yield inputs/targets as (B, T) of ids and rewards as (B,) of floats
yield generated_token_sequences, inputs, targets, rewards, advantages
# yield packed 1D inputs/targets with cu_seqlens, and rewards/advantages as (B,) of floats
yield generated_token_sequences, inputs, targets, cu_seqlens, rewards, advantages
# -----------------------------------------------------------------------------
# Simple evaluation loop for GSM8K pass@k
@ -247,23 +251,31 @@ for step in range(num_steps):
sequence_lengths = []
for example_step in range(examples_per_rank):
# Get one batch corresponding to one example in the training dataset
sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
sequences_all, inputs_all, targets_all, cu_seqlens_all, rewards_all, advantages_all = next(batch_iterator)
# Evaluate the loss and gradients
model.train() # ensure the model is in train mode
# We need one more loop because we can never exceed the device_batch_size
assert inputs_all.size(0) % args.device_batch_size == 0
num_passes = inputs_all.size(0) // args.device_batch_size
num_seqs = len(cu_seqlens_all) - 1
assert num_seqs % args.device_batch_size == 0
num_passes = num_seqs // args.device_batch_size
for pass_idx in range(num_passes):
# Pluck out the batch for this pass
# Pluck out the sub-batch for this pass from the packed 1D buffer
b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
inputs = inputs_all[b0:b1]
targets = targets_all[b0:b1]
t0, t1 = cu_seqlens_all[b0].item(), cu_seqlens_all[b1].item()
inputs = inputs_all[t0:t1]
targets = targets_all[t0:t1]
cu_seqlens = cu_seqlens_all[b0:b1+1] - cu_seqlens_all[b0]
rewards = rewards_all[b0:b1]
advantages = advantages_all[b0:b1]
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
logp = -model(inputs, targets, cu_seqlens=cu_seqlens, loss_reduction='none') # (T_sub,)
# Expand per-sequence advantages to per-token positions using cu_seqlens boundaries
token_advantages = torch.zeros_like(logp)
for i in range(b1 - b0):
s, e = cu_seqlens[i].item(), cu_seqlens[i+1].item()
token_advantages[s:e] = advantages[i]
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
pg_obj = (logp * token_advantages).sum()
# normalize by the number of valid tokens, number of passes, and examples_per_rank
num_valid = (targets >= 0).sum().clamp(min=1)
pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)

View File

@ -10,6 +10,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-s
"""
import gc
import math
import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
@ -21,7 +22,8 @@ from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
from nanochat.loss_eval import evaluate_bpb
import torch.distributed as dist
from nanochat.flash_attention import HAS_FA3
from nanochat.flash_attention import HAS_FA
from nanochat.dataloader import sft_data_loader_varlen
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
@ -89,8 +91,8 @@ use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
# Flash Attention status
if not HAS_FA3:
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
if not HAS_FA:
print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback. Training will be less efficient.")
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
@ -178,147 +180,137 @@ val_dataset = TaskMixture([
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
]) # total: 24K + 14K + 1.32K ~= 39K rows
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
# A big problem is that we don't know the final num_iterations in advance. So we create
# these two global variables and update them from within the data generator.
last_step = False # we will toggle this to True when we reach the end of the training dataset
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
current_epoch = 1 # track epoch for logging
def sft_data_generator_bos_bestfit(split, buffer_size=100):
"""
BOS-aligned dataloader for SFT with bestfit-pad packing.
Each row in the batch starts with BOS (beginning of a conversation).
Conversations are packed using best-fit algorithm. When no conversation fits,
the row is padded (instead of cropping) to ensure no tokens are ever discarded.
Padding positions have targets masked with -1 (ignore_index for cross-entropy).
# Pre-tokenize and pre-pack all conversations into batch plans.
# This runs the same best-fit packing algorithm offline at startup, so we know
# num_iterations and max_num_docs exactly before training starts.
def tokenize_and_pack_sft(dataset, tokenizer, B, T, bos_token, ddp_rank, ddp_world_size, buffer_size=100):
"""
Pre-tokenize and pre-pack SFT conversations using best-fit packing.
Preserves TaskMixture's shuffled ordering (no length sorting) and uses the
same buffer-based best-fit algorithm as the original inline dataloader.
Returns (conversations, batch_plans, total_micro_batches, max_num_docs).
"""
global last_step, approx_progress, current_epoch
assert split in {"train", "val"}, "split must be 'train' or 'val'"
dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset)
assert dataset_size > 0
row_capacity = args.max_seq_len + 1 # +1 for target at last position
bos_token = tokenizer.get_bos_token_id()
buffer_capacity = B * T + 1
# Conversation buffer: list of (token_ids, loss_mask) tuples
conversations = []
num_convs = (dataset_size - ddp_rank + ddp_world_size - 1) // ddp_world_size
cursor = ddp_rank
while cursor < dataset_size:
ids, mask = tokenizer.render_conversation(dataset[cursor])
conversations.append((ids, mask))
cursor += ddp_world_size
if len(conversations) % 5000 == 0:
print0(f"\r\033[KTokenizing: {len(conversations):,}/{num_convs:,} ({100*len(conversations)/num_convs:.0f}%)", end='', flush=True)
print0(f"\r\033[KTokenized {len(conversations):,} conversations", flush=True)
batch_plans = []
conv_buffer = []
cursor = ddp_rank # Each rank processes different conversations (for fetching)
consumed = ddp_rank # Track actual consumption separately from buffering
epoch = 1
it = 0 # iteration counter
fetch_cursor = 0
max_doc_count = 0
def refill_buffer():
nonlocal cursor, epoch
while len(conv_buffer) < buffer_size:
conversation = dataset[cursor]
ids, mask = tokenizer.render_conversation(conversation)
conv_buffer.append((ids, mask))
cursor += ddp_world_size
if cursor >= dataset_size:
cursor = cursor % dataset_size
epoch += 1
# Note: last_step is now triggered based on consumption, not fetching
def refill():
nonlocal fetch_cursor
while len(conv_buffer) < buffer_size and fetch_cursor < len(conversations):
conv_buffer.append(fetch_cursor)
fetch_cursor += 1
while True:
rows = []
mask_rows = []
row_lengths = [] # Track actual content length (excluding padding) for each row
for _ in range(args.device_batch_size):
row = []
mask_row = []
padded = False
while len(row) < row_capacity:
# Ensure buffer has conversations
while len(conv_buffer) < buffer_size:
refill_buffer()
remaining = row_capacity - len(row)
# Find largest conversation that fits entirely
best_idx = -1
best_len = 0
for i, (conv, _) in enumerate(conv_buffer):
conv_len = len(conv)
if conv_len <= remaining and conv_len > best_len:
best_idx = i
best_len = conv_len
if best_idx >= 0:
# Found a conversation that fits - use it entirely
conv, conv_mask = conv_buffer.pop(best_idx)
row.extend(conv)
mask_row.extend(conv_mask)
consumed += ddp_world_size # Track actual consumption
else:
# No conversation fits - pad the remainder instead of cropping
# This ensures we never discard any tokens
content_len = len(row)
row.extend([bos_token] * remaining) # Pad with BOS tokens
mask_row.extend([0] * remaining)
padded = True
break # Row is now full (with padding)
# Track content length: full row if no padding, otherwise the length before padding
if padded:
row_lengths.append(content_len)
refill()
if not conv_buffer:
break
batch_indices = []
pos = 0
while pos < buffer_capacity:
refill()
if not conv_buffer:
break
remaining = buffer_capacity - pos
best_buf_idx = -1
best_len = 0
for i, conv_idx in enumerate(conv_buffer):
conv_len = len(conversations[conv_idx][0])
if conv_len <= remaining and conv_len > best_len:
best_buf_idx = i
best_len = conv_len
if best_buf_idx >= 0:
batch_indices.append(conv_buffer.pop(best_buf_idx))
pos += best_len
else:
row_lengths.append(row_capacity)
rows.append(row[:row_capacity])
mask_rows.append(mask_row[:row_capacity])
break
if batch_indices:
doc_count = len(batch_indices) + (1 if pos < buffer_capacity else 0)
max_doc_count = max(max_doc_count, doc_count)
batch_plans.append(batch_indices)
# Stopping condition to respect num_iterations, if given
it += 1
if 0 < args.num_iterations <= it and split == "train":
last_step = True
max_num_docs = math.ceil((max_doc_count + 1) / 16) * 16
return conversations, batch_plans, len(batch_plans), max(max_num_docs, 16)
# Update progress tracking (based on consumed, not cursor, to account for buffering)
if split == "train":
current_epoch = epoch
if args.num_iterations > 0:
approx_progress = it / args.num_iterations
else:
approx_progress = consumed / dataset_size
# Trigger last_step when we've consumed enough (instead of when cursor wraps)
if consumed >= dataset_size:
last_step = True
bos_token = tokenizer.get_bos_token_id()
t_pack_start = time.time()
train_convs, train_plans, train_micro_batches, train_max_docs = tokenize_and_pack_sft(
train_dataset, tokenizer, args.device_batch_size, args.max_seq_len,
bos_token, ddp_rank, ddp_world_size)
t_pack_train = time.time()
val_convs, val_plans, val_micro_batches, val_max_docs = tokenize_and_pack_sft(
val_dataset, tokenizer, args.device_batch_size, args.max_seq_len,
bos_token, ddp_rank, ddp_world_size)
t_pack_val = time.time()
max_num_docs = max(train_max_docs, val_max_docs)
print0(f"Pre-tokenize & pack: train {t_pack_train - t_pack_start:.1f}s, val {t_pack_val - t_pack_train:.1f}s, total {t_pack_val - t_pack_start:.1f}s")
# Build tensors
use_cuda = device_type == "cuda"
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous()
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous()
# Document length and packing statistics
import numpy as np
train_doc_lens = [len(ids) for ids, _ in train_convs]
train_docs_per_batch = [len(plan) for plan in train_plans]
train_tokens_per_batch = [sum(len(train_convs[i][0]) for i in plan) for plan in train_plans]
buffer_capacity = args.device_batch_size * args.max_seq_len + 1
train_packing_eff = [t / buffer_capacity for t in train_tokens_per_batch]
dl = np.array(train_doc_lens)
dpb = np.array(train_docs_per_batch)
pe = np.array(train_packing_eff)
print0(f"Train doc lengths: n={len(dl):,} | mean={dl.mean():.0f} median={np.median(dl):.0f} "
f"min={dl.min()} max={dl.max()} p5={np.percentile(dl,5):.0f} p95={np.percentile(dl,95):.0f}")
print0(f"Train docs/batch: n={len(dpb):,} | mean={dpb.mean():.1f} median={np.median(dpb):.0f} "
f"min={dpb.min()} max={dpb.max()} p5={np.percentile(dpb,5):.0f} p95={np.percentile(dpb,95):.0f}")
print0(f"Train packing eff: mean={pe.mean():.3f} median={np.median(pe):.3f} "
f"min={pe.min():.3f} max={pe.max():.3f}")
# Apply the loss mask from render_conversation (mask=1 for assistant completions,
# mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns
# with targets (shifted by 1). Unmasked positions get -1 (ignore_index).
mask_tensor = torch.tensor(mask_rows, dtype=torch.int8)
mask_targets = mask_tensor[:, 1:].to(device=device)
targets[mask_targets == 0] = -1
# num_iterations: exact count of optimization steps. The -1 accounts for the
# prefetch batch that the training loop requests but never trains on.
data_num_iterations = (train_micro_batches - 1) // grad_accum_steps
if args.num_iterations > 0:
num_iterations = min(args.num_iterations, data_num_iterations)
else:
num_iterations = data_num_iterations
if ddp:
num_iter_tensor = torch.tensor([num_iterations], dtype=torch.long, device=device)
dist.all_reduce(num_iter_tensor, op=dist.ReduceOp.MIN)
num_iterations = num_iter_tensor.item()
print0(f"Pre-packed {len(train_convs):,} train conversations into {train_micro_batches:,} micro-batches "
f"=> {num_iterations:,} optimization steps (max {max_num_docs} docs/batch)")
# Mask out padding positions in targets (set to -1 = ignore_index)
# For each row, positions >= (content_length - 1) in targets should be masked
for i, content_len in enumerate(row_lengths):
if content_len < row_capacity:
targets[i, content_len-1:] = -1
yield inputs, targets
train_loader = sft_data_generator_bos_bestfit("train")
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
train_loader = sft_data_loader_varlen(
train_convs, train_plans, args.device_batch_size, args.max_seq_len,
max_num_docs, bos_token, device=device)
build_val_loader = lambda: sft_data_loader_varlen(
val_convs, val_plans, args.device_batch_size, args.max_seq_len,
max_num_docs, bos_token, device=device, cycle=True)
# Learning rate schedule (linear warmup, constant, linear warmdown)
# Same shape as base_train but uses progress (0→1) instead of absolute step counts,
# because SFT doesn't always know num_iterations in advance (dataset-driven stopping).
def get_lr_multiplier(progress):
if progress < args.warmup_ratio:
return (progress + 1e-8) / args.warmup_ratio
elif progress <= 1.0 - args.warmdown_ratio:
def get_lr_multiplier(it):
warmup_iters = round(args.warmup_ratio * num_iterations)
warmdown_iters = round(args.warmdown_ratio * num_iterations)
if it < warmup_iters:
return (it + 1) / warmup_iters
elif it <= num_iterations - warmdown_iters:
return 1.0
else:
decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio
return (1 - decay) * 1.0 + decay * args.final_lr_frac
progress = (num_iterations - it) / warmdown_iters
return progress * 1.0 + (1 - progress) * args.final_lr_frac
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
@ -328,21 +320,16 @@ def get_muon_momentum(it):
# -----------------------------------------------------------------------------
# Training loop
x, y = next(train_loader) # prefetch the very first batch of data
x, y, cu_seqlens = next(train_loader) # prefetch the very first batch of data
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
step = 0
while True:
last_step = step == num_iterations
flops_so_far = num_flops_per_token * args.total_batch_size * step
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
if ddp:
last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
last_step = bool(last_step_tensor.item())
# once in a while: evaluate the val bpb (all ranks participate)
if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
model.eval()
@ -430,17 +417,16 @@ while True:
synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
loss = model(x, y)
loss = model(x, y, cu_seqlens=cu_seqlens)
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
if scaler is not None:
scaler.scale(loss).backward()
else:
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
progress = max(progress, approx_progress) # only increase progress monotonically
x, y, cu_seqlens = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# step the optimizer
lrm = get_lr_multiplier(progress)
lrm = get_lr_multiplier(step)
muon_momentum = get_muon_momentum(step)
for group in optimizer.param_groups:
group["lr"] = group["initial_lr"] * lrm
@ -467,13 +453,13 @@ while True:
# logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * progress
pct_done = 100 * step / num_iterations
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
if step % 10 == 0:
wandb_run.log({
"step": step,
@ -484,7 +470,6 @@ while True:
"train/dt": dt,
"train/tok_per_sec": tok_per_sec,
"train/mfu": mfu,
"train/epoch": current_epoch,
})
# The garbage collector spends ~500ms scanning for cycles quite frequently.

View File

@ -1,14 +1,14 @@
"""
Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
Test Flash Attention unified interface - verify FA (FA3/FA2) and SDPA produce identical results.
Run: python -m pytest tests/test_attention_fallback.py -v -s
Note on test structure:
Tests are split into two classes due to dtype/device constraints:
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
and verify they produce identical results. These require a Hopper GPU (FA3 only
works on sm90+) and use bfloat16 (FA3 doesn't support float32).
1. TestFA3VsSDPA: Comparison tests that run both FA and SDPA on the same inputs
and verify they produce identical results. These require an Ampere+ GPU
(FA3 on Hopper, FA2 on Ampere/Ada) and use bfloat16.
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
@ -16,24 +16,29 @@ Note on test structure:
import torch
import pytest
import nanochat.flash_attention as fa_module
from nanochat.flash_attention import flash_attn, HAS_FA3
from nanochat.flash_attention import flash_attn, HAS_FA
from nanochat.engine import KVCache
def set_impl(impl):
"""Set the implementation override ('fa3', 'sdpa', or None for auto) and re-resolve USE_FA3."""
"""Set the implementation override ('fa3', 'fa2', 'sdpa', or None for auto) and re-resolve USE_FA."""
fa_module._override_impl = impl
fa_module.USE_FA3 = fa_module._resolve_use_fa3()
fa_module.USE_FA = fa_module._resolve_use_fa()
def run_both_impls(fn):
"""Run a function with both FA3 and SDPA, return both outputs."""
set_impl('fa3')
out_fa3 = fn()
"""Run a function with both FA (FA3 or FA2) and SDPA, return both outputs."""
set_impl('fa')
out_fa = fn()
set_impl('sdpa')
out_sdpa = fn()
set_impl(None) # reset
return out_fa3, out_sdpa
return out_fa, out_sdpa
def make_cu_seqlens(B, T, device):
"""Create cu_seqlens for B documents each of length T."""
return torch.arange(0, (B + 1) * T, T, dtype=torch.int32, device=device)
def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
@ -48,9 +53,9 @@ def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
# =============================================================================
# FA vs SDPA comparison tests (require Ampere+ GPU)
# =============================================================================
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
@pytest.mark.skipif(not HAS_FA, reason="FA required to compare implementations")
class TestFA3VsSDPA:
"""Compare FA3 and SDPA produce identical results. Requires Hopper GPU."""
"""Compare FA and SDPA produce identical results. Requires Ampere+ GPU."""
DEVICE = "cuda"
DTYPE = torch.bfloat16
@ -58,12 +63,15 @@ class TestFA3VsSDPA:
def test_basic_causal(self):
"""Basic causal attention."""
B, T, H, D = 2, 64, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
return flash_attn.flash_attn_varlen_func(q, k, v,
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal")
@ -72,12 +80,15 @@ class TestFA3VsSDPA:
def test_full_context(self):
"""Full context (window_size=-1)."""
B, T, H, D = 2, 128, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
return flash_attn.flash_attn_varlen_func(q, k, v,
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(-1, -1))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context")
@ -87,12 +98,15 @@ class TestFA3VsSDPA:
"""Sliding window attention."""
B, T, H, D = 2, 128, 4, 32
window = 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
return flash_attn.flash_attn_varlen_func(q, k, v,
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(window, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window")
@ -104,12 +118,15 @@ class TestFA3VsSDPA:
n_heads = 8
n_kv_heads = 2
q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
q = torch.randn(B * T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B * T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B * T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
return flash_attn.flash_attn_varlen_func(q, k, v,
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa")
@ -118,12 +135,15 @@ class TestFA3VsSDPA:
def test_larger_model(self):
"""Larger dimensions closer to real model."""
B, T, H, D = 4, 256, 12, 64
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
def run():
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
return flash_attn.flash_attn_varlen_func(q, k, v,
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(-1, -1))
y_fa3, y_sdpa = run_both_impls(run)
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model")
@ -215,21 +235,24 @@ class TestFA3VsSDPA:
def test_backward_gradients_match(self):
"""Verify gradients are similar between FA3 and SDPA."""
B, T, H, D = 2, 32, 4, 16
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
q_data = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k_data = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v_data = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
def run():
q = q_data.clone().requires_grad_(True)
k = k_data.clone().requires_grad_(True)
v = v_data.clone().requires_grad_(True)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
y = flash_attn.flash_attn_varlen_func(q, k, v,
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0))
loss = y.sum()
loss.backward()
return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()
set_impl('fa3')
set_impl('fa')
y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run()
set_impl('sdpa')
y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
@ -261,13 +284,16 @@ class TestSDPAOnly:
"""Test SDPA forward pass produces valid output."""
set_impl('sdpa')
B, T, H, D = 2, 64, 4, 32
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
y = flash_attn.flash_attn_varlen_func(q, k, v,
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0))
assert y.shape == (B, T, H, D)
assert y.shape == (B * T, H, D)
assert not torch.isnan(y).any(), "Output contains NaN"
set_impl(None)
@ -275,11 +301,14 @@ class TestSDPAOnly:
"""Test gradients flow through SDPA."""
set_impl('sdpa')
B, T, H, D = 2, 32, 4, 16
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
y = flash_attn.flash_attn_varlen_func(q, k, v,
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0))
loss = y.sum()
loss.backward()
@ -340,23 +369,23 @@ class TestSDPAOnly:
class TestOverrideMechanism:
"""Test that the override mechanism works correctly."""
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required")
def test_override_fa3(self):
"""Test that override='fa3' uses FA3."""
set_impl('fa3')
assert fa_module.USE_FA3 == True
@pytest.mark.skipif(not HAS_FA, reason="FA required")
def test_override_fa(self):
"""Test that override='fa' uses FA."""
set_impl('fa')
assert fa_module.USE_FA == True
set_impl(None)
def test_override_sdpa(self):
"""Test that override='sdpa' uses SDPA."""
set_impl('sdpa')
assert fa_module.USE_FA3 == False
assert fa_module.USE_FA == False
set_impl(None)
def test_override_auto(self):
"""Test that override=None uses auto-detection."""
set_impl(None)
assert fa_module.USE_FA3 == HAS_FA3
assert fa_module.USE_FA == HAS_FA
if __name__ == "__main__":
@ -366,7 +395,7 @@ if __name__ == "__main__":
print(f"CUDA device: {torch.cuda.get_device_name()}")
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")
print(f"HAS_FA3: {HAS_FA3}")
print(f"HAS_FA: {HAS_FA}")
print()
pytest.main([__file__, "-v", "-s"])