diff --git a/nanochat/common.py b/nanochat/common.py
index 0f40b47..e92380d 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -93,15 +93,15 @@ def compute_init():
     """Basic initialization that we keep doing over and over, so make common."""
     # Detect hardware
-    if torch.cuda.is_available():
+    if hasattr(torch.version, 'hip') and torch.version.hip and torch.cuda.is_available():
+        device_type = "cuda" # ROCm uses cuda naming in torch
+        backend = "nccl" # RCCL is exposed through the "nccl" backend name on ROCm
+    elif torch.cuda.is_available():
         device_type = "cuda"
         backend = "nccl"
     elif torch.xpu.is_available():
         device_type = "xpu"
         backend = "ccl"
-    elif hasattr(torch.version, 'hip') and torch.version.hip and torch.cuda.is_available():
-        device_type = "cuda" # ROCm uses cuda naming in torch
-        backend = "rccl"
     else:
         device_type = "cpu"
         backend = "gloo"
 
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 5a066b2..732f331 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -191,11 +191,15 @@ class GPT(nn.Module):
             fan_out = module.weight.size(0)
             fan_in = module.weight.size(1)
             std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
-            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
+            # Initialize on CPU, then move to device
+            weight = torch.empty_like(module.weight, device='cpu').normal_(mean=0.0, std=std)
+            module.weight.data.copy_(weight)
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
         elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
+            # Initialize on CPU, then move to device
+            weight = torch.empty_like(module.weight, device='cpu').normal_(mean=0.0, std=1.0)
+            module.weight.data.copy_(weight)
 
     # TODO: bump base theta more, e.g. 100K is more common more recently
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):