mirror of
https://github.com/karpathy/nanochat.git
synced 2026-05-19 22:27:37 +00:00
fix KV cache dtype mismatch on CPU: use COMPUTE_DTYPE instead of hardcoded logic
The KV cache was hardcoded to float32 on non-CUDA devices, but the model weights are loaded in bfloat16 via the NANOCHAT_DTYPE env var. This dtype mismatch caused a RuntimeError in scaled_dot_product_attention. The KV cache now uses COMPUTE_DTYPE from common.py, which respects the env var. Also broadened the CI/CD path triggers to nanochat/**. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
c3f683f3e3
commit
40586713bd
3
.github/workflows/deploy.yml
vendored
3
.github/workflows/deploy.yml
vendored
|
|
@ -4,8 +4,7 @@ on:
|
|||
push:
|
||||
branches: [master]
|
||||
paths:
|
||||
- 'nanochat/ui.html'
|
||||
- 'nanochat/logo.svg'
|
||||
- 'nanochat/**'
|
||||
- 'scripts/chat_web.py'
|
||||
- 'scripts/chat_cli.py'
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import signal
|
|||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from collections import deque
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.common import compute_init, autodetect_device_type, COMPUTE_DTYPE
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
|
@ -183,7 +183,7 @@ class Engine:
|
|||
# As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
|
||||
# I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
|
||||
# In particular, the KVCache should allocate its tensors lazily
|
||||
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
|
||||
dtype = COMPUTE_DTYPE
|
||||
rng = torch.Generator(device=device)
|
||||
rng.manual_seed(seed)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user