From a430ed5a638cdd0bd75fce39b08bb58d27be3d21 Mon Sep 17 00:00:00 2001
From: bedwards
Date: Thu, 20 Nov 2025 19:58:29 -0600
Subject: [PATCH] enable autocast for mps device

---
 nanochat/engine.py    | 2 +-
 scripts/base_eval.py  | 2 +-
 scripts/base_loss.py  | 2 +-
 scripts/base_train.py | 6 +++---
 scripts/chat_cli.py   | 2 +-
 scripts/chat_eval.py  | 2 +-
 scripts/chat_sft.py   | 2 +-
 scripts/chat_web.py   | 8 ++++----
 scripts/mid_train.py  | 8 ++++----
 9 files changed, 17 insertions(+), 17 deletions(-)

diff --git a/nanochat/engine.py b/nanochat/engine.py
index d749d94..5ac62da 100644
--- a/nanochat/engine.py
+++ b/nanochat/engine.py
@@ -330,7 +330,7 @@ if __name__ == "__main__":
     # init compute
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
     device_type = autodetect_device_type()
-    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type != "cpu" else nullcontext()
 
     # load the model and tokenizer
     model, tokenizer, meta = load_model("base", device, phase="eval")
diff --git a/scripts/base_eval.py b/scripts/base_eval.py
index 3663538..7578a82 100644
--- a/scripts/base_eval.py
+++ b/scripts/base_eval.py
@@ -154,7 +154,7 @@ def main():
     # distributed / precision setup
     device_type = autodetect_device_type()
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type != "cpu" else nullcontext()
 
     # Load model and tokenizer from command line or from file system
     if args.hf_path is not None:
diff --git a/scripts/base_loss.py b/scripts/base_loss.py
index abcde5f..50b0dd3 100644
--- a/scripts/base_loss.py
+++ b/scripts/base_loss.py
@@ -29,7 +29,7 @@ device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=model_tag, step=model_step)
 sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type != "cpu" else nullcontext()
 
 # Evaluate the loss on each split
 tokens_per_step = device_batch_size * sequence_len * ddp_world_size
diff --git a/scripts/base_train.py b/scripts/base_train.py
index c9ea6c9..5fe114c 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -72,9 +72,9 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
 device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
-synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
-get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type != "cpu" else nullcontext()
+synchronize = torch.cuda.synchronize if device_type != "cpu" else lambda: None
+get_max_memory = torch.cuda.max_memory_allocated if device_type != "cpu" else lambda: 0
 
 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process
diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py
index b14843a..c82802e 100644
--- a/scripts/chat_cli.py
+++ b/scripts/chat_cli.py
@@ -27,7 +27,7 @@ args = parser.parse_args()
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type != "cpu" else nullcontext()
 model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
 
 # Special tokens for the chat state machine
diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py
index cae2f0f..ce4185e 100644
--- a/scripts/chat_eval.py
+++ b/scripts/chat_eval.py
@@ -200,7 +200,7 @@ if __name__ == "__main__":
     device_type = autodetect_device_type() if args.device_type == "" else args.device_type
     ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
     ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
+    autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type != "cpu" else nullcontext()
     model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
     engine = Engine(model, tokenizer)
 
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index bbeb1f9..dd1bdec 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -66,7 +66,7 @@ device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0
 ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type != "cpu" else nullcontext()
 
 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process
diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 4b67b62..d8a49fb 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -100,7 +100,7 @@ class WorkerPool:
 
     def __init__(self, num_gpus: Optional[int] = None):
         if num_gpus is None:
-            if device_type == "cuda":
+            if device_type != "cpu":
                 num_gpus = torch.cuda.device_count()
             else:
                 num_gpus = 1 # e.g. cpu|mps
@@ -112,11 +112,11 @@ class WorkerPool:
         """Load model on each GPU."""
         print(f"Initializing worker pool with {self.num_gpus} GPUs...")
         if self.num_gpus > 1:
-            assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
+            assert device_type != "cpu", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
 
 
         for gpu_id in range(self.num_gpus):
-            if device_type == "cuda":
+            if device_type != "cpu":
                 device = torch.device(f"cuda:{gpu_id}")
                 print(f"Loading model on GPU {gpu_id}...")
             else:
@@ -125,7 +125,7 @@ class WorkerPool:
 
             model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
             engine = Engine(model, tokenizer)
-            autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
+            autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type != "cpu" else nullcontext()
 
             worker = Worker(
                 gpu_id=gpu_id,
diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index 6c2b82f..f8c41c0 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -57,9 +57,9 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
 device_type = autodetect_device_type() if device_type == "" else device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
-synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
-get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type != "cpu" else nullcontext()
+synchronize = torch.cuda.synchronize if device_type != "cpu" else lambda: None
+get_max_memory = torch.cuda.max_memory_allocated if device_type != "cpu" else lambda: 0
 
 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process
@@ -123,7 +123,7 @@ def mid_data_generator(split):
     needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
     token_buffer = deque()
     # CUDA supports memory pinning for faster transfers between CPU and GPU:
-    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
+    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type != "cpu"))
     cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
     it = 0 # iteration counter
     while True: