diff --git a/scripts/base_train.py b/scripts/base_train.py index 72ee147..afa3b7a 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -12,7 +12,7 @@ python -m scripts.base_train --depth=4 --max_seq_len=512 --device_batch_size=1 - """ import os -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import time from contextlib import nullcontext diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index f93a6e6..1d14187 100644 --- a/scripts/chat_sft.py +++ b/scripts/chat_sft.py @@ -10,7 +10,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft """ import os -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import wandb import torch diff --git a/scripts/mid_train.py b/scripts/mid_train.py index dd0768c..848c7e7 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -11,7 +11,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_ from collections import deque import os -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" import time import wandb import torch