mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-07 01:40:30 +00:00
Merge 4f2f78a4a5 into 83dccc20ae
This commit is contained in:
commit
4cdf5d35b5
|
|
@ -140,14 +140,66 @@ def get_dist_info():
|
|||
return False, 0, 0, 1
|
||||
|
||||
def autodetect_device_type():
|
||||
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
|
||||
"""
|
||||
Enhanced platform detection for ARM64, RISC-V, and edge devices.
|
||||
|
||||
Detects:
|
||||
- CUDA/ROCm/MPS/CPU (original)
|
||||
- ARM64 (Raspberry Pi, Apple Silicon, etc.)
|
||||
- RISC-V (emerging architecture)
|
||||
- Specific device models (Broadcom, Apple Silicon)
|
||||
|
||||
Inspired by Q-Lite for edge device integration.
|
||||
"""
|
||||
import platform
|
||||
import torch
|
||||
|
||||
# Original detection (CUDA/MPS/CPU)
|
||||
if torch.cuda.is_available():
|
||||
device_type = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
device_type = "mps"
|
||||
else:
|
||||
device_type = "cpu"
|
||||
print0(f"Autodetected device type: {device_type}")
|
||||
|
||||
# Enhanced detection (ARM64/RISC-V)
|
||||
machine = platform.machine()
|
||||
system = platform.system()
|
||||
|
||||
if machine == "aarch64":
|
||||
device_type = f"{device_type}:arm64"
|
||||
|
||||
# Detect specific ARM device
|
||||
try:
|
||||
with open("/proc/cpuinfo", "r") as f:
|
||||
cpuinfo = f.read()
|
||||
|
||||
if "Raspberry Pi" in cpuinfo:
|
||||
# Extract Pi version
|
||||
for line in cpuinfo.split("\n"):
|
||||
if "Hardware" in line:
|
||||
device_model = line.split(":")[1].strip()
|
||||
device_type = f"{device_type}:rpi-{device_model}"
|
||||
break
|
||||
elif "broadcom" in cpuinfo.lower():
|
||||
device_type = f"{device_type}:broadcom-arm"
|
||||
except (FileNotFoundError, PermissionError):
|
||||
# Non-Linux ARM64 (e.g., macOS ARM64)
|
||||
if system == "Darwin":
|
||||
device_type = f"{device_type}:apple-silicon"
|
||||
else:
|
||||
device_type = f"{device_type}:generic-arm64"
|
||||
|
||||
elif machine == "riscv64":
|
||||
device_type = f"{device_type}:riscv64"
|
||||
|
||||
# Add platform info logging
|
||||
print0(f"Autodetected platform:")
|
||||
print0(f" System: {system}")
|
||||
print0(f" Machine: {machine}")
|
||||
print0(f" Device: {device_type}")
|
||||
|
||||
return device_type
|
||||
return device_type
|
||||
|
||||
def compute_init(device_type="cuda"): # cuda|cpu|mps
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user