mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-20 15:58:41 +00:00
On branch feature-add-mamba-arch-support
Changes to be committed:
new file: IMPLEMENTATION_SUMMARY.md
new file: MAMBA_INTEGRATION.md
new file: QUICKSTART_MAMBA.md
new file: configs/README.md
new file: configs/hybrid_alternating_d20.py
new file: configs/hybrid_early_t_late_m_d20.py
new file: configs/mamba_d20.py
new file: configs/rtx3070_d16.py
new file: configs/transformer_d20.py
new file: nanochat/blocks/__init__.py
new file: nanochat/blocks/mamba_block.py
new file: nanochat/blocks/transformer_block.py
modified: nanochat/checkpoint_manager.py
modified: nanochat/gpt.py
new file: tests/test_hybrid_blocks.py
32 lines
967 B
Python
"""Pure Mamba configuration for the d20 model (561M parameters).

Replaces all transformer blocks with Mamba SSM blocks.
Expected benefits: faster training/inference, lower memory,
better long-range modeling.

This is a flat config module: the training script presumably reads these
module-level names (same convention as base_train.py defaults) — confirm
against the consumer.
"""

# Model architecture
depth = 20
# Derive the pattern length from `depth` so the two can never drift apart.
# "M" marks a Mamba block (vs. a transformer block in hybrid configs).
block_pattern = ["M"] * depth  # All Mamba blocks

# Mamba-specific parameters
mamba_d_state = 16  # Conservative state dimension for 12GB GPUs
mamba_d_conv = 4    # Standard convolution kernel
mamba_expand = 2    # Standard expansion factor
mamba_use_mlp = False  # Mamba has built-in gating, MLP often redundant

# Training (same as base_train.py defaults)
max_seq_len = 2048
device_batch_size = 32  # Can potentially use more since no attention overhead
total_batch_size = 524288
target_param_data_ratio = 20  # Chinchilla ratio

# Optimization
embedding_lr = 0.2
unembedding_lr = 0.004
matrix_lr = 0.02
weight_decay = 0.0
grad_clip = 1.0

# For 12GB GPUs, use:
# device_batch_size = 4
# max_seq_len = 1024