diff --git a/scripts/base_train.py b/scripts/base_train.py
index bb8d8a6..bcbd484 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -106,21 +106,19 @@ vocab_size = tokenizer.get_vocab_size()
 print0(f"Vocab size: {vocab_size:,}")
 
 # Model kwargs are derived from the desired depth of the model
+# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
+# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
+# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
 num_layers = args.depth
-model_dim = args.depth * args.aspect_ratio
-def find_num_heads(model_dim, target_head_dim):
-    # Find num_heads that divides model_dim evenly, with head_dim closest to target.
-    ideal = max(1, round(model_dim / target_head_dim))
-    for offset in range(model_dim):
-        for candidate in [ideal + offset, ideal - offset]:
-            if candidate > 0 and model_dim % candidate == 0:
-                return candidate
-    return 1
-num_heads = find_num_heads(model_dim, args.head_dim)
+base_dim = args.depth * args.aspect_ratio
+model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
+num_heads = model_dim // args.head_dim
 num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
+head_dim = model_dim // num_heads
 print0(f"num_layers: {num_layers}")
-print0(f"model_dim: {model_dim}")
+print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
 print0(f"num_heads: {num_heads}")
+print0(f"head_dim: {head_dim}")
 print0(f"num_kv_heads: {num_kv_heads}")
 
 # Optimizer / data / training length related hyperparameters
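
For intuition, here is a minimal standalone sketch of the new derivation for an odd depth; the depth, aspect_ratio, and head_dim values below are illustrative assumptions, not defaults taken from the script:

    # Illustrative sketch of the "nudge up to a multiple of head_dim" arithmetic
    # (depth=11, aspect_ratio=64, head_dim=128 are assumed values for this example)
    depth, aspect_ratio, head_dim = 11, 64, 128
    base_dim = depth * aspect_ratio                                 # 704
    model_dim = ((base_dim + head_dim - 1) // head_dim) * head_dim  # ceiling to multiple of 128 -> 768
    num_heads = model_dim // head_dim                               # 6
    print(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
    # -> model_dim: 768 (base: 704, nudge: +64); head_dim stays exactly 128

Since model_dim is now a multiple of head_dim by construction, the old search for a divisor (find_num_heads) is no longer needed and num_heads reduces to a single integer division.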