From 08c628cb83febdf17b90ced69cfd65c263cf1316 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Tue, 14 Oct 2025 05:07:30 +0000 Subject: [PATCH] feat: Add ROCm and device-agnostic support This change adds support for ROCm and makes the codebase device-agnostic, allowing it to run on different hardware backends including ROCm, CUDA, and CPU. The key changes are: - Modified `pyproject.toml` to use ROCm-compatible PyTorch wheels and added the `pytorch-triton-rocm` dependency. - Refactored `nanochat/common.py` to dynamically detect the available hardware and set the device and distributed backend accordingly. - Updated all training, evaluation, and inference scripts to be device-agnostic, removing hardcoded CUDA references. - Adapted `speedrun.sh` for single-device execution by replacing `torchrun` with `python`. - Updated `nanochat/report.py` to provide more generic GPU information. --- nanochat/common.py | 32 +++++++++++++++++++++----------- nanochat/dataloader.py | 6 +++--- nanochat/engine.py | 8 ++++---- nanochat/report.py | 6 +++--- pyproject.toml | 14 +++++++++----- scripts/base_eval.py | 2 +- scripts/base_loss.py | 4 ++-- scripts/base_train.py | 15 ++++++++------- scripts/chat_cli.py | 2 +- scripts/chat_eval.py | 2 +- scripts/chat_rl.py | 2 +- scripts/chat_sft.py | 2 +- scripts/chat_web.py | 2 +- scripts/mid_train.py | 9 +++++---- speedrun.sh | 18 +++++++++--------- 15 files changed, 70 insertions(+), 54 deletions(-) diff --git a/nanochat/common.py b/nanochat/common.py index 8b10df9..0f40b47 100644 --- a/nanochat/common.py +++ b/nanochat/common.py @@ -92,16 +92,24 @@ def get_dist_info(): def compute_init(): """Basic initialization that we keep doing over and over, so make common.""" - # CUDA is currently required - assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm" + # Detect hardware + if torch.cuda.is_available(): + device_type = "cuda" + backend = "nccl" + elif 
torch.xpu.is_available():
+        device_type = "xpu"
+        backend = "ccl"
+    # NOTE: ROCm builds also report torch.cuda.is_available() and expose RCCL
+    # under the "nccl" backend name, so they are handled by the first branch;
+    # "rccl" is not a valid torch.distributed backend string.
+    else:
+        device_type = "cpu"
+        backend = "gloo"
 
     # Reproducibility
     torch.manual_seed(42)
-    torch.cuda.manual_seed(42)
-    # skipping full reproducibility for now, possibly investigate slowdown later
-    # torch.use_deterministic_algorithms(True)
-    # torch.backends.cudnn.deterministic = True
-    # torch.backends.cudnn.benchmark = False
+    if device_type == "cuda":
+        torch.cuda.manual_seed(42) # ROCm builds also use the torch.cuda namespace
 
     # Precision
     torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
@@ -109,14 +117,16 @@ def compute_init():
 
     # Distributed setup: Distributed Data Parallel (DDP), optional
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
     if ddp:
-        device = torch.device("cuda", ddp_local_rank)
-        torch.cuda.set_device(device) # make "cuda" default to this device
-        dist.init_process_group(backend="nccl", device_id=device)
+        device = torch.device(device_type, ddp_local_rank)
+        if device_type != "cpu":
+            torch.cuda.set_device(device) # make "cuda" default to this device
+        dist.init_process_group(backend=backend, device_id=device if device_type != "cpu" else None)
         dist.barrier()
     else:
-        device = torch.device("cuda")
+        device = torch.device(device_type)
 
     if ddp_rank == 0:
+        logger.info(f"Using device: {device}")
         logger.info(f"Distributed world size: {ddp_world_size}")
 
     return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
 
diff --git a/nanochat/dataloader.py b/nanochat/dataloader.py
index c1636b1..aecd5cc 100644
--- a/nanochat/dataloader.py
+++ b/nanochat/dataloader.py
@@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
 from nanochat.dataset import parquets_iter_batched
 from nanochat.tokenizer import get_tokenizer
 
-def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, 
tokenizer_batch_size=128): +def tokenizing_distributed_data_loader(B, T, split, device, tokenizer_threads=4, tokenizer_batch_size=128): """Stream pretraining text from parquet files, tokenize, yield training batches.""" assert split in ["train", "val"], "split must be 'train' or 'val'" ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info() @@ -44,6 +44,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz inputs_cpu = scratch[:-1].to(dtype=torch.int32) targets_cpu = scratch[1:] # Reshape to 2D and move to GPU async - inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True) - targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True) + inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True) + targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True) yield inputs, targets diff --git a/nanochat/engine.py b/nanochat/engine.py index de1253a..ff697a4 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -308,7 +308,7 @@ if __name__ == "__main__": prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id) # generate the reference sequence using the model.generate() function generated_tokens = [] - torch.cuda.synchronize() + if device.type != 'cpu': torch.cuda.synchronize() t0 = time.time() stream = model.generate(prompt_tokens, **kwargs) for token in stream: @@ -316,7 +316,7 @@ if __name__ == "__main__": chunk = tokenizer.decode([token]) print(chunk, end="", flush=True) print() - torch.cuda.synchronize() + if device.type != 'cpu': torch.cuda.synchronize() t1 = time.time() print(f"Reference time: {t1 - t0:.2f}s") reference_ids = generated_tokens @@ -324,7 +324,7 @@ if __name__ == "__main__": generated_tokens = [] engine = Engine(model, tokenizer) stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32 - torch.cuda.synchronize() + if device.type 
!= 'cpu': torch.cuda.synchronize() t0 = time.time() for token_column, token_masks in stream: token = token_column[0] # only print out the first row @@ -332,7 +332,7 @@ if __name__ == "__main__": chunk = tokenizer.decode([token]) print(chunk, end="", flush=True) print() - torch.cuda.synchronize() + if device.type != 'cpu': torch.cuda.synchronize() t1 = time.time() print(f"Engine time: {t1 - t0:.2f}s") # compare the two sequences diff --git a/nanochat/report.py b/nanochat/report.py index 02cd8b0..a3bba3b 100644 --- a/nanochat/report.py +++ b/nanochat/report.py @@ -56,8 +56,8 @@ def get_gpu_info(): info["names"].append(props.name) info["memory_gb"].append(props.total_memory / (1024**3)) - # Get CUDA version - info["cuda_version"] = torch.version.cuda or "unknown" + # Get driver version + info["driver_version"] = torch.version.cuda if torch.version.cuda else torch.version.hip return info @@ -145,7 +145,7 @@ Generated: {timestamp} total_vram = sum(gpu_info["memory_gb"]) header += f"""- GPUs: {gpu_info['count']}x {gpu_names} - GPU Memory: {total_vram:.1f} GB total -- CUDA Version: {gpu_info['cuda_version']} +- Driver Version: {gpu_info['driver_version']} """ else: header += "- GPUs: None available\n" diff --git a/pyproject.toml b/pyproject.toml index ef3833a..9b007b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,6 +14,7 @@ dependencies = [ "tiktoken>=0.11.0", "tokenizers>=0.22.0", "torch>=2.8.0", + "pytorch-triton-rocm==3.4.0; platform_machine == 'x86_64' and sys_platform == 'linux'", "uvicorn>=0.36.0", "wandb>=0.21.3", ] @@ -22,15 +23,18 @@ dependencies = [ requires = ["maturin>=1.7,<2.0"] build-backend = "maturin" -# target torch to cuda 12.8 +# target torch to rocm 6.3 [tool.uv.sources] torch = [ - { index = "pytorch-cu128" }, + { index = "pytorch-rocm63" }, +] +pytorch-triton-rocm = [ + { index = "pytorch-rocm63" }, ] [[tool.uv.index]] -name = "pytorch-cu128" -url = "https://download.pytorch.org/whl/cu128" +name = "pytorch-rocm63" +url = 
"https://download.pytorch.org/whl/rocm6.3" explicit = true [tool.maturin] @@ -39,7 +43,7 @@ bindings = "pyo3" python-source = "." manifest-path = "rustbpe/Cargo.toml" -[dependency-groups] +[project.optional-dependencies] dev = [ "maturin>=1.9.4", "pytest>=8.0.0", diff --git a/scripts/base_eval.py b/scripts/base_eval.py index a566d49..ee56440 100644 --- a/scripts/base_eval.py +++ b/scripts/base_eval.py @@ -122,7 +122,7 @@ def main(): # distributed / precision setup ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() - autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) + autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16) # Load model and tokenizer from command line or from file system if len(sys.argv) >= 2: diff --git a/scripts/base_loss.py b/scripts/base_loss.py index ba3876d..57d99cb 100644 --- a/scripts/base_loss.py +++ b/scripts/base_loss.py @@ -28,7 +28,7 @@ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=mode sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really # Set up the precision we'll run with -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16) # Evaluate the loss on each split tokens_per_step = device_batch_size * sequence_len * ddp_world_size @@ -37,7 +37,7 @@ steps = split_tokens // tokens_per_step token_bytes = get_token_bytes(device=device) bpb_results = {} for split_name in ["train", "val"]: - loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name) + loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device) with autocast_ctx: bpb = evaluate_bpb(model, loader, steps, token_bytes) print0(f"{split_name} bpb: {bpb:.4f}") diff --git a/scripts/base_train.py b/scripts/base_train.py index b691ed4..9fbd271 100644 --- a/scripts/base_train.py +++ 
b/scripts/base_train.py @@ -59,7 +59,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin # Compute init ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16) # wandb logging init use_dummy_wandb = run == "dummy" or not master_process @@ -96,7 +96,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la with torch.device("meta"): model_config = GPTConfig(**model_config_kwargs) model = GPT(model_config) -model.to_empty(device="cuda") +model.to_empty(device=device) model.init_weights() orig_model = model # original, uncompiled model, for saving raw model state_dict model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through @@ -133,8 +133,8 @@ adamw_optimizer, muon_optimizer = optimizers # Initialize the DataLoaders for train/val base_dir = get_base_dir() tokens_dir = os.path.join(base_dir, "tokenized_data") -train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train") -build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val") +train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device) +build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device) x, y = next(train_loader) # kick off load of the very first batch of data # ----------------------------------------------------------------------------- @@ -252,7 +252,7 @@ for step in range(num_iterations + 1): # ------------------------------------------------------------------------- # single training step # evaluate the gradient - torch.cuda.synchronize() + if device.type != 'cpu': 
torch.cuda.synchronize() t0 = time.time() for micro_step in range(grad_accum_steps): with autocast_ctx: @@ -275,7 +275,7 @@ for step in range(num_iterations + 1): for opt in optimizers: opt.step() model.zero_grad(set_to_none=True) - torch.cuda.synchronize() + if device.type != 'cpu': torch.cuda.synchronize() t1 = time.time() dt = t1 - t0 # ------------------------------------------------------------------------- @@ -304,7 +304,8 @@ for step in range(num_iterations + 1): }) # print a few more stats -print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB") +if device.type != 'cpu': + print0(f"Peak memory usage: {torch.cuda.max_memory_allocated(device=device) / 1024 / 1024:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") print0(f"Minimum validation bpb: {min_val_bpb:.4f}") diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py index 3a38147..e37cf64 100644 --- a/scripts/chat_cli.py +++ b/scripts/chat_cli.py @@ -21,7 +21,7 @@ args = parser.parse_args() # Init the model and tokenizer ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16) model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step) # Special tokens for the chat state machine diff --git a/scripts/chat_eval.py b/scripts/chat_eval.py index df6a01a..0a377a1 100644 --- a/scripts/chat_eval.py +++ b/scripts/chat_eval.py @@ -195,7 +195,7 @@ if __name__ == "__main__": ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16 - autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) + autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=ptdtype) model, tokenizer, meta = load_model(args.source, device, phase="eval", 
model_tag=args.model_tag, step=args.step) engine = Engine(model, tokenizer) diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index af70bda..45e7538 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -57,7 +57,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype) +autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=dtype) # wandb logging init use_dummy_wandb = run == "dummy" or not master_process diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index 8389deb..0b5be36 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -63,7 +63,7 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() master_process = ddp_rank == 0 dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype) +autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=dtype) # wandb logging init use_dummy_wandb = run == "dummy" or not master_process diff --git a/scripts/chat_web.py b/scripts/chat_web.py index 1a4cfe2..55abc79 100644 --- a/scripts/chat_web.py +++ b/scripts/chat_web.py @@ -32,7 +32,7 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th args = parser.parse_args() ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) +autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16) class ChatMessage(BaseModel): role: str diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 202682d..18daedf 100644 --- 
a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -53,7 +53,7 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init() master_process = ddp_rank == 0 dtype = torch.float32 if dtype == 'float32' else torch.bfloat16 -autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype) +autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=dtype) # wandb logging init use_dummy_wandb = run == "dummy" or not master_process @@ -214,7 +214,7 @@ while True: # ------------------------------------------------------------------------- # single training step # evaluate the gradient - torch.cuda.synchronize() + if device.type != 'cpu': torch.cuda.synchronize() t0 = time.time() for micro_step in range(grad_accum_steps): with autocast_ctx: @@ -235,7 +235,7 @@ while True: for opt in optimizers: opt.step() model.zero_grad(set_to_none=True) - torch.cuda.synchronize() + if device.type != 'cpu': torch.cuda.synchronize() t1 = time.time() dt = t1 - t0 # ------------------------------------------------------------------------- @@ -267,7 +267,8 @@ while True: }) # print a few more stats -print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB") +if device.type != 'cpu': + print0(f"Peak memory usage: {torch.cuda.max_memory_allocated(device=device) / 1024 / 1024:.2f}MiB") print0(f"Total training time: {total_training_time/60:.2f}m") print0(f"Minimum validation bpb: {min_val_bpb:.4f}") diff --git a/speedrun.sh b/speedrun.sh index d2498ee..71e9350 100644 --- a/speedrun.sh +++ b/speedrun.sh @@ -92,25 +92,25 @@ echo "Waiting for dataset download to complete..." 
wait $DATASET_DOWNLOAD_PID # pretrain the d20 model -torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN +python -m scripts.base_train -- --depth=20 --run=$WANDB_RUN # evaluate the model on a larger chunk of train/val data and draw some samples -torchrun --standalone --nproc_per_node=8 -m scripts.base_loss +python -m scripts.base_loss # evaluate the model on CORE tasks -torchrun --standalone --nproc_per_node=8 -m scripts.base_eval +python -m scripts.base_eval # ----------------------------------------------------------------------------- # Midtraining (teach the model conversation special tokens, tool use, multiple choice) # run midtraining and eval the model -torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN -torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid +python -m scripts.mid_train -- --run=$WANDB_RUN +python -m scripts.chat_eval -- -i mid # ----------------------------------------------------------------------------- # Supervised Finetuning (domain adaptation to each sequence all by itself per row) # train sft and re-eval right away (should see a small bump) -torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN -torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft +python -m scripts.chat_sft -- --run=$WANDB_RUN +python -m scripts.chat_eval -- -i sft # chat with the model over CLI! Leave out the -p to chat interactively # python -m scripts.chat_cli -p "Why is the sky blue?" 
@@ -123,9 +123,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft # (optional) # run reinforcement learning -# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN +# python -m scripts.chat_rl -- --run=$WANDB_RUN # eval the RL model only on GSM8K -# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K +# python -m scripts.chat_eval -- -i rl -a GSM8K # ----------------------------------------------------------------------------- # Generate the full report by putting together all the sections