fix KV cache dtype mismatch on CPU: use COMPUTE_DTYPE instead of hardcoded logic

The KV cache was hardcoded to float32 on non-CUDA devices, but the model
weights are loaded in bfloat16 when the NANOCHAT_DTYPE env var requests it.
The mismatch between the bfloat16 model and the float32 cache caused a
RuntimeError in scaled_dot_product_attention. The cache now uses
COMPUTE_DTYPE from common.py, which respects the env var.
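
The idea behind the fix can be sketched in a few lines of plain Python. This is only an illustration, not the code in common.py: the helper names `resolve_compute_dtype_name` and `old_kv_cache_dtype_name`, the set of accepted values, and the device-based fallback are all assumptions made for the example, and dtype names are kept as strings so the sketch runs without torch.

```python
import os

# Hypothetical set of accepted values; the real common.py may accept others.
_ALLOWED = ("bfloat16", "float16", "float32")


def resolve_compute_dtype_name(device_type, env=None):
    """Return the one dtype name that weights and KV cache should share."""
    env = os.environ if env is None else env
    override = env.get("NANOCHAT_DTYPE", "").strip().lower()
    if override:
        if override not in _ALLOWED:
            raise ValueError(f"unsupported NANOCHAT_DTYPE: {override!r}")
        return override
    # Assumed fallback when the env var is unset: the old device-based rule.
    return "bfloat16" if device_type == "cuda" else "float32"


def old_kv_cache_dtype_name(device_type):
    """The rule this commit removes: it ignores NANOCHAT_DTYPE entirely."""
    return "bfloat16" if device_type == "cuda" else "float32"


if __name__ == "__main__":
    env = {"NANOCHAT_DTYPE": "bfloat16"}
    weights = resolve_compute_dtype_name("cpu", env)
    # Before the fix: the cache dtype disagrees with the weights on CPU.
    print(weights, old_kv_cache_dtype_name("cpu"))
    # After the fix: the cache asks the same resolver as the weights.
    print(weights, resolve_compute_dtype_name("cpu", env))
```

In real code the resolved name would presumably be mapped to a torch dtype once, at import time, so that every tensor allocation reads the same constant instead of re-deriving it from the device.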

Also broadened CI/CD path triggers to nanochat/**.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Manmohan Sharma 2026-03-23 10:04:33 -04:00
parent c3f683f3e3
commit 40586713bd
2 changed files with 3 additions and 4 deletions


@@ -4,8 +4,7 @@ on:
   push:
     branches: [master]
     paths:
-      - 'nanochat/ui.html'
-      - 'nanochat/logo.svg'
+      - 'nanochat/**'
       - 'scripts/chat_web.py'
       - 'scripts/chat_cli.py'


@@ -17,7 +17,7 @@ import signal
 import warnings
 from contextlib import contextmanager
 from collections import deque
-from nanochat.common import compute_init, autodetect_device_type
+from nanochat.common import compute_init, autodetect_device_type, COMPUTE_DTYPE
 from nanochat.checkpoint_manager import load_model
 # -----------------------------------------------------------------------------
@@ -183,7 +183,7 @@ class Engine:
         # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
         # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
         # In particular, the KVCache should allocate its tensors lazily
-        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
+        dtype = COMPUTE_DTYPE
         rng = torch.Generator(device=device)
         rng.manual_seed(seed)