mirror of https://github.com/karpathy/nanochat.git
synced 2026-05-13 19:30:23 +00:00
Add FA3 probe + kernels upgrade for smoke and d12 runners
This commit is contained in:
parent 6ee799ef00
commit e26bf8719b
@@ -116,8 +116,33 @@ sed -i 's/ --target-param-data-ratio=8//' runs/speedrun.sh
echo "[runner] speedrun.sh edits applied:"
grep -n 'depth\|target-param' runs/speedrun.sh || true

# Explicit venv setup BEFORE speedrun.sh so we can run diagnostic probes
# inside the venv. speedrun.sh's uv sync is idempotent (no-op the second time).
export OMP_NUM_THREADS=1
export NANOCHAT_BASE_DIR
mkdir -p "$NANOCHAT_BASE_DIR"
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync --extra gpu
source .venv/bin/activate
pip install --quiet --upgrade huggingface_hub

# Ensure HF token flows to the kernels lib (some libs read HF_HUB_TOKEN, not HF_TOKEN)
export HF_HUB_TOKEN="${HF_TOKEN}"

# Bump kernels to latest — pyproject pins >=0.11.7 and uv often picks exactly that;
# 0.11.x had kernel-resolution bugs that affect FA3 loading silently.
echo "[runner] upgrading kernels lib for FA3 reliability"
uv pip install --quiet --upgrade 'kernels>=0.13.0' 2>&1 || \
    echo "[runner] WARN: kernels upgrade failed (continuing)"

# FA3 diagnostic probe — surfaces real errors (nanochat silently swallows them).
# Non-fatal: SDPA fallback is automatic. We want this output in the log
# regardless of outcome so we can decide what to do about FA3.
echo "[runner] === FA3 PROBE BEGIN ==="
python "$WORKDIR/runs/runpod/probe_fa3.py" || echo "[runner] FA3 probe reported issues (non-fatal — continuing with SDPA fallback)"
echo "[runner] === FA3 PROBE END ==="

(
    while true; do
        sleep "$BACKUP_INTERVAL"
216 runs/runpod/probe_fa3.py Normal file

@@ -0,0 +1,216 @@
#!/usr/bin/env python3
# pyright: reportMissingImports=false
"""
Comprehensive FA3 / kernels diagnostic probe.

nanochat/flash_attention.py:_load_flash_attention_3 swallows ALL exceptions silently
and falls back to SDPA. This script runs the same code path with full tracebacks
so we can see why FA3 isn't loading on the pod.

Run inside the pod (after uv sync, with venv active):
    python runs/runpod/probe_fa3.py

Exits 0 if FA3 is fully usable, 1 if any check fails.

Note: the pyright pragma above is intentional — torch/huggingface_hub/kernels
are only present at pod runtime; the local IDE will flag them as unresolved.
"""
import os
import sys
import traceback
import platform
import subprocess

OK = "\033[32mOK\033[0m"
FAIL = "\033[31mFAIL\033[0m"
WARN = "\033[33mWARN\033[0m"


def section(n, name):
    print()
    print("=" * 80)
    print(f"[{n}] {name}")
    print("=" * 80)


def fmt_token(tok):
    if not tok:
        return "NOT SET"
    return f"SET (len={len(tok)}, prefix={tok[:7]}…)"


passed_all = True


def fail(msg):
    global passed_all
    passed_all = False
    print(f" {FAIL} {msg}")


def warn(msg):
    print(f" {WARN} {msg}")


def ok(msg):
    print(f" {OK} {msg}")


# ---------------------------------------------------------------------------
section(1, "Environment")
print(f" python : {sys.version.split()[0]}")
print(f" platform : {platform.platform()}")
print(f" cwd : {os.getcwd()}")
hf_tok = os.environ.get("HF_TOKEN", "")
hf_hub_tok = os.environ.get("HF_HUB_TOKEN", "")
print(f" HF_TOKEN : {fmt_token(hf_tok)}")
print(f" HF_HUB_TOKEN : {fmt_token(hf_hub_tok)}")
print(f" HF_HOME : {os.environ.get('HF_HOME', '(default ~/.cache/huggingface)')}")
print(f" HUGGINGFACE_HUB_CACHE : {os.environ.get('HUGGINGFACE_HUB_CACHE', '(unset)')}")
print(f" WANDB_API_KEY: {fmt_token(os.environ.get('WANDB_API_KEY',''))}")

if not hf_tok:
    fail("HF_TOKEN env var is empty — kernels lib will fall back to anonymous and may rate-limit")

# ---------------------------------------------------------------------------
section(2, "Network connectivity")
for url, label in [
    ("https://huggingface.co", "huggingface.co"),
    ("https://cdn-lfs.huggingface.co", "cdn-lfs.huggingface.co"),
    ("https://github.com", "github.com"),
]:
    try:
        rc = subprocess.run(
            ["curl", "-sfI", "--max-time", "10", url],
            capture_output=True, text=True, timeout=15,
        ).returncode
        if rc == 0:
            ok(f"{label} reachable")
        else:
            fail(f"{label} unreachable (curl rc={rc})")
    except Exception as e:
        fail(f"{label}: {type(e).__name__}: {e}")

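# (If a pod image ever ships without curl, an equivalent stdlib-only check is
#  possible; a hedged sketch, not part of the committed probe:
#
#      import urllib.request
#      req = urllib.request.Request(url, method="HEAD")
#      urllib.request.urlopen(req, timeout=10)  # raises on HTTP error/timeout
#  )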
# ---------------------------------------------------------------------------
|
||||
section(3, "huggingface_hub auth (does the token actually work?)")
|
||||
try:
|
||||
from huggingface_hub import whoami
|
||||
info = whoami(token=hf_tok or None)
|
||||
ok(f"authenticated as: {info.get('name','?')} (type={info.get('type','?')})")
|
||||
print(f" orgs: {[o.get('name') for o in info.get('orgs', [])]}")
|
||||
print(f" access token role: {info.get('auth',{}).get('accessToken',{}).get('role','?')}")
|
||||
except Exception as e:
|
||||
fail(f"whoami failed: {type(e).__name__}: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# ---------------------------------------------------------------------------
section(4, "PyTorch / CUDA / GPU")
try:
    import torch
    print(f" torch : {torch.__version__}")
    print(f" cuda available : {torch.cuda.is_available()}")
    print(f" cuda version : {torch.version.cuda}")
    if torch.cuda.is_available():
        print(f" device count : {torch.cuda.device_count()}")
        for i in range(torch.cuda.device_count()):
            major, minor = torch.cuda.get_device_capability(i)
            name = torch.cuda.get_device_name(i)
            mark = OK if major == 9 else WARN
            print(f" device {i} : {name} sm{major}{minor} [{mark}]")
        major, _ = torch.cuda.get_device_capability(0)
        if major != 9:
            fail(f"FA3 requires sm90 (Hopper); device 0 is sm{major}{_}")
    else:
        fail("CUDA not available")
except Exception as e:
    fail(f"torch import failed: {type(e).__name__}: {e}")
    traceback.print_exc()

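# (Rule of thumb for the sm90 check above: H100/H200 are sm90 and FA3-capable;
#  A100 is sm80, and L40S/RTX 4090 are sm89, so all of those trip fail() here.)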
# ---------------------------------------------------------------------------
section(5, "kernels library")
try:
    import kernels
    ver = getattr(kernels, "__version__", "?")
    print(f" kernels : {ver}")
    if ver != "?":
        major_minor = tuple(int(x) for x in ver.split(".")[:2])
        if major_minor < (0, 13):
            warn(f"kernels {ver} < 0.13 — older versions have known kernel-resolution bugs; consider 'uv pip install --upgrade kernels'")
        else:
            ok(f"kernels {ver} is recent")
    print(f" kernels path : {kernels.__file__}")
except Exception as e:
    fail(f"kernels not importable: {type(e).__name__}: {e}")
    traceback.print_exc()
    sys.exit(1)

# ---------------------------------------------------------------------------
section(6, "Fetch varunneal/flash-attention-3 (THE actual nanochat code path)")
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
try:
    from kernels import get_kernel
    print(" calling get_kernel('varunneal/flash-attention-3') …")
    k = get_kernel("varunneal/flash-attention-3")
    ok(f"get_kernel returned: {type(k).__name__}")
    print(f" module path: {getattr(k, '__file__', '(no __file__)')}")
    iface = k.flash_attn_interface
    ok(f"flash_attn_interface: {type(iface).__name__}")
    fn = iface.flash_attn_func
    ok(f"flash_attn_func: callable={callable(fn)}")
    print("\n >>> FA3 binary is usable on this pod. <<<")
except Exception as e:
    fail(f"FA3 fetch failed: {type(e).__name__}: {e}")
    print()
    traceback.print_exc()
    print()
    print(" Likely causes:")
    print(" 1. Network/DNS issue (HF Hub unreachable from this DC)")
    print(" 2. Old kernels version with resolver bugs (try kernels>=0.13)")
    print(" 3. HF token not flowing — try `export HF_HUB_TOKEN=$HF_TOKEN`")
    print(" 4. No prebuilt binary for this torch/cuda combo (we have torch 2.9 + cu128 — should be supported)")

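# (Background, to the best of our understanding of the kernels lib: get_kernel
#  resolves the repo to a prebuilt binary on the HF Hub matched to the local
#  torch/CUDA/Python ABI, downloads it into the HF cache, and imports it.
#  That is why sections 2, 3, and 7 probe network, auth, and cache state.)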
# ---------------------------------------------------------------------------
section(7, "HF Hub cache state")
import pathlib
cache_root = pathlib.Path(os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface")))
print(f" cache root : {cache_root}")
if cache_root.exists():
    try:
        size = sum(p.stat().st_size for p in cache_root.rglob("*") if p.is_file())
        print(f" size on disk : {size / 1024 / 1024:.1f} MB")
    except Exception as e:
        print(f" (could not size cache: {e})")
    fa3_marker = list(cache_root.rglob("*flash*attention*3*"))
    if fa3_marker:
        ok(f"found FA3-related cache entries: {len(fa3_marker)}")
        for p in fa3_marker[:5]:
            print(f" {p}")
    else:
        warn("no flash-attention-3 entries in cache yet")
else:
    print(" (cache directory does not exist)")

# ---------------------------------------------------------------------------
section(8, "Replicate nanochat.flash_attention detection")
try:
    from nanochat.flash_attention import HAS_FA3, USE_FA3, _fa3
    if HAS_FA3:
        ok("nanochat.flash_attention.HAS_FA3 = True")
    else:
        fail("nanochat.flash_attention.HAS_FA3 = False (despite section 6 results)")
    print(f" USE_FA3 = {USE_FA3}")
    print(f" _fa3 object: {_fa3}")
except Exception as e:
    fail(f"import nanochat.flash_attention failed: {type(e).__name__}: {e}")
    traceback.print_exc()

# ---------------------------------------------------------------------------
section(9, "Verdict")
if passed_all:
    print(f" {OK} all checks passed — FA3 is wired up and base_train should use it")
    sys.exit(0)
else:
    print(f" {FAIL} one or more checks failed — see above. Run will fall back to SDPA (slower, possibly much slower).")
    print()
    print(" Continuing the training run anyway is safe; FA3 fallback to SDPA is automatic.")
    sys.exit(1)

@@ -91,9 +91,24 @@ uv sync --extra gpu
source .venv/bin/activate
pip install --quiet --upgrade huggingface_hub

# Ensure HF token flows to the kernels lib (some libs read HF_HUB_TOKEN, not HF_TOKEN)
export HF_HUB_TOKEN="${HF_TOKEN}"

# Bump kernels to latest — pyproject pins >=0.11.7, uv often picks exactly that;
# 0.11.x had kernel-resolution bugs that affect FA3 loading silently.
echo "[smoke] upgrading kernels lib for FA3 reliability"
uv pip install --quiet --upgrade 'kernels>=0.13.0' 2>&1 || \
    echo "[smoke] WARN: kernels upgrade failed (continuing with whatever uv installed)"

# GPU sanity
python -c "import torch; print('[smoke] torch', torch.__version__, 'cuda', torch.cuda.is_available(), 'devices', torch.cuda.device_count())"

# FA3 diagnostic probe — surfaces the real error if FA3 won't load (nanochat
# silently swallows it). Non-fatal: SDPA fallback is automatic if probe fails.
echo "[smoke] === FA3 PROBE BEGIN ==="
python "$WORKDIR/runs/runpod/probe_fa3.py" || echo "[smoke] FA3 probe reported issues (non-fatal — continuing with SDPA fallback)"
echo "[smoke] === FA3 PROBE END ==="

# Minimum dataset + tokenizer (1 shard, 50M chars — enough for the tokenizer
# to train on AND for base_train to consume 20 iterations of tokens)
python -m nanochat.dataset -n 1
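If a future runner should treat a missing FA3 as fatal rather than eating the SDPA slowdown, it could gate on the probe's exit code (0 = FA3 usable, 1 = any check failed); a minimal sketch against the same $WORKDIR layout as above, not part of this commit:

# Strict variant: abort the run when the FA3 probe fails instead of continuing.
if ! python "$WORKDIR/runs/runpod/probe_fa3.py"; then
    echo "[smoke] FA3 unusable; aborting rather than paying the SDPA slowdown"
    exit 1
fi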