modal script update

This commit is contained in:
Yoyo 2026-03-05 11:35:02 -05:00
parent 841849cdb8
commit 4ac7562d3d

View File

@ -243,7 +243,7 @@ def stage_pretrain(
f"--run={run_name}",
f"--model-tag={model_tag}",
f"--core-metric-every={CORE_METRIC_EVERY}", # skip expensive CORE eval
"--save-every=-1", # only save at end (picochat is fast)
"--save-every=500", # checkpoint every 500 steps (survive disconnects)
"--eval-every=100", # val/bpb every 100 steps for dense W&B curves
],
)
@ -290,28 +290,35 @@ def run_rope500k() -> None:
# =============================================================================
# MAIN ENTRYPOINT: run all 3 ablations
# PIPELINE ORCHESTRATOR (runs on Modal servers, not locally)
# =============================================================================
@app.local_entrypoint()
def main() -> None:
@app.function(
image=image,
secrets=[secret],
volumes={VOLUME_MOUNT: volume},
cpu=1,
timeout=60 * 60 * 5, # 5 hours: enough for tokenizer + 3 training runs
)
def run_pipeline(num_shards: int = NUM_SHARDS) -> None:
"""
Full ablation pipeline:
1. Download data (once)
2. Train tokenizer (once)
3. Baseline relu2, RoPE 10K
4. SwiGLU swiglu, RoPE 10K
5. RoPE-500K relu2, RoPE 500K
Full ablation pipeline running entirely on Modal servers.
Called via .spawn() from main() so your laptop can close immediately.
Runs are sequential to stay within budget (~$5-9 total on A10G).
Data and tokenizer are shared across all 3 training runs via the volume.
Stages:
1. Download data (idempotent — skips shards already on volume)
2. Train tokenizer (idempotent — skips if tokenizer.pkl exists)
3. Baseline relu2, RoPE 10K (~50 min)
4. SwiGLU swiglu, RoPE 10K (~50 min)
5. RoPE-500K relu2, RoPE 500K (~50 min)
"""
_setup_cache()
print("\n" + "="*60)
print("Picochat Ablation Study | yoyoliuuu/nanochat | W&B")
print("="*60 + "\n")
print("[1/5] Downloading data shards...")
stage_data.remote(num_shards=NUM_SHARDS)
stage_data.remote(num_shards=num_shards)
print("[2/5] Training tokenizer...")
stage_tokenizer.remote()
@ -343,3 +350,19 @@ def main() -> None:
print("\n" + "="*60)
print("All done! Check W&B at wandb.ai/yoyoliuuu/nanochat")
print("="*60 + "\n")
# =============================================================================
# MAIN ENTRYPOINT: just submits the pipeline and exits immediately
# =============================================================================
@app.local_entrypoint()
def main() -> None:
    """
    Local entrypoint: hand the pipeline off to Modal, then return.

    Calls ``run_pipeline.spawn()`` without waiting on its result, printing one
    status line before the submission and one after it.
    Progress can be followed at wandb.ai/yoyoliuuu/nanochat or modal.com/apps.
    """
    # NOTE(review): whether the spawned call keeps running after this process
    # exits may depend on launching with `modal run --detach` -- confirm.
    submitting_msg = "Submitting pipeline to Modal (runs server-side, safe to close terminal)..."
    submitted_msg = "Submitted! Monitor at wandb.ai/yoyoliuuu/nanochat"

    print(submitting_msg)
    run_pipeline.spawn()  # return value is discarded; nothing here waits on it
    print(submitted_msg)