mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Add training continuation script and update MacOS guide
Introduces continue_training.sh to automatically resume interrupted training stages by detecting existing checkpoints and proceeding as needed. Updates README_MACOS.md with instructions and troubleshooting for using the new script, including manual continuation steps and improved guidance for memory, architecture, and performance issues.
This commit is contained in:
parent b81d789992
commit e83d633179
@@ -140,21 +140,134 @@ bash dev/runmac_overnight.sh
DEPTH=8 BASE_ITERATIONS=1000 bash dev/runmac_overnight.sh
```

## Continuing Training After Interruption

### Use `continue_training.sh` (Recommended)

If training was interrupted or you want to continue from existing checkpoints:

```bash
bash dev/continue_training.sh
```

**What it does:**

- ✅ Checks for existing base/mid/sft checkpoints
- ✅ Automatically continues from where you left off
- ✅ Skips completed stages
- ✅ Matches model tags (d4, d6, d8) correctly
- ✅ Uses memory-optimized batch sizes

**Example scenarios:**

1. **Base training completed, but mid/sft interrupted:**

   ```
   Status:
     ✓ Base model: d8/step_001000
     ✗ Midtraining: Not found

   → Will run: Midtraining → SFT
   ```

2. **Base and mid complete, only need SFT:**

   ```
   Status:
     ✓ Base model: d8/step_001000
     ✓ Midtraining: d8/step_000150
     ✗ SFT: Not found

   → Will run: SFT only
   ```

3. **Everything complete:**

   ```
   Status:
     ✓ Base model: d8/step_001000
     ✓ Midtraining: d8/step_000150
     ✓ SFT: d8/step_000150

   🎉 All training stages complete!
   → Ready to chat!
   ```
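Checkpoint detection keys off the directory layout under `~/.cache/nanochat`: each stage directory holds model-tag subdirectories (`d4`, `d6`, `d8`), and each tag holds `model_<step>.pt` files. A minimal sketch of the "pick newest tag, then newest step" idea, using a throwaway demo directory rather than real checkpoints:

```shell
# Build an example layout, then pick the newest tag and newest step inside it.
mkdir -p demo/base_checkpoints/d8
touch demo/base_checkpoints/d8/model_000500.pt demo/base_checkpoints/d8/model_001000.pt

tag=$(ls -1 demo/base_checkpoints | grep -E '^d[0-9]+$' | sort -V | tail -1)
step=$(ls -1 "demo/base_checkpoints/$tag" | sed -n 's/^model_\([0-9]*\)\.pt$/\1/p' | sort -n | tail -1)
echo "$tag/step_$step"   # → d8/step_001000

rm -rf demo
```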

### Manual Continuation

If you prefer manual control:

```bash
source .venv/bin/activate

# Continue midtraining from existing base model
python -m scripts.mid_train \
    --num_iterations=150 \
    --device_batch_size=16

# Continue SFT from existing mid model
python -m scripts.chat_sft \
    --num_iterations=150 \
    --device_batch_size=16

# Chat with the result
python -m scripts.chat_cli -i sft
```
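Before continuing manually, it can help to confirm which checkpoints actually exist. A quick sketch (the cache path matches the one used by the scripts; the fallback message is just illustrative):

```shell
# List existing base checkpoints, or report that none are present.
CKPT_DIR="$HOME/.cache/nanochat/base_checkpoints"
if [ -d "$CKPT_DIR" ]; then
    ls -1 "$CKPT_DIR"
else
    echo "no base checkpoints found in $CKPT_DIR"
fi
```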

## Troubleshooting

### Training Won't Start

**Error: `AssertionError: total_batch_size must be divisible by...`**

Fix: Ensure `total_batch_size` is divisible by `device_batch_size × max_seq_len`:

```bash
# For max_seq_len=1024:
# device_batch_size=16 → total_batch_size=16384 (16 × 1024)
# device_batch_size=8  → total_batch_size=8192  (8 × 1024)
```
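The rule above can be checked with plain shell arithmetic before launching a run (a sketch; the variable names mirror the script flags rather than any official tool):

```shell
# Derive a valid total_batch_size from device_batch_size and max_seq_len.
DEVICE_BATCH_SIZE=16
MAX_SEQ_LEN=1024
TOTAL_BATCH_SIZE=$((DEVICE_BATCH_SIZE * MAX_SEQ_LEN))

# The assertion passes when the remainder is zero.
if [ $((TOTAL_BATCH_SIZE % (DEVICE_BATCH_SIZE * MAX_SEQ_LEN))) -eq 0 ]; then
    echo "ok: total_batch_size=$TOTAL_BATCH_SIZE"
fi
```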

**Error: `split_tokens must be divisible by tokens_per_step`**

Fix: Pass `--device_batch_size` to base_loss:

```bash
python -m scripts.base_loss --device_batch_size=16 --split_tokens=16384
```

### Architecture Issues

**Running x86_64 Python on ARM64 Mac (Rosetta 2)**

Check your Python architecture:

```bash
file .venv/bin/python
# Should show: Mach-O 64-bit executable arm64
# Bad:         Mach-O 64-bit executable x86_64
```
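You can also cross-check the interpreter against the machine from inside Python (a sketch; under Rosetta 2 a translated interpreter typically reports `x86_64` from `platform.machine()` even when the shell itself is native):

```shell
# Compare the hardware architecture with what Python reports.
MACHINE_ARCH=$(uname -m)
PY_ARCH=$(python3 -c 'import platform; print(platform.machine())')
echo "machine=$MACHINE_ARCH python=$PY_ARCH"
if [ "$MACHINE_ARCH" = "arm64" ] && [ "$PY_ARCH" = "x86_64" ]; then
    echo "warning: Python appears to be running under Rosetta 2" >&2
fi
```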

Fix: Recreate the venv with native ARM64 Python:

```bash
rm -rf .venv
uv venv --python /opt/homebrew/opt/python@3.10/bin/python3.10
uv sync
maturin develop --release
```

**Performance impact:** Native ARM64 is ~2-3× faster than Rosetta 2.

### Memory & Performance Issues

**Script fails with memory errors:**

- Reduce `MEMORY_SIZE=64` or `DEVICE_BATCH_SIZE=8`
- Reduce `DEPTH=4`
- Close other applications

**Training is slow:**

- Close other applications
- Check memory profile: `sysctl hw.memsize`
- Verify MPS: check logs for "Autodetected device type: mps"
- Verify ARM64: `file .venv/bin/python` should show `arm64`
- Check CPU usage: should be 80-100% on one core
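
`sysctl hw.memsize` reports bytes; converting to GB uses the same arithmetic as the scripts (a quick sketch with an assumed example value rather than a live `sysctl` call):

```shell
# Example: a 64GB machine reports 68719476736 bytes from `sysctl -n hw.memsize`.
BYTES=68719476736
GB=$((BYTES / 1024 / 1024 / 1024))
echo "${GB}GB"   # → 64GB
```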

**Chat responses are still poor:**

- Increase iterations: `BASE_ITERATIONS=1000 MID_ITERATIONS=300 SFT_ITERATIONS=300`
- Download more data: `DATA_SHARDS=100`
- Increase model size: `DEPTH=8` (warning: needs more memory)

## Running in Background

230 dev/continue_training.sh (new executable file)

@@ -0,0 +1,230 @@
#!/bin/bash
# Smart training continuation script
# Checks for existing checkpoints and continues from where you left off

set -e

echo "=================================="
echo "nanochat Training Continuation"
echo "=================================="
echo "Started: $(date)"
echo ""

# Activate virtual environment
source .venv/bin/activate

# Memory-based configuration (same as runmac_overnight.sh)
if [ -z "$MEMORY_SIZE" ]; then
    if [[ "$OSTYPE" == "darwin"* ]]; then
        MEMORY_SIZE=$(sysctl hw.memsize | awk '{print int($2/1024/1024/1024)}')
        echo "Auto-detected memory: ${MEMORY_SIZE}GB"
    else
        MEMORY_SIZE=16
    fi
fi

# Calculate optimal batch sizes
if [ "$MEMORY_SIZE" -ge 128 ]; then
    DEVICE_BATCH_SIZE=16
    TOTAL_BATCH_SIZE=16384
    EVAL_TOKENS=16384
    SPLIT_TOKENS=16384
elif [ "$MEMORY_SIZE" -ge 64 ]; then
    DEVICE_BATCH_SIZE=8
    TOTAL_BATCH_SIZE=8192
    EVAL_TOKENS=8192
    SPLIT_TOKENS=8192
elif [ "$MEMORY_SIZE" -ge 32 ]; then
    DEVICE_BATCH_SIZE=4
    TOTAL_BATCH_SIZE=4096
    EVAL_TOKENS=4096
    SPLIT_TOKENS=4096
else
    DEVICE_BATCH_SIZE=1
    TOTAL_BATCH_SIZE=1024
    EVAL_TOKENS=2048
    SPLIT_TOKENS=2048
fi

# Allow manual overrides for iteration counts
MID_ITERATIONS=${MID_ITERATIONS:-150}
SFT_ITERATIONS=${SFT_ITERATIONS:-150}

echo "Configuration:"
echo "  Memory: ${MEMORY_SIZE}GB"
echo "  Device batch size: $DEVICE_BATCH_SIZE"
echo "  Total batch size: $TOTAL_BATCH_SIZE"
echo ""

# Check what exists
CACHE_DIR="$HOME/.cache/nanochat"
BASE_DIR="$CACHE_DIR/base_checkpoints"
MID_DIR="$CACHE_DIR/mid_checkpoints"
SFT_DIR="$CACHE_DIR/sft_checkpoints"

echo "Checking existing checkpoints..."
echo ""

# Function to find the latest checkpoint and report it as "tag/step_N"
find_latest_checkpoint() {
    local dir=$1
    if [ ! -d "$dir" ]; then
        echo "none"
        return
    fi
    # Find the latest model tag directory (d4, d6, d8, ...)
    local latest_tag=$(ls -1 "$dir" 2>/dev/null | grep -E "^d[0-9]+$" | sort -V | tail -1)
    if [ -z "$latest_tag" ]; then
        echo "none"
        return
    fi
    # Find the latest step in that tag
    local latest_step=$(ls -1 "$dir/$latest_tag" 2>/dev/null | grep -E "^model_[0-9]+\.pt$" | sed 's/model_//;s/\.pt//' | sort -n | tail -1)
    if [ -z "$latest_step" ]; then
        echo "none"
        return
    fi
    echo "$latest_tag/step_$latest_step"
}

BASE_CHECKPOINT=$(find_latest_checkpoint "$BASE_DIR")
MID_CHECKPOINT=$(find_latest_checkpoint "$MID_DIR")
SFT_CHECKPOINT=$(find_latest_checkpoint "$SFT_DIR")

# Extract the model tag (e.g., "d8" from "d8/step_001000")
BASE_TAG=$(echo "$BASE_CHECKPOINT" | cut -d'/' -f1)
MID_TAG=$(echo "$MID_CHECKPOINT" | cut -d'/' -f1)
SFT_TAG=$(echo "$SFT_CHECKPOINT" | cut -d'/' -f1)

echo "Status:"
if [ "$BASE_CHECKPOINT" != "none" ]; then
    echo "  ✓ Base model: $BASE_CHECKPOINT"
else
    echo "  ✗ Base model: Not found"
fi

if [ "$MID_CHECKPOINT" != "none" ]; then
    echo "  ✓ Midtraining: $MID_CHECKPOINT"
else
    echo "  ✗ Midtraining: Not found"
fi

if [ "$SFT_CHECKPOINT" != "none" ]; then
    echo "  ✓ SFT: $SFT_CHECKPOINT"
else
    echo "  ✗ SFT: Not found"
fi
echo ""

# Determine what to do
if [ "$SFT_CHECKPOINT" != "none" ]; then
    echo "🎉 All training stages complete!"
    echo ""
    echo "Your chatbot is ready. Chat with:"
    echo "  python -m scripts.chat_cli -i sft"
    echo ""
    echo "Or start web UI:"
    echo "  python -m scripts.chat_web -i sft"
    echo ""
    exit 0
fi

if [ "$BASE_CHECKPOINT" = "none" ]; then
    echo "❌ No base model found. Please run base training first:"
    echo "  bash dev/runmac_overnight.sh"
    echo ""
    exit 1
fi

# Download identity conversations if needed
if [ ! -f "$CACHE_DIR/identity_conversations.jsonl" ]; then
    echo "Downloading identity conversations..."
    curl -L -o "$CACHE_DIR/identity_conversations.jsonl" \
        https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
    echo ""
fi

# Continue from where we left off
# Check if we need midtraining for the current base model tag
if [ "$MID_CHECKPOINT" = "none" ] || [ "$MID_TAG" != "$BASE_TAG" ]; then
    if [ "$MID_TAG" != "$BASE_TAG" ] && [ "$MID_CHECKPOINT" != "none" ]; then
        echo "⚠️  Found mid checkpoint for $MID_TAG but base model is $BASE_TAG"
        echo "   Need to run midtraining for $BASE_TAG"
    fi

    echo "📍 Continuing from: Base model complete ($BASE_TAG)"
    echo "📋 Next steps: Midtraining → SFT"
    echo ""

    # Run midtraining
    echo "Step 1/2: Midtraining ($MID_ITERATIONS iterations)..."
    echo "  Loading base checkpoint: $BASE_CHECKPOINT"
    echo "  Device batch size: $DEVICE_BATCH_SIZE"
    python -m scripts.mid_train \
        --num_iterations=$MID_ITERATIONS \
        --device_batch_size=$DEVICE_BATCH_SIZE \
        --max_seq_len=1024 \
        --total_batch_size=$TOTAL_BATCH_SIZE \
        --eval_every=50 \
        --eval_tokens=$EVAL_TOKENS

    echo ""
    echo "✓ Midtraining complete!"
    echo ""
fi

# Re-check the mid checkpoint and verify its tag matches the base model
MID_CHECKPOINT=$(find_latest_checkpoint "$MID_DIR")
MID_TAG=$(echo "$MID_CHECKPOINT" | cut -d'/' -f1)

if [ "$MID_CHECKPOINT" = "none" ]; then
    echo "❌ Midtraining failed to produce a checkpoint"
    exit 1
fi

# Verify tags match
if [ "$MID_TAG" != "$BASE_TAG" ]; then
    echo "❌ Tag mismatch: Base is $BASE_TAG but mid is $MID_TAG"
    echo "This shouldn't happen. Please check checkpoints manually."
    exit 1
fi

# Check if we need SFT for the current mid model tag
if [ "$SFT_CHECKPOINT" = "none" ] || [ "$SFT_TAG" != "$MID_TAG" ]; then
    if [ "$SFT_TAG" != "$MID_TAG" ] && [ "$SFT_CHECKPOINT" != "none" ]; then
        echo "⚠️  Found SFT checkpoint for $SFT_TAG but mid model is $MID_TAG"
        echo "   Need to run SFT for $MID_TAG"
    fi

    # Run SFT
    echo "📍 Continuing from: Midtraining complete ($MID_TAG)"
    echo "📋 Next step: SFT (final stage!)"
    echo ""
    echo "Step 2/2: Chat fine-tuning (SFT) ($SFT_ITERATIONS iterations)..."
    echo "  Loading mid checkpoint: $MID_CHECKPOINT"
    echo "  Device batch size: $DEVICE_BATCH_SIZE"
    python -m scripts.chat_sft \
        --num_iterations=$SFT_ITERATIONS \
        --device_batch_size=$DEVICE_BATCH_SIZE \
        --target_examples_per_step=$((DEVICE_BATCH_SIZE * 2)) \
        --eval_steps=10
else
    echo "✓ SFT already complete for $SFT_TAG"
fi

echo ""
echo "=================================="
echo "🎉 All Training Complete!"
echo "=================================="
echo "Finished: $(date)"
echo ""
echo "Your chatbot is ready! Chat with:"
echo "  python -m scripts.chat_cli -i sft"
echo ""
echo "Or start the web UI:"
echo "  python -m scripts.chat_web -i sft"
echo ""
echo "Generate final report:"
echo "  python -m nanochat.report generate"
echo "=================================="