mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-06 07:35:32 +00:00
Fix chat_web.py to detect mps
This commit is contained in:
parent
cf2baf9933
commit
423bf31dbe
|
|
@ -45,7 +45,7 @@ from pydantic import BaseModel
|
|||
from typing import List, Optional, AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from nanochat.common import compute_init
|
||||
from nanochat.common import compute_init, autodetect_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
|
|
@ -80,7 +80,8 @@ logging.basicConfig(
|
|||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
device_type = autodetect_device_type()
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
|
||||
@dataclass
|
||||
class Worker:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user