diff --git a/nanochat/gpt.py b/nanochat/gpt.py index a077256..cb4bd05 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -28,8 +28,8 @@ from nanochat.flash_attention import flash_attn @dataclass class GPTConfig: - sequence_len: int = 1024 - vocab_size: int = 50304 + sequence_len: int = 2048 + vocab_size: int = 32768 n_layer: int = 12 n_head: int = 6 # number of query heads n_kv_head: int = 6 # number of key/value heads (GQA) @@ -37,7 +37,7 @@ class GPTConfig: # Sliding window attention pattern string, tiled across layers. Final layer always L. # Characters: L=long (full context), S=short (half context) # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long - window_pattern: str = "L" + window_pattern: str = "SSSL" def norm(x):