diff --git a/scripts/chat_web.py b/scripts/chat_web.py index c07725e..57afc9a 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -45,7 +45,7 @@ from pydantic import BaseModel from typing import List, Optional, AsyncGenerator from dataclasses import dataclass -from nanochat.common import compute_init +from nanochat.common import compute_init, autodetect_device_type from nanochat.checkpoint_manager import load_model from nanochat.engine import Engine @@ -80,7 +80,8 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) -ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() +device_type = autodetect_device_type() +ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) @dataclass class Worker: