From 40586713bddf444ec597cf666957d0d4b436e6c9 Mon Sep 17 00:00:00 2001
From: Manmohan Sharma
Date: Mon, 23 Mar 2026 10:04:33 -0400
Subject: [PATCH] fix KV cache dtype mismatch on CPU: use COMPUTE_DTYPE instead of hardcoded logic

The KV cache was hardcoded to float32 on non-CUDA devices, but the model
weights are loaded in bfloat16 via NANOCHAT_DTYPE env var. This caused a
RuntimeError in scaled_dot_product_attention. Now uses COMPUTE_DTYPE from
common.py which respects the env var. Also broadened CI/CD path triggers
to nanochat/**.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 .github/workflows/deploy.yml | 3 +--
 nanochat/engine.py           | 4 ++--
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml
index f4a9442d..06cb35d1 100644
--- a/.github/workflows/deploy.yml
+++ b/.github/workflows/deploy.yml
@@ -4,8 +4,7 @@ on:
   push:
     branches: [master]
     paths:
-      - 'nanochat/ui.html'
-      - 'nanochat/logo.svg'
+      - 'nanochat/**'
       - 'scripts/chat_web.py'
       - 'scripts/chat_cli.py'
 
diff --git a/nanochat/engine.py b/nanochat/engine.py
index aa2e6a98..4bdfd654 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -17,7 +17,7 @@ import signal
 import warnings
 from contextlib import contextmanager
 from collections import deque
-from nanochat.common import compute_init, autodetect_device_type
+from nanochat.common import compute_init, autodetect_device_type, COMPUTE_DTYPE
 from nanochat.checkpoint_manager import load_model
 
 # -----------------------------------------------------------------------------
@@ -183,7 +183,7 @@ class Engine:
         # As a quick hack, we're making generate() function inherit and know about this repo-wise assumption.
         # I think there has to be a bigger refactor to deal with device/dtype tracking across the codebase.
         # In particular, the KVCache should allocate its tensors lazily
-        dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
+        dtype = COMPUTE_DTYPE
         rng = torch.Generator(device=device)
         rng.manual_seed(seed)