nanochat/configs/mamba_d20.py
CadBane d7c1db6408 Added Mamba architecture support
On branch feature-add-mamba-arch-support
 Changes to be committed:
	new file:   IMPLEMENTATION_SUMMARY.md
	new file:   MAMBA_INTEGRATION.md
	new file:   QUICKSTART_MAMBA.md
	new file:   configs/README.md
	new file:   configs/hybrid_alternating_d20.py
	new file:   configs/hybrid_early_t_late_m_d20.py
	new file:   configs/mamba_d20.py
	new file:   configs/rtx3070_d16.py
	new file:   configs/transformer_d20.py
	new file:   nanochat/blocks/__init__.py
	new file:   nanochat/blocks/mamba_block.py
    new file:   nanochat/blocks/transformer_block.py
	modified:   nanochat/checkpoint_manager.py
	modified:   nanochat/gpt.py
	new file:   tests/test_hybrid_blocks.py
2025-10-15 10:32:22 +02:00

32 lines
967 B
Python

# Pure Mamba configuration for d20 model (561M parameters).
# Replaces all transformer blocks with Mamba SSM blocks.
# Expected benefits: faster training/inference, lower memory, better long-range modeling.

# Model architecture
depth = 20
block_pattern = ["M"] * depth  # all Mamba blocks; derived from depth so the two stay in sync

# Mamba-specific parameters
mamba_d_state = 16  # conservative SSM state dimension for 12GB GPUs
mamba_d_conv = 4  # standard convolution kernel width
mamba_expand = 2  # standard inner-dimension expansion factor
mamba_use_mlp = False  # Mamba has built-in gating, so a separate MLP is often redundant

# Training (same as base_train.py defaults)
max_seq_len = 2048
device_batch_size = 32  # can potentially go higher since there is no attention overhead
total_batch_size = 524288
target_param_data_ratio = 20  # Chinchilla tokens-per-parameter ratio

# Optimization
embedding_lr = 0.2
unembedding_lr = 0.004
matrix_lr = 0.02
weight_decay = 0.0
grad_clip = 1.0

# For 12GB GPUs, use:
# device_batch_size = 4
# max_seq_len = 1024