nanochat/speedrun_spark_ja.sh
karaage0703 55f8d4acf2 Fix OOM in Japanese tokenizer training by reducing max_chars
Reduce --max_chars from 2B to 500M characters to prevent OOM during
BPE training. Japanese text generates significantly more unique
sequences than English (92M+ unique sequences observed), causing
memory exhaustion during heap construction.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 12:09:24 +09:00


#!/usr/bin/env bash
# speedrun_spark_ja.sh: Japanese training script (single-GPU DGX Spark version)
# Runs training on a Japanese dataset (fineweb-2-edu-japanese)
set -euo pipefail
# ===== User settings =====
DEPTH=20
DEVICE_BATCH_SIZE=16
DATA_SHARDS=30
NUM_ITERATIONS=1000
#NUM_ITERATIONS=10   # uncomment for a quick smoke test
CACHE_DIR="$HOME/.cache/nanochat"
# ========================
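# Notes on the settings above: DEPTH sets the transformer depth (and hence model
# size); DEVICE_BATCH_SIZE and NUM_ITERATIONS are sized for a single DGX Spark
# GPU; DATA_SHARDS controls how much pretraining data is downloaded.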
# --- Japanese language setting ---
export NANOCHAT_LANG=ja
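# Assumed to switch nanochat's dataset/tokenizer handling to Japanese; the
# dataset step below also passes --lang ja explicitly.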
# --- Runtime environment / OOM mitigations ---
export PYTORCH_ALLOC_CONF="expandable_segments:True,max_split_size_mb:256"
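# expandable_segments lets the CUDA caching allocator grow segments on demand,
# and max_split_size_mb:256 limits block splitting; both reduce
# fragmentation-driven OOMs.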
export TORCHDYNAMO_DISABLE=1
export TORCHINDUCTOR_DISABLE=1
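# Disabling torch.compile (Dynamo) and the Inductor backend trades some speed
# for stability and lower peak memory.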
# ---- Start timing ----
T0=$(date +%s)
echo "=== nanochat 日本語学習 speedrun (single GPU on DGX Spark) ==="
echo "DEPTH=${DEPTH}, DEVICE_BATCH_SIZE=${DEVICE_BATCH_SIZE}, LANG=${NANOCHAT_LANG}"
python - <<'PY'
import torch
print("torch", torch.__version__, "cuda", torch.version.cuda)
print("gpu", torch.cuda.get_device_name(0), "cc", torch.cuda.get_device_capability(0))
PY
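# Optionally also report total GPU memory, which is useful when tuning
# DEVICE_BATCH_SIZE. A minimal sketch using the standard torch API:
# python -c 'import torch; p = torch.cuda.get_device_properties(0); print("mem_gb", round(p.total_memory / 1e9, 1))'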
echo "== 1) 日本語データ準備 =="
python -m nanochat.dataset -n "${DATA_SHARDS}" --lang ja
echo "== 2) 日本語トークナイザ学習 =="
python -m scripts.tok_train --max_chars=500000000
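# The eval and directory listing below are best-effort (|| true); their failure
# should not abort the run.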
python -m scripts.tok_eval || true
ls -l "${CACHE_DIR}/tokenizer" || true
echo "== 3) BASE (pretrain) =="
python -m scripts.base_train \
  --depth="${DEPTH}" \
  --device_batch_size="${DEVICE_BATCH_SIZE}" \
  --num_iterations="${NUM_ITERATIONS}"
echo "== 4) MID =="
python -m scripts.mid_train \
  --device_batch_size="${DEVICE_BATCH_SIZE}" \
  --num_iterations="${NUM_ITERATIONS}"
echo "== 5) SFT =="
python -m scripts.chat_sft \
  --device_batch_size="${DEVICE_BATCH_SIZE}" \
  --num_iterations="${NUM_ITERATIONS}"
# echo "== 6) 日本語評価 =="
# python -m scripts.chat_eval -i sft
# ---- Stop timing & report ----
T1=$(date +%s)
ELAPSED=$((T1 - T0))
printf "\n== SUMMARY ==\nTotal elapsed: %d s (%02d:%02d:%02d)\n" \
"$ELAPSED" "$((ELAPSED/3600))" "$(((ELAPSED%3600)/60))" "$((ELAPSED%60))"
echo "✅ 日本語学習完了Web UI → python -m scripts.chat_web"