Mirror of https://github.com/karpathy/nanochat.git
Synced 2025-12-06 04:12:13 +00:00
add head_dim, num_heads, num_kv_heads, depth_to_width_scale as arguments to base_train.py to allow modeling flexibility
This commit is contained in:
parent cf587acb1a
commit 1688ba9597
@@ -36,6 +36,14 @@ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
 device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
+# Model architecture
 depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
+
+depth_to_width_scale = 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
+# model_dim (width) = depth * depth_to_width_scale
+
+head_dim = 128 # safe default
+num_heads = -1 # inferred from head_dim by default; if specified, head_dim is derived as model_dim // num_heads instead (see below)
+num_kv_heads = -1 # same as num_heads by default; if specified, num_heads must be divisible by num_kv_heads
 
 max_seq_len = 2048 # max context length
 # Training horizon. Only one of these 3 will be used, in this order of precedence.
 num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
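With the defaults above, the derivation added in the next hunk is simple arithmetic. A minimal standalone sketch of the default path (plain Python mirroring the diff below, not an excerpt from base_train.py):

depth = 20
depth_to_width_scale = 64
head_dim = 128

model_dim = depth * depth_to_width_scale                      # 20 * 64 = 1280
num_heads = max(1, (model_dim + (head_dim - 1)) // head_dim)  # ceil(1280 / 128) = 10
num_kv_heads = num_heads                                      # 1:1 ratio, i.e. GQA disabled

print(model_dim, num_heads, num_kv_heads)                     # 1280 10 10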
@@ -86,11 +94,20 @@ print0(f"Vocab size: {vocab_size:,}")
 
 # Model kwargs are derived from the desired depth of the model
 num_layers = depth
-model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
-num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
-num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
+model_dim = depth * depth_to_width_scale # model_dim = depth * 64 by default
+if num_heads == -1:
+    num_heads = max(1, (model_dim + (head_dim - 1)) // head_dim) # ceil div by head_dim
+else:
+    head_dim = model_dim // num_heads # num_heads was specified explicitly, so head_dim is overwritten to match
+if num_kv_heads == -1:
+    num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
+
+assert num_heads * head_dim == model_dim, f"num_heads ({num_heads}) * head_dim ({head_dim}) != model_dim ({model_dim})"
+assert num_heads % num_kv_heads == 0, f"num_heads ({num_heads}) % num_kv_heads ({num_kv_heads}) != 0"
+
+print0(f"num_layers: {num_layers}")
+print0(f"model_dim: {model_dim}")
+print0(f"head_dim: {head_dim}")
+print0(f"num_heads: {num_heads}")
+print0(f"num_kv_heads: {num_kv_heads}")
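To see how the override paths and the two asserts interact, here is a hedged, self-contained sketch that wraps the logic above in a hypothetical helper (derive_model_kwargs is for illustration only and does not exist in base_train.py):

# Hypothetical helper mirroring the derivation in the diff above; not part of base_train.py.
def derive_model_kwargs(depth, depth_to_width_scale=64, head_dim=128, num_heads=-1, num_kv_heads=-1):
    model_dim = depth * depth_to_width_scale
    if num_heads == -1:
        num_heads = max(1, (model_dim + (head_dim - 1)) // head_dim)  # ceil div by head_dim
    else:
        head_dim = model_dim // num_heads  # explicit num_heads wins; head_dim is derived from it
    if num_kv_heads == -1:
        num_kv_heads = num_heads  # 1:1 query:kv heads, i.e. GQA disabled
    assert num_heads * head_dim == model_dim
    assert num_heads % num_kv_heads == 0
    return dict(model_dim=model_dim, head_dim=head_dim, num_heads=num_heads, num_kv_heads=num_kv_heads)

print(derive_model_kwargs(20))                                # defaults: 1280 wide, 10 heads of dim 128
print(derive_model_kwargs(20, num_heads=20))                  # explicit heads: head_dim becomes 1280 // 20 = 64
print(derive_model_kwargs(20, num_heads=20, num_kv_heads=5))  # 4:1 GQA; 20 % 5 == 0 passes

Note that when num_heads is left at -1, the first assert effectively requires model_dim (= depth * depth_to_width_scale) to be a multiple of head_dim: e.g. depth=13 with the defaults gives model_dim=832, which rounds up to 7 heads of 128 (896) and fails the check.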
@@ -347,3 +364,4 @@ get_report().log(section="Base model training", data=[
 # cleanup
+wandb_run.finish() # wandb run finish
 compute_cleanup()