From bbc816dc77983fe7820c79248e01e1f077251980 Mon Sep 17 00:00:00 2001
From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com>
Date: Sun, 23 Nov 2025 16:03:02 +0000
Subject: [PATCH] Reduce base_train batch size and set PYTORCH_HIP_ALLOC_CONF

To address "HIP out of memory" errors on some AMD ROCm configurations
(potentially due to memory fragmentation or limited per-device VRAM),
this change:

1. Reduces the default `device_batch_size` from 32 to 16.
2. Explicitly sets `PYTORCH_HIP_ALLOC_CONF=expandable_segments:True` when
   ROCm is detected, which helps the allocator manage fragmented memory
   better than the default behavior.
---
 scripts/base_train.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/scripts/base_train.py b/scripts/base_train.py
index 74068ed..3bcb38a 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -19,6 +19,9 @@ import wandb
 import torch
 if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip):
     os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+    # Also set the HIP-specific env var if on ROCm, as suggested by OOM errors
+    if hasattr(torch.version, 'hip') and torch.version.hip:
+        os.environ["PYTORCH_HIP_ALLOC_CONF"] = "expandable_segments:True"
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
@@ -43,7 +46,7 @@ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable
 target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
 target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
 # Optimization
-device_batch_size = 32 # per-device batch size (set to not OOM)
+device_batch_size = 16 # per-device batch size (reduced from 32 to avoid OOM on some GPUs)
 total_batch_size = 524288 # total desired batch size, in #tokens
 embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
 unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
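
Note (not part of the commit): a minimal standalone sketch of the same ROCm
detection and allocator setting, useful for checking that the environment
variable is picked up on a given machine before running the full training
script. The file name `check_hip_alloc_conf.py` is hypothetical, and the
sketch assumes a ROCm build of PyTorch, where `torch.version.hip` is a
non-empty version string (it is None on CUDA-only builds).

    # check_hip_alloc_conf.py -- hypothetical helper, not part of this patch
    import os
    import torch

    if hasattr(torch.version, "hip") and torch.version.hip:
        # Same setting that base_train.py now applies; it must be in place
        # before the first GPU allocation so the caching allocator sees it.
        os.environ.setdefault("PYTORCH_HIP_ALLOC_CONF", "expandable_segments:True")
        print(f"ROCm detected (HIP {torch.version.hip}); "
              f"PYTORCH_HIP_ALLOC_CONF={os.environ['PYTORCH_HIP_ALLOC_CONF']}")
    else:
        print("No ROCm/HIP runtime detected; nothing to set.")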