mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 13:15:21 +00:00
Merge 6ee8fd6908 into c0dbf1f3ff
This commit is contained in:
commit
29ad903517
|
|
@ -200,3 +200,36 @@ This commit is special because all of the improvements that went into [this comm
|
|||
## Run 6
|
||||
|
||||
Achieved Mar 14, 2026 on commit `a825e63`. Exactly the same launch command as Run 4 except `--target-param-data-ratio=8`. Improvements in the architecture are allowing us to train shorter and shorter time. Instead of an undertrained d24 I attempted to train an overtrained d22 but it was worse. This set of changes came from autoresearch round 2, where I asked it to reference the modded-nanogpt repo for inspiration. So the exploration tried out a number of ideas and in particular found a way to incorporate the backout and smear in such a way that they are helpful (I had previously tried them manually a long time ago and they caused regressions). The smear idea in particular is a little bit heavier and bloaty because it is essentially an "early fusion" of context across tokens, producing a kind of a bigram input into the network and allowing it to focus on higher ngrams earlier. But for this reason the code gets a bit more complex and required some changes to inference. I verified with a unit test that the Engine inference is correct compared to the naive inference of `GPT.generate()`. The average of 5 runs was CORE 0.262634 and each of them lasted 1.65 hours (99 minutes).
|
||||
|
||||
## Run 7
|
||||
|
||||
Achieved Mar 22, 2026 on commit `a075166`. Launch command (same as `runs/speedrun.sh`):
|
||||
|
||||
```
|
||||
OMP_NUM_THREADS=1 torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- \
|
||||
--depth=24 \
|
||||
--run="d24-varlen" \
|
||||
--model-tag="d24-varlen" \
|
||||
--device-batch-size=16 \
|
||||
--sample-every=-1 \
|
||||
--save-every=-1 \
|
||||
--core-metric-max-per-task=-1 \
|
||||
--core-metric-every=999999 \
|
||||
--target-param-data-ratio=8 \
|
||||
--fp8
|
||||
```
|
||||
|
||||
Result:
|
||||
|
||||
```
|
||||
step 05567/05568 (99.98%) | loss: 2.388741 | lrm: 0.05 | dt: 1048.30ms | tok/sec: 1,000,264 | bf16_mfu: 60.37 | epoch: 1 pq: 117 rg: 64 | total time: 97.38m | eta: 0.0m
|
||||
Step 05568 | Validation bpb: 0.724772
|
||||
Step 05568 | CORE metric: 0.2614
|
||||
Peak memory usage: 52865.94MiB
|
||||
Total training time: 97.38m
|
||||
Minimum validation bpb: 0.724772
|
||||
```
|
||||
|
||||
The big change in this run is switching nanochat's entire pre-training attention path from standard batched attention to FlashAttention's variable-length packed attention (`flash_attn_varlen_func` with `cu_seqlens`). Each document now attends only to itself, which saves compute by not calculating attention scores across document boundaries. The dataloader is also simplified: the old bestfit bin-packing is replaced by greedy 1D packing where only the final document in each micro-batch is truncated. See the PR description in [dev/PULL_REQUEST.md](PULL_REQUEST.md) for full details.
|
||||
|
||||
Previous record was 1.65 hours / 99 minutes (Run 6), so 1.62 hours / 97.38 minutes is ~1.8% speed improvement. Model checkpoint: [ChrisMcCormick/nanochat-varlen-d24-2026-03-22](https://huggingface.co/ChrisMcCormick/nanochat-varlen-d24-2026-03-22).
|
||||
|
|
|
|||
|
|
@ -101,15 +101,6 @@ def find_common_length(token_sequences, direction='left'):
|
|||
return min_len
|
||||
|
||||
|
||||
def stack_sequences(tokens, pad_token_id):
|
||||
"""Stack up a list of token sequences, pad to longest on the right"""
|
||||
bsz, seq_len = len(tokens), max(len(x) for x in tokens)
|
||||
input_ids = torch.full((bsz, seq_len), pad_token_id, dtype=torch.long)
|
||||
for i, x in enumerate(tokens):
|
||||
input_ids[i, :len(x)] = torch.tensor(x, dtype=torch.long)
|
||||
return input_ids
|
||||
|
||||
|
||||
def batch_sequences_mc(tokenizer, prompts):
|
||||
# In multiple choice, contexts are the same but the continuation is different (common prefix)
|
||||
tokens = tokenizer(prompts, prepend=tokenizer.get_bos_token_id())
|
||||
|
|
@ -142,26 +133,31 @@ def batch_sequences_lm(tokenizer, prompts):
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def forward_model(model, input_ids):
|
||||
def forward_model(model, tokens_list, device):
|
||||
"""
|
||||
Take BxT tensor of token ids, return BxT tensor of losses and argmax predictions.
|
||||
The last column of losses is set to nan because we don't have autoregressive targets there.
|
||||
Pack token sequences into varlen, forward, extract per-sequence losses and predictions.
|
||||
Returns (losses_list, preds_list) where each element corresponds to one input sequence.
|
||||
The last element of each loss vector is nan (no autoregressive target there).
|
||||
"""
|
||||
batch_size, seq_len = input_ids.size()
|
||||
outputs = model(input_ids)
|
||||
# Roll the tensor to the left by one position to get the (autoregressive) target ids
|
||||
target_ids = torch.roll(input_ids, shifts=-1, dims=1)
|
||||
# Calculate cross entropy at all positions
|
||||
losses = torch.nn.functional.cross_entropy(
|
||||
outputs.view(batch_size * seq_len, -1),
|
||||
target_ids.view(batch_size * seq_len),
|
||||
reduction='none'
|
||||
).view(batch_size, seq_len)
|
||||
# Set the last column to be nan because there is no autoregressive loss there
|
||||
losses[:, -1] = float('nan')
|
||||
# Get the argmax predictions at each position
|
||||
predictions = outputs.argmax(dim=-1)
|
||||
return losses, predictions
|
||||
packed = torch.cat([torch.tensor(t, dtype=torch.long, device=device) for t in tokens_list])
|
||||
cu_seqlens = torch.zeros(len(tokens_list) + 1, dtype=torch.int32, device=device)
|
||||
for i, t in enumerate(tokens_list):
|
||||
cu_seqlens[i + 1] = cu_seqlens[i] + len(t)
|
||||
|
||||
outputs = model(packed, cu_seqlens=cu_seqlens).squeeze(0) # (total_T, V)
|
||||
|
||||
losses_list = []
|
||||
preds_list = []
|
||||
for i in range(len(tokens_list)):
|
||||
s, e = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
|
||||
doc_logits = outputs[s:e]
|
||||
doc_tokens = packed[s:e]
|
||||
doc_targets = torch.roll(doc_tokens, -1)
|
||||
doc_losses = torch.nn.functional.cross_entropy(doc_logits[:-1], doc_targets[:-1], reduction='none')
|
||||
doc_losses = torch.cat([doc_losses, torch.tensor([float('nan')], device=device)])
|
||||
losses_list.append(doc_losses)
|
||||
preds_list.append(doc_logits.argmax(dim=-1))
|
||||
return losses_list, preds_list
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
@ -212,26 +208,21 @@ def evaluate_example(idx, model, tokenizer, data, device, task_meta):
|
|||
new_end_idxs.append(e)
|
||||
tokens, start_idxs, end_idxs = new_tokens, new_start_idxs, new_end_idxs
|
||||
|
||||
# Stack up all the sequences into a batch
|
||||
pad_token_id = tokenizer.get_bos_token_id() # use BOS as pad token is ok
|
||||
input_ids = stack_sequences(tokens, pad_token_id)
|
||||
input_ids = input_ids.to(device)
|
||||
|
||||
# Forward the model, get the autoregressive loss and argmax prediction at each token
|
||||
losses, predictions = forward_model(model, input_ids)
|
||||
# Forward the model with varlen packing (no padding waste)
|
||||
losses_list, preds_list = forward_model(model, tokens, device)
|
||||
|
||||
# See if the losses/predictions come out correctly
|
||||
if task_type == 'language_modeling':
|
||||
# language modeling task is currently always batch size 1
|
||||
si = start_idxs[0]
|
||||
ei = end_idxs[0]
|
||||
# predictions[i] predict input_ids[i+1] autoregressively
|
||||
predicted_tokens = predictions[0, si-1:ei-1]
|
||||
actual_tokens = input_ids[0, si:ei]
|
||||
# preds_list[i][j] predicts tokens[i][j+1] autoregressively
|
||||
predicted_tokens = preds_list[0][si-1:ei-1]
|
||||
actual_tokens = torch.tensor(tokens[0][si:ei], device=device)
|
||||
is_correct = torch.all(predicted_tokens == actual_tokens).item()
|
||||
elif task_type in ['multiple_choice', 'schema']:
|
||||
# For MC/schema: find the option with lowest average loss
|
||||
mean_losses = [losses[i, si-1:ei-1].mean().item()
|
||||
mean_losses = [losses_list[i][si-1:ei-1].mean().item()
|
||||
for i, (si, ei) in enumerate(zip(start_idxs, end_idxs))]
|
||||
pred_idx = mean_losses.index(min(mean_losses))
|
||||
is_correct = pred_idx == item['gold']
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
"""
|
||||
Distributed dataloaders for pretraining.
|
||||
Distributed dataloader for pretraining.
|
||||
|
||||
BOS-aligned bestfit:
|
||||
- Every row starts with BOS token
|
||||
- Documents packed using best-fit algorithm to minimize cropping
|
||||
- When no document fits remaining space, crops a document to fill exactly
|
||||
- 100% utilization (no padding), ~35% tokens cropped at T=2048
|
||||
Varlen 1D packing:
|
||||
- Packs documents into 1D buffer with cu_seqlens for per-document attention isolation
|
||||
- No cropping, no padding: every token is used exactly once
|
||||
- Yields (inputs_1d, targets_1d, cu_seqlens) for flash_attn_varlen_func
|
||||
|
||||
Compared to the original tokenizing_distributed_data_loader:
|
||||
BOS-aligned loses ~35% of tokens to cropping, but ensures that
|
||||
|
|
@ -71,31 +70,43 @@ def _document_batches(split, resume_state_dict, tokenizer_batch_size):
|
|||
epoch += 1
|
||||
|
||||
|
||||
def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
||||
tokenizer, B, T, split,
|
||||
|
||||
# =============================================================================
|
||||
# 1D packed varlen dataloader
|
||||
# =============================================================================
|
||||
# Packs documents into a single flat buffer of B*T tokens with cu_seqlens marking
|
||||
# document boundaries for flash_attn_varlen_func. Each document gets its own
|
||||
# attention context. Greedy packing: documents are added sequentially until the
|
||||
# buffer is full. Only the last document in each micro-batch gets cropped.
|
||||
#
|
||||
# Requires specificying a fixed maximum number of docs supported per batch.
|
||||
# The dataloader will append additional documents to the final segment if needed,
|
||||
# resulting in cross-document attention bleeding, but that hasn't been a problem
|
||||
# in practice.
|
||||
# It's recommended to keep max_num_docs tight rather than padding it conservatively
|
||||
# because an oversized `cu_seqlens` tensor will hurt FlashAttention performance
|
||||
# somewhat.
|
||||
|
||||
def tokenizing_distributed_data_loader_varlen(
|
||||
tokenizer, B, T, split, max_num_docs,
|
||||
tokenizer_threads=4, tokenizer_batch_size=128,
|
||||
device="cuda", resume_state_dict=None,
|
||||
buffer_size=1000
|
||||
):
|
||||
"""
|
||||
BOS-aligned dataloader with Best-Fit Cropping.
|
||||
1D packed varlen dataloader for use with flash_attn_varlen_func.
|
||||
|
||||
Reduces token waste compared to simple greedy cropping by searching a buffer
|
||||
for documents that fit well, while maintaining 100% utilization (no padding).
|
||||
|
||||
Algorithm for each row:
|
||||
1. From buffered docs, pick the LARGEST doc that fits entirely
|
||||
2. Repeat until no doc fits
|
||||
3. When nothing fits, crop a doc to fill remaining space exactly
|
||||
|
||||
Key properties:
|
||||
- Every row starts with BOS
|
||||
- 100% utilization (no padding, every token is trained on)
|
||||
- Approximately 35% of all tokens are discarded due to cropping
|
||||
Yields (inputs, targets, cu_seqlens, state_dict) where:
|
||||
- inputs: 1D long tensor of shape (B*T,)
|
||||
- targets: 1D long tensor of shape (B*T,), shifted by 1
|
||||
- cu_seqlens: int32 tensor of shape (max_num_docs,), cumulative doc lengths
|
||||
padded with total_tokens for unused slots (ghost segments of length 0)
|
||||
- state_dict: {"pq_idx", "rg_idx", "epoch"} for checkpoint resume
|
||||
"""
|
||||
assert split in ["train", "val"], "split must be 'train' or 'val'"
|
||||
|
||||
row_capacity = T + 1
|
||||
total_tokens = B * T
|
||||
buffer_capacity = total_tokens + 1 # +1 so the last input position has a target
|
||||
|
||||
batches = _document_batches(split, resume_state_dict, tokenizer_batch_size)
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
doc_buffer = []
|
||||
|
|
@ -105,62 +116,139 @@ def tokenizing_distributed_data_loader_with_state_bos_bestfit(
|
|||
nonlocal pq_idx, rg_idx, epoch
|
||||
doc_batch, (pq_idx, rg_idx, epoch) = next(batches)
|
||||
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
|
||||
for tokens in token_lists:
|
||||
doc_buffer.append(tokens)
|
||||
doc_buffer.extend(token_lists)
|
||||
|
||||
# Pre-allocate buffers once: layout is [inputs (B*T) | targets (B*T)]
|
||||
# This gives us contiguous views and a single HtoD transfer
|
||||
# Pre-allocate all buffers once
|
||||
use_cuda = device == "cuda"
|
||||
row_buffer = torch.empty((B, row_capacity), dtype=torch.long) # for building rows without creating Python lists
|
||||
cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=use_cuda) # staging area (CPU)
|
||||
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device=device) # on-device buffer
|
||||
cpu_inputs = cpu_buffer[:B * T].view(B, T) # a few views into these buffers just for convenience
|
||||
cpu_targets = cpu_buffer[B * T:].view(B, T)
|
||||
inputs = gpu_buffer[:B * T].view(B, T)
|
||||
targets = gpu_buffer[B * T:].view(B, T)
|
||||
pack_buffer = torch.empty(buffer_capacity, dtype=torch.long) # 1D packing workspace
|
||||
cpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, pin_memory=use_cuda)
|
||||
gpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, device=device)
|
||||
cpu_inputs = cpu_buffer[:total_tokens]
|
||||
cpu_targets = cpu_buffer[total_tokens:]
|
||||
inputs = gpu_buffer[:total_tokens]
|
||||
targets = gpu_buffer[total_tokens:]
|
||||
cu_seqlens_cpu = torch.empty(max_num_docs, dtype=torch.int32)
|
||||
cu_seqlens_gpu = torch.empty(max_num_docs, dtype=torch.int32, device=device)
|
||||
|
||||
warned = False
|
||||
warned_seqlen = False
|
||||
while True:
|
||||
for row_idx in range(B):
|
||||
pos = 0
|
||||
while pos < row_capacity:
|
||||
# Ensure buffer has documents
|
||||
while len(doc_buffer) < buffer_size:
|
||||
refill_buffer()
|
||||
# Greedily pack documents into a single 1D buffer
|
||||
pos = 0
|
||||
doc_count = 0
|
||||
cu_seqlens_cpu[0] = 0
|
||||
|
||||
remaining = row_capacity - pos
|
||||
while pos < buffer_capacity:
|
||||
while len(doc_buffer) == 0:
|
||||
refill_buffer()
|
||||
|
||||
# Find largest doc that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, doc in enumerate(doc_buffer):
|
||||
doc_len = len(doc)
|
||||
if doc_len <= remaining and doc_len > best_len:
|
||||
best_idx = i
|
||||
best_len = doc_len
|
||||
doc = doc_buffer.pop(0)
|
||||
doc_len = min(len(doc), T) # truncate to max_seq_len
|
||||
remaining = buffer_capacity - pos
|
||||
use_len = min(doc_len, remaining) # crop last doc to fill exactly
|
||||
|
||||
if best_idx >= 0:
|
||||
doc = doc_buffer.pop(best_idx)
|
||||
doc_len = len(doc)
|
||||
row_buffer[row_idx, pos:pos + doc_len] = torch.tensor(doc, dtype=torch.long)
|
||||
pos += doc_len
|
||||
else:
|
||||
# No doc fits - crop shortest in buffer to fill remaining and minimize waste
|
||||
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
|
||||
doc = doc_buffer.pop(shortest_idx)
|
||||
row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
|
||||
pos += remaining
|
||||
pack_buffer[pos:pos + use_len] = torch.tensor(doc[:use_len], dtype=torch.long)
|
||||
pos += use_len
|
||||
if doc_count < max_num_docs - 1:
|
||||
doc_count += 1
|
||||
cu_seqlens_cpu[doc_count] = min(pos, total_tokens)
|
||||
else:
|
||||
if not warned:
|
||||
print(f"Warning: too many documents for cu_seqlens size ({max_num_docs}), "
|
||||
f"merging remaining docs (cross-document attention bleeding)")
|
||||
warned = True
|
||||
merged_len = min(pos, total_tokens) - cu_seqlens_cpu[doc_count].item()
|
||||
if merged_len > T and not warned_seqlen:
|
||||
print(f"Warning: merged segment length ({merged_len}) exceeds max_seq_len ({T}). "
|
||||
f"Increase max_num_docs to avoid silent attention truncation.")
|
||||
warned_seqlen = True
|
||||
|
||||
# Copy to pinned CPU buffer, then single HtoD transfer
|
||||
cpu_inputs.copy_(row_buffer[:, :-1])
|
||||
cpu_targets.copy_(row_buffer[:, 1:])
|
||||
# Ensure the final document boundary always points to the end of the batch
|
||||
cu_seqlens_cpu[doc_count] = total_tokens
|
||||
|
||||
# Pad remaining cu_seqlens slots (ghost segments of length 0)
|
||||
cu_seqlens_cpu[doc_count + 1:] = total_tokens
|
||||
|
||||
# Split into inputs/targets (standard next-token prediction shift)
|
||||
cpu_inputs.copy_(pack_buffer[:total_tokens])
|
||||
cpu_targets.copy_(pack_buffer[1:total_tokens + 1])
|
||||
|
||||
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx, "epoch": epoch}
|
||||
|
||||
# Single HtoD copy into persistent GPU buffer and yield
|
||||
# H2D transfer: single copy for tokens, small copy for cu_seqlens
|
||||
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
|
||||
yield inputs, targets, state_dict
|
||||
cu_seqlens_gpu.copy_(cu_seqlens_cpu, non_blocking=use_cuda)
|
||||
yield inputs, targets, cu_seqlens_gpu, state_dict
|
||||
|
||||
def tokenizing_distributed_data_loader_bos_bestfit(*args, **kwargs):
|
||||
"""Helper that omits state_dict from yields."""
|
||||
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state_bos_bestfit(*args, **kwargs):
|
||||
yield inputs, targets
|
||||
|
||||
# =============================================================================
|
||||
# SFT varlen dataloader (replay from pre-packed batch plans)
|
||||
# =============================================================================
|
||||
|
||||
def sft_data_loader_varlen(
|
||||
conversations, batch_plan, B, T, max_num_docs, bos_token,
|
||||
device="cuda", cycle=False,
|
||||
):
|
||||
"""
|
||||
Replay dataloader for SFT: constructs 1D-packed varlen batches from
|
||||
pre-computed batch plans (see tokenize_and_pack_sft in chat_sft.py).
|
||||
|
||||
Args:
|
||||
conversations: list of (ids, mask) tuples (pre-tokenized)
|
||||
batch_plan: list of lists of conversation indices
|
||||
B, T: batch dimensions (total_tokens = B * T)
|
||||
max_num_docs: cu_seqlens tensor size (exact max from pre-packing)
|
||||
bos_token: BOS token id for padding
|
||||
device: target device
|
||||
cycle: if True, repeat the batch plan indefinitely (for val eval)
|
||||
"""
|
||||
total_tokens = B * T
|
||||
buffer_capacity = total_tokens + 1
|
||||
use_cuda = torch.device(device).type == "cuda"
|
||||
|
||||
pack_buffer = torch.empty(buffer_capacity, dtype=torch.long)
|
||||
mask_buffer = torch.empty(buffer_capacity, dtype=torch.int8)
|
||||
cpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, pin_memory=use_cuda)
|
||||
gpu_buffer = torch.empty(2 * total_tokens, dtype=torch.long, device=device)
|
||||
cpu_inputs = cpu_buffer[:total_tokens]
|
||||
cpu_targets = cpu_buffer[total_tokens:]
|
||||
inputs = gpu_buffer[:total_tokens]
|
||||
targets = gpu_buffer[total_tokens:]
|
||||
cu_seqlens_cpu = torch.empty(max_num_docs, dtype=torch.int32)
|
||||
cu_seqlens_gpu = torch.empty(max_num_docs, dtype=torch.int32, device=device)
|
||||
|
||||
while True:
|
||||
for conv_indices in batch_plan:
|
||||
pos = 0
|
||||
doc_count = 0
|
||||
cu_seqlens_cpu[0] = 0
|
||||
|
||||
for conv_idx in conv_indices:
|
||||
ids, mask = conversations[conv_idx]
|
||||
conv_len = len(ids)
|
||||
pack_buffer[pos:pos + conv_len] = torch.tensor(ids, dtype=torch.long)
|
||||
mask_buffer[pos:pos + conv_len] = torch.tensor(mask, dtype=torch.int8)
|
||||
pos += conv_len
|
||||
doc_count += 1
|
||||
cu_seqlens_cpu[doc_count] = min(pos, total_tokens)
|
||||
|
||||
if pos < buffer_capacity:
|
||||
remaining = buffer_capacity - pos
|
||||
pack_buffer[pos:pos + remaining] = bos_token
|
||||
mask_buffer[pos:pos + remaining] = 0
|
||||
doc_count += 1
|
||||
cu_seqlens_cpu[doc_count] = total_tokens
|
||||
|
||||
cu_seqlens_cpu[doc_count + 1:] = total_tokens
|
||||
|
||||
cpu_inputs.copy_(pack_buffer[:total_tokens])
|
||||
cpu_targets.copy_(pack_buffer[1:total_tokens + 1])
|
||||
target_mask = mask_buffer[1:total_tokens + 1]
|
||||
cpu_targets[target_mask == 0] = -1
|
||||
|
||||
gpu_buffer.copy_(cpu_buffer, non_blocking=use_cuda)
|
||||
cu_seqlens_gpu.copy_(cu_seqlens_cpu, non_blocking=use_cuda)
|
||||
yield inputs, targets, cu_seqlens_gpu
|
||||
|
||||
if not cycle:
|
||||
break
|
||||
|
|
|
|||
|
|
@ -1,66 +1,86 @@
|
|||
"""
|
||||
Unified Flash Attention interface with automatic FA3/SDPA switching.
|
||||
Unified Flash Attention interface with three-tier automatic backend selection:
|
||||
|
||||
Exports `flash_attn` module that matches the FA3 API exactly, but falls back
|
||||
to PyTorch SDPA on non-Hopper GPUs (including Blackwell), MPS, and CPU.
|
||||
FA3 (Hopper sm90) -> FA2 (Ampere sm80 / Ada sm89) -> PyTorch SDPA fallback
|
||||
|
||||
Usage (drop-in replacement for FA3):
|
||||
Exports `flash_attn` module with two functions:
|
||||
|
||||
Usage:
|
||||
from nanochat.flash_attention import flash_attn
|
||||
|
||||
# Training (no KV cache)
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
# Training with packed variable-length sequences: q, k, v are (total_tokens, H, D)
|
||||
y = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, ...)
|
||||
|
||||
# Inference (with KV cache)
|
||||
y = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k=k, v=v, ...)
|
||||
|
||||
All non-cached forward passes go through varlen. (B, T) callers that don't provide
|
||||
cu_seqlens get it auto-constructed in GPT.forward.
|
||||
|
||||
FA3 and FA2 both support flash_attn_varlen_func with per-document attention isolation.
|
||||
The SDPA fallback reshapes to (B, T_seq) and uses is_causal=True -- no doc isolation,
|
||||
but efficient kernels on all hardware (Mac, CPU, Blackwell, older GPUs).
|
||||
"""
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Detection: Try to load FA3 on Hopper+ GPUs
|
||||
# Detection: Try FA3 (Hopper), then FA2 (Ampere/Ada), then SDPA fallback
|
||||
# =============================================================================
|
||||
def _load_flash_attention_3():
|
||||
"""Try to load Flash Attention 3 (requires Hopper GPU, sm90)."""
|
||||
def _load_flash_attention():
|
||||
"""Try to load Flash Attention kernels. Returns (module, version_string) or (None, None)."""
|
||||
if not torch.cuda.is_available():
|
||||
return None
|
||||
return None, None
|
||||
try:
|
||||
major, _ = torch.cuda.get_device_capability()
|
||||
# FA3 kernels are compiled for Hopper (sm90) only
|
||||
# Ada (sm89), Blackwell (sm100) need SDPA fallback until FA3 is recompiled
|
||||
if major != 9:
|
||||
return None
|
||||
import os
|
||||
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
||||
from kernels import get_kernel
|
||||
return get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
||||
|
||||
# FA3: Hopper (sm90) only
|
||||
if major == 9:
|
||||
try:
|
||||
return get_kernel('varunneal/flash-attention-3').flash_attn_interface, 'fa3'
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# FA2: Ampere (sm80), Ada (sm89), and Hopper fallback
|
||||
if major >= 8:
|
||||
try:
|
||||
return get_kernel('kernels-community/flash-attn2').flash_attn_interface, 'fa2'
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
return None
|
||||
pass
|
||||
return None, None
|
||||
|
||||
|
||||
_fa3 = _load_flash_attention_3()
|
||||
HAS_FA3 = _fa3 is not None
|
||||
_fa, FA_VERSION = _load_flash_attention()
|
||||
HAS_FA = _fa is not None
|
||||
|
||||
# Override for testing: set to 'fa3', 'sdpa', or None (auto)
|
||||
# Override for testing: set to 'fa3', 'fa2', 'sdpa', or None (auto)
|
||||
_override_impl = None
|
||||
|
||||
|
||||
def _resolve_use_fa3():
|
||||
"""Decide once whether to use FA3, based on availability, override, and dtype."""
|
||||
if _override_impl == 'fa3':
|
||||
assert HAS_FA3, "Cannot override to FA3: not available on this hardware"
|
||||
def _resolve_use_fa():
|
||||
"""Decide once whether to use FA, based on availability, override, and dtype."""
|
||||
if _override_impl in ('fa3', 'fa2', 'fa'):
|
||||
assert HAS_FA, "Cannot override to FA: not available on this hardware"
|
||||
return True
|
||||
if _override_impl == 'sdpa':
|
||||
return False
|
||||
if HAS_FA3:
|
||||
# FA3 Hopper kernels only support bf16 and fp8; fp16/fp32 must use SDPA fallback
|
||||
if HAS_FA:
|
||||
from nanochat.common import COMPUTE_DTYPE
|
||||
if COMPUTE_DTYPE == torch.bfloat16:
|
||||
return True
|
||||
return False
|
||||
if FA_VERSION == 'fa3':
|
||||
# FA3 Hopper kernels only support bf16 and fp8
|
||||
return COMPUTE_DTYPE == torch.bfloat16
|
||||
else:
|
||||
# FA2 supports bf16 and fp16
|
||||
return COMPUTE_DTYPE in (torch.bfloat16, torch.float16)
|
||||
return False
|
||||
|
||||
USE_FA3 = _resolve_use_fa3()
|
||||
USE_FA = _resolve_use_fa()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
|
@ -101,31 +121,57 @@ def _sdpa_attention(q, k, v, window_size, enable_gqa):
|
|||
|
||||
return F.scaled_dot_product_attention(q, k, v, attn_mask=mask, enable_gqa=enable_gqa)
|
||||
|
||||
|
||||
def _sdpa_varlen_attention(q, k, v, max_seqlen, window_size, enable_gqa):
|
||||
"""
|
||||
SDPA fallback for varlen: reshapes packed (T, H, D) to (B, T_seq, H, D)
|
||||
and uses standard causal SDPA. No document isolation (cross-doc bleeding
|
||||
within each T_seq chunk), but uses efficient is_causal=True kernels.
|
||||
"""
|
||||
T, H, D = q.shape
|
||||
H_kv = k.shape[1]
|
||||
B = T // max_seqlen
|
||||
q = q.view(B, max_seqlen, H, D).transpose(1, 2)
|
||||
k = k.view(B, max_seqlen, H_kv, D).transpose(1, 2)
|
||||
v = v.view(B, max_seqlen, H_kv, D).transpose(1, 2)
|
||||
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
|
||||
return y.transpose(1, 2).reshape(T, H, D)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Public API: Same interface as FA3
|
||||
# =============================================================================
|
||||
def flash_attn_func(q, k, v, causal=False, window_size=(-1, -1)):
|
||||
def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k,
|
||||
max_seqlen_q, max_seqlen_k,
|
||||
causal=False, window_size=(-1, -1)):
|
||||
"""
|
||||
Flash Attention for training (no KV cache).
|
||||
Flash Attention for packed variable-length sequences (training, no KV cache).
|
||||
|
||||
1D packed inputs where multiple documents are concatenated into one buffer.
|
||||
Each document attends only to itself, with boundaries defined by cu_seqlens.
|
||||
|
||||
Args:
|
||||
q, k, v: Tensors of shape (B, T, H, D)
|
||||
causal: Whether to use causal masking
|
||||
q, k, v: Tensors of shape (total_tokens, H, D)
|
||||
cu_seqlens_q, cu_seqlens_k: Cumulative sequence lengths, shape (max_num_seqs,).
|
||||
Format: [0, end_doc1, end_doc2, ..., total, total, ...]
|
||||
max_seqlen_q, max_seqlen_k: Max individual sequence length (FA3 tiling hint).
|
||||
causal: Whether to use causal masking.
|
||||
window_size: (left, right) sliding window. -1 means unlimited.
|
||||
|
||||
Returns:
|
||||
Output tensor of shape (B, T, H, D)
|
||||
Output tensor of shape (total_tokens, H, D)
|
||||
"""
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_func(q, k, v, causal=causal, window_size=window_size)
|
||||
if USE_FA:
|
||||
return _fa.flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k,
|
||||
causal=causal, window_size=window_size,
|
||||
)
|
||||
|
||||
# SDPA fallback: transpose (B, T, H, D) -> (B, H, T, D)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
# SDPA fallback: reshape to (B, T_seq) and use standard causal SDPA (no doc isolation)
|
||||
enable_gqa = q.size(1) != k.size(1)
|
||||
y = _sdpa_attention(q, k, v, window_size, enable_gqa)
|
||||
return y.transpose(1, 2) # back to (B, T, H, D)
|
||||
return _sdpa_varlen_attention(q, k, v, max_seqlen_q, window_size, enable_gqa)
|
||||
|
||||
|
||||
def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=None,
|
||||
|
|
@ -146,8 +192,8 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||
Returns:
|
||||
Output tensor of shape (B, T_new, H, D)
|
||||
"""
|
||||
if USE_FA3:
|
||||
return _fa3.flash_attn_with_kvcache(
|
||||
if USE_FA and hasattr(_fa, 'flash_attn_with_kvcache'):
|
||||
return _fa.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache, k=k, v=v, cache_seqlens=cache_seqlens,
|
||||
causal=causal, window_size=window_size
|
||||
)
|
||||
|
|
@ -178,10 +224,10 @@ def flash_attn_with_kvcache(q, k_cache, v_cache, k=None, v=None, cache_seqlens=N
|
|||
|
||||
|
||||
# =============================================================================
|
||||
# Export: flash_attn module interface (drop-in replacement for FA3)
|
||||
# Export: flash_attn module interface (drop-in replacement for FA3/FA2)
|
||||
# =============================================================================
|
||||
from types import SimpleNamespace
|
||||
flash_attn = SimpleNamespace(
|
||||
flash_attn_func=flash_attn_func,
|
||||
flash_attn_varlen_func=flash_attn_varlen_func,
|
||||
flash_attn_with_kvcache=flash_attn_with_kvcache,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ class CausalSelfAttention(nn.Module):
|
|||
self.ve_gate_channels = 12
|
||||
self.ve_gate = Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache, cu_seqlens=None, max_seq_len=None):
|
||||
B, T, C = x.size()
|
||||
|
||||
# Project the input to get queries, keys, and values
|
||||
|
|
@ -103,10 +103,7 @@ class CausalSelfAttention(nn.Module):
|
|||
|
||||
# Flash Attention (FA3 on Hopper+, PyTorch SDPA fallback elsewhere)
|
||||
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
|
||||
if kv_cache is None:
|
||||
# Training: causal attention with optional sliding window
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
else:
|
||||
if kv_cache is not None:
|
||||
# Inference: use flash_attn_with_kvcache which handles cache management
|
||||
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
|
||||
y = flash_attn.flash_attn_with_kvcache(
|
||||
|
|
@ -119,6 +116,15 @@ class CausalSelfAttention(nn.Module):
|
|||
# Advance position after last layer processes
|
||||
if self.layer_idx == kv_cache.n_layers - 1:
|
||||
kv_cache.advance(T)
|
||||
else:
|
||||
# Varlen: packed 1D sequence with per-document attention isolation
|
||||
assert cu_seqlens is not None
|
||||
y = flash_attn.flash_attn_varlen_func(
|
||||
q[0], k[0], v[0],
|
||||
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=max_seq_len, max_seqlen_k=max_seq_len,
|
||||
causal=True, window_size=window_size)
|
||||
y = y.unsqueeze(0)
|
||||
|
||||
# Re-assemble the heads and project back to residual stream
|
||||
y = y.contiguous().view(B, T, -1)
|
||||
|
|
@ -145,8 +151,8 @@ class Block(nn.Module):
|
|||
self.attn = CausalSelfAttention(config, layer_idx)
|
||||
self.mlp = MLP(config)
|
||||
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache):
|
||||
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache)
|
||||
def forward(self, x, ve, cos_sin, window_size, kv_cache, cu_seqlens=None, max_seq_len=None):
|
||||
x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache, cu_seqlens, max_seq_len)
|
||||
x = x + self.mlp(norm(x))
|
||||
return x
|
||||
|
||||
|
|
@ -189,10 +195,11 @@ class GPT(nn.Module):
|
|||
kv_dim = config.n_kv_head * head_dim
|
||||
self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)})
|
||||
# To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only.
|
||||
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
|
||||
# so let's just over-compute them by 10X, but assert fail if we ever reach that amount.
|
||||
# In the future we can dynamically grow the cache, for now it's fine.
|
||||
self.rotary_seq_len = config.sequence_len * 10 # 10X over-compute should be enough, TODO make nicer?
|
||||
# Rotary embeddings are small in memory, so we over-compute generously. With varlen
|
||||
# training the full micro-batch is one sequence (T = batch_size * seq_len), so we need
|
||||
# enough headroom for that. 64X covers batch sizes up to 64, and the assert in forward
|
||||
# will catch if we ever exceed.
|
||||
self.rotary_seq_len = config.sequence_len * 64
|
||||
head_dim = config.n_embd // config.n_head
|
||||
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
||||
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
|
||||
|
|
@ -408,7 +415,18 @@ class GPT(nn.Module):
|
|||
group["initial_lr"] = group["lr"]
|
||||
return optimizer
|
||||
|
||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||
def forward(self, idx, targets=None, cu_seqlens=None, kv_cache=None, loss_reduction='mean'):
|
||||
if cu_seqlens is not None:
|
||||
assert idx.ndim == 1
|
||||
idx = idx.unsqueeze(0)
|
||||
if targets is not None:
|
||||
targets = targets.unsqueeze(0)
|
||||
max_seq_len = self.config.sequence_len
|
||||
elif kv_cache is not None:
|
||||
max_seq_len = None
|
||||
else:
|
||||
raise ValueError("GPT.forward requires either cu_seqlens or kv_cache")
|
||||
|
||||
B, T = idx.size()
|
||||
|
||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||
|
|
@ -451,7 +469,7 @@ class GPT(nn.Module):
|
|||
for i, block in enumerate(self.transformer.h):
|
||||
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
||||
ve = self.value_embeds[str(i)](idx).to(x.dtype) if str(i) in self.value_embeds else None
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache)
|
||||
x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache, cu_seqlens, max_seq_len)
|
||||
if i == backout_layer:
|
||||
x_backout = x
|
||||
# Subtract mid-layer residual to remove low-level features before logit projection
|
||||
|
|
@ -489,10 +507,11 @@ class GPT(nn.Module):
|
|||
if temperature > 0:
|
||||
rng = torch.Generator(device=device)
|
||||
rng.manual_seed(seed)
|
||||
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
|
||||
ids = torch.tensor(tokens, dtype=torch.long, device=device) # 1D, no batch dim
|
||||
for _ in range(max_tokens):
|
||||
logits = self.forward(ids) # (B, T, vocab_size)
|
||||
logits = logits[:, -1, :] # (B, vocab_size)
|
||||
cu_seqlens = torch.tensor([0, ids.size(0)], dtype=torch.int32, device=device)
|
||||
logits = self.forward(ids, cu_seqlens=cu_seqlens) # (1, T, vocab_size)
|
||||
logits = logits[:, -1, :] # (1, vocab_size)
|
||||
if top_k is not None and top_k > 0:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
|
|
@ -502,6 +521,6 @@ class GPT(nn.Module):
|
|||
next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
|
||||
else:
|
||||
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
ids = torch.cat((ids, next_ids), dim=1)
|
||||
ids = torch.cat((ids, next_ids.squeeze(0))) # stay 1D
|
||||
token = next_ids.item()
|
||||
yield token
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ def evaluate_bpb(model, batches, steps, token_bytes):
|
|||
total_bytes = torch.tensor(0, dtype=torch.int64, device=model.get_device())
|
||||
batch_iter = iter(batches)
|
||||
for _ in range(steps):
|
||||
x, y = next(batch_iter)
|
||||
loss2d = model(x, y, loss_reduction='none') # (B, T)
|
||||
x, y, cu_seqlens, *_ = next(batch_iter)
|
||||
loss2d = model(x, y, cu_seqlens=cu_seqlens, loss_reduction='none')
|
||||
loss2d = loss2d.view(-1) # flatten
|
||||
y = y.view(-1) # flatten
|
||||
if (y.int() < 0).any(): # mps does not currently have kernel for < 0 for int64, only int32
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ Examples:
|
|||
"""
|
||||
import os
|
||||
import csv
|
||||
import math
|
||||
import time
|
||||
import json
|
||||
import yaml
|
||||
|
|
@ -35,7 +36,7 @@ from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir,
|
|||
from nanochat.tokenizer import HuggingFaceTokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import evaluate_task
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_varlen
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
|
||||
|
|
@ -48,7 +49,9 @@ class ModelWrapper:
|
|||
self.model = model
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
def __call__(self, input_ids, targets=None, loss_reduction='mean'):
|
||||
def __call__(self, input_ids, targets=None, cu_seqlens=None, loss_reduction='mean'):
|
||||
if cu_seqlens is not None:
|
||||
return self._forward_varlen(input_ids, targets, cu_seqlens, loss_reduction)
|
||||
logits = self.model(input_ids).logits
|
||||
if targets is None:
|
||||
return logits
|
||||
|
|
@ -60,6 +63,48 @@ class ModelWrapper:
|
|||
)
|
||||
return loss
|
||||
|
||||
def _forward_varlen(self, input_ids, targets, cu_seqlens, loss_reduction):
|
||||
"""Unpack 1D varlen to padded (B, T), forward through HF model, repack to (1, total_T, V)."""
|
||||
device = input_ids.device
|
||||
doc_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
||||
num_docs = (doc_lens > 0).sum().item()
|
||||
max_doc_len = doc_lens.max().item()
|
||||
|
||||
# Unpack: 1D -> padded (num_docs, max_doc_len)
|
||||
batched = torch.zeros(num_docs, max_doc_len, dtype=input_ids.dtype, device=device)
|
||||
doc_idx = 0
|
||||
for i in range(len(doc_lens)):
|
||||
length = doc_lens[i].item()
|
||||
if length == 0:
|
||||
continue
|
||||
start = cu_seqlens[i].item()
|
||||
batched[doc_idx, :length] = input_ids[start:start + length]
|
||||
doc_idx += 1
|
||||
|
||||
logits = self.model(batched).logits # (num_docs, max_doc_len, V)
|
||||
|
||||
# Repack: padded (num_docs, max_doc_len, V) -> (1, total_T, V)
|
||||
total_T = input_ids.size(0)
|
||||
packed = torch.empty(1, total_T, logits.size(-1), dtype=logits.dtype, device=device)
|
||||
doc_idx = 0
|
||||
for i in range(len(doc_lens)):
|
||||
length = doc_lens[i].item()
|
||||
if length == 0:
|
||||
continue
|
||||
start = cu_seqlens[i].item()
|
||||
packed[0, start:start + length] = logits[doc_idx, :length]
|
||||
doc_idx += 1
|
||||
|
||||
if targets is None:
|
||||
return packed
|
||||
loss = torch.nn.functional.cross_entropy(
|
||||
packed.view(-1, packed.size(-1)),
|
||||
targets.view(-1),
|
||||
ignore_index=-1,
|
||||
reduction=loss_reduction
|
||||
)
|
||||
return loss
|
||||
|
||||
def get_device(self):
|
||||
return next(self.model.parameters()).device
|
||||
|
||||
|
|
@ -269,8 +314,10 @@ def main():
|
|||
print0(f"Adjusted split_tokens to {args.split_tokens} (must be divisible by {tokens_per_step})")
|
||||
steps = args.split_tokens // tokens_per_step
|
||||
|
||||
avg_num_docs = args.device_batch_size * sequence_len // 400
|
||||
max_num_docs = math.ceil(avg_num_docs / 16) * 16
|
||||
for split_name in ["train", "val"]:
|
||||
loader = tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, sequence_len, split_name, device=device)
|
||||
loader = tokenizing_distributed_data_loader_varlen(tokenizer, args.device_batch_size, sequence_len, split_name, max_num_docs=max_num_docs, device=device)
|
||||
bpb = evaluate_bpb(model, loader, steps, token_bytes)
|
||||
bpb_results[split_name] = bpb
|
||||
print0(f"{split_name} bpb: {bpb:.6f}")
|
||||
|
|
|
|||
|
|
@ -26,13 +26,13 @@ import torch
|
|||
import torch.distributed as dist
|
||||
|
||||
from nanochat.gpt import GPT, GPTConfig, Linear
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_bos_bestfit, tokenizing_distributed_data_loader_with_state_bos_bestfit
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader_varlen
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type, get_peak_flops, COMPUTE_DTYPE, COMPUTE_DTYPE_REASON, is_ddp_initialized
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from nanochat.flash_attention import HAS_FA, FA_VERSION
|
||||
from scripts.base_eval import evaluate_core
|
||||
print_banner()
|
||||
|
||||
|
|
@ -100,17 +100,19 @@ use_dummy_wandb = args.run == "dummy" or not master_process
|
|||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
from nanochat.flash_attention import USE_FA3
|
||||
using_fa3 = USE_FA3
|
||||
if using_fa3:
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
from nanochat.flash_attention import USE_FA
|
||||
if USE_FA:
|
||||
if FA_VERSION == 'fa3':
|
||||
print0("✓ Using Flash Attention 3 (Hopper GPU detected), efficient, new and awesome.")
|
||||
else:
|
||||
print0(f"✓ Using Flash Attention 2 (Ampere/Ada GPU detected).")
|
||||
else:
|
||||
print0("!" * 80)
|
||||
if HAS_FA3 and COMPUTE_DTYPE != torch.bfloat16:
|
||||
print0(f"WARNING: Flash Attention 3 only supports bf16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
|
||||
if HAS_FA and COMPUTE_DTYPE != torch.bfloat16:
|
||||
print0(f"WARNING: Flash Attention only supports bf16/fp16, but COMPUTE_DTYPE={COMPUTE_DTYPE}. Using PyTorch SDPA fallback")
|
||||
else:
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback")
|
||||
print0("WARNING: Training will be less efficient without FA3")
|
||||
print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback")
|
||||
print0("WARNING: Training will be less efficient without Flash Attention")
|
||||
if args.window_pattern != "L":
|
||||
print0(f"WARNING: SDPA has no support for sliding window attention (window_pattern='{args.window_pattern}'). Your GPU utilization will be terrible.")
|
||||
print0("WARNING: Recommend using --window-pattern L for full context attention without alternating sliding window patterns.")
|
||||
|
|
@ -326,10 +328,13 @@ if scaler is not None:
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the DataLoaders for train/val
|
||||
# max_num_docs: see dataloader.py for explanation
|
||||
avg_num_docs = args.device_batch_size * args.max_seq_len // 400
|
||||
max_num_docs = math.ceil(avg_num_docs / 16) * 16
|
||||
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
|
||||
train_loader = tokenizing_distributed_data_loader_with_state_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader_bos_bestfit(tokenizer, args.device_batch_size, args.max_seq_len, split="val", device=device)
|
||||
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
train_loader = tokenizing_distributed_data_loader_varlen(tokenizer, args.device_batch_size, args.max_seq_len, split="train", max_num_docs=max_num_docs, device=device, resume_state_dict=dataloader_resume_state_dict)
|
||||
build_val_loader = lambda: tokenizing_distributed_data_loader_varlen(tokenizer, args.device_batch_size, args.max_seq_len, split="val", max_num_docs=max_num_docs, device=device)
|
||||
x, y, cu_seqlens, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Calculate the number of iterations we will train for and set up the various schedulers
|
||||
|
|
@ -507,14 +512,14 @@ while True:
|
|||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
loss = model(x, y)
|
||||
loss = model(x, y, cu_seqlens=cu_seqlens)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
if scaler is not None:
|
||||
scaler.scale(loss).backward()
|
||||
else:
|
||||
loss.backward()
|
||||
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
x, y, cu_seqlens, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# step the optimizer
|
||||
lrm = get_lr_multiplier(step)
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
|
|
|
|||
|
|
@ -89,7 +89,6 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
|
|||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
device = model.get_device()
|
||||
bos = tokenizer.get_bos_token_id() # use BOS as pad token is ok, these positions are ignored
|
||||
|
||||
# We'll process batches of independent problems at a time because there is no sampling needed
|
||||
num_problems = len(task_object) if max_problems is None else min(len(task_object), max_problems)
|
||||
|
|
@ -102,17 +101,18 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
|
|||
for i in range(ddp_rank, num_batches, ddp_world_size):
|
||||
i0, i1 = i * batch_size, min((i + 1) * batch_size, num_problems)
|
||||
|
||||
# Prepare the batch of problems. They might all be of different length, so we pad/collate them.
|
||||
# Prepare the batch of problems and pack into varlen (no padding waste)
|
||||
conversations = [task_object[ii] for ii in range(i0, i1)]
|
||||
prompt_ids = [tokenizer.render_for_completion(conversation) for conversation in conversations] # TODO: remake the way this works
|
||||
max_length = max(len(ids) for ids in prompt_ids)
|
||||
answer_time_positions = [len(ids) - 1 for ids in prompt_ids] # where the last token is (and the predicted answer)
|
||||
padded_prompt_ids = [ids + [bos] * (max_length - len(ids)) for ids in prompt_ids]
|
||||
prompt_ids = torch.tensor(padded_prompt_ids, dtype=torch.long, device=device)
|
||||
packed = torch.cat([torch.tensor(ids, dtype=torch.long, device=device) for ids in prompt_ids])
|
||||
cu_seqlens = torch.zeros(len(prompt_ids) + 1, dtype=torch.int32, device=device)
|
||||
for j, ids in enumerate(prompt_ids):
|
||||
cu_seqlens[j + 1] = cu_seqlens[j] + len(ids)
|
||||
|
||||
# Get the logits for the whole batch of conversations in parallel (efficiency win here)
|
||||
with torch.no_grad():
|
||||
logits = model(prompt_ids) # (B, T, V)
|
||||
logits = model(packed, cu_seqlens=cu_seqlens).squeeze(0) # (total_T, V)
|
||||
|
||||
# Focus on the available answer on just the letters corresponding to choices
|
||||
# Note that this helps the evaluation a lot because it specifically narrows the focus to only the available letters
|
||||
|
|
@ -130,7 +130,7 @@ def run_categorical_eval(task_object, tokenizer, model, batch_size, max_problems
|
|||
letter_ids.append(letter_to_id_cache[letter])
|
||||
# focus logits just down to the answer position and the available letters of the answer
|
||||
answer_pos = answer_time_positions[idx]
|
||||
focus_logits = logits[idx, answer_pos, letter_ids]
|
||||
focus_logits = logits[cu_seqlens[idx].item() + answer_pos, letter_ids]
|
||||
# get the argmax letter (the predicted answer)
|
||||
argmax_letter_id = focus_logits.argmax(dim=-1).item()
|
||||
predicted_letter = letters[argmax_letter_id]
|
||||
|
|
|
|||
|
|
@ -84,7 +84,6 @@ print0(f"Calculated number of steps: {num_steps}")
|
|||
|
||||
@torch.no_grad()
|
||||
def get_batch():
|
||||
assistant_end = tokenizer.encode_special("<|assistant_end|>") # ok to use this token, it's only for padding and isn't used in the loss.
|
||||
rank_indices = range(ddp_rank, len(train_task), ddp_world_size) # each rank is responsible for different examples in the training data
|
||||
for example_idx in itertools.cycle(rank_indices):
|
||||
|
||||
|
|
@ -125,25 +124,30 @@ def get_batch():
|
|||
reward = train_task.reward(conversation, generated_text)
|
||||
rewards.append(reward)
|
||||
|
||||
# Pad the sequences so that their lengths (in time) match
|
||||
max_length = max(len(seq) for seq in generated_token_sequences)
|
||||
padded_generated_token_sequences = [seq + [assistant_end] * (max_length - len(seq)) for seq in generated_token_sequences]
|
||||
padded_masks = [mask + [0] * (max_length - len(mask)) for mask in masks]
|
||||
# Stack up the sequences and masks into PyTorch tensors
|
||||
ids = torch.tensor(padded_generated_token_sequences, dtype=torch.long, device=device)
|
||||
mask_ids = torch.tensor(padded_masks, dtype=torch.long, device=device)
|
||||
# Generate autoregressive inputs and targets to the Transformer
|
||||
inputs = ids[:, :-1]
|
||||
targets = ids[:, 1:].clone() # clone to avoid in-place modification:
|
||||
targets[mask_ids[:, 1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
|
||||
# Pack the sequences into a 1D buffer (varlen packing, no padding needed)
|
||||
# Generate autoregressive inputs and targets for each sequence
|
||||
input_seqs = []
|
||||
target_seqs = []
|
||||
for seq, mask in zip(generated_token_sequences, masks):
|
||||
seq_t = torch.tensor(seq, dtype=torch.long, device=device)
|
||||
mask_t = torch.tensor(mask, dtype=torch.long, device=device)
|
||||
input_seqs.append(seq_t[:-1])
|
||||
tgt = seq_t[1:].clone() # clone to avoid in-place modification:
|
||||
tgt[mask_t[1:] == 0] = -1 # <-- inplace modification right here. -1 is the ignore index
|
||||
target_seqs.append(tgt)
|
||||
# NOTE also that the Engine returns mask=0 for BOTH the prompt tokens AND the tool use tokens.
|
||||
# So we will (correctly) end up not training on the prompt tokens, or the tool use forced tokens.
|
||||
# Concatenate the sequences and masks into 1D PyTorch tensors
|
||||
inputs = torch.cat(input_seqs)
|
||||
targets = torch.cat(target_seqs)
|
||||
lengths = torch.tensor([len(s) for s in input_seqs], dtype=torch.int32, device=device)
|
||||
cu_seqlens = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)]).to(torch.int32)
|
||||
rewards = torch.tensor(rewards, dtype=torch.float, device=device)
|
||||
# Calculate the advantages by simply subtracting the mean (instead of z-score (x-mu)/sigma)
|
||||
mu = rewards.mean()
|
||||
advantages = rewards - mu
|
||||
# yield inputs/targets as (B, T) of ids and rewards as (B,) of floats
|
||||
yield generated_token_sequences, inputs, targets, rewards, advantages
|
||||
# yield packed 1D inputs/targets with cu_seqlens, and rewards/advantages as (B,) of floats
|
||||
yield generated_token_sequences, inputs, targets, cu_seqlens, rewards, advantages
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Simple evaluation loop for GSM8K pass@k
|
||||
|
|
@ -247,23 +251,31 @@ for step in range(num_steps):
|
|||
sequence_lengths = []
|
||||
for example_step in range(examples_per_rank):
|
||||
# Get one batch corresponding to one example in the training dataset
|
||||
sequences_all, inputs_all, targets_all, rewards_all, advantages_all = next(batch_iterator)
|
||||
sequences_all, inputs_all, targets_all, cu_seqlens_all, rewards_all, advantages_all = next(batch_iterator)
|
||||
# Evaluate the loss and gradients
|
||||
model.train() # ensure the model is in train mode
|
||||
# We need one more loop because we can never exceed the device_batch_size
|
||||
assert inputs_all.size(0) % args.device_batch_size == 0
|
||||
num_passes = inputs_all.size(0) // args.device_batch_size
|
||||
num_seqs = len(cu_seqlens_all) - 1
|
||||
assert num_seqs % args.device_batch_size == 0
|
||||
num_passes = num_seqs // args.device_batch_size
|
||||
for pass_idx in range(num_passes):
|
||||
# Pluck out the batch for this pass
|
||||
# Pluck out the sub-batch for this pass from the packed 1D buffer
|
||||
b0, b1 = pass_idx * args.device_batch_size, (pass_idx + 1) * args.device_batch_size
|
||||
inputs = inputs_all[b0:b1]
|
||||
targets = targets_all[b0:b1]
|
||||
t0, t1 = cu_seqlens_all[b0].item(), cu_seqlens_all[b1].item()
|
||||
inputs = inputs_all[t0:t1]
|
||||
targets = targets_all[t0:t1]
|
||||
cu_seqlens = cu_seqlens_all[b0:b1+1] - cu_seqlens_all[b0]
|
||||
rewards = rewards_all[b0:b1]
|
||||
advantages = advantages_all[b0:b1]
|
||||
# Calculate log probabilities. Note that the loss calculates NLL = -logp, so we negate
|
||||
logp = -model(inputs, targets, loss_reduction='none').view_as(inputs) # (B, T)
|
||||
logp = -model(inputs, targets, cu_seqlens=cu_seqlens, loss_reduction='none') # (T_sub,)
|
||||
# Expand per-sequence advantages to per-token positions using cu_seqlens boundaries
|
||||
token_advantages = torch.zeros_like(logp)
|
||||
for i in range(b1 - b0):
|
||||
s, e = cu_seqlens[i].item(), cu_seqlens[i+1].item()
|
||||
token_advantages[s:e] = advantages[i]
|
||||
# Calculate the PG objective. Note that ignore_index=-1 ensures that invalid tokens have loss 0.
|
||||
pg_obj = (logp * advantages.unsqueeze(-1)).sum()
|
||||
pg_obj = (logp * token_advantages).sum()
|
||||
# normalize by the number of valid tokens, number of passes, and examples_per_rank
|
||||
num_valid = (targets >= 0).sum().clamp(min=1)
|
||||
pg_obj = pg_obj / (num_valid * num_passes * examples_per_rank)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-s
|
|||
"""
|
||||
|
||||
import gc
|
||||
import math
|
||||
import argparse
|
||||
import os
|
||||
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
||||
|
|
@ -21,7 +22,8 @@ from nanochat.tokenizer import get_token_bytes
|
|||
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
import torch.distributed as dist
|
||||
from nanochat.flash_attention import HAS_FA3
|
||||
from nanochat.flash_attention import HAS_FA
|
||||
from nanochat.dataloader import sft_data_loader_varlen
|
||||
from nanochat.engine import Engine
|
||||
from scripts.chat_eval import run_chat_eval
|
||||
|
||||
|
|
@ -89,8 +91,8 @@ use_dummy_wandb = args.run == "dummy" or not master_process
|
|||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
|
||||
|
||||
# Flash Attention status
|
||||
if not HAS_FA3:
|
||||
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
|
||||
if not HAS_FA:
|
||||
print0("WARNING: Flash Attention not available, using PyTorch SDPA fallback. Training will be less efficient.")
|
||||
|
||||
# Load the model and tokenizer
|
||||
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
|
||||
|
|
@ -178,147 +180,137 @@ val_dataset = TaskMixture([
|
|||
MMLU(subset="all", split="test", stop=5200), # 14K rows in test set, use only 5.2K to match the train ratios
|
||||
GSM8K(subset="main", split="test", stop=420), # 1.32K rows in test set, use only 420 to match the train ratios
|
||||
]) # total: 24K + 14K + 1.32K ~= 39K rows
|
||||
# DataLoader is defined here, it emits inputs, targets : 2D tensors of shape (device_batch_size, max_seq_len)
|
||||
# A big problem is that we don't know the final num_iterations in advance. So we create
|
||||
# these two global variables and update them from within the data generator.
|
||||
last_step = False # we will toggle this to True when we reach the end of the training dataset
|
||||
approx_progress = 0.0 # will go from 0 to 1 over the course of the epoch
|
||||
current_epoch = 1 # track epoch for logging
|
||||
def sft_data_generator_bos_bestfit(split, buffer_size=100):
|
||||
"""
|
||||
BOS-aligned dataloader for SFT with bestfit-pad packing.
|
||||
|
||||
Each row in the batch starts with BOS (beginning of a conversation).
|
||||
Conversations are packed using best-fit algorithm. When no conversation fits,
|
||||
the row is padded (instead of cropping) to ensure no tokens are ever discarded.
|
||||
Padding positions have targets masked with -1 (ignore_index for cross-entropy).
|
||||
# Pre-tokenize and pre-pack all conversations into batch plans.
|
||||
# This runs the same best-fit packing algorithm offline at startup, so we know
|
||||
# num_iterations and max_num_docs exactly before training starts.
|
||||
def tokenize_and_pack_sft(dataset, tokenizer, B, T, bos_token, ddp_rank, ddp_world_size, buffer_size=100):
|
||||
"""
|
||||
Pre-tokenize and pre-pack SFT conversations using best-fit packing.
|
||||
|
||||
Preserves TaskMixture's shuffled ordering (no length sorting) and uses the
|
||||
same buffer-based best-fit algorithm as the original inline dataloader.
|
||||
|
||||
Returns (conversations, batch_plans, total_micro_batches, max_num_docs).
|
||||
"""
|
||||
global last_step, approx_progress, current_epoch
|
||||
assert split in {"train", "val"}, "split must be 'train' or 'val'"
|
||||
dataset = train_dataset if split == "train" else val_dataset
|
||||
dataset_size = len(dataset)
|
||||
assert dataset_size > 0
|
||||
row_capacity = args.max_seq_len + 1 # +1 for target at last position
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
buffer_capacity = B * T + 1
|
||||
|
||||
# Conversation buffer: list of (token_ids, loss_mask) tuples
|
||||
conversations = []
|
||||
num_convs = (dataset_size - ddp_rank + ddp_world_size - 1) // ddp_world_size
|
||||
cursor = ddp_rank
|
||||
while cursor < dataset_size:
|
||||
ids, mask = tokenizer.render_conversation(dataset[cursor])
|
||||
conversations.append((ids, mask))
|
||||
cursor += ddp_world_size
|
||||
if len(conversations) % 5000 == 0:
|
||||
print0(f"\r\033[KTokenizing: {len(conversations):,}/{num_convs:,} ({100*len(conversations)/num_convs:.0f}%)", end='', flush=True)
|
||||
print0(f"\r\033[KTokenized {len(conversations):,} conversations", flush=True)
|
||||
|
||||
batch_plans = []
|
||||
conv_buffer = []
|
||||
cursor = ddp_rank # Each rank processes different conversations (for fetching)
|
||||
consumed = ddp_rank # Track actual consumption separately from buffering
|
||||
epoch = 1
|
||||
it = 0 # iteration counter
|
||||
fetch_cursor = 0
|
||||
max_doc_count = 0
|
||||
|
||||
def refill_buffer():
|
||||
nonlocal cursor, epoch
|
||||
while len(conv_buffer) < buffer_size:
|
||||
conversation = dataset[cursor]
|
||||
ids, mask = tokenizer.render_conversation(conversation)
|
||||
conv_buffer.append((ids, mask))
|
||||
cursor += ddp_world_size
|
||||
if cursor >= dataset_size:
|
||||
cursor = cursor % dataset_size
|
||||
epoch += 1
|
||||
# Note: last_step is now triggered based on consumption, not fetching
|
||||
def refill():
|
||||
nonlocal fetch_cursor
|
||||
while len(conv_buffer) < buffer_size and fetch_cursor < len(conversations):
|
||||
conv_buffer.append(fetch_cursor)
|
||||
fetch_cursor += 1
|
||||
|
||||
while True:
|
||||
rows = []
|
||||
mask_rows = []
|
||||
row_lengths = [] # Track actual content length (excluding padding) for each row
|
||||
for _ in range(args.device_batch_size):
|
||||
row = []
|
||||
mask_row = []
|
||||
padded = False
|
||||
while len(row) < row_capacity:
|
||||
# Ensure buffer has conversations
|
||||
while len(conv_buffer) < buffer_size:
|
||||
refill_buffer()
|
||||
|
||||
remaining = row_capacity - len(row)
|
||||
|
||||
# Find largest conversation that fits entirely
|
||||
best_idx = -1
|
||||
best_len = 0
|
||||
for i, (conv, _) in enumerate(conv_buffer):
|
||||
conv_len = len(conv)
|
||||
if conv_len <= remaining and conv_len > best_len:
|
||||
best_idx = i
|
||||
best_len = conv_len
|
||||
|
||||
if best_idx >= 0:
|
||||
# Found a conversation that fits - use it entirely
|
||||
conv, conv_mask = conv_buffer.pop(best_idx)
|
||||
row.extend(conv)
|
||||
mask_row.extend(conv_mask)
|
||||
consumed += ddp_world_size # Track actual consumption
|
||||
else:
|
||||
# No conversation fits - pad the remainder instead of cropping
|
||||
# This ensures we never discard any tokens
|
||||
content_len = len(row)
|
||||
row.extend([bos_token] * remaining) # Pad with BOS tokens
|
||||
mask_row.extend([0] * remaining)
|
||||
padded = True
|
||||
break # Row is now full (with padding)
|
||||
|
||||
# Track content length: full row if no padding, otherwise the length before padding
|
||||
if padded:
|
||||
row_lengths.append(content_len)
|
||||
refill()
|
||||
if not conv_buffer:
|
||||
break
|
||||
batch_indices = []
|
||||
pos = 0
|
||||
while pos < buffer_capacity:
|
||||
refill()
|
||||
if not conv_buffer:
|
||||
break
|
||||
remaining = buffer_capacity - pos
|
||||
best_buf_idx = -1
|
||||
best_len = 0
|
||||
for i, conv_idx in enumerate(conv_buffer):
|
||||
conv_len = len(conversations[conv_idx][0])
|
||||
if conv_len <= remaining and conv_len > best_len:
|
||||
best_buf_idx = i
|
||||
best_len = conv_len
|
||||
if best_buf_idx >= 0:
|
||||
batch_indices.append(conv_buffer.pop(best_buf_idx))
|
||||
pos += best_len
|
||||
else:
|
||||
row_lengths.append(row_capacity)
|
||||
rows.append(row[:row_capacity])
|
||||
mask_rows.append(mask_row[:row_capacity])
|
||||
break
|
||||
if batch_indices:
|
||||
doc_count = len(batch_indices) + (1 if pos < buffer_capacity else 0)
|
||||
max_doc_count = max(max_doc_count, doc_count)
|
||||
batch_plans.append(batch_indices)
|
||||
|
||||
# Stopping condition to respect num_iterations, if given
|
||||
it += 1
|
||||
if 0 < args.num_iterations <= it and split == "train":
|
||||
last_step = True
|
||||
max_num_docs = math.ceil((max_doc_count + 1) / 16) * 16
|
||||
return conversations, batch_plans, len(batch_plans), max(max_num_docs, 16)
|
||||
|
||||
# Update progress tracking (based on consumed, not cursor, to account for buffering)
|
||||
if split == "train":
|
||||
current_epoch = epoch
|
||||
if args.num_iterations > 0:
|
||||
approx_progress = it / args.num_iterations
|
||||
else:
|
||||
approx_progress = consumed / dataset_size
|
||||
# Trigger last_step when we've consumed enough (instead of when cursor wraps)
|
||||
if consumed >= dataset_size:
|
||||
last_step = True
|
||||
bos_token = tokenizer.get_bos_token_id()
|
||||
t_pack_start = time.time()
|
||||
train_convs, train_plans, train_micro_batches, train_max_docs = tokenize_and_pack_sft(
|
||||
train_dataset, tokenizer, args.device_batch_size, args.max_seq_len,
|
||||
bos_token, ddp_rank, ddp_world_size)
|
||||
t_pack_train = time.time()
|
||||
val_convs, val_plans, val_micro_batches, val_max_docs = tokenize_and_pack_sft(
|
||||
val_dataset, tokenizer, args.device_batch_size, args.max_seq_len,
|
||||
bos_token, ddp_rank, ddp_world_size)
|
||||
t_pack_val = time.time()
|
||||
max_num_docs = max(train_max_docs, val_max_docs)
|
||||
print0(f"Pre-tokenize & pack: train {t_pack_train - t_pack_start:.1f}s, val {t_pack_val - t_pack_train:.1f}s, total {t_pack_val - t_pack_start:.1f}s")
|
||||
|
||||
# Build tensors
|
||||
use_cuda = device_type == "cuda"
|
||||
batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
|
||||
inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda).contiguous()
|
||||
targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda).contiguous()
|
||||
# Document length and packing statistics
|
||||
import numpy as np
|
||||
train_doc_lens = [len(ids) for ids, _ in train_convs]
|
||||
train_docs_per_batch = [len(plan) for plan in train_plans]
|
||||
train_tokens_per_batch = [sum(len(train_convs[i][0]) for i in plan) for plan in train_plans]
|
||||
buffer_capacity = args.device_batch_size * args.max_seq_len + 1
|
||||
train_packing_eff = [t / buffer_capacity for t in train_tokens_per_batch]
|
||||
dl = np.array(train_doc_lens)
|
||||
dpb = np.array(train_docs_per_batch)
|
||||
pe = np.array(train_packing_eff)
|
||||
print0(f"Train doc lengths: n={len(dl):,} | mean={dl.mean():.0f} median={np.median(dl):.0f} "
|
||||
f"min={dl.min()} max={dl.max()} p5={np.percentile(dl,5):.0f} p95={np.percentile(dl,95):.0f}")
|
||||
print0(f"Train docs/batch: n={len(dpb):,} | mean={dpb.mean():.1f} median={np.median(dpb):.0f} "
|
||||
f"min={dpb.min()} max={dpb.max()} p5={np.percentile(dpb,5):.0f} p95={np.percentile(dpb,95):.0f}")
|
||||
print0(f"Train packing eff: mean={pe.mean():.3f} median={np.median(pe):.3f} "
|
||||
f"min={pe.min():.3f} max={pe.max():.3f}")
|
||||
|
||||
# Apply the loss mask from render_conversation (mask=1 for assistant completions,
|
||||
# mask=0 for user prompts, BOS, special tokens, tool outputs). mask[1:] aligns
|
||||
# with targets (shifted by 1). Unmasked positions get -1 (ignore_index).
|
||||
mask_tensor = torch.tensor(mask_rows, dtype=torch.int8)
|
||||
mask_targets = mask_tensor[:, 1:].to(device=device)
|
||||
targets[mask_targets == 0] = -1
|
||||
# num_iterations: exact count of optimization steps. The -1 accounts for the
|
||||
# prefetch batch that the training loop requests but never trains on.
|
||||
data_num_iterations = (train_micro_batches - 1) // grad_accum_steps
|
||||
if args.num_iterations > 0:
|
||||
num_iterations = min(args.num_iterations, data_num_iterations)
|
||||
else:
|
||||
num_iterations = data_num_iterations
|
||||
if ddp:
|
||||
num_iter_tensor = torch.tensor([num_iterations], dtype=torch.long, device=device)
|
||||
dist.all_reduce(num_iter_tensor, op=dist.ReduceOp.MIN)
|
||||
num_iterations = num_iter_tensor.item()
|
||||
print0(f"Pre-packed {len(train_convs):,} train conversations into {train_micro_batches:,} micro-batches "
|
||||
f"=> {num_iterations:,} optimization steps (max {max_num_docs} docs/batch)")
|
||||
|
||||
# Mask out padding positions in targets (set to -1 = ignore_index)
|
||||
# For each row, positions >= (content_length - 1) in targets should be masked
|
||||
for i, content_len in enumerate(row_lengths):
|
||||
if content_len < row_capacity:
|
||||
targets[i, content_len-1:] = -1
|
||||
|
||||
yield inputs, targets
|
||||
|
||||
train_loader = sft_data_generator_bos_bestfit("train")
|
||||
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
|
||||
progress = 0 # will go from 0 to 1 over the course of the epoch
|
||||
train_loader = sft_data_loader_varlen(
|
||||
train_convs, train_plans, args.device_batch_size, args.max_seq_len,
|
||||
max_num_docs, bos_token, device=device)
|
||||
build_val_loader = lambda: sft_data_loader_varlen(
|
||||
val_convs, val_plans, args.device_batch_size, args.max_seq_len,
|
||||
max_num_docs, bos_token, device=device, cycle=True)
|
||||
|
||||
# Learning rate schedule (linear warmup, constant, linear warmdown)
|
||||
# Same shape as base_train but uses progress (0→1) instead of absolute step counts,
|
||||
# because SFT doesn't always know num_iterations in advance (dataset-driven stopping).
|
||||
def get_lr_multiplier(progress):
|
||||
if progress < args.warmup_ratio:
|
||||
return (progress + 1e-8) / args.warmup_ratio
|
||||
elif progress <= 1.0 - args.warmdown_ratio:
|
||||
def get_lr_multiplier(it):
|
||||
warmup_iters = round(args.warmup_ratio * num_iterations)
|
||||
warmdown_iters = round(args.warmdown_ratio * num_iterations)
|
||||
if it < warmup_iters:
|
||||
return (it + 1) / warmup_iters
|
||||
elif it <= num_iterations - warmdown_iters:
|
||||
return 1.0
|
||||
else:
|
||||
decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio
|
||||
return (1 - decay) * 1.0 + decay * args.final_lr_frac
|
||||
progress = (num_iterations - it) / warmdown_iters
|
||||
return progress * 1.0 + (1 - progress) * args.final_lr_frac
|
||||
|
||||
# Momentum scheduler for Muon optimizer
|
||||
def get_muon_momentum(it):
|
||||
|
|
@ -328,21 +320,16 @@ def get_muon_momentum(it):
|
|||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Training loop
|
||||
x, y = next(train_loader) # prefetch the very first batch of data
|
||||
x, y, cu_seqlens = next(train_loader) # prefetch the very first batch of data
|
||||
min_val_bpb = float("inf")
|
||||
smooth_train_loss = 0 # EMA of training loss
|
||||
ema_beta = 0.9 # EMA decay factor
|
||||
total_training_time = 0 # total wall-clock time of training
|
||||
step = 0
|
||||
while True:
|
||||
last_step = step == num_iterations
|
||||
flops_so_far = num_flops_per_token * args.total_batch_size * step
|
||||
|
||||
# Synchronize last_step across all ranks to avoid hangs in the distributed setting
|
||||
if ddp:
|
||||
last_step_tensor = torch.tensor(last_step, dtype=torch.int32, device=device)
|
||||
dist.all_reduce(last_step_tensor, op=dist.ReduceOp.MAX)
|
||||
last_step = bool(last_step_tensor.item())
|
||||
|
||||
# once in a while: evaluate the val bpb (all ranks participate)
|
||||
if last_step or (args.eval_every > 0 and step % args.eval_every == 0):
|
||||
model.eval()
|
||||
|
|
@ -430,17 +417,16 @@ while True:
|
|||
synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
loss = model(x, y)
|
||||
loss = model(x, y, cu_seqlens=cu_seqlens)
|
||||
train_loss = loss.detach() # for logging
|
||||
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
|
||||
if scaler is not None:
|
||||
scaler.scale(loss).backward()
|
||||
else:
|
||||
loss.backward()
|
||||
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
progress = max(progress, approx_progress) # only increase progress monotonically
|
||||
x, y, cu_seqlens = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
|
||||
# step the optimizer
|
||||
lrm = get_lr_multiplier(progress)
|
||||
lrm = get_lr_multiplier(step)
|
||||
muon_momentum = get_muon_momentum(step)
|
||||
for group in optimizer.param_groups:
|
||||
group["lr"] = group["initial_lr"] * lrm
|
||||
|
|
@ -467,13 +453,13 @@ while True:
|
|||
# logging
|
||||
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
|
||||
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
|
||||
pct_done = 100 * progress
|
||||
pct_done = 100 * step / num_iterations
|
||||
tok_per_sec = int(args.total_batch_size / dt)
|
||||
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
|
||||
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
|
||||
if step > 10:
|
||||
total_training_time += dt # only count the time after the first 10 steps
|
||||
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
|
||||
print0(f"step {step:05d}/{num_iterations:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | total time: {total_training_time/60:.2f}m")
|
||||
if step % 10 == 0:
|
||||
wandb_run.log({
|
||||
"step": step,
|
||||
|
|
@ -484,7 +470,6 @@ while True:
|
|||
"train/dt": dt,
|
||||
"train/tok_per_sec": tok_per_sec,
|
||||
"train/mfu": mfu,
|
||||
"train/epoch": current_epoch,
|
||||
})
|
||||
|
||||
# The garbage collector spends ~500ms scanning for cycles quite frequently.
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
"""
|
||||
Test Flash Attention unified interface - verify FA3 and SDPA produce identical results.
|
||||
Test Flash Attention unified interface - verify FA (FA3/FA2) and SDPA produce identical results.
|
||||
|
||||
Run: python -m pytest tests/test_attention_fallback.py -v -s
|
||||
|
||||
Note on test structure:
|
||||
Tests are split into two classes due to dtype/device constraints:
|
||||
|
||||
1. TestFA3VsSDPA: Comparison tests that run both FA3 and SDPA on the same inputs
|
||||
and verify they produce identical results. These require a Hopper GPU (FA3 only
|
||||
works on sm90+) and use bfloat16 (FA3 doesn't support float32).
|
||||
1. TestFA3VsSDPA: Comparison tests that run both FA and SDPA on the same inputs
|
||||
and verify they produce identical results. These require an Ampere+ GPU
|
||||
(FA3 on Hopper, FA2 on Ampere/Ada) and use bfloat16.
|
||||
|
||||
2. TestSDPAOnly: Tests that only exercise the SDPA fallback path. These can run
|
||||
on any device (CUDA, CPU, MPS) with the appropriate dtype for that device.
|
||||
|
|
@ -16,24 +16,29 @@ Note on test structure:
|
|||
import torch
|
||||
import pytest
|
||||
import nanochat.flash_attention as fa_module
|
||||
from nanochat.flash_attention import flash_attn, HAS_FA3
|
||||
from nanochat.flash_attention import flash_attn, HAS_FA
|
||||
from nanochat.engine import KVCache
|
||||
|
||||
|
||||
def set_impl(impl):
|
||||
"""Set the implementation override ('fa3', 'sdpa', or None for auto) and re-resolve USE_FA3."""
|
||||
"""Set the implementation override ('fa3', 'fa2', 'sdpa', or None for auto) and re-resolve USE_FA."""
|
||||
fa_module._override_impl = impl
|
||||
fa_module.USE_FA3 = fa_module._resolve_use_fa3()
|
||||
fa_module.USE_FA = fa_module._resolve_use_fa()
|
||||
|
||||
|
||||
def run_both_impls(fn):
|
||||
"""Run a function with both FA3 and SDPA, return both outputs."""
|
||||
set_impl('fa3')
|
||||
out_fa3 = fn()
|
||||
"""Run a function with both FA (FA3 or FA2) and SDPA, return both outputs."""
|
||||
set_impl('fa')
|
||||
out_fa = fn()
|
||||
set_impl('sdpa')
|
||||
out_sdpa = fn()
|
||||
set_impl(None) # reset
|
||||
return out_fa3, out_sdpa
|
||||
return out_fa, out_sdpa
|
||||
|
||||
|
||||
def make_cu_seqlens(B, T, device):
|
||||
"""Create cu_seqlens for B documents each of length T."""
|
||||
return torch.arange(0, (B + 1) * T, T, dtype=torch.int32, device=device)
|
||||
|
||||
|
||||
def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
|
||||
|
|
@ -48,9 +53,9 @@ def assert_close(t1, t2, name, atol=1e-2, rtol=1e-2):
|
|||
# =============================================================================
|
||||
# FA3 vs SDPA comparison tests (require Hopper GPU)
|
||||
# =============================================================================
|
||||
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required to compare implementations")
|
||||
@pytest.mark.skipif(not HAS_FA, reason="FA required to compare implementations")
|
||||
class TestFA3VsSDPA:
|
||||
"""Compare FA3 and SDPA produce identical results. Requires Hopper GPU."""
|
||||
"""Compare FA and SDPA produce identical results. Requires Ampere+ GPU."""
|
||||
|
||||
DEVICE = "cuda"
|
||||
DTYPE = torch.bfloat16
|
||||
|
|
@ -58,12 +63,15 @@ class TestFA3VsSDPA:
|
|||
def test_basic_causal(self):
|
||||
"""Basic causal attention."""
|
||||
B, T, H, D = 2, 64, 4, 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
return flash_attn.flash_attn_varlen_func(q, k, v,
|
||||
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "basic_causal")
|
||||
|
|
@ -72,12 +80,15 @@ class TestFA3VsSDPA:
|
|||
def test_full_context(self):
|
||||
"""Full context (window_size=-1)."""
|
||||
B, T, H, D = 2, 128, 4, 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
|
||||
return flash_attn.flash_attn_varlen_func(q, k, v,
|
||||
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(-1, -1))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "full_context")
|
||||
|
|
@ -87,12 +98,15 @@ class TestFA3VsSDPA:
|
|||
"""Sliding window attention."""
|
||||
B, T, H, D = 2, 128, 4, 32
|
||||
window = 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(window, 0))
|
||||
return flash_attn.flash_attn_varlen_func(q, k, v,
|
||||
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(window, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "sliding_window")
|
||||
|
|
@ -104,12 +118,15 @@ class TestFA3VsSDPA:
|
|||
n_heads = 8
|
||||
n_kv_heads = 2
|
||||
|
||||
q = torch.randn(B, T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
q = torch.randn(B * T, n_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B * T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B * T, n_kv_heads, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
return flash_attn.flash_attn_varlen_func(q, k, v,
|
||||
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "gqa")
|
||||
|
|
@ -118,12 +135,15 @@ class TestFA3VsSDPA:
|
|||
def test_larger_model(self):
|
||||
"""Larger dimensions closer to real model."""
|
||||
B, T, H, D = 4, 256, 12, 64
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
|
||||
|
||||
def run():
|
||||
return flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(-1, -1))
|
||||
return flash_attn.flash_attn_varlen_func(q, k, v,
|
||||
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(-1, -1))
|
||||
|
||||
y_fa3, y_sdpa = run_both_impls(run)
|
||||
max_diff, mean_diff = assert_close(y_fa3, y_sdpa, "larger_model")
|
||||
|
|
@ -215,21 +235,24 @@ class TestFA3VsSDPA:
|
|||
def test_backward_gradients_match(self):
|
||||
"""Verify gradients are similar between FA3 and SDPA."""
|
||||
B, T, H, D = 2, 32, 4, 16
|
||||
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
|
||||
|
||||
q_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_data = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
q_data = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k_data = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v_data = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
|
||||
def run():
|
||||
q = q_data.clone().requires_grad_(True)
|
||||
k = k_data.clone().requires_grad_(True)
|
||||
v = v_data.clone().requires_grad_(True)
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
y = flash_attn.flash_attn_varlen_func(q, k, v,
|
||||
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0))
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
return y.detach(), q.grad.detach(), k.grad.detach(), v.grad.detach()
|
||||
|
||||
set_impl('fa3')
|
||||
set_impl('fa')
|
||||
y_fa3, q_grad_fa3, k_grad_fa3, v_grad_fa3 = run()
|
||||
set_impl('sdpa')
|
||||
y_sdpa, q_grad_sdpa, k_grad_sdpa, v_grad_sdpa = run()
|
||||
|
|
@ -261,13 +284,16 @@ class TestSDPAOnly:
|
|||
"""Test SDPA forward pass produces valid output."""
|
||||
set_impl('sdpa')
|
||||
B, T, H, D = 2, 64, 4, 32
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE)
|
||||
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
|
||||
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
y = flash_attn.flash_attn_varlen_func(q, k, v,
|
||||
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0))
|
||||
|
||||
assert y.shape == (B, T, H, D)
|
||||
assert y.shape == (B * T, H, D)
|
||||
assert not torch.isnan(y).any(), "Output contains NaN"
|
||||
set_impl(None)
|
||||
|
||||
|
|
@ -275,11 +301,14 @@ class TestSDPAOnly:
|
|||
"""Test gradients flow through SDPA."""
|
||||
set_impl('sdpa')
|
||||
B, T, H, D = 2, 32, 4, 16
|
||||
q = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
k = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
v = torch.randn(B, T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
q = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
k = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
v = torch.randn(B * T, H, D, device=self.DEVICE, dtype=self.DTYPE, requires_grad=True)
|
||||
cu_seqlens = make_cu_seqlens(B, T, self.DEVICE)
|
||||
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=(T, 0))
|
||||
y = flash_attn.flash_attn_varlen_func(q, k, v,
|
||||
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=T, max_seqlen_k=T, causal=True, window_size=(T, 0))
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
|
||||
|
|
@ -340,23 +369,23 @@ class TestSDPAOnly:
|
|||
class TestOverrideMechanism:
|
||||
"""Test that the override mechanism works correctly."""
|
||||
|
||||
@pytest.mark.skipif(not HAS_FA3, reason="FA3 required")
|
||||
def test_override_fa3(self):
|
||||
"""Test that override='fa3' uses FA3."""
|
||||
set_impl('fa3')
|
||||
assert fa_module.USE_FA3 == True
|
||||
@pytest.mark.skipif(not HAS_FA, reason="FA required")
|
||||
def test_override_fa(self):
|
||||
"""Test that override='fa' uses FA."""
|
||||
set_impl('fa')
|
||||
assert fa_module.USE_FA == True
|
||||
set_impl(None)
|
||||
|
||||
def test_override_sdpa(self):
|
||||
"""Test that override='sdpa' uses SDPA."""
|
||||
set_impl('sdpa')
|
||||
assert fa_module.USE_FA3 == False
|
||||
assert fa_module.USE_FA == False
|
||||
set_impl(None)
|
||||
|
||||
def test_override_auto(self):
|
||||
"""Test that override=None uses auto-detection."""
|
||||
set_impl(None)
|
||||
assert fa_module.USE_FA3 == HAS_FA3
|
||||
assert fa_module.USE_FA == HAS_FA
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -366,7 +395,7 @@ if __name__ == "__main__":
|
|||
print(f"CUDA device: {torch.cuda.get_device_name()}")
|
||||
major, minor = torch.cuda.get_device_capability()
|
||||
print(f"Compute capability: {major}.{minor}")
|
||||
print(f"HAS_FA3: {HAS_FA3}")
|
||||
print(f"HAS_FA: {HAS_FA}")
|
||||
print()
|
||||
|
||||
pytest.main([__file__, "-v", "-s"])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user