From 19fa71d6e51639b6a154b5f21614df89708cae93 Mon Sep 17 00:00:00 2001
From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com>
Date: Tue, 14 Oct 2025 06:07:13 +0000
Subject: [PATCH] fix: Resolve HIP error and improve device detection

This commit fixes a `torch.AcceleratorError: HIP error: invalid device
function` that occurred during weight initialization on ROCm devices. It
also improves the device detection logic to correctly identify and
prioritize the ROCm backend.

The key changes are:
- Patched `nanochat/gpt.py` to initialize weights on the CPU before
  moving them to the target device, which avoids the HIP kernel error.
- Simplified and corrected the device detection logic in
  `nanochat/common.py` to ensure the ROCm backend is properly selected
  when available. Note that torch.distributed exposes RCCL through the
  "nccl" backend name, so "nccl" is the correct backend string on both
  CUDA and ROCm builds.
---
 nanochat/common.py | 8 ++++----
 nanochat/gpt.py    | 8 ++++++--
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/nanochat/common.py b/nanochat/common.py
index 0f40b47..e92380d 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -93,15 +93,15 @@ def compute_init():
     """Basic initialization that we keep doing over and over, so make common."""
     # Detect hardware
-    if torch.cuda.is_available():
+    if hasattr(torch.version, 'hip') and torch.version.hip and torch.cuda.is_available():
+        device_type = "cuda" # ROCm uses cuda naming in torch
+        backend = "nccl" # RCCL is reached via the "nccl" backend name; "rccl" is not a valid backend
+    elif torch.cuda.is_available():
         device_type = "cuda"
         backend = "nccl"
     elif torch.xpu.is_available():
         device_type = "xpu"
         backend = "ccl"
-    elif hasattr(torch.version, 'hip') and torch.version.hip and torch.cuda.is_available():
-        device_type = "cuda" # ROCm uses cuda naming in torch
-        backend = "rccl"
     else:
         device_type = "cpu"
         backend = "gloo"
 
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 5a066b2..732f331 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -191,11 +191,15 @@ class GPT(nn.Module):
             fan_out = module.weight.size(0)
             fan_in = module.weight.size(1)
             std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
-            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+            # Initialize on CPU, then move to device
+            weight = torch.empty_like(module.weight, device='cpu').normal_(mean=0.0, std=std)
+            module.weight.data.copy_(weight)
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
+            # Initialize on CPU, then move to device
+            weight = torch.empty_like(module.weight, device='cpu').normal_(mean=0.0, std=1.0)
+            module.weight.data.copy_(weight)
         # TODO: bump base theta more, e.g. 100K is more common more recently
 
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):