diff --git a/nanochat/common.py b/nanochat/common.py index 2dd0792..9c1e30e 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -140,14 +140,66 @@ def get_dist_info(): return False, 0, 0, 1 def autodetect_device_type(): - # prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU + """ + Enhanced platform detection for ARM64, RISC-V, and edge devices. + + Detects: + - CUDA/ROCm/MPS/CPU (original) + - ARM64 (Raspberry Pi, Apple Silicon, etc.) + - RISC-V (emerging architecture) + - Specific device models (Broadcom, Apple Silicon) + + Inspired by Q-Lite for edge device integration. + """ + import platform + import torch + + # Original detection (CUDA/MPS/CPU) if torch.cuda.is_available(): device_type = "cuda" elif torch.backends.mps.is_available(): device_type = "mps" else: device_type = "cpu" - print0(f"Autodetected device type: {device_type}") + + # Enhanced detection (ARM64/RISC-V) + machine = platform.machine() + system = platform.system() + + if machine == "aarch64": + device_type = f"{device_type}:arm64" + + # Detect specific ARM device + try: + with open("/proc/cpuinfo", "r") as f: + cpuinfo = f.read() + + if "Raspberry Pi" in cpuinfo: + # Extract Pi version + for line in cpuinfo.split("\n"): + if "Hardware" in line: + device_model = line.split(":")[1].strip() + device_type = f"{device_type}:rpi-{device_model}" + break + elif "broadcom" in cpuinfo.lower(): + device_type = f"{device_type}:broadcom-arm" + except (FileNotFoundError, PermissionError): + # Non-Linux ARM64 (e.g., macOS ARM64) + if system == "Darwin": + device_type = f"{device_type}:apple-silicon" + else: + device_type = f"{device_type}:generic-arm64" + + elif machine == "riscv64": + device_type = f"{device_type}:riscv64" + + # Add platform info logging + print0(f"Autodetected platform:") + print0(f" System: {system}") + print0(f" Machine: {machine}") + print0(f" Device: {device_type}") + + return device_type return device_type def compute_init(device_type="cuda"): # cuda|cpu|mps