mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-16 05:48:37 +00:00
Merge 3e5fccdfa4 into 50413d2d67
This commit is contained in:
commit
fb47904d51
|
|
@ -29,7 +29,21 @@ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
|||
# Official docs of FA3 label it as "beta" and want you to install FA3 from source, which is a pain.
|
||||
# Wishing for official FA3 wheels soon, for now this seems to be a fast way to get them (ty varunneal)
|
||||
from kernels import get_kernel
|
||||
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
||||
|
||||
flash_attn = None
|
||||
try:
|
||||
# Flash Attention 3 uses NVIDIA Hopper-specific features like TMA (Tensor Memory Accelerator).
|
||||
# These are only physically available on GPUs with Compute Capability >= 9.0 (e.g. H100).
|
||||
# We explicitly check for this to prevent "No kernel image available" crashes on Ampere/Ada GPUs (RTX 30xx/40xx) etc.
|
||||
if torch.cuda.is_available():
|
||||
if torch.cuda.get_device_capability()[0] >= 9:
|
||||
flash_attn = get_kernel('varunneal/flash-attention-3').flash_attn_interface
|
||||
else:
|
||||
# If the kernel image is not available, try installing the wheel manually from https://windreamer.github.io/flash-attention3-wheels/
|
||||
import flash_attn_interface as flash_attn
|
||||
except Exception:
|
||||
# Fallback to PyTorch SDPA on non-Hopper NVIDIA GPUs, Mac (MPS), or CPU.
|
||||
pass
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
|
|
@ -92,17 +106,54 @@ class CausalSelfAttention(nn.Module):
|
|||
# window_size is (left, right) tuple: (N, 0) for causal, (-1, 0) for full context
|
||||
if kv_cache is None:
|
||||
# Training: causal attention with optional sliding window
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
if flash_attn is not None:
|
||||
y = flash_attn.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
||||
else:
|
||||
# Fallback Path (CPU/MPS): Needs Transpose to (B, H, T, D) for SDPA
|
||||
y = F.scaled_dot_product_attention(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
is_causal=True
|
||||
)
|
||||
y = y.transpose(1, 2) # Restore layout to (B, T, H, D)
|
||||
else:
|
||||
# Inference: use flash_attn_with_kvcache which handles cache management
|
||||
k_cache, v_cache = kv_cache.get_layer_cache(self.layer_idx)
|
||||
y = flash_attn.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache,
|
||||
k=k, v=v,
|
||||
cache_seqlens=kv_cache.cache_seqlens,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
)
|
||||
|
||||
if flash_attn is not None:
|
||||
# Optimized Path (Linux/CUDA with FA3)
|
||||
y = flash_attn.flash_attn_with_kvcache(
|
||||
q, k_cache, v_cache,
|
||||
k=k, v=v,
|
||||
cache_seqlens=kv_cache.cache_seqlens,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
)
|
||||
else:
|
||||
# Fallback Path (macOS/MPS or CPU) - Manual Cache Update + SDPA
|
||||
positions = kv_cache.cache_seqlens
|
||||
# Update cache manually for the batch
|
||||
for b in range(B):
|
||||
pos = positions[b].item()
|
||||
k_cache[b, pos:pos+T] = k[b]
|
||||
v_cache[b, pos:pos+T] = v[b]
|
||||
|
||||
# Compute attention manually
|
||||
y = torch.empty_like(q)
|
||||
for b in range(B):
|
||||
pos = positions[b].item()
|
||||
|
||||
# Fetch history + cast to q.dtype (MPS fix) + Transpose (CPU fix)
|
||||
k_curr = k_cache[b:b+1, :pos+T].to(dtype=q.dtype).transpose(1, 2)
|
||||
v_curr = v_cache[b:b+1, :pos+T].to(dtype=q.dtype).transpose(1, 2)
|
||||
q_curr = q[b:b+1].transpose(1, 2)
|
||||
|
||||
# Standard Attention (is_causal=False because context is explicit)
|
||||
att_out = F.scaled_dot_product_attention(q_curr, k_curr, v_curr, is_causal=False)
|
||||
# Transpose back to (B, T, H, D) and store
|
||||
y[b] = att_out.transpose(1, 2)[0]
|
||||
|
||||
# Advance position after last layer processes
|
||||
if self.layer_idx == kv_cache.n_layers - 1:
|
||||
kv_cache.advance(T)
|
||||
|
|
@ -405,4 +456,4 @@ class GPT(nn.Module):
|
|||
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
ids = torch.cat((ids, next_ids), dim=1)
|
||||
token = next_ids.item()
|
||||
yield token
|
||||
yield token
|
||||
Loading…
Reference in New Issue
Block a user