Reduce base_train batch size and set PYTORCH_HIP_ALLOC_CONF
To address "HIP out of memory" errors on some AMD ROCm configurations (potentially due to memory fragmentation or limited per-device VRAM), this change:

1. Reduces the default `device_batch_size` from 32 to 16.
2. Explicitly sets `PYTORCH_HIP_ALLOC_CONF=expandable_segments:True` when ROCm is detected, which helps the allocator manage fragmented memory better than the default behavior.
parent 41ba458c3b
commit bbc816dc77
@@ -19,6 +19,9 @@ import wandb
 import torch
 if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip):
     os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+    # Also set the HIP-specific env var if on ROCm, as suggested by OOM errors
+    if hasattr(torch.version, 'hip') and torch.version.hip:
+        os.environ["PYTORCH_HIP_ALLOC_CONF"] = "expandable_segments:True"

 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
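One ordering detail worth noting: PyTorch reads these allocator config variables lazily, when the caching allocator first initializes, so setting them via os.environ after import torch still takes effect as long as no tensor has touched the device yet. A minimal standalone sketch of that constraint (my illustration, not nanochat code; the check mirrors the diff above):

import os
import torch

# Sketch only: works because the CUDA/HIP caching allocator consumes
# PYTORCH_*_ALLOC_CONF at its first initialization, not at import time.
if torch.cuda.is_available() or (hasattr(torch.version, 'hip') and torch.version.hip):
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    if hasattr(torch.version, 'hip') and torch.version.hip:
        os.environ["PYTORCH_HIP_ALLOC_CONF"] = "expandable_segments:True"

# First device allocation happens only after the env vars are in place.
# On ROCm builds of PyTorch, the "cuda" device string maps to the HIP device.
x = torch.zeros(1, device="cuda")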
@@ -43,7 +46,7 @@ num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
 target_flops = -1.0 # calculate num_iterations to reach target_flops. Useful for scaling laws experiments (-1 = disable)
 target_param_data_ratio = 20 # calculate num_iterations to maintain fixed data:param ratio (Chinchilla=20) (-1 = disable)
 # Optimization
-device_batch_size = 32 # per-device batch size (set to not OOM)
+device_batch_size = 16 # per-device batch size (reduced from 32 to avoid OOM on some GPUs)
 total_batch_size = 524288 # total desired batch size, in #tokens
 embedding_lr = 0.2 # learning rate for the embedding parameters (Adam)
 unembedding_lr = 0.004 # learning rate for the unembedding parameters (Adam)
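Because total_batch_size is fixed in tokens, halving device_batch_size does not change the effective batch; the trainer just takes more gradient-accumulation micro-steps per optimizer step. A rough sketch of that arithmetic (sequence_len and world_size are illustrative values I assume here, not taken from this diff):

# Illustrative arithmetic only, not the exact base_train logic.
total_batch_size = 524288  # desired tokens per optimizer step (from the config above)
sequence_len = 2048        # assumed context length, for illustration
world_size = 8             # assumed GPU count, for illustration

for device_batch_size in (32, 16):
    tokens_per_micro_step = device_batch_size * sequence_len * world_size
    grad_accum_steps = total_batch_size // tokens_per_micro_step
    print(f"device_batch_size={device_batch_size} -> grad_accum_steps={grad_accum_steps}")
    # device_batch_size=32 -> grad_accum_steps=1
    # device_batch_size=16 -> grad_accum_steps=2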