feat: Add ROCm and device-agnostic support

This change adds support for ROCm and makes the codebase device-agnostic, allowing it to run on different hardware backends including ROCm, CUDA, and CPU.

The key changes are:
- Modified `pyproject.toml` to use ROCm-compatible PyTorch wheels and added the `pytorch-triton-rocm` dependency.
- Refactored `nanochat/common.py` to dynamically detect the available hardware and set the device and distributed backend accordingly.
- Updated all training, evaluation, and inference scripts to be device-agnostic, removing hardcoded CUDA references.
- Adapted `speedrun.sh` for single-device execution by replacing `torchrun` with `python`.
- Updated `nanochat/report.py` to provide more generic GPU information.
This commit is contained in:
google-labs-jules[bot] 2025-10-14 05:07:30 +00:00
parent dd6ff9a1cc
commit 08c628cb83
15 changed files with 70 additions and 54 deletions

View File

@ -92,16 +92,24 @@ def get_dist_info():
def compute_init():
"""Basic initialization that we keep doing over and over, so make common."""
# CUDA is currently required
assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
# Detect hardware
if torch.cuda.is_available():
device_type = "cuda"
backend = "nccl"
elif torch.xpu.is_available():
device_type = "xpu"
backend = "ccl"
elif hasattr(torch.version, 'hip') and torch.version.hip and torch.cuda.is_available():
device_type = "cuda" # ROCm uses cuda naming in torch
backend = "rccl"
else:
device_type = "cpu"
backend = "gloo"
# Reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed(42)
# skipping full reproducibility for now, possibly investigate slowdown later
# torch.use_deterministic_algorithms(True)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
if device_type != "cpu":
torch.cuda.manual_seed(42) # works for rocm too
# Precision
torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
@ -109,14 +117,16 @@ def compute_init():
# Distributed setup: Distributed Data Parallel (DDP), optional
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
if ddp:
device = torch.device("cuda", ddp_local_rank)
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend="nccl", device_id=device)
device = torch.device(device_type, ddp_local_rank)
if device_type != "cpu":
torch.cuda.set_device(device) # make "cuda" default to this device
dist.init_process_group(backend=backend, device_id=device if device_type != "cpu" else None)
dist.barrier()
else:
device = torch.device("cuda")
device = torch.device(device_type)
if ddp_rank == 0:
logger.info(f"Using device: {device}")
logger.info(f"Distributed world size: {ddp_world_size}")
return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device

View File

@ -6,7 +6,7 @@ from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
def tokenizing_distributed_data_loader(B, T, split, device, tokenizer_threads=4, tokenizer_batch_size=128):
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
assert split in ["train", "val"], "split must be 'train' or 'val'"
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
@ -44,6 +44,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
yield inputs, targets

View File

@ -308,7 +308,7 @@ if __name__ == "__main__":
prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
# generate the reference sequence using the model.generate() function
generated_tokens = []
torch.cuda.synchronize()
if device.type != 'cpu': torch.cuda.synchronize()
t0 = time.time()
stream = model.generate(prompt_tokens, **kwargs)
for token in stream:
@ -316,7 +316,7 @@ if __name__ == "__main__":
chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
print()
torch.cuda.synchronize()
if device.type != 'cpu': torch.cuda.synchronize()
t1 = time.time()
print(f"Reference time: {t1 - t0:.2f}s")
reference_ids = generated_tokens
@ -324,7 +324,7 @@ if __name__ == "__main__":
generated_tokens = []
engine = Engine(model, tokenizer)
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
torch.cuda.synchronize()
if device.type != 'cpu': torch.cuda.synchronize()
t0 = time.time()
for token_column, token_masks in stream:
token = token_column[0] # only print out the first row
@ -332,7 +332,7 @@ if __name__ == "__main__":
chunk = tokenizer.decode([token])
print(chunk, end="", flush=True)
print()
torch.cuda.synchronize()
if device.type != 'cpu': torch.cuda.synchronize()
t1 = time.time()
print(f"Engine time: {t1 - t0:.2f}s")
# compare the two sequences

View File

@ -56,8 +56,8 @@ def get_gpu_info():
info["names"].append(props.name)
info["memory_gb"].append(props.total_memory / (1024**3))
# Get CUDA version
info["cuda_version"] = torch.version.cuda or "unknown"
# Get driver version
info["driver_version"] = torch.version.cuda if torch.version.cuda else torch.version.hip
return info
@ -145,7 +145,7 @@ Generated: {timestamp}
total_vram = sum(gpu_info["memory_gb"])
header += f"""- GPUs: {gpu_info['count']}x {gpu_names}
- GPU Memory: {total_vram:.1f} GB total
- CUDA Version: {gpu_info['cuda_version']}
- Driver Version: {gpu_info['driver_version']}
"""
else:
header += "- GPUs: None available\n"

View File

@ -14,6 +14,7 @@ dependencies = [
"tiktoken>=0.11.0",
"tokenizers>=0.22.0",
"torch>=2.8.0",
"pytorch-triton-rocm==3.4.0; platform_machine == 'x86_64' and sys_platform == 'linux'",
"uvicorn>=0.36.0",
"wandb>=0.21.3",
]
@ -22,15 +23,18 @@ dependencies = [
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"
# target torch to cuda 12.8
# target torch to rocm 6.3
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128" },
{ index = "pytorch-rocm63" },
]
pytorch-triton-rocm = [
{ index = "pytorch-rocm63" },
]
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
name = "pytorch-rocm63"
url = "https://download.pytorch.org/whl/rocm6.3"
explicit = true
[tool.maturin]
@ -39,7 +43,7 @@ bindings = "pyo3"
python-source = "."
manifest-path = "rustbpe/Cargo.toml"
[dependency-groups]
[project.optional-dependencies]
dev = [
"maturin>=1.9.4",
"pytest>=8.0.0",

View File

@ -122,7 +122,7 @@ def main():
# distributed / precision setup
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16)
# Load model and tokenizer from command line or from file system
if len(sys.argv) >= 2:

View File

@ -28,7 +28,7 @@ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=mode
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
# Set up the precision we'll run with
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16)
# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
@ -37,7 +37,7 @@ steps = split_tokens // tokens_per_step
token_bytes = get_token_bytes(device=device)
bpb_results = {}
for split_name in ["train", "val"]:
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name)
loader = tokenizing_distributed_data_loader(device_batch_size, sequence_len, split_name, device=device)
with autocast_ctx:
bpb = evaluate_bpb(model, loader, steps, token_bytes)
print0(f"{split_name} bpb: {bpb:.4f}")

View File

@ -59,7 +59,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16)
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@ -96,7 +96,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device="cuda")
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
@ -133,8 +133,8 @@ adamw_optimizer, muon_optimizer = optimizers
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train")
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
@ -252,7 +252,7 @@ for step in range(num_iterations + 1):
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
torch.cuda.synchronize()
if device.type != 'cpu': torch.cuda.synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
@ -275,7 +275,7 @@ for step in range(num_iterations + 1):
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
torch.cuda.synchronize()
if device.type != 'cpu': torch.cuda.synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
@ -304,7 +304,8 @@ for step in range(num_iterations + 1):
})
# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
if device.type != 'cpu':
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated(device=device) / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

View File

@ -21,7 +21,7 @@ args = parser.parse_args()
# Init the model and tokenizer
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
# Special tokens for the chat state machine

View File

@ -195,7 +195,7 @@ if __name__ == "__main__":
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=ptdtype)
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
engine = Engine(model, tokenizer)

View File

@ -57,7 +57,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=dtype)
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process

View File

@ -63,7 +63,7 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=dtype)
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process

View File

@ -32,7 +32,7 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
args = parser.parse_args()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16)
class ChatMessage(BaseModel):
role: str

View File

@ -53,7 +53,7 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=dtype)
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process
@ -214,7 +214,7 @@ while True:
# -------------------------------------------------------------------------
# single training step
# evaluate the gradient
torch.cuda.synchronize()
if device.type != 'cpu': torch.cuda.synchronize()
t0 = time.time()
for micro_step in range(grad_accum_steps):
with autocast_ctx:
@ -235,7 +235,7 @@ while True:
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
torch.cuda.synchronize()
if device.type != 'cpu': torch.cuda.synchronize()
t1 = time.time()
dt = t1 - t0
# -------------------------------------------------------------------------
@ -267,7 +267,8 @@ while True:
})
# print a few more stats
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
if device.type != 'cpu':
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated(device=device) / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")

View File

@ -92,25 +92,25 @@ echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID
# pretrain the d20 model
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
python -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
python -m scripts.base_loss
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
python -m scripts.base_eval
# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
# run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
python -m scripts.mid_train -- --run=$WANDB_RUN
python -m scripts.chat_eval -- -i mid
# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
# train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
python -m scripts.chat_sft -- --run=$WANDB_RUN
python -m scripts.chat_eval -- -i sft
# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"
@ -123,9 +123,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
# (optional)
# run reinforcement learning
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
# python -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
# python -m scripts.chat_eval -- -i rl -a GSM8K
# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections