Mirror of https://github.com/karpathy/nanochat.git, synced 2026-01-23 03:44:19 +00:00
resolve a crash for odd depths because FA3 needs head_dim % 8 == 0
This commit is contained in:
parent 413e91aa0f
commit cf5c9e5b8e
@@ -106,21 +106,19 @@ vocab_size = tokenizer.get_vocab_size()
 print0(f"Vocab size: {vocab_size:,}")
 
 # Model kwargs are derived from the desired depth of the model
+# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division
+# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly)
+# (For very small depths, this gives a slight "unfair" advantage to models with odd depths)
 num_layers = args.depth
-model_dim = args.depth * args.aspect_ratio
-def find_num_heads(model_dim, target_head_dim):
-    # Find num_heads that divides model_dim evenly, with head_dim closest to target.
-    ideal = max(1, round(model_dim / target_head_dim))
-    for offset in range(model_dim):
-        for candidate in [ideal + offset, ideal - offset]:
-            if candidate > 0 and model_dim % candidate == 0:
-                return candidate
-    return 1
-num_heads = find_num_heads(model_dim, args.head_dim)
+base_dim = args.depth * args.aspect_ratio
+model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim
+num_heads = model_dim // args.head_dim
 num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
 head_dim = model_dim // num_heads
 print0(f"num_layers: {num_layers}")
-print0(f"model_dim: {model_dim}")
+print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})")
 print0(f"num_heads: {num_heads}")
 print0(f"head_dim: {head_dim}")
 print0(f"num_kv_heads: {num_kv_heads}")
 
 # Optimizer / data / training length related hyperparameters
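For context, a minimal sketch of the failure mode this commit fixes. It assumes nanochat-style defaults of aspect_ratio=64 and a target head_dim of 128 (both are assumptions here; neither value appears in this hunk) and shows how the old divisor search could land on a head_dim that is not a multiple of 8, which FA3 (FlashAttention 3) rejects, while the new round-up nudge cannot:

# Minimal repro sketch of the odd-depth crash. Assumes aspect_ratio=64 and a
# target head_dim of 128 (assumed defaults, not shown in this hunk).
# FA3 requires head_dim % 8 == 0.

def find_num_heads(model_dim, target_head_dim):
    # Old approach: pick a num_heads that divides model_dim evenly,
    # with the resulting head_dim closest to the target.
    ideal = max(1, round(model_dim / target_head_dim))
    for offset in range(model_dim):
        for candidate in [ideal + offset, ideal - offset]:
            if candidate > 0 and model_dim % candidate == 0:
                return candidate
    return 1

aspect_ratio, target_head_dim = 64, 128  # assumed defaults

# Old path: an odd depth like 23 gives model_dim=1472, num_heads=16,
# head_dim=92, and 92 % 8 == 4 -> FA3 crash.
depth = 23
model_dim = depth * aspect_ratio                         # 1472
num_heads = find_num_heads(model_dim, target_head_dim)   # 16
print(model_dim // num_heads)                            # 92 (not a multiple of 8)

# New path: round model_dim up to the next multiple of the target head_dim
# (ceiling division), so head_dim comes out as exactly target_head_dim.
base_dim = depth * aspect_ratio                          # 1472
model_dim = ((base_dim + target_head_dim - 1) // target_head_dim) * target_head_dim
num_heads = model_dim // target_head_dim                 # 12
print(model_dim, model_dim // num_heads)                 # 1536 128 (nudge of +64)

The nudge is plain ceil(base_dim / head_dim) * head_dim, so head_dim equals args.head_dim exactly and the FA3 divisibility constraint holds for every depth, odd or even.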