From 6ee8fd6908d4da5cff9220a849be34ce110249c4 Mon Sep 17 00:00:00 2001 From: Chris McCormick Date: Sun, 22 Mar 2026 13:16:08 -0700 Subject: [PATCH] Refactor for FA varlen Made-with: Cursor --- dev/LEADERBOARD.md | 33 ++++ nanochat/core_eval.py | 67 ++++---- nanochat/dataloader.py | 226 ++++++++++++++++++-------- nanochat/flash_attention.py | 138 ++++++++++------ nanochat/gpt.py | 53 +++++-- nanochat/loss_eval.py | 4 +- scripts/base_eval.py | 53 ++++++- scripts/base_train.py | 35 ++-- scripts/chat_eval.py | 14 +- scripts/chat_rl.py | 56 ++++--- scripts/chat_sft.py | 265 +++++++++++++++---------------- tests/test_attention_fallback.py | 139 +++++++++------- 12 files changed, 669 insertions(+), 414 deletions(-) diff --git a/dev/LEADERBOARD.md b/dev/LEADERBOARD.md index 65c0809..adce8bc 100644 --- a/dev/LEADERBOARD.md +++ b/dev/LEADERBOARD.md @@ -200,3 +200,36 @@ This commit is special because all of the improvements that went into [this comm ## Run 6 Achieved Mar 14, 2026 on commit `a825e63`. Exactly the same launch command as Run 4 except `--target-param-data-ratio=8`. Improvements in the architecture are allowing us to train shorter and shorter time. Instead of an undertrained d24 I attempted to train an overtrained d22 but it was worse. This set of changes came from autoresearch round 2, where I asked it to reference the modded-nanogpt repo for inspiration. So the exploration tried out a number of ideas and in particular found a way to incorporate the backout and smear in such a way that they are helpful (I had previously tried them manually a long time ago and they caused regressions). The smear idea in particular is a little bit heavier and bloaty because it is essentially an "early fusion" of context across tokens, producing a kind of a bigram input into the network and allowing it to focus on higher ngrams earlier. But for this reason the code gets a bit more complex and required some changes to inference. 
I verified with a unit test that the Engine inference is correct compared to the naive inference of `GPT.generate()`. The average of 5 runs was CORE 0.262634 and each of them lasted 1.65 hours (99 minutes). + +## Run 7 + +Achieved Mar 22, 2026 on commit `a075166`. Launch command (same as `runs/speedrun.sh`): + +``` +OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \ + --depth=24 \ + --run="d24-varlen" \ + --model-tag="d24-varlen" \ + --device-batch-size=16 \ + --sample-every=-1 \ + --save-every=-1 \ + --core-metric-max-per-task=-1 \ + --core-metric-every=999999 \ + --target-param-data-ratio=8 \ + --fp8 +``` + +Result: + +``` +step 05567/05568 (99.98%) | loss: 2.388741 | lrm: 0.05 | dt: 1048.30ms | tok/sec: 1,000,264 | bf16_mfu: 60.37 | epoch: 1 pq: 117 rg: 64 | total time: 97.38m | eta: 0.0m +Step 05568 | Validation bpb: 0.724772 +Step 05568 | CORE metric: 0.2614 +Peak memory usage: 52865.94MiB +Total training time: 97.38m +Minimum validation bpb: 0.724772 +``` + +The big change in this run is switching nanochat's entire pre-training attention path from standard batched attention to FlashAttention's variable-length packed attention (`flash_attn_varlen_func` with `cu_seqlens`). Each document now attends only to itself, which saves compute by not calculating attention scores across document boundaries. The dataloader is also simplified: the old bestfit bin-packing is replaced by greedy 1D packing where only the final document in each micro-batch is truncated. See the PR description in [dev/PULL_REQUEST.md](PULL_REQUEST.md) for full details. + +Previous record was 1.65 hours / 99 minutes (Run 6), so 1.62 hours / 97.38 minutes is ~1.8% speed improvement. Model checkpoint: [ChrisMcCormick/nanochat-varlen-d24-2026-03-22](https://huggingface.co/ChrisMcCormick/nanochat-varlen-d24-2026-03-22). 
diff --git a/nanochat/core_eval.py b/nanochat/core_eval.py index f3c9a9f..7276cf2 100644 --- a/nanochat/core_eval.py +++ b/nanochat/core_eval.py @@ -101,15 +101,6 @@ def find_common_length(token_sequences, direction='left'): return min_len -def stack_sequences(tokens, pad_token_id): - """Stack up a list of token sequences, pad to longest on the right""" - bsz, seq_len = len(tokens), max(len(x) for x in tokens) - input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long) - for i, x in enumerate(tokens): - input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long) - return input_ids - - def batch_sequences_mc(tokenizer, prompts): # In multiple choice, contexts are the same but the continuation is different (common prefix) tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id()) @@ -142,26 +133,31 @@ def batch_sequences_lm(tokenizer, prompts): @torch.no_grad() -def forward_model(model, input_ids): +def forward_model(model, tokens_list, device): """ - Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions. - The last column of losses is set to nan because we don't have autoregressive targets there. + Pack token sequences into varlen, forward, extract per-sequence losses and predictions. + Returns (losses_list, preds_list) where each element corresponds to one input sequence. + The last element of each loss vector is nan (no autoregressive target there). 
""" - batch_size, seq_len = input_ids.size() - outputs = model(input_ids) - # Roll the tensor to the left by one position to get the (autoregressive) target ids - target_ids = torch.roll(input_ids, shifts=-1, dims=1) - # Calculate cross entropy at all positions - losses = torch.nn.functional.cross_entropy( - outputs.view(batch_size * seq_len, -1), - target_ids.view(batch_size * seq_len), - reduction='none' - ).view(batch_size, seq_len) - # Set the last column to be nan because there is no autoregressive loss there - losses[:, -1] = float('nan') - # Get the argmax predictions at each position - predictions = outputs.argmax(dim=-1) - return losses, predictions + packed = torch.cat([torch.tensor(t, dtype=torch.long, device=device) for t in tokens_list]) + cu_seqlens = torch.zeros(len(tokens_list) + 1, dtype=torch.int32, device=device) + for i, t in enumerate(tokens_list): + cu_seqlens[i + 1] = cu_seqlens[i] + len(t) + + outputs = model(packed, cu_seqlens=cu_seqlens).squeeze(0) # (total_T, V) + + losses_list = [] + preds_list = [] + for i in range(len(tokens_list)): + s, e = cu_seqlens[i].item(), cu_seqlens[i + 1].item() + doc_logits = outputs[s:e] + doc_tokens = packed[s:e] + doc_targets = torch.roll(doc_tokens, -1) + doc_losses = torch.nn.functional.cross_entropy(doc_logits[:-1], doc_targets[:-1], reduction='none') + doc_losses = torch.cat([doc_losses, torch.tensor([float('nan')], device=device)]) + losses_list.append(doc_losses) + preds_list.append(doc_logits.argmax(dim=-1)) + return losses_list, preds_list @torch.no_grad() @@ -212,26 +208,21 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta): new_end_idxs.append(e) tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs - # Stack up all the sequences into a batch - pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok - input_ids = stack_sequences(tokens, pad_token_id) - input_ids = input_ids.to(device) - - # Forward the model, get the autoregressive loss and 
argmax prediction at each token - losses, predictions = forward_model(model, input_ids) + # Forward the model with varlen packing (no padding waste) + losses_list, preds_list = forward_model(model, tokens, device) # See if the losses/predictions come out correctly if task_type == 'language_modeling': # language modeling task is currently always batch size 1 si = start_idxs[0] ei = end_idxs[0] - # predictions[i] predict input_ids[i+1] autoregressively - predicted_tokens = predictions[0, si-1:ei-1] - actual_tokens = input_ids[0, si:ei] + # preds_list[i][j] predicts tokens[i][j+1] autoregressively + predicted_tokens = preds_list[0][si-1:ei-1] + actual_tokens = torch.tensor(tokens[0][si:ei], device=device) is_correct = torch.all(predicted_tokens == actual_tokens).item() elif task_type in ['multiple_choice', 'schema']: # For MC/schema: find the option with lowest average loss - mean_losses = [losses[i, si-1:ei-1].mean().item() + mean_losses = [losses_list[i][si-1:ei-1].mean().item() for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))] pred_idx = mean_losses.index(min(mean_losses)) is_correct = pred_idx == item['gold'] diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py index 4cb2279..d937683 100644 --- a/nanochat/dataloader.py +++ b/nanochat/dataloader.py @@ -1,11 +1,10 @@ """ -Distributed dataloaders for pretraining. +Distributed dataloader for pretraining. 
-BOS-aligned bestfit: - - Every row starts with BOS token - - Documents packed using best-fit algorithm to minimize cropping - - When no document fits remaining space, crops a document to fill exactly - - 100% utilization (no padding), ~35% tokens cropped at T=2048 +Varlen 1D packing: + - Packs documents into 1D buffer with cu_seqlens for per-document attention isolation + - No cropping, no padding: every token is used exactly once + - Yields (inputs_1d, targets_1d, cu_seqlens) for flash_attn_varlen_func Compared to the original tokenizing_distributed_data_loader: BOS-aligned loses ~35% of tokens to cropping, but ensures that @@ -71,31 +70,43 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size): epoch += 1 -def tokenizing_distributed_data_loader_with_state_bos_bestfit( - tokenizer, B, T, split, + +# ============================================================================= +# 1D packed varlen dataloader +# ============================================================================= +# Packs documents into a single flat buffer of B*T tokens with cu_seqlens marking +# document boundaries for flash_attn_varlen_func. Each document gets its own +# attention context. Greedy packing: documents are added sequentially until the +# buffer is full. Only the last document in each micro-batch gets cropped. +# +# Requires specifying a fixed maximum number of docs supported per batch. +# The dataloader will append additional documents to the final segment if needed, +# resulting in cross-document attention bleeding, but that hasn't been a problem +# in practice. +# It's recommended to keep max_num_docs tight rather than padding it conservatively +# because an oversized `cu_seqlens` tensor will hurt FlashAttention performance +# somewhat. 
+ +def tokenizing_distributed_data_loader_varlen( + tokenizer, B, T, split, max_num_docs, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None, - buffer_size=1000 ): """ - BOS-aligned dataloader with Best-Fit Cropping. + 1D packed varlen dataloader for use with flash_attn_varlen_func. - Reduces token waste compared to simple greedy cropping by searching a buffer - for documents that fit well, while maintaining 100% utilization (no padding). - - Algorithm for each row: - 1. From buffered docs, pick the LARGEST doc that fits entirely - 2. Repeat until no doc fits - 3. When nothing fits, crop a doc to fill remaining space exactly - - Key properties: - - Every row starts with BOS - - 100% utilization (no padding, every token is trained on) - - Approximately 35% of all tokens are discarded due to cropping + Yields (inputs, targets, cu_seqlens, state_dict) where: + - inputs: 1D long tensor of shape (B*T,) + - targets: 1D long tensor of shape (B*T,), shifted by 1 + - cu_seqlens: int32 tensor of shape (max_num_docs,), cumulative doc lengths + padded with total_tokens for unused slots (ghost segments of length 0) + - state_dict: {"pq_idx", "rg_idx", "epoch"} for checkpoint resume """ assert split in ["train", "val"], "split must be 'train' or 'val'" - row_capacity = T + 1 + total_tokens = B * T + buffer_capacity = total_tokens + 1 # +1 so the last input position has a target + batches = _document_batches(split, resume_state_dict, tokenizer_batch_size) bos_token = tokenizer.get_bos_token_id() doc_buffer = [] @@ -105,62 +116,139 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit( nonlocal pq_idx, rg_idx, epoch doc_batch, (pq_idx, rg_idx, epoch) = next(batches) token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads) - for tokens in token_lists: - doc_buffer.append(tokens) + doc_buffer.extend(token_lists) - # Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)] - # This gives us 
contiguous views and a single HtoD transfer + # Pre-allocate all buffers once use_cuda = device == "cuda" - row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists - cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU) - gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer - cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience - cpu_targets = cpu_buffer[B * T:].view(B, T) - inputs = gpu_buffer[:B * T].view(B, T) - targets = gpu_buffer[B * T:].view(B, T) + pack_buffer = torch.empty(buffer_capacity, dtype=torch.long) # 1D packing workspace + cpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, pin_memory=use_cuda) + gpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, device=device) + cpu_inputs = cpu_buffer[:total_tokens] + cpu_targets = cpu_buffer[total_tokens:] + inputs = gpu_buffer[:total_tokens] + targets = gpu_buffer[total_tokens:] + cu_seqlens_cpu = torch.empty(max_num_docs, dtype=torch.int32) + cu_seqlens_gpu = torch.empty(max_num_docs, dtype=torch.int32, device=device) + warned = False + warned_seqlen = False while True: - for row_idx in range(B): - pos = 0 - while pos < row_capacity: - # Ensure buffer has documents - while len(doc_buffer) < buffer_size: - refill_buffer() + # Greedily pack documents into a single 1D buffer + pos = 0 + doc_count = 0 + cu_seqlens_cpu[0] = 0 - remaining = row_capacity - pos + while pos < buffer_capacity: + while len(doc_buffer) == 0: + refill_buffer() - # Find largest doc that fits entirely - best_idx = -1 - best_len = 0 - for i, doc in enumerate(doc_buffer): - doc_len = len(doc) - if doc_len <= remaining and doc_len > best_len: - best_idx = i - best_len = doc_len + doc = doc_buffer.pop(0) + doc_len = min(len(doc), T) # truncate to max_seq_len + remaining = buffer_capacity - pos + use_len = min(doc_len, remaining) # crop last doc to fill 
exactly - if best_idx >= 0: - doc = doc_buffer.pop(best_idx) - doc_len = len(doc) - row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long) - pos += doc_len - else: - # No doc fits - crop shortest in buffer to fill remaining and minimize waste - shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i])) - doc = doc_buffer.pop(shortest_idx) - row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long) - pos += remaining + pack_buffer[pos:pos + use_len] = torch.tensor(doc[:use_len], dtype=torch.long) + pos += use_len + if doc_count < max_num_docs - 1: + doc_count += 1 + cu_seqlens_cpu[doc_count] = min(pos, total_tokens) + else: + if not warned: + print(f"Warning: too many documents for cu_seqlens size ({max_num_docs}), " + f"merging remaining docs (cross-document attention bleeding)") + warned = True + merged_len = min(pos, total_tokens) - cu_seqlens_cpu[doc_count].item() + if merged_len > T and not warned_seqlen: + print(f"Warning: merged segment length ({merged_len}) exceeds max_seq_len ({T}). 
" + f"Increase max_num_docs to avoid silent attention truncation.") + warned_seqlen = True - # Copy to pinned CPU buffer, then single HtoD transfer - cpu_inputs.copy_(row_buffer[:, :-1]) - cpu_targets.copy_(row_buffer[:, 1:]) + # Ensure the final document boundary always points to the end of the batch + cu_seqlens_cpu[doc_count] = total_tokens + + # Pad remaining cu_seqlens slots (ghost segments of length 0) + cu_seqlens_cpu[doc_count + 1:] = total_tokens + + # Split into inputs/targets (standard next-token prediction shift) + cpu_inputs.copy_(pack_buffer[:total_tokens]) + cpu_targets.copy_(pack_buffer[1:total_tokens + 1]) state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch} - # Single HtoD copy into persistent GPU buffer and yield + # H2D transfer: single copy for tokens, small copy for cu_seqlens gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda) - yield inputs, targets, state_dict + cu_seqlens_gpu.copy_(cu_seqlens_cpu, non_blocking=use_cuda) + yield inputs, targets, cu_seqlens_gpu, state_dict -def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs): - """Helper that omits state_dict from yields.""" - for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs): - yield inputs, targets + +# ============================================================================= +# SFT varlen dataloader (replay from pre-packed batch plans) +# ============================================================================= + +def sft_data_loader_varlen( + conversations, batch_plan, B, T, max_num_docs, bos_token, + device="cuda", cycle=False, +): + """ + Replay dataloader for SFT: constructs 1D-packed varlen batches from + pre-computed batch plans (see tokenize_and_pack_sft in chat_sft.py). 
+ + Args: + conversations: list of (ids, mask) tuples (pre-tokenized) + batch_plan: list of lists of conversation indices + B, T: batch dimensions (total_tokens = B * T) + max_num_docs: cu_seqlens tensor size (exact max from pre-packing) + bos_token: BOS token id for padding + device: target device + cycle: if True, repeat the batch plan indefinitely (for val eval) + """ + total_tokens = B * T + buffer_capacity = total_tokens + 1 + use_cuda = torch.device(device).type == "cuda" + + pack_buffer = torch.empty(buffer_capacity, dtype=torch.long) + mask_buffer = torch.empty(buffer_capacity, dtype=torch.int8) + cpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, pin_memory=use_cuda) + gpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, device=device) + cpu_inputs = cpu_buffer[:total_tokens] + cpu_targets = cpu_buffer[total_tokens:] + inputs = gpu_buffer[:total_tokens] + targets = gpu_buffer[total_tokens:] + cu_seqlens_cpu = torch.empty(max_num_docs, dtype=torch.int32) + cu_seqlens_gpu = torch.empty(max_num_docs, dtype=torch.int32, device=device) + + while True: + for conv_indices in batch_plan: + pos = 0 + doc_count = 0 + cu_seqlens_cpu[0] = 0 + + for conv_idx in conv_indices: + ids, mask = conversations[conv_idx] + conv_len = len(ids) + pack_buffer[pos:pos + conv_len] = torch.tensor(ids, dtype=torch.long) + mask_buffer[pos:pos + conv_len] = torch.tensor(mask, dtype=torch.int8) + pos += conv_len + doc_count += 1 + cu_seqlens_cpu[doc_count] = min(pos, total_tokens) + + if pos < buffer_capacity: + remaining = buffer_capacity - pos + pack_buffer[pos:pos + remaining] = bos_token + mask_buffer[pos:pos + remaining] = 0 + doc_count += 1 + cu_seqlens_cpu[doc_count] = total_tokens + + cu_seqlens_cpu[doc_count + 1:] = total_tokens + + cpu_inputs.copy_(pack_buffer[:total_tokens]) + cpu_targets.copy_(pack_buffer[1:total_tokens + 1]) + target_mask = mask_buffer[1:total_tokens + 1] + cpu_targets[target_mask == 0] = -1 + + gpu_buffer.copy_(cpu_buffer, 
non_blocking=use_cuda) + cu_seqlens_gpu.copy_(cu_seqlens_cpu, non_blocking=use_cuda) + yield inputs, targets, cu_seqlens_gpu + + if not cycle: + break diff --git a/nanochat/flash_attention.py b/nanochat/flash_attention.py index af2aee3..70ed00b 100644 --- a/nanochat/flash_attention.py +++ b/nanochat/flash_attention.py @@ -1,66 +1,86 @@ """ -Unified Flash Attention interface with automatic FA3/SDPA switching. +Unified Flash Attention interface with three-tier automatic backend selection: -Exports `flash_attn` module that matches the FA3 API exactly, but falls back -to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU. + FA3 (Hopper sm90) -> FA2 (Ampere sm80 / Ada sm89) -> PyTorch SDPA fallback -Usage (drop-in replacement for FA3): +Exports `flash_attn` module with two functions: + +Usage: from nanochat.flash_attention import flash_attn - # Training (no KV cache) - y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) + # Training with packed variable-length sequences: q, k, v are (total_tokens, H, D) + y = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, ...) # Inference (with KV cache) y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...) + +All non-cached forward passes go through varlen. (B, T) callers that don't provide +cu_seqlens get it auto-constructed in GPT.forward. + +FA3 and FA2 both support flash_attn_varlen_func with per-document attention isolation. +The SDPA fallback reshapes to (B, T_seq) and uses is_causal=True -- no doc isolation, +but efficient kernels on all hardware (Mac, CPU, Blackwell, older GPUs). 
""" import torch import torch.nn.functional as F # ============================================================================= -# Detection: Try to load FA3 on Hopper+ GPUs +# Detection: Try FA3 (Hopper), then FA2 (Ampere/Ada), then SDPA fallback # ============================================================================= -def _load_flash_attention_3(): - """Try to load Flash Attention 3 (requires Hopper GPU, sm90).""" +def _load_flash_attention(): + """Try to load Flash Attention kernels. Returns (module, version_string) or (None, None).""" if not torch.cuda.is_available(): - return None + return None, None try: major, _ = torch.cuda.get_device_capability() - # FA3 kernels are compiled for Hopper (sm90) only - # Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled - if major != 9: - return None import os os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1" from kernels import get_kernel - return get_kernel('varunneal/flash-attention-3').flash_attn_interface + + # FA3: Hopper (sm90) only + if major == 9: + try: + return get_kernel('varunneal/flash-attention-3').flash_attn_interface, 'fa3' + except Exception: + pass + + # FA2: Ampere (sm80), Ada (sm89), and Hopper fallback + if major >= 8: + try: + return get_kernel('kernels-community/flash-attn2').flash_attn_interface, 'fa2' + except Exception: + pass except Exception: - return None + pass + return None, None -_fa3 = _load_flash_attention_3() -HAS_FA3 = _fa3 is not None +_fa, FA_VERSION = _load_flash_attention() +HAS_FA = _fa is not None -# Override for testing: set to 'fa3', 'sdpa', or None (auto) +# Override for testing: set to 'fa3', 'fa2', 'sdpa', or None (auto) _override_impl = None -def _resolve_use_fa3(): - """Decide once whether to use FA3, based on availability, override, and dtype.""" - if _override_impl == 'fa3': - assert HAS_FA3, "Cannot override to FA3: not available on this hardware" +def _resolve_use_fa(): + """Decide once whether to use FA, based on availability, override, and 
dtype.""" + if _override_impl in ('fa3', 'fa2', 'fa'): + assert HAS_FA, "Cannot override to FA: not available on this hardware" return True if _override_impl == 'sdpa': return False - if HAS_FA3: - # FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback + if HAS_FA: from nanochat.common import COMPUTE_DTYPE - if COMPUTE_DTYPE == torch.bfloat16: - return True - return False + if FA_VERSION == 'fa3': + # FA3 Hopper kernels only support bf16 and fp8 + return COMPUTE_DTYPE == torch.bfloat16 + else: + # FA2 supports bf16 and fp16 + return COMPUTE_DTYPE in (torch.bfloat16, torch.float16) return False -USE_FA3 = _resolve_use_fa3() +USE_FA = _resolve_use_fa() # ============================================================================= @@ -101,31 +121,57 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa): return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa) + +def _sdpa_varlen_attention(q, k, v, max_seqlen, window_size, enable_gqa): + """ + SDPA fallback for varlen: reshapes packed (T, H, D) to (B, T_seq, H, D) + and uses standard causal SDPA. No document isolation (cross-doc bleeding + within each T_seq chunk), but uses efficient is_causal=True kernels. 
+ """ + T, H, D = q.shape + H_kv = k.shape[1] + B = T // max_seqlen + q = q.view(B, max_seqlen, H, D).transpose(1, 2) + k = k.view(B, max_seqlen, H_kv, D).transpose(1, 2) + v = v.view(B, max_seqlen, H_kv, D).transpose(1, 2) + y = _sdpa_attention(q, k, v, window_size, enable_gqa) + return y.transpose(1, 2).reshape(T, H, D) + + # ============================================================================= # Public API: Same interface as FA3 # ============================================================================= -def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)): +def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + causal=False, window_size=(-1, -1)): """ - Flash Attention for training (no KV cache). + Flash Attention for packed variable-length sequences (training, no KV cache). + + 1D packed inputs where multiple documents are concatenated into one buffer. + Each document attends only to itself, with boundaries defined by cu_seqlens. Args: - q, k, v: Tensors of shape (B, T, H, D) - causal: Whether to use causal masking + q, k, v: Tensors of shape (total_tokens, H, D) + cu_seqlens_q, cu_seqlens_k: Cumulative sequence lengths, shape (max_num_seqs,). + Format: [0, end_doc1, end_doc2, ..., total, total, ...] + max_seqlen_q, max_seqlen_k: Max individual sequence length (FA3 tiling hint). + causal: Whether to use causal masking. window_size: (left, right) sliding window. -1 means unlimited. 
Returns: - Output tensor of shape (B, T, H, D) + Output tensor of shape (total_tokens, H, D) """ - if USE_FA3: - return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size) + if USE_FA: + return _fa.flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, + causal=causal, window_size=window_size, + ) - # SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D) - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) + # SDPA fallback: reshape to (B, T_seq) and use standard causal SDPA (no doc isolation) enable_gqa = q.size(1) != k.size(1) - y = _sdpa_attention(q, k, v, window_size, enable_gqa) - return y.transpose(1, 2) # back to (B, T, H, D) + return _sdpa_varlen_attention(q, k, v, max_seqlen_q, window_size, enable_gqa) def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None, @@ -146,8 +192,8 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N Returns: Output tensor of shape (B, T_new, H, D) """ - if USE_FA3: - return _fa3.flash_attn_with_kvcache( + if USE_FA and hasattr(_fa, 'flash_attn_with_kvcache'): + return _fa.flash_attn_with_kvcache( q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size ) @@ -178,10 +224,10 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N # ============================================================================= -# Export: flash_attn module interface (drop-in replacement for FA3) +# Export: flash_attn module interface (drop-in replacement for FA3/FA2) # ============================================================================= from types import SimpleNamespace flash_attn = SimpleNamespace( - flash_attn_func=flash_attn_func, + flash_attn_varlen_func=flash_attn_varlen_func, flash_attn_with_kvcache=flash_attn_with_kvcache, ) diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 
0b822e4..d83feb4 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -79,7 +79,7 @@ class CausalSelfAttention(nn.Module): self.ve_gate_channels = 12 self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None - def forward(self, x, ve, cos_sin, window_size, kv_cache): + def forward(self, x, ve, cos_sin, window_size, kv_cache, cu_seqlens=None, max_seq_len=None): B, T, C = x.size() # Project the input to get queries, keys, and values @@ -103,10 +103,7 @@ class CausalSelfAttention(nn.Module): # Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere) # window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context - if kv_cache is None: - # Training: causal attention with optional sliding window - y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size) - else: + if kv_cache is not None: # Inference: use flash_attn_with_kvcache which handles cache management k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx) y = flash_attn.flash_attn_with_kvcache( @@ -119,6 +116,15 @@ class CausalSelfAttention(nn.Module): # Advance position after last layer processes if self.layer_idx == kv_cache.n_layers - 1: kv_cache.advance(T) + else: + # Varlen: packed 1D sequence with per-document attention isolation + assert cu_seqlens is not None + y = flash_attn.flash_attn_varlen_func( + q[0], k[0], v[0], + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=max_seq_len, max_seqlen_k=max_seq_len, + causal=True, window_size=window_size) + y = y.unsqueeze(0) # Re-assemble the heads and project back to residual stream y = y.contiguous().view(B, T, -1) @@ -145,8 +151,8 @@ class Block(nn.Module): self.attn = CausalSelfAttention(config, layer_idx) self.mlp = MLP(config) - def forward(self, x, ve, cos_sin, window_size, kv_cache): - x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache) + def forward(self, x, ve, cos_sin, window_size, kv_cache, cu_seqlens=None, 
max_seq_len=None): + x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache, cu_seqlens, max_seq_len) x = x + self.mlp(norm(x)) return x @@ -189,10 +195,11 @@ class GPT(nn.Module): kv_dim = config.n_kv_head * head_dim self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)}) # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. - # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, - # so let's just over-compute them by 10X, but assert fail if we ever reach that amount. - # In the future we can dynamically grow the cache, for now it's fine. - self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer? + # Rotary embeddings are small in memory, so we over-compute generously. With varlen + # training the full micro-batch is one sequence (T = batch_size * seq_len), so we need + # enough headroom for that. 64X covers batch sizes up to 64, and the assert in forward + # will catch if we ever exceed. 
+ self.rotary_seq_len = config.sequence_len * 64 head_dim = config.n_embd // config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint @@ -408,7 +415,18 @@ class GPT(nn.Module): group["initial_lr"] = group["lr"] return optimizer - def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): + def forward(self, idx, targets=None, cu_seqlens=None, kv_cache=None, loss_reduction='mean'): + if cu_seqlens is not None: + assert idx.ndim == 1 + idx = idx.unsqueeze(0) + if targets is not None: + targets = targets.unsqueeze(0) + max_seq_len = self.config.sequence_len + elif kv_cache is not None: + max_seq_len = None + else: + raise ValueError("GPT.forward requires either cu_seqlens or kv_cache") + B, T = idx.size() # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2)) @@ -451,7 +469,7 @@ class GPT(nn.Module): for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None - x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) + x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache, cu_seqlens, max_seq_len) if i == backout_layer: x_backout = x # Subtract mid-layer residual to remove low-level features before logit projection @@ -489,10 +507,11 @@ class GPT(nn.Module): if temperature > 0: rng = torch.Generator(device=device) rng.manual_seed(seed) - ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim + ids = torch.tensor(tokens, dtype=torch.long, device=device) # 1D, no batch dim for _ in range(max_tokens): - logits = self.forward(ids) # (B, T, vocab_size) - logits = logits[:, -1, :] # (B, vocab_size) + cu_seqlens = torch.tensor([0, ids.size(0)], dtype=torch.int32, device=device) + logits = self.forward(ids, 
cu_seqlens=cu_seqlens) # (1, T, vocab_size) + logits = logits[:, -1, :] # (1, vocab_size) if top_k is not None and top_k > 0: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') @@ -502,6 +521,6 @@ class GPT(nn.Module): next_ids = torch.multinomial(probs, num_samples=1, generator=rng) else: next_ids = torch.argmax(logits, dim=-1, keepdim=True) - ids = torch.cat((ids, next_ids), dim=1) + ids = torch.cat((ids, next_ids.squeeze(0))) # stay 1D token = next_ids.item() yield token diff --git a/nanochat/loss_eval.py b/nanochat/loss_eval.py index 5a556e6..c47855a 100644 --- a/nanochat/loss_eval.py +++ b/nanochat/loss_eval.py @@ -29,8 +29,8 @@ def evaluate_bpb(model, batches, steps, token_bytes): total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device()) batch_iter = iter(batches) for _ in range(steps): - x, y = next(batch_iter) - loss2d = model(x, y, loss_reduction='none') # (B, T) + x, y, cu_seqlens, *_ = next(batch_iter) + loss2d = model(x, y, cu_seqlens=cu_seqlens, loss_reduction='none') loss2d = loss2d.view(-1) # flatten y = y.view(-1) # flatten if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32 diff --git a/scripts/base_eval.py b/scripts/base_eval.py index a57bbaf..bcf827f 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -21,6 +21,7 @@ Examples: """ import os import csv +import math import time import json import yaml @@ -35,7 +36,7 @@ from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes from nanochat.checkpoint_manager import load_model from nanochat.core_eval import evaluate_task -from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit +from nanochat.dataloader import tokenizing_distributed_data_loader_varlen from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine @@ -48,7 +49,9 @@ class ModelWrapper: 
self.model = model self.max_seq_len = max_seq_len - def __call__(self, input_ids, targets=None, loss_reduction='mean'): + def __call__(self, input_ids, targets=None, cu_seqlens=None, loss_reduction='mean'): + if cu_seqlens is not None: + return self._forward_varlen(input_ids, targets, cu_seqlens, loss_reduction) logits = self.model(input_ids).logits if targets is None: return logits @@ -60,6 +63,48 @@ class ModelWrapper: ) return loss + def _forward_varlen(self, input_ids, targets, cu_seqlens, loss_reduction): + """Unpack 1D varlen to padded (B, T), forward through HF model, repack to (1, total_T, V).""" + device = input_ids.device + doc_lens = cu_seqlens[1:] - cu_seqlens[:-1] + num_docs = (doc_lens > 0).sum().item() + max_doc_len = doc_lens.max().item() + + # Unpack: 1D -> padded (num_docs, max_doc_len) + batched = torch.zeros(num_docs, max_doc_len, dtype=input_ids.dtype, device=device) + doc_idx = 0 + for i in range(len(doc_lens)): + length = doc_lens[i].item() + if length == 0: + continue + start = cu_seqlens[i].item() + batched[doc_idx, :length] = input_ids[start:start + length] + doc_idx += 1 + + logits = self.model(batched).logits # (num_docs, max_doc_len, V) + + # Repack: padded (num_docs, max_doc_len, V) -> (1, total_T, V) + total_T = input_ids.size(0) + packed = torch.empty(1, total_T, logits.size(-1), dtype=logits.dtype, device=device) + doc_idx = 0 + for i in range(len(doc_lens)): + length = doc_lens[i].item() + if length == 0: + continue + start = cu_seqlens[i].item() + packed[0, start:start + length] = logits[doc_idx, :length] + doc_idx += 1 + + if targets is None: + return packed + loss = torch.nn.functional.cross_entropy( + packed.view(-1, packed.size(-1)), + targets.view(-1), + ignore_index=-1, + reduction=loss_reduction + ) + return loss + def get_device(self): return next(self.model.parameters()).device @@ -269,8 +314,10 @@ def main(): print0(f"Adjusted split_tokens to {args.split_tokens} (must be divisible by {tokens_per_step})") steps = 
args.split_tokens // tokens_per_step + avg_num_docs = args.device_batch_size * sequence_len // 400 + max_num_docs = math.ceil(avg_num_docs / 16) * 16 for split_name in ["train", "val"]: - loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device) + loader = tokenizing_distributed_data_loader_varlen(tokenizer, args.device_batch_size, sequence_len, split_name, max_num_docs=max_num_docs, device=device) bpb = evaluate_bpb(model, loader, steps, token_bytes) bpb_results[split_name] = bpb print0(f"{split_name} bpb: {bpb:.6f}") diff --git a/scripts/base_train.py b/scripts/base_train.py index 86aa770..9b270ab 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -26,13 +26,13 @@ import torch import torch.distributed as dist from nanochat.gpt import GPT, GPTConfig, Linear -from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit +from nanochat.dataloader import tokenizing_distributed_data_loader_varlen from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized from nanochat.tokenizer import get_tokenizer, get_token_bytes from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint from nanochat.loss_eval import evaluate_bpb from nanochat.engine import Engine -from nanochat.flash_attention import HAS_FA3 +from nanochat.flash_attention import HAS_FA, FA_VERSION from scripts.base_eval import evaluate_core print_banner() @@ -100,17 +100,19 @@ use_dummy_wandb = args.run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config) # Flash Attention status -from nanochat.flash_attention import USE_FA3 -using_fa3 = USE_FA3 -if using_fa3: - print0("✓ Using Flash Attention 3 (Hopper 
GPU detected), efficient, new and awesome.") +from nanochat.flash_attention import USE_FA +if USE_FA: + if FA_VERSION == 'fa3': + print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.") + else: + print0(f"✓ Using Flash Attention 2 (Ampere/Ada GPU detected).") else: print0("!" * 80) - if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16: - print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback") + if HAS_FA and COMPUTE_DTYPE != torch.bfloat16: + print0(f"WARNING: Flash Attention only supports bf16/fp16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback") else: - print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback") - print0("WARNING: Training will be less efficient without FA3") + print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback") + print0("WARNING: Training will be less efficient without Flash Attention") if args.window_pattern != "L": print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). 
Your GPU utilization will be terrible.") print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.") @@ -326,10 +328,13 @@ if scaler is not None: # ----------------------------------------------------------------------------- # Initialize the DataLoaders for train/val +# max_num_docs: see dataloader.py for explanation +avg_num_docs = args.device_batch_size * args.max_seq_len // 400 +max_num_docs = math.ceil(avg_num_docs / 16) * 16 dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"] -train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict) -build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device) -x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data +train_loader = tokenizing_distributed_data_loader_varlen(tokenizer, args.device_batch_size, args.max_seq_len, split="train", max_num_docs=max_num_docs, device=device, resume_state_dict=dataloader_resume_state_dict) +build_val_loader = lambda: tokenizing_distributed_data_loader_varlen(tokenizer, args.device_batch_size, args.max_seq_len, split="val", max_num_docs=max_num_docs, device=device) +x, y, cu_seqlens, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data # ----------------------------------------------------------------------------- # Calculate the number of iterations we will train for and set up the various schedulers @@ -507,14 +512,14 @@ while True: synchronize() t0 = time.time() for micro_step in range(grad_accum_steps): - loss = model(x, y) + loss = model(x, y, cu_seqlens=cu_seqlens) train_loss = loss.detach() # for logging loss = loss / grad_accum_steps # each .backward() is a grad 
sum => normalize loss here if scaler is not None: scaler.scale(loss).backward() else: loss.backward() - x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward + x, y, cu_seqlens, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward # step the optimizer lrm = get_lr_multiplier(step) muon_momentum = get_muon_momentum(step) diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index 858d4c2..39e2f6c 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -89,7 +89,6 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() device = model.get_device() - bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored # We'll process batches of independent problems at a time because there is no sampling needed num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems) @@ -102,17 +101,18 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems for i in range(ddp_rank, num_batches, ddp_world_size): i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems) - # Prepare the batch of problems. They might all be of different length, so we pad/collate them. 
+ # Prepare the batch of problems and pack into varlen (no padding waste) conversations = [task_object[ii] for ii in range(i0, i1)] prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works - max_length = max(len(ids) for ids in prompt_ids) answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # where the last token is (and the predicted answer) - padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids] - prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device) + packed = torch.cat([torch.tensor(ids, dtype=torch.long, device=device) for ids in prompt_ids]) + cu_seqlens = torch.zeros(len(prompt_ids) + 1, dtype=torch.int32, device=device) + for j, ids in enumerate(prompt_ids): + cu_seqlens[j + 1] = cu_seqlens[j] + len(ids) # Get the logits for the whole batch of conversations in parallel (efficiency win here) with torch.no_grad(): - logits = model(prompt_ids) # (B, T, V) + logits = model(packed, cu_seqlens=cu_seqlens).squeeze(0) # (total_T, V) # Focus on the available answer on just the letters corresponding to choices # Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters @@ -130,7 +130,7 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems letter_ids.append(letter_to_id_cache[letter]) # focus logits just down to the answer position and the available letters of the answer answer_pos = answer_time_positions[idx] - focus_logits = logits[idx, answer_pos, letter_ids] + focus_logits = logits[cu_seqlens[idx].item() + answer_pos, letter_ids] # get the argmax letter (the predicted answer) argmax_letter_id = focus_logits.argmax(dim=-1).item() predicted_letter = letters[argmax_letter_id] diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index cb2cb0e..e9a3887 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -84,7 +84,6 @@ print0(f"Calculated number 
of steps: {num_steps}") @torch.no_grad() def get_batch(): - assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss. rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data for example_idx in itertools.cycle(rank_indices): @@ -125,25 +124,30 @@ def get_batch(): reward = train_task.reward(conversation, generated_text) rewards.append(reward) - # Pad the sequences so that their lengths (in time) match - max_length = max(len(seq) for seq in generated_token_sequences) - padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences] - padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks] - # Stack up the sequences and masks into PyTorch tensors - ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device) - mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device) - # Generate autoregressive inputs and targets to the Transformer - inputs = ids[:, :-1] - targets = ids[:, 1:].clone() # clone to avoid in-place modification: - targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index + # Pack the sequences into a 1D buffer (varlen packing, no padding needed) + # Generate autoregressive inputs and targets for each sequence + input_seqs = [] + target_seqs = [] + for seq, mask in zip(generated_token_sequences, masks): + seq_t = torch.tensor(seq, dtype=torch.long, device=device) + mask_t = torch.tensor(mask, dtype=torch.long, device=device) + input_seqs.append(seq_t[:-1]) + tgt = seq_t[1:].clone() # clone to avoid in-place modification: + tgt[mask_t[1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index + target_seqs.append(tgt) # NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens. 
# So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens. + # Concatenate the sequences and masks into 1D PyTorch tensors + inputs = torch.cat(input_seqs) + targets = torch.cat(target_seqs) + lengths = torch.tensor([len(s) for s in input_seqs], dtype=torch.int32, device=device) + cu_seqlens = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)]).to(torch.int32) rewards = torch.tensor(rewards, dtype=torch.float, device=device) # Calculate the advantages by simply subtracting the mean (instead of z-score (x-mu)/sigma) mu = rewards.mean() advantages = rewards - mu - # yield inputs/targets as (B, T) of ids and rewards as (B,) of floats - yield generated_token_sequences, inputs, targets, rewards, advantages + # yield packed 1D inputs/targets with cu_seqlens, and rewards/advantages as (B,) of floats + yield generated_token_sequences, inputs, targets, cu_seqlens, rewards, advantages # ----------------------------------------------------------------------------- # Simple evaluation loop for GSM8K pass@k @@ -247,23 +251,31 @@ for step in range(num_steps): sequence_lengths = [] for example_step in range(examples_per_rank): # Get one batch corresponding to one example in the training dataset - sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator) + sequences_all, inputs_all, targets_all, cu_seqlens_all, rewards_all, advantages_all = next(batch_iterator) # Evaluate the loss and gradients model.train() # ensure the model is in train mode # We need one more loop because we can never exceed the device_batch_size - assert inputs_all.size(0) % args.device_batch_size == 0 - num_passes = inputs_all.size(0) // args.device_batch_size + num_seqs = len(cu_seqlens_all) - 1 + assert num_seqs % args.device_batch_size == 0 + num_passes = num_seqs // args.device_batch_size for pass_idx in range(num_passes): - # Pluck out the batch for this pass + # Pluck out the sub-batch for this pass from the packed 1D buffer b0, b1 
= pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size - inputs = inputs_all[b0:b1] - targets = targets_all[b0:b1] + t0, t1 = cu_seqlens_all[b0].item(), cu_seqlens_all[b1].item() + inputs = inputs_all[t0:t1] + targets = targets_all[t0:t1] + cu_seqlens = cu_seqlens_all[b0:b1+1] - cu_seqlens_all[b0] rewards = rewards_all[b0:b1] advantages = advantages_all[b0:b1] # Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate - logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T) + logp = -model(inputs, targets, cu_seqlens=cu_seqlens, loss_reduction='none') # (T_sub,) + # Expand per-sequence advantages to per-token positions using cu_seqlens boundaries + token_advantages = torch.zeros_like(logp) + for i in range(b1 - b0): + s, e = cu_seqlens[i].item(), cu_seqlens[i+1].item() + token_advantages[s:e] = advantages[i] # Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0. - pg_obj = (logp * advantages.unsqueeze(-1)).sum() + pg_obj = (logp * token_advantages).sum() # normalize by the number of valid tokens, number of passes, and examples_per_rank num_valid = (targets >= 0).sum().clamp(min=1) pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank) diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb6..8a47e50 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -10,6 +10,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-s """ import gc +import math import argparse import os os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" @@ -21,7 +22,8 @@ from nanochat.tokenizer import get_token_bytes from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state from nanochat.loss_eval import evaluate_bpb import torch.distributed as dist -from nanochat.flash_attention import HAS_FA3 +from nanochat.flash_attention import HAS_FA +from nanochat.dataloader import 
sft_data_loader_varlen from nanochat.engine import Engine from scripts.chat_eval import run_chat_eval @@ -89,8 +91,8 @@ use_dummy_wandb = args.run == "dummy" or not master_process wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config) # Flash Attention status -if not HAS_FA3: - print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.") +if not HAS_FA: + print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback. Training will be less efficient.") # Load the model and tokenizer model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step) @@ -178,147 +180,137 @@ val_dataset = TaskMixture([ MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios ]) # total: 24K + 14K + 1.32K ~= 39K rows -# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len) -# A big problem is that we don't know the final num_iterations in advance. So we create -# these two global variables and update them from within the data generator. -last_step = False # we will toggle this to True when we reach the end of the training dataset -approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch -current_epoch = 1 # track epoch for logging -def sft_data_generator_bos_bestfit(split, buffer_size=100): - """ - BOS-aligned dataloader for SFT with bestfit-pad packing. - Each row in the batch starts with BOS (beginning of a conversation). - Conversations are packed using best-fit algorithm. When no conversation fits, - the row is padded (instead of cropping) to ensure no tokens are ever discarded. - Padding positions have targets masked with -1 (ignore_index for cross-entropy). 
+# Pre-tokenize and pre-pack all conversations into batch plans. +# This runs the same best-fit packing algorithm offline at startup, so we know +# num_iterations and max_num_docs exactly before training starts. +def tokenize_and_pack_sft(dataset, tokenizer, B, T, bos_token, ddp_rank, ddp_world_size, buffer_size=100): + """ + Pre-tokenize and pre-pack SFT conversations using best-fit packing. + + Preserves TaskMixture's shuffled ordering (no length sorting) and uses the + same buffer-based best-fit algorithm as the original inline dataloader. + + Returns (conversations, batch_plans, total_micro_batches, max_num_docs). """ - global last_step, approx_progress, current_epoch - assert split in {"train", "val"}, "split must be 'train' or 'val'" - dataset = train_dataset if split == "train" else val_dataset dataset_size = len(dataset) - assert dataset_size > 0 - row_capacity = args.max_seq_len + 1 # +1 for target at last position - bos_token = tokenizer.get_bos_token_id() + buffer_capacity = B * T + 1 - # Conversation buffer: list of (token_ids, loss_mask) tuples + conversations = [] + num_convs = (dataset_size - ddp_rank + ddp_world_size - 1) // ddp_world_size + cursor = ddp_rank + while cursor < dataset_size: + ids, mask = tokenizer.render_conversation(dataset[cursor]) + conversations.append((ids, mask)) + cursor += ddp_world_size + if len(conversations) % 5000 == 0: + print0(f"\r\033[KTokenizing: {len(conversations):,}/{num_convs:,} ({100*len(conversations)/num_convs:.0f}%)", end='', flush=True) + print0(f"\r\033[KTokenized {len(conversations):,} conversations", flush=True) + + batch_plans = [] conv_buffer = [] - cursor = ddp_rank # Each rank processes different conversations (for fetching) - consumed = ddp_rank # Track actual consumption separately from buffering - epoch = 1 - it = 0 # iteration counter + fetch_cursor = 0 + max_doc_count = 0 - def refill_buffer(): - nonlocal cursor, epoch - while len(conv_buffer) < buffer_size: - conversation = dataset[cursor] - ids, 
mask = tokenizer.render_conversation(conversation) - conv_buffer.append((ids, mask)) - cursor += ddp_world_size - if cursor >= dataset_size: - cursor = cursor % dataset_size - epoch += 1 - # Note: last_step is now triggered based on consumption, not fetching + def refill(): + nonlocal fetch_cursor + while len(conv_buffer) < buffer_size and fetch_cursor < len(conversations): + conv_buffer.append(fetch_cursor) + fetch_cursor += 1 while True: - rows = [] - mask_rows = [] - row_lengths = [] # Track actual content length (excluding padding) for each row - for _ in range(args.device_batch_size): - row = [] - mask_row = [] - padded = False - while len(row) < row_capacity: - # Ensure buffer has conversations - while len(conv_buffer) < buffer_size: - refill_buffer() - - remaining = row_capacity - len(row) - - # Find largest conversation that fits entirely - best_idx = -1 - best_len = 0 - for i, (conv, _) in enumerate(conv_buffer): - conv_len = len(conv) - if conv_len <= remaining and conv_len > best_len: - best_idx = i - best_len = conv_len - - if best_idx >= 0: - # Found a conversation that fits - use it entirely - conv, conv_mask = conv_buffer.pop(best_idx) - row.extend(conv) - mask_row.extend(conv_mask) - consumed += ddp_world_size # Track actual consumption - else: - # No conversation fits - pad the remainder instead of cropping - # This ensures we never discard any tokens - content_len = len(row) - row.extend([bos_token] * remaining) # Pad with BOS tokens - mask_row.extend([0] * remaining) - padded = True - break # Row is now full (with padding) - - # Track content length: full row if no padding, otherwise the length before padding - if padded: - row_lengths.append(content_len) + refill() + if not conv_buffer: + break + batch_indices = [] + pos = 0 + while pos < buffer_capacity: + refill() + if not conv_buffer: + break + remaining = buffer_capacity - pos + best_buf_idx = -1 + best_len = 0 + for i, conv_idx in enumerate(conv_buffer): + conv_len = 
len(conversations[conv_idx][0]) + if conv_len <= remaining and conv_len > best_len: + best_buf_idx = i + best_len = conv_len + if best_buf_idx >= 0: + batch_indices.append(conv_buffer.pop(best_buf_idx)) + pos += best_len else: - row_lengths.append(row_capacity) - rows.append(row[:row_capacity]) - mask_rows.append(mask_row[:row_capacity]) + break + if batch_indices: + doc_count = len(batch_indices) + (1 if pos < buffer_capacity else 0) + max_doc_count = max(max_doc_count, doc_count) + batch_plans.append(batch_indices) - # Stopping condition to respect num_iterations, if given - it += 1 - if 0 < args.num_iterations <= it and split == "train": - last_step = True + max_num_docs = math.ceil((max_doc_count + 1) / 16) * 16 + return conversations, batch_plans, len(batch_plans), max(max_num_docs, 16) - # Update progress tracking (based on consumed, not cursor, to account for buffering) - if split == "train": - current_epoch = epoch - if args.num_iterations > 0: - approx_progress = it / args.num_iterations - else: - approx_progress = consumed / dataset_size - # Trigger last_step when we've consumed enough (instead of when cursor wraps) - if consumed >= dataset_size: - last_step = True +bos_token = tokenizer.get_bos_token_id() +t_pack_start = time.time() +train_convs, train_plans, train_micro_batches, train_max_docs = tokenize_and_pack_sft( + train_dataset, tokenizer, args.device_batch_size, args.max_seq_len, + bos_token, ddp_rank, ddp_world_size) +t_pack_train = time.time() +val_convs, val_plans, val_micro_batches, val_max_docs = tokenize_and_pack_sft( + val_dataset, tokenizer, args.device_batch_size, args.max_seq_len, + bos_token, ddp_rank, ddp_world_size) +t_pack_val = time.time() +max_num_docs = max(train_max_docs, val_max_docs) +print0(f"Pre-tokenize & pack: train {t_pack_train - t_pack_start:.1f}s, val {t_pack_val - t_pack_train:.1f}s, total {t_pack_val - t_pack_start:.1f}s") - # Build tensors - use_cuda = device_type == "cuda" - batch_tensor = torch.tensor(rows, 
dtype=torch.long, pin_memory=use_cuda) - inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous() - targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous() +# Document length and packing statistics +import numpy as np +train_doc_lens = [len(ids) for ids, _ in train_convs] +train_docs_per_batch = [len(plan) for plan in train_plans] +train_tokens_per_batch = [sum(len(train_convs[i][0]) for i in plan) for plan in train_plans] +buffer_capacity = args.device_batch_size * args.max_seq_len + 1 +train_packing_eff = [t / buffer_capacity for t in train_tokens_per_batch] +dl = np.array(train_doc_lens) +dpb = np.array(train_docs_per_batch) +pe = np.array(train_packing_eff) +print0(f"Train doc lengths: n={len(dl):,} | mean={dl.mean():.0f} median={np.median(dl):.0f} " + f"min={dl.min()} max={dl.max()} p5={np.percentile(dl,5):.0f} p95={np.percentile(dl,95):.0f}") +print0(f"Train docs/batch: n={len(dpb):,} | mean={dpb.mean():.1f} median={np.median(dpb):.0f} " + f"min={dpb.min()} max={dpb.max()} p5={np.percentile(dpb,5):.0f} p95={np.percentile(dpb,95):.0f}") +print0(f"Train packing eff: mean={pe.mean():.3f} median={np.median(pe):.3f} " + f"min={pe.min():.3f} max={pe.max():.3f}") - # Apply the loss mask from render_conversation (mask=1 for assistant completions, - # mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns - # with targets (shifted by 1). Unmasked positions get -1 (ignore_index). - mask_tensor = torch.tensor(mask_rows, dtype=torch.int8) - mask_targets = mask_tensor[:, 1:].to(device=device) - targets[mask_targets == 0] = -1 +# num_iterations: exact count of optimization steps. The -1 accounts for the +# prefetch batch that the training loop requests but never trains on. 
+data_num_iterations = (train_micro_batches - 1) // grad_accum_steps +if args.num_iterations > 0: + num_iterations = min(args.num_iterations, data_num_iterations) +else: + num_iterations = data_num_iterations +if ddp: + num_iter_tensor = torch.tensor([num_iterations], dtype=torch.long, device=device) + dist.all_reduce(num_iter_tensor, op=dist.ReduceOp.MIN) + num_iterations = num_iter_tensor.item() +print0(f"Pre-packed {len(train_convs):,} train conversations into {train_micro_batches:,} micro-batches " + f"=> {num_iterations:,} optimization steps (max {max_num_docs} docs/batch)") - # Mask out padding positions in targets (set to -1 = ignore_index) - # For each row, positions >= (content_length - 1) in targets should be masked - for i, content_len in enumerate(row_lengths): - if content_len < row_capacity: - targets[i, content_len-1:] = -1 - - yield inputs, targets - -train_loader = sft_data_generator_bos_bestfit("train") -build_val_loader = lambda: sft_data_generator_bos_bestfit("val") -progress = 0 # will go from 0 to 1 over the course of the epoch +train_loader = sft_data_loader_varlen( + train_convs, train_plans, args.device_batch_size, args.max_seq_len, + max_num_docs, bos_token, device=device) +build_val_loader = lambda: sft_data_loader_varlen( + val_convs, val_plans, args.device_batch_size, args.max_seq_len, + max_num_docs, bos_token, device=device, cycle=True) # Learning rate schedule (linear warmup, constant, linear warmdown) -# Same shape as base_train but uses progress (0→1) instead of absolute step counts, -# because SFT doesn't always know num_iterations in advance (dataset-driven stopping). 
-def get_lr_multiplier(progress): - if progress < args.warmup_ratio: - return (progress + 1e-8) / args.warmup_ratio - elif progress <= 1.0 - args.warmdown_ratio: +def get_lr_multiplier(it): + warmup_iters = round(args.warmup_ratio * num_iterations) + warmdown_iters = round(args.warmdown_ratio * num_iterations) + if it < warmup_iters: + return (it + 1) / warmup_iters + elif it <= num_iterations - warmdown_iters: return 1.0 else: - decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio - return (1 - decay) * 1.0 + decay * args.final_lr_frac + progress = (num_iterations - it) / warmdown_iters + return progress * 1.0 + (1 - progress) * args.final_lr_frac # Momentum scheduler for Muon optimizer def get_muon_momentum(it): @@ -328,21 +320,16 @@ def get_muon_momentum(it): # ----------------------------------------------------------------------------- # Training loop -x, y = next(train_loader) # prefetch the very first batch of data +x, y, cu_seqlens = next(train_loader) # prefetch the very first batch of data min_val_bpb = float("inf") smooth_train_loss = 0 # EMA of training loss ema_beta = 0.9 # EMA decay factor total_training_time = 0 # total wall-clock time of training step = 0 while True: + last_step = step == num_iterations flops_so_far = num_flops_per_token * args.total_batch_size * step - # Synchronize last_step across all ranks to avoid hangs in the distributed setting - if ddp: - last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device) - dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX) - last_step = bool(last_step_tensor.item()) - # once in a while: evaluate the val bpb (all ranks participate) if last_step or (args.eval_every > 0 and step % args.eval_every == 0): model.eval() @@ -430,17 +417,16 @@ while True: synchronize() t0 = time.time() for micro_step in range(grad_accum_steps): - loss = model(x, y) + loss = model(x, y, cu_seqlens=cu_seqlens) train_loss = loss.detach() # for logging loss = loss / grad_accum_steps # 
each .backward() is a grad sum => normalize loss here if scaler is not None: scaler.scale(loss).backward() else: loss.backward() - x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward - progress = max(progress, approx_progress) # only increase progress monotonically + x, y, cu_seqlens = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward # step the optimizer - lrm = get_lr_multiplier(progress) + lrm = get_lr_multiplier(step) muon_momentum = get_muon_momentum(step) for group in optimizer.param_groups: group["lr"] = group["initial_lr"] * lrm @@ -467,13 +453,13 @@ while True: # logging smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA - pct_done = 100 * progress + pct_done = 100 * step / num_iterations tok_per_sec = int(args.total_batch_size / dt) flops_per_sec = num_flops_per_token * args.total_batch_size / dt mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size) if step > 10: total_training_time += dt # only count the time after the first 10 steps - print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m") + print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m") if step % 10 == 0: wandb_run.log({ "step": step, @@ -484,7 +470,6 @@ while True: "train/dt": dt, "train/tok_per_sec": tok_per_sec, "train/mfu": mfu, - "train/epoch": current_epoch, }) # The garbage collector spends ~500ms scanning for cycles quite frequently. 
diff --git a/tests/test_attention_fallback.py b/tests/test_attention_fallback.py index 3eddc72..875279d 100644 --- a/tests/test_attention_fallback.py +++ b/tests/test_attention_fallback.py @@ -1,14 +1,14 @@ """ -Test Flash Attention unified interface - verify FA3 and SDPA produce identical results. +Test Flash Attention unified interface - verify FA (FA3/FA2) and SDPA produce identical results. Run: python -m pytest tests/test_attention_fallback.py -v -s Note on test structure: Tests are split into two classes due to dtype/device constraints: - 1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs - and verify they produce identical results. These require a Hopper GPU (FA3 only - works on sm90+) and use bfloat16 (FA3 doesn't support float32). + 1. TestFA3VsSDPA: Comparison tests that run both FA and SDPA on the same inputs + and verify they produce identical results. These require an Ampere+ GPU + (FA3 on Hopper, FA2 on Ampere/Ada) and use bfloat16. 2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run on any device (CUDA, CPU, MPS) with the appropriate dtype for that device. 
@@ -16,24 +16,29 @@ Note on test structure: import torch import pytest import nanochat.flash_attention as fa_module -from nanochat.flash_attention import flash_attn, HAS_FA3 +from nanochat.flash_attention import flash_attn, HAS_FA from nanochat.engine import KVCache def set_impl(impl): - """Set the implementation override ('fa3', 'sdpa', or None for auto) and re-resolve USE_FA3.""" + """Set the implementation override ('fa', 'fa3', 'fa2', 'sdpa', or None for auto) and re-resolve USE_FA.""" fa_module._override_impl = impl - fa_module.USE_FA3 = fa_module._resolve_use_fa3() + fa_module.USE_FA = fa_module._resolve_use_fa() def run_both_impls(fn): - """Run a function with both FA3 and SDPA, return both outputs.""" - set_impl('fa3') - out_fa3 = fn() + """Run a function with both FA (FA3 or FA2) and SDPA, return both outputs.""" + set_impl('fa') + out_fa = fn() set_impl('sdpa') out_sdpa = fn() set_impl(None) # reset - return out_fa3, out_sdpa + return out_fa, out_sdpa + + + +def make_cu_seqlens(B, T, device): + """Create cu_seqlens for B documents each of length T.""" + return torch.arange(0, (B + 1) * T, T, dtype=torch.int32, device=device) def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2): @@ -48,9 +53,9 @@ def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2): # ============================================================================= # FA3 vs SDPA comparison tests (require Hopper GPU) # ============================================================================= -@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations") +@pytest.mark.skipif(not HAS_FA, reason="FA required to compare implementations") class TestFA3VsSDPA: - """Compare FA3 and SDPA produce identical results. Requires Hopper GPU.""" + """Verify FA and SDPA produce identical results.
Requires Ampere+ GPU.""" DEVICE = "cuda" DTYPE = torch.bfloat16 @@ -58,12 +63,15 @@ class TestFA3VsSDPA: def test_basic_causal(self): """Basic causal attention.""" B, T, H, D = 2, 64, 4, 32 - q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) + q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + cu_seqlens = make_cu_seqlens(B, T, self.DEVICE) def run(): - return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0)) + return flash_attn.flash_attn_varlen_func(q, k, v, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0)) y_fa3, y_sdpa = run_both_impls(run) max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal") @@ -72,12 +80,15 @@ class TestFA3VsSDPA: def test_full_context(self): """Full context (window_size=-1).""" B, T, H, D = 2, 128, 4, 32 - q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) + q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + cu_seqlens = make_cu_seqlens(B, T, self.DEVICE) def run(): - return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1)) + return flash_attn.flash_attn_varlen_func(q, k, v, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(-1, -1)) y_fa3, y_sdpa = run_both_impls(run) max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context") @@ -87,12 +98,15 @@ class 
TestFA3VsSDPA: """Sliding window attention.""" B, T, H, D = 2, 128, 4, 32 window = 32 - q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) + q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + cu_seqlens = make_cu_seqlens(B, T, self.DEVICE) def run(): - return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0)) + return flash_attn.flash_attn_varlen_func(q, k, v, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(window, 0)) y_fa3, y_sdpa = run_both_impls(run) max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window") @@ -104,12 +118,15 @@ class TestFA3VsSDPA: n_heads = 8 n_kv_heads = 2 - q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE) - k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE) - v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE) + q = torch.randn(B * T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE) + k = torch.randn(B * T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE) + v = torch.randn(B * T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE) + cu_seqlens = make_cu_seqlens(B, T, self.DEVICE) def run(): - return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0)) + return flash_attn.flash_attn_varlen_func(q, k, v, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0)) y_fa3, y_sdpa = run_both_impls(run) max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa") @@ -118,12 +135,15 @@ class TestFA3VsSDPA: def test_larger_model(self): """Larger dimensions closer to real model.""" B, T, H, D = 4, 256, 12, 64 
- q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) + q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + cu_seqlens = make_cu_seqlens(B, T, self.DEVICE) def run(): - return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1)) + return flash_attn.flash_attn_varlen_func(q, k, v, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(-1, -1)) y_fa3, y_sdpa = run_both_impls(run) max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model") @@ -215,21 +235,24 @@ class TestFA3VsSDPA: def test_backward_gradients_match(self): """Verify gradients are similar between FA3 and SDPA.""" B, T, H, D = 2, 32, 4, 16 + cu_seqlens = make_cu_seqlens(B, T, self.DEVICE) - q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) + q_data = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + k_data = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + v_data = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) def run(): q = q_data.clone().requires_grad_(True) k = k_data.clone().requires_grad_(True) v = v_data.clone().requires_grad_(True) - y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0)) + y = flash_attn.flash_attn_varlen_func(q, k, v, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0)) loss = y.sum() loss.backward() return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach() - set_impl('fa3') + set_impl('fa') 
y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run() set_impl('sdpa') y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run() @@ -261,13 +284,16 @@ class TestSDPAOnly: """Test SDPA forward pass produces valid output.""" set_impl('sdpa') B, T, H, D = 2, 64, 4, 32 - q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) - v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE) + q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE) + cu_seqlens = make_cu_seqlens(B, T, self.DEVICE) - y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0)) + y = flash_attn.flash_attn_varlen_func(q, k, v, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0)) - assert y.shape == (B, T, H, D) + assert y.shape == (B * T, H, D) assert not torch.isnan(y).any(), "Output contains NaN" set_impl(None) @@ -275,11 +301,14 @@ class TestSDPAOnly: """Test gradients flow through SDPA.""" set_impl('sdpa') B, T, H, D = 2, 32, 4, 16 - q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True) - k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True) - v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True) + q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True) + k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True) + v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True) + cu_seqlens = make_cu_seqlens(B, T, self.DEVICE) - y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0)) + y = flash_attn.flash_attn_varlen_func(q, k, v, + cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, + max_seqlen_q=T, 
max_seqlen_k=T, causal=True, window_size=(T, 0)) loss = y.sum() loss.backward() @@ -340,23 +369,23 @@ class TestSDPAOnly: class TestOverrideMechanism: """Test that the override mechanism works correctly.""" - @pytest.mark.skipif(not HAS_FA3, reason="FA3 required") - def test_override_fa3(self): - """Test that override='fa3' uses FA3.""" - set_impl('fa3') - assert fa_module.USE_FA3 == True + @pytest.mark.skipif(not HAS_FA, reason="FA required") + def test_override_fa(self): + """Test that override='fa' uses FA.""" + set_impl('fa') + assert fa_module.USE_FA == True set_impl(None) def test_override_sdpa(self): """Test that override='sdpa' uses SDPA.""" set_impl('sdpa') - assert fa_module.USE_FA3 == False + assert fa_module.USE_FA == False set_impl(None) def test_override_auto(self): """Test that override=None uses auto-detection.""" set_impl(None) - assert fa_module.USE_FA3 == HAS_FA3 + assert fa_module.USE_FA == HAS_FA if __name__ == "__main__": @@ -366,7 +395,7 @@ if __name__ == "__main__": print(f"CUDA device: {torch.cuda.get_device_name()}") major, minor = torch.cuda.get_device_capability() print(f"Compute capability: {major}.{minor}") - print(f"HAS_FA3: {HAS_FA3}") + print(f"HAS_FA: {HAS_FA}") print() pytest.main([__file__, "-v", "-s"])