This commit is contained in:
Ralph 2026-03-03 10:44:18 +02:00 committed by GitHub
commit 4cdf5d35b5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -140,14 +140,66 @@ def get_dist_info():
return False, 0, 0, 1
def autodetect_device_type():
# prefer to use CUDA if available, otherwise use MPS, otherwise fallback on CPU
"""
Enhanced platform detection for ARM64, RISC-V, and edge devices.
Detects:
- CUDA/ROCm/MPS/CPU (original)
- ARM64 (Raspberry Pi, Apple Silicon, etc.)
- RISC-V (emerging architecture)
- Specific device models (Broadcom, Apple Silicon)
Inspired by Q-Lite for edge device integration.
"""
import platform
import torch
# Original detection (CUDA/MPS/CPU)
if torch.cuda.is_available():
device_type = "cuda"
elif torch.backends.mps.is_available():
device_type = "mps"
else:
device_type = "cpu"
print0(f"Autodetected device type: {device_type}")
# Enhanced detection (ARM64/RISC-V)
machine = platform.machine()
system = platform.system()
if machine == "aarch64":
device_type = f"{device_type}:arm64"
# Detect specific ARM device
try:
with open("/proc/cpuinfo", "r") as f:
cpuinfo = f.read()
if "Raspberry Pi" in cpuinfo:
# Extract Pi version
for line in cpuinfo.split("\n"):
if "Hardware" in line:
device_model = line.split(":")[1].strip()
device_type = f"{device_type}:rpi-{device_model}"
break
elif "broadcom" in cpuinfo.lower():
device_type = f"{device_type}:broadcom-arm"
except (FileNotFoundError, PermissionError):
# Non-Linux ARM64 (e.g., macOS ARM64)
if system == "Darwin":
device_type = f"{device_type}:apple-silicon"
else:
device_type = f"{device_type}:generic-arm64"
elif machine == "riscv64":
device_type = f"{device_type}:riscv64"
# Add platform info logging
print0(f"Autodetected platform:")
print0(f" System: {system}")
print0(f" Machine: {machine}")
print0(f" Device: {device_type}")
return device_type
return device_type
def compute_init(device_type="cuda"): # cuda|cpu|mps