restore original assert messages

This commit is contained in:
Sofie Van Landeghem 2026-02-24 16:51:34 +01:00 committed by GitHub
parent 7e42702894
commit c546a44001
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -434,8 +434,8 @@ class GPT(nn.Module):
self._ensure_rope_cache(T0 + T, device=idx.device)
# Now it's safe to slice
assert idx.device == self.cos.device, f"RoPE buffers and idx device mismatch: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "RoPE buffers must be bfloat16"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]