diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 0b717bb..3e12424 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -434,8 +434,8 @@ class GPT(nn.Module): self._ensure_rope_cache(T0 + T, device=idx.device) # Now it's safe to slice - assert idx.device == self.cos.device, f"RoPE buffers and idx device mismatch: {idx.device} != {self.cos.device}" - assert self.cos.dtype == torch.bfloat16, "RoPE buffers must be bfloat16" + assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" + assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]