mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 21:55:14 +00:00
fix: Resolve HIP error and improve device detection
This commit fixes a `torch.AcceleratorError: HIP error: invalid device function` that occurred during weight initialization on ROCm devices. It also improves the device detection logic to correctly identify and prioritize the ROCm backend. The key changes are:

- Patched `nanochat/gpt.py` to initialize weights on the CPU before moving them to the target device, which avoids the HIP kernel error.
- Simplified and corrected the device detection logic in `nanochat/common.py` to ensure the ROCm backend is properly selected when available.
This commit is contained in:
parent
054d903cae
commit
19fa71d6e5
|
|
@ -93,15 +93,15 @@ def compute_init():
|
|||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
|
||||
# Detect hardware
|
||||
if torch.cuda.is_available():
|
||||
if hasattr(torch.version, 'hip') and torch.version.hip and torch.cuda.is_available():
|
||||
device_type = "cuda" # ROCm uses cuda naming in torch
|
||||
backend = "rccl"
|
||||
elif torch.cuda.is_available():
|
||||
device_type = "cuda"
|
||||
backend = "nccl"
|
||||
elif torch.xpu.is_available():
|
||||
device_type = "xpu"
|
||||
backend = "ccl"
|
||||
elif hasattr(torch.version, 'hip') and torch.version.hip and torch.cuda.is_available():
|
||||
device_type = "cuda" # ROCm uses cuda naming in torch
|
||||
backend = "rccl"
|
||||
else:
|
||||
device_type = "cpu"
|
||||
backend = "gloo"
|
||||
|
|
|
|||
|
|
@ -191,11 +191,15 @@ class GPT(nn.Module):
|
|||
fan_out = module.weight.size(0)
|
||||
fan_in = module.weight.size(1)
|
||||
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
|
||||
# Initialize on CPU, then move to device
|
||||
weight = torch.empty_like(module.weight, device='cpu').normal_(mean=0.0, std=std)
|
||||
module.weight.data.copy_(weight)
|
||||
if module.bias is not None:
|
||||
torch.nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
|
||||
# Initialize on CPU, then move to device
|
||||
weight = torch.empty_like(module.weight, device='cpu').normal_(mean=0.0, std=1.0)
|
||||
module.weight.data.copy_(weight)
|
||||
|
||||
# TODO: bump base theta more, e.g. 100K is more common more recently
|
||||
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user