mirror of https://github.com/karpathy/nanochat.git
synced 2026-01-03 10:12:42 +00:00
add head_dim, num_heads, num_kv_heads, depth_to_width_ratio as arguments to base_train.py to allow modeling flexibility
parent cf587acb1a
commit 1688ba9597
@@ -36,6 +36,14 @@ run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
 device_type = "" # cuda|cpu|mps (empty => autodetect good device type default, in order: CUDA > MPS > CPU)
 # Model architecture
 depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
+
+depth_to_width_scale = 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
+# model_dim (width) = depth * depth_to_width_scale
+
+head_dim = 128 # safe default
+num_heads = -1 # inferred from head_dim by default. If num_heads is specified, the head_dim-based calculation is skipped (see below)
+num_kv_heads = -1 # same as num_heads by default. If specified, num_heads should be divisible by num_kv_heads
+
 max_seq_len = 2048 # max context length
 # Training horizon. Only one of these 3 will be used, in this order of precedence.
 num_iterations = -1 # explicit number of steps of the optimization (-1 = disable)
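
The first new knob decouples the depth-to-width aspect ratio from its previously hardcoded value of 64. A tiny worked example of the rule (an illustrative sketch, not part of base_train.py):

    # model_dim (width) = depth * depth_to_width_scale
    depth_to_width_scale = 64  # the default above
    for depth in (12, 20, 26):
        print(depth, depth * depth_to_width_scale)  # -> 768, 1280, 1664
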
@@ -86,11 +94,20 @@ print0(f"Vocab size: {vocab_size:,}")
 
 # Model kwargs are derived from the desired depth of the model
 num_layers = depth
-model_dim = depth * 64 # aspect ratio 64 (usually this is varied from 64 -> 128 as model size increases)
-num_heads = max(1, (model_dim + 127) // 128) # head dim 128 (the division here is ceil div)
-num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
+model_dim = depth * depth_to_width_scale # model_dim = depth * 64 by default
+if num_heads == -1:
+    num_heads = max(1, (model_dim + (head_dim - 1)) // head_dim) # ceil division by head_dim
+else:
+    head_dim = model_dim // num_heads # num_heads was specified, so head_dim is derived from it instead
+if num_kv_heads == -1:
+    num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled)
+
+assert num_heads * head_dim == model_dim, f"num_heads ({num_heads}) * head_dim ({head_dim}) != model_dim ({model_dim})"
+assert num_heads % num_kv_heads == 0, f"num_heads ({num_heads}) % num_kv_heads ({num_kv_heads}) != 0"
+
 print0(f"num_layers: {num_layers}")
 print0(f"model_dim: {model_dim}")
+print0(f"head_dim: {head_dim}")
 print0(f"num_heads: {num_heads}")
 print0(f"num_kv_heads: {num_kv_heads}")
 
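
Putting the resolution logic above together: a self-contained sketch (the resolve_heads helper is hypothetical, not part of base_train.py) with a few worked cases for the depth=20 default, where model_dim = 20 * 64 = 1280:

    def resolve_heads(model_dim, head_dim=128, num_heads=-1, num_kv_heads=-1):
        # mirrors the derivation in the diff above (illustrative only)
        if num_heads == -1:
            # infer num_heads from head_dim via ceil division
            num_heads = max(1, (model_dim + (head_dim - 1)) // head_dim)
        else:
            # num_heads was given explicitly, so derive head_dim from it
            head_dim = model_dim // num_heads
        if num_kv_heads == -1:
            num_kv_heads = num_heads  # 1:1 ratio, i.e. GQA disabled
        assert num_heads * head_dim == model_dim
        assert num_heads % num_kv_heads == 0
        return head_dim, num_heads, num_kv_heads

    print(resolve_heads(1280))                  # (128, 10, 10): defaults
    print(resolve_heads(1280, num_heads=20))    # (64, 20, 20): head_dim overwritten
    print(resolve_heads(1280, num_kv_heads=5))  # (128, 10, 5): 2:1 GQA ratio

Note that the asserts catch invalid combinations, e.g. an explicit num_heads that does not divide model_dim evenly.
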
@@ -347,3 +364,4 @@ get_report().log(section="Base model training", data=[
 # cleanup
 wandb_run.finish() # wandb run finish
 compute_cleanup()
+