ablation + anyscale

This commit is contained in:
Yoyo 2026-03-04 21:20:24 -05:00
parent f7b71341fd
commit 6d0afeacd3
543 changed files with 1776 additions and 9 deletions

View File

@ -24,3 +24,5 @@ uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
# train tokenizer
python -m nanochat.dataset -n 240
# modal set up
pip install modal

View File

@ -37,6 +37,9 @@ class GPTConfig:
# Characters: L=long (full context), S=short (half context)
# Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long
window_pattern: str = "SSSL"
# Ablation options
mlp_type: str = "relu2" # "relu2" (baseline) or "swiglu"
rope_base: int = 10000 # RoPE base theta (10K baseline, 500K for long-context)
def norm(x):
@ -121,14 +124,24 @@ class CausalSelfAttention(nn.Module):
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
self.mlp_type = config.mlp_type
if config.mlp_type == "swiglu":
# hidden_dim = 8/3 * n_embd to match relu2 param count:
# relu2: n_embd*(4n) + (4n)*n_embd = 8*n^2
# swiglu: n*(h) + n*(h) + h*n = 3*h*n => h = 8/3*n
hidden_dim = int(8 / 3 * config.n_embd)
self.c_gate = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.c_up = nn.Linear(config.n_embd, hidden_dim, bias=False)
self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=False)
else: # relu2
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
return x
if self.mlp_type == "swiglu":
return self.c_proj(F.silu(self.c_gate(x)) * self.c_up(x))
else:
return self.c_proj(F.relu(self.c_fc(x)).square())
class Block(nn.Module):
@ -213,7 +226,11 @@ class GPT(nn.Module):
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
torch.nn.init.zeros_(block.attn.c_proj.weight) # projections are zero
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
if block.mlp.mlp_type == "swiglu":
torch.nn.init.uniform_(block.mlp.c_gate.weight, -s, s)
torch.nn.init.uniform_(block.mlp.c_up.weight, -s, s)
else:
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
torch.nn.init.zeros_(block.mlp.c_proj.weight)
# Per-layer scalars
@ -240,8 +257,9 @@ class GPT(nn.Module):
for ve in self.value_embeds.values():
ve.to(dtype=torch.bfloat16)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# TODO: bump base theta more? e.g. 100K is more common more recently
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=None, device=None):
if base is None:
base = self.config.rope_base
# autodetect the device from model embeddings
if device is None:
device = self.transformer.wte.weight.device

View File

@ -0,0 +1 @@
{"rustc_fingerprint":3153245733862020455,"outputs":{"17747080675513052775":{"success":true,"status":"","code":0,"stdout":"rustc 1.93.1 (01f6ddf75 2026-02-11)\nbinary: rustc\ncommit-hash: 01f6ddf7588f42ae2d7eb0a2f21d44e8e96674cf\ncommit-date: 2026-02-11\nhost: x86_64-unknown-linux-gnu\nrelease: 1.93.1\nLLVM version: 21.1.8\n","stderr":""},"7971740275564407648":{"success":true,"status":"","code":0,"stdout":"___\nlib___.rlib\nlib___.so\nlib___.so\nlib___.a\nlib___.so\n/home/yoyo/.rustup/toolchains/stable-x86_64-unknown-linux-gnu\noff\npacked\nunpacked\n___\ndebug_assertions\npanic=\"unwind\"\nproc_macro\ntarget_abi=\"\"\ntarget_arch=\"x86_64\"\ntarget_endian=\"little\"\ntarget_env=\"gnu\"\ntarget_family=\"unix\"\ntarget_feature=\"fxsr\"\ntarget_feature=\"sse\"\ntarget_feature=\"sse2\"\ntarget_has_atomic=\"16\"\ntarget_has_atomic=\"32\"\ntarget_has_atomic=\"64\"\ntarget_has_atomic=\"8\"\ntarget_has_atomic=\"ptr\"\ntarget_os=\"linux\"\ntarget_pointer_width=\"64\"\ntarget_vendor=\"unknown\"\nunix\n","stderr":""}},"successes":{}}

Binary file not shown.

View File

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"","declared_features":"","target":0,"profile":0,"path":0,"deps":[[966925859616469517,"build_script_build",false,18326824378327578846]],"local":[{"RerunIfChanged":{"output":"release/build/ahash-6e0ddbeeb0963122/output","paths":["build.rs"]}}],"rustflags":[],"config":0,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
768837cee41ae808

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"default\", \"getrandom\", \"runtime-rng\", \"std\"]","declared_features":"[\"atomic-polyfill\", \"compile-time-rng\", \"const-random\", \"default\", \"getrandom\", \"nightly-arm-aes\", \"no-rng\", \"runtime-rng\", \"serde\", \"std\"]","target":8470944000320059508,"profile":2040997289075261528,"path":16416957882855069506,"deps":[[966925859616469517,"build_script_build",false,2183787099258598574],[3331586631144870129,"getrandom",false,17730662864604905174],[3722963349756955755,"once_cell",false,4314456024153461717],[7843059260364151289,"cfg_if",false,10089321399605156288],[14131061446229887432,"zerocopy",false,11340975796785757038]],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/ahash-a1f007138bd92198/dep-lib-ahash","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"default\", \"getrandom\", \"runtime-rng\", \"std\"]","declared_features":"[\"atomic-polyfill\", \"compile-time-rng\", \"const-random\", \"default\", \"getrandom\", \"nightly-arm-aes\", \"no-rng\", \"runtime-rng\", \"serde\", \"std\"]","target":17883862002600103897,"profile":1369601567987815722,"path":11330912615325810443,"deps":[[5398981501050481332,"version_check",false,8987439642223450148]],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/ahash-cd37cac257383fa2/dep-build-script-build-script-build","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
3502d0d0115d6431

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"perf-literal\", \"std\"]","declared_features":"[\"default\", \"logging\", \"perf-literal\", \"std\"]","target":7534583537114156500,"profile":2040997289075261528,"path":4645813219967419164,"deps":[[15932120279885307830,"memchr",false,10396455531107345722]],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/aho-corasick-e3889325011ae6bb/dep-lib-aho_corasick","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
0c59d0718144ab20

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[]","declared_features":"[\"experimental-strategies\", \"experimental-thread-local\", \"internal-test-strategies\", \"serde\", \"weak\"]","target":8262801893777646146,"profile":2040997289075261528,"path":2759164353353770132,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/arc-swap-abf696bd8cc9a491/dep-lib-arc_swap","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
03efeb1a60525ba5

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[]","declared_features":"[]","target":6962977057026645649,"profile":1369601567987815722,"path":8461504647337710719,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/autocfg-aba430aa48ce4fb3/dep-lib-autocfg","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
81703cf40e463393

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"std\"]","declared_features":"[\"default\", \"serde\", \"std\"]","target":1565461888733056401,"profile":2040997289075261528,"path":1592813740539531122,"deps":[[5692597712387868707,"bit_vec",false,17444092127660309242]],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/bit-set-6b8f42ad13979896/dep-lib-bit_set","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
fa7a610a5edd15f2

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"std\"]","declared_features":"[\"borsh\", \"borsh_std\", \"default\", \"miniserde\", \"nanoserde\", \"serde\", \"serde_no_std\", \"serde_std\", \"std\"]","target":1886748672988989682,"profile":2040997289075261528,"path":16673263365007446360,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/bit-vec-9b94034b2e5140d5/dep-lib-bit_vec","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
71fda64eef503318

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"alloc\"]","declared_features":"[\"alloc\", \"default\", \"std\"]","target":13710694652376480987,"profile":2040997289075261528,"path":10772041731947363315,"deps":[[14156967978702956262,"rustversion",false,12357750521860089894]],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/castaway-9c60905fea19c288/dep-lib-castaway","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
c05936805b78048c

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[]","declared_features":"[\"core\", \"rustc-dep-of-std\"]","target":13840298032947503755,"profile":2040997289075261528,"path":15565018148930908515,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/cfg-if-5cae3b2ea5837b6a/dep-lib-cfg_if","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
b551f2562703ee58

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"default\", \"std\"]","declared_features":"[\"arbitrary\", \"borsh\", \"bytes\", \"default\", \"diesel\", \"markup\", \"proptest\", \"quickcheck\", \"rkyv\", \"serde\", \"smallvec\", \"sqlx\", \"sqlx-mysql\", \"sqlx-postgres\", \"sqlx-sqlite\", \"std\", \"zeroize\"]","target":7968499388442294171,"profile":2040997289075261528,"path":13933458373579383538,"deps":[[1127187624154154345,"castaway",false,1743826469469486449],[1216309103264968120,"ryu",false,13475922803077515806],[7695812897323945497,"itoa",false,8880013175622664606],[7843059260364151289,"cfg_if",false,10089321399605156288],[13785866025199020095,"static_assertions",false,3779651264367808892],[14156967978702956262,"rustversion",false,12357750521860089894]],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/compact_str-9bf829de7b24ec74/dep-lib-compact_str","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"default\", \"std\"]","declared_features":"[\"default\", \"std\"]","target":15353977948366730291,"profile":14791228037615401302,"path":244282978214369260,"deps":[[3528074118530651198,"crossbeam_epoch",false,12084335119325868716],[4468123440088164316,"crossbeam_utils",false,4066709884124278728]],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/crossbeam-deque-9a3abdc3674f18e8/dep-lib-crossbeam_deque","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"alloc\", \"std\"]","declared_features":"[\"alloc\", \"default\", \"loom\", \"loom-crate\", \"nightly\", \"std\"]","target":5830366855417007734,"profile":2040997289075261528,"path":5214367671550285343,"deps":[[4468123440088164316,"crossbeam_utils",false,4066709884124278728]],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/crossbeam-epoch-85d9d4dbce2834b5/dep-lib-crossbeam_epoch","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"default\", \"std\"]","declared_features":"[\"default\", \"loom\", \"nightly\", \"std\"]","target":5408242616063297496,"profile":1419616050453328851,"path":14306531806754996452,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/crossbeam-utils-0798667d6011c212/dep-build-script-build-script-build","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"","declared_features":"","target":0,"profile":0,"path":0,"deps":[[4468123440088164316,"build_script_build",false,3378941045167230312]],"local":[{"RerunIfChanged":{"output":"release/build/crossbeam-utils-3a211e20cc1ffb53/output","paths":["no_atomic.rs"]}}],"rustflags":[],"config":0,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"default\", \"std\"]","declared_features":"[\"default\", \"loom\", \"nightly\", \"std\"]","target":9626079250877207070,"profile":14791228037615401302,"path":10014696583203501017,"deps":[[4468123440088164316,"build_script_build",false,11747075223134917507]],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/crossbeam-utils-6c429d87be32f91e/dep-lib-crossbeam_utils","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
eddf97c895cfd4dc

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[]","declared_features":"[\"extra\", \"serde\", \"unstable\", \"unstable_nightly\"]","target":1019866667645897510,"profile":2040997289075261528,"path":16896791896759728600,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/dary_heap-e91823f3b0b21e52/dep-lib-dary_heap","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
cdc7b6513fa06e7d

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[]","declared_features":"[\"default\", \"serde\", \"std\", \"use_std\"]","target":17124342308084364240,"profile":2040997289075261528,"path":4332997825135745380,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/either-e7d9422f2c8df24b/dep-lib-either","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
71ffc4d84eda4492

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[]","declared_features":"[]","target":1524667692659508025,"profile":2040997289075261528,"path":4581097645938585182,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/equivalent-b7fca33b7ad402aa/dep-lib-equivalent","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
58c377dd4129682a

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[\"default\", \"perf\", \"std\", \"unicode\"]","declared_features":"[\"default\", \"perf\", \"std\", \"track_caller\", \"unicode\"]","target":1052671945528592274,"profile":2040997289075261528,"path":8103441100212571416,"deps":[[7507008215594894126,"regex_syntax",false,8010140734463789758],[9519969280819313548,"bit_set",false,10606898577428738177],[16311927252525485886,"regex_automata",false,4709622341456386799]],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/fancy-regex-047b2b18fb3cb285/dep-lib-fancy_regex","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
d63efc2ceff70ff6

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[]","declared_features":"[\"rustc-dep-of-std\", \"std\", \"wasm_js\"]","target":11669924403970522481,"profile":2896712256932847751,"path":1777140628112724188,"deps":[[3331586631144870129,"build_script_build",false,13351924516332481098],[7843059260364151289,"cfg_if",false,10089321399605156288],[11887305395906501191,"libc",false,10424884877148727811]],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/getrandom-de50dc268e3e9e3b/dep-lib-getrandom","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"","declared_features":"","target":0,"profile":0,"path":0,"deps":[[3331586631144870129,"build_script_build",false,3128377812680934069]],"local":[{"RerunIfChanged":{"output":"release/build/getrandom-fde142ed4afd426e/output","paths":["build.rs"]}}],"rustflags":[],"config":0,"compile_kind":0}

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[]","declared_features":"[\"rustc-dep-of-std\", \"std\", \"wasm_js\"]","target":5408242616063297496,"profile":7474683644146943920,"path":618837956153051478,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/getrandom-fe4f3207607d445f/dep-build-script-build-script-build","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
532b0305907161d2

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[]","declared_features":"[\"alloc\", \"allocator-api2\", \"core\", \"default\", \"default-hasher\", \"equivalent\", \"inline-more\", \"nightly\", \"raw-entry\", \"rayon\", \"rustc-dep-of-std\", \"rustc-internal-api\", \"serde\"]","target":13796197676120832388,"profile":2040997289075261528,"path":5190282083066847217,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/hashbrown-73de462bcf2ff96a/dep-lib-hashbrown","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

View File

@ -0,0 +1 @@
This file has an mtime of when this was started.

View File

@ -0,0 +1 @@
f295b09387ad7550

View File

@ -0,0 +1 @@
{"rustc":1100337564441796057,"features":"[]","declared_features":"[]","target":17886154901722686619,"profile":1369601567987815722,"path":6806002441689259226,"deps":[],"local":[{"CheckDepInfo":{"dep_info":"release/.fingerprint/heck-b2d11fd1520b7534/dep-lib-heck","checksum":false}}],"rustflags":[],"config":2069994364910194474,"compile_kind":0}

Some files were not shown because too many files have changed in this diff Show More