Compare commits

...

3 Commits

Author SHA1 Message Date
Sofie Van Landeghem
8ef82c45b5
Merge 3b372875c1 into bc1fca39f3 2025-11-15 23:43:42 +08:00
Andrej Karpathy
bc1fca39f3 mqa -> gqa to reduce confusion 2025-11-15 15:43:37 +00:00
konstin
3b372875c1 Manage the Python module with maturin 2025-10-31 15:58:05 +01:00
4 changed files with 14 additions and 7 deletions

1
.gitignore vendored
View File

@ -1,6 +1,7 @@
.venv/ .venv/
__pycache__/ __pycache__/
*.pyc *.pyc
*.so
rustbpe/target/ rustbpe/target/
dev-ignore/ dev-ignore/
report.md report.md

View File

@ -8,7 +8,7 @@ Notable features:
- norm after token embedding - norm after token embedding
- no learnable params in rmsnorm - no learnable params in rmsnorm
- no bias in linear layers - no bias in linear layers
- Multi-Query Attention (MQA) support for more efficient inference - Group-Query Attention (GQA) support for more efficient inference
""" """
import math import math
@ -29,7 +29,7 @@ class GPTConfig:
vocab_size: int = 50304 vocab_size: int = 50304
n_layer: int = 12 n_layer: int = 12
n_head: int = 6 # number of query heads n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (MQA) n_kv_head: int = 6 # number of key/value heads (GQA)
n_embd: int = 768 n_embd: int = 768

View File

@ -149,7 +149,7 @@ class HuggingFaceTokenizer:
# ----------------------------------------------------------------------------- # -----------------------------------------------------------------------------
# Tokenizer based on rustbpe + tiktoken combo # Tokenizer based on rustbpe + tiktoken combo
import pickle import pickle
import rustbpe from nanochat import rustbpe
import tiktoken import tiktoken
class RustBPETokenizer: class RustBPETokenizer:

View File

@ -23,7 +23,7 @@ requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin" build-backend = "maturin"
[tool.maturin] [tool.maturin]
module-name = "rustbpe" module-name = "nanochat.rustbpe"
bindings = "pyo3" bindings = "pyo3"
python-source = "." python-source = "."
manifest-path = "rustbpe/Cargo.toml" manifest-path = "rustbpe/Cargo.toml"
@ -67,9 +67,15 @@ cpu = [
gpu = [ gpu = [
"torch>=2.8.0", "torch>=2.8.0",
] ]
[tool.uv] [tool.uv]
conflicts = [ cache-keys = [
{ file = "pyproject.toml" },
{ file = "rustbpe/src/**/*.rs" },
{ file = "rustbpe/Cargo.toml" },
{ file = "rustbpe/Cargo.lock" }
]
conflicts = [
[ [
{ extra = "cpu" }, { extra = "cpu" },
{ extra = "gpu" }, { extra = "gpu" },