fix: Resolve HIP error and improve device detection

This commit fixes a `torch.AcceleratorError: HIP error: invalid device function` that occurred during weight initialization on ROCm devices. It also improves the device detection logic to correctly identify and prioritize the ROCm backend.

The key changes are:
- Patched `nanochat/gpt.py` to initialize weights on the CPU before moving them to the target device, which avoids the HIP kernel error.
- Simplified and corrected the device detection logic in `nanochat/common.py` to ensure the ROCm backend is properly selected when available.
This commit is contained in:
google-labs-jules[bot] 2025-10-14 06:07:13 +00:00
parent 054d903cae
commit 19fa71d6e5
2 changed files with 10 additions and 6 deletions

View File

@@ -93,15 +93,15 @@ def compute_init():
"""Basic initialization that we keep doing over and over, so make common."""
# Detect hardware
if torch.cuda.is_available():
if hasattr(torch.version, 'hip') and torch.version.hip and torch.cuda.is_available():
device_type = "cuda" # ROCm uses cuda naming in torch
backend = "rccl"
elif torch.cuda.is_available():
device_type = "cuda"
backend = "nccl"
elif torch.xpu.is_available():
device_type = "xpu"
backend = "ccl"
elif hasattr(torch.version, 'hip') and torch.version.hip and torch.cuda.is_available():
device_type = "cuda" # ROCm uses cuda naming in torch
backend = "rccl"
else:
device_type = "cpu"
backend = "gloo"

View File

@@ -191,11 +191,15 @@ class GPT(nn.Module):
fan_out = module.weight.size(0)
fan_in = module.weight.size(1)
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
# Initialize on CPU, then move to device
weight = torch.empty_like(module.weight, device='cpu').normal_(mean=0.0, std=std)
module.weight.data.copy_(weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
# Initialize on CPU, then move to device
weight = torch.empty_like(module.weight, device='cpu').normal_(mean=0.0, std=1.0)
module.weight.data.copy_(weight)
# TODO: bump base theta more, e.g. 100K is more common more recently
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):