diff --git a/nanochat/gpt.py b/nanochat/gpt.py index b640f1e..8b220c3 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -8,7 +8,7 @@ Notable features: - norm after token embedding - no learnable params in rmsnorm - no bias in linear layers -- Multi-Query Attention (MQA) support for more efficient inference +- Group-Query Attention (GQA) support for more efficient inference """ import math @@ -29,7 +29,7 @@ class GPTConfig: vocab_size: int = 50304 n_layer: int = 12 n_head: int = 6 # number of query heads - n_kv_head: int = 6 # number of key/value heads (MQA) + n_kv_head: int = 6 # number of key/value heads (GQA) n_embd: int = 768