Fix chat_web.py to detect mps

This commit is contained in:
Shivam Ojha 2025-10-18 17:00:34 -05:00
parent cf2baf9933
commit 423bf31dbe

View File

@ -45,7 +45,7 @@ from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
from nanochat.common import compute_init
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@ -80,7 +80,8 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
@dataclass
class Worker: