diff --git a/nanochat/common.py b/nanochat/common.py
index 8b10df9..6eea6cd 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -89,15 +89,18 @@ def get_dist_info():
     else:
         return False, 0, 0, 1
 
-def compute_init():
+def compute_init(device_type="cuda"):
     """Basic initialization that we keep doing over and over, so make common."""
 
     # CUDA is currently required
-    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
+    assert torch.cuda.is_available() | torch.mps.is_available(), "CUDA or MPS is needed for a distributed run atm"
 
     # Reproducibility
     torch.manual_seed(42)
-    torch.cuda.manual_seed(42)
+    if device_type == "cuda":
+        torch.cuda.manual_seed(42)
+    elif device_type == "mps":
+        torch.mps.manual_seed(42)
     # skipping full reproducibility for now, possibly investigate slowdown later
     # torch.use_deterministic_algorithms(True)
     # torch.backends.cudnn.deterministic = True
@@ -109,12 +112,15 @@ def compute_init():
     # Distributed setup: Distributed Data Parallel (DDP), optional
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
     if ddp:
-        device = torch.device("cuda", ddp_local_rank)
-        torch.cuda.set_device(device) # make "cuda" default to this device
+        device = torch.device(device_type, ddp_local_rank)
+        if device_type == "mps":
+            torch.mps.set_device(device) # make "device" default to this device
+        else:
+            torch.cuda.set_device(device) # make "device" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
         dist.barrier()
     else:
-        device = torch.device("cuda")
+        device = torch.device(device_type)
     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")
 
diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py
index 3a38147..4fb7e92 100644
--- a/scripts/chat_cli.py
+++ b/scripts/chat_cli.py
@@ -17,11 +17,13 @@ parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
 parser.add_argument('-p', '--prompt', type=str, default='', help='Prompt the model, get a single response back')
 parser.add_argument('-t', '--temperature', type=float, default=0.6, help='Temperature for generation')
 parser.add_argument('-k', '--top-k', type=int, default=50, help='Top-k sampling parameter')
+parser.add_argument('-d', '--device', type=str, default='cuda', choices=['mps', 'cuda'], help='Device to run the model on: cuda|mps')
+
 args = parser.parse_args()
 
 # Init the model and tokenizer
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(args.device)
+autocast_ctx = torch.amp.autocast(device_type=args.device, dtype=torch.bfloat16)
 model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
 
 # Special tokens for the chat state machine
diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 1a4cfe2..44ae749 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -29,10 +29,11 @@ parser.add_argument('-g', '--model-tag', type=str, default=None, help='Model tag
 parser.add_argument('-s', '--step', type=int, default=None, help='Step to load')
 parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
 parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
+parser.add_argument('-d', '--device', type=str, default='cuda', choices=['mps', 'cuda'], help='Device to run the model on: cuda|mps')
 args = parser.parse_args()
 
-ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(args.device)
+autocast_ctx = torch.amp.autocast(device_type=args.device, dtype=torch.bfloat16)
 
 class ChatMessage(BaseModel):
     role: str