more configurable

This commit is contained in:
Muheng 2026-01-08 15:43:36 +00:00
parent 9196ff6fc0
commit 1f09520820
2 changed files with 22 additions and 6 deletions

View File

@ -49,19 +49,29 @@ def _get_env_float(name, default):
except ValueError as exc:
raise ValueError(f"Invalid {name} env value: {val}") from exc
def _get_env_int(name, default):
val = os.getenv(name)
if val is None or val == "":
return default
try:
return int(val)
except ValueError as exc:
raise ValueError(f"Invalid {name} env value: {val}") from exc
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
# Runtime
device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
# Model architecture
depth = 6 # the depth of the Transformer model to train (matches nanoMoE n_layer=6), rest of the kwargs are derived
depth = _get_env_int("DEPTH", 6) # the depth of the Transformer model to train (matches nanoMoE n_layer=6), rest of the kwargs are derived
depth = _get_env_int("N_LAYER", depth)
max_seq_len = 1024 # max context length (matches nanoMoE block_size=1024)
dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ (matches nanoMoE)
bias = False # do we use bias inside LayerNorm and Linear layers? (matches nanoMoE)
# MoE settings (matching nanoMoE config/train_nano_moe.py)
n_exp = 8 # number of experts (matches train_nano_moe.py)
top_k = 2 # number of active experts (matches train_nano_moe.py)
n_exp = _get_env_int("N_EXP", 8) # number of experts (matches train_nano_moe.py)
top_k = _get_env_int("TOP_K", 2) # number of active experts (matches train_nano_moe.py)
use_aux_loss = True # apply auxiliary loss (from Switch Transformer) (matches train_nano_moe.py)
use_router_z_loss = True # apply router z loss (from ST-MoE) (matches train_nano_moe.py)
use_noisy_top_k = False # use noisy top-k routing (matches train_nano_moe.py)
@ -107,11 +117,13 @@ save_every = 10000 # every how many steps to save model checkpoints (-1 = disabl
# System
compile = True # use PyTorch 2.0 to compile the model to be faster (matches nanoMoE)
# Output
model_tag = f"d6_min_lr{min_lr}_max_lr{learning_rate}" # optionally override the model tag for the output checkpoint directory name
model_tag = os.getenv("MODEL_TAG", "") # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
if model_tag == "":
model_tag = f"d{depth}_min_lr{min_lr}_max_lr{learning_rate}"
user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# -----------------------------------------------------------------------------
@ -145,7 +157,7 @@ print0(f"Vocab size: {vocab_size:,}")
# Model kwargs are derived from the desired depth of the model
# For nanoMoE, we use n_layer, n_head, n_embd directly
n_layer = 6
n_layer = depth
model_dim = 384 # matches train_nano_moe.py
num_heads = 6 # matches train_nano_moe.py
n_head = num_heads

View File

@ -159,11 +159,15 @@ fi
MIN_LR=${MIN_LR:-6e-5}
LEARNING_RATE=${LEARNING_RATE:-6e-4}
DEPTH=${DEPTH:-${N_LAYER:-6}}
MODEL_TAG=${MODEL_TAG:-d${DEPTH}_min_lr${MIN_LR}_max_lr${LEARNING_RATE}}
# Number of processes/GPUs to use
NPROC_PER_NODE=8
# Master port for distributed training (default: 29500)
# Set this to avoid port conflicts when running multiple torchrun tasks simultaneously
# Example: MASTER_PORT=29501 bash speedrun.sh
MASTER_PORT=${MASTER_PORT:-29501}
LOG_TAG=${LOG_TAG:-$(date +%Y%m%d_%H%M%S)}
LOG_FILE=${LOG_FILE:-$NANOCHAT_BASE_DIR/${MODEL_TAG}_${LOG_TAG}.log}
# # # pretrain the d20 model
MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train >> $NANOCHAT_BASE_DIR/d6_min_lr${MIN_LR}_max_lr${LEARNING_RATE}.log 2>&1
MASTER_PORT=$MASTER_PORT torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train >> "$LOG_FILE" 2>&1