mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
feat: cross-platform support for CPU and GPU environments
This commit is contained in:
parent
dd6ff9a1cc
commit
447567634c
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -3,3 +3,4 @@ __pycache__/
|
|||
*.pyc
|
||||
rustbpe/target/
|
||||
dev-ignore/
|
||||
.idea
|
||||
|
|
|
|||
|
|
@ -79,6 +79,18 @@ def is_ddp():
|
|||
# TODO is there a proper way
|
||||
return int(os.environ.get('RANK', -1)) != -1
|
||||
|
||||
def is_macos():
|
||||
"""Check if running on macOS."""
|
||||
import platform
|
||||
return platform.system() == "Darwin"
|
||||
|
||||
def get_device_type():
|
||||
"""Get the device type string for autocast: 'cuda' or 'cpu'."""
|
||||
# Use CPU if on macOS or if CUDA is not available
|
||||
if is_macos() or not torch.cuda.is_available():
|
||||
return "cpu"
|
||||
return "cuda"
|
||||
|
||||
def get_dist_info():
|
||||
if is_ddp():
|
||||
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
|
||||
|
|
@ -92,12 +104,14 @@ def get_dist_info():
|
|||
def compute_init():
|
||||
"""Basic initialization that we keep doing over and over, so make common."""
|
||||
|
||||
# CUDA is currently required
|
||||
assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
|
||||
# Check if CUDA is available
|
||||
has_cuda = torch.cuda.is_available()
|
||||
on_macos = is_macos()
|
||||
|
||||
# Reproducibility
|
||||
torch.manual_seed(42)
|
||||
torch.cuda.manual_seed(42)
|
||||
if has_cuda:
|
||||
torch.cuda.manual_seed(42)
|
||||
# skipping full reproducibility for now, possibly investigate slowdown later
|
||||
# torch.use_deterministic_algorithms(True)
|
||||
# torch.backends.cudnn.deterministic = True
|
||||
|
|
@ -108,13 +122,25 @@ def compute_init():
|
|||
|
||||
# Distributed setup: Distributed Data Parallel (DDP), optional
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
|
||||
if ddp:
|
||||
|
||||
# Determine device
|
||||
if on_macos or not has_cuda:
|
||||
device = torch.device("cpu")
|
||||
if on_macos:
|
||||
logger.info("Running on macOS with CPU")
|
||||
else:
|
||||
logger.info("Running on CPU (CUDA not available)")
|
||||
if ddp:
|
||||
logger.warning("DDP requested but will run on CPU")
|
||||
elif ddp:
|
||||
device = torch.device("cuda", ddp_local_rank)
|
||||
torch.cuda.set_device(device) # make "cuda" default to this device
|
||||
dist.init_process_group(backend="nccl", device_id=device)
|
||||
dist.barrier()
|
||||
logger.info(f"Running on CUDA with DDP (rank {ddp_rank}/{ddp_world_size})")
|
||||
else:
|
||||
device = torch.device("cuda")
|
||||
logger.info("Running on CUDA (single GPU)")
|
||||
|
||||
if ddp_rank == 0:
|
||||
logger.info(f"Distributed world size: {ddp_world_size}")
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from collections import deque
|
|||
|
||||
import torch
|
||||
|
||||
from nanochat.common import get_dist_info
|
||||
from nanochat.common import get_dist_info, get_device_type
|
||||
from nanochat.dataset import parquets_iter_batched
|
||||
from nanochat.tokenizer import get_tokenizer
|
||||
|
||||
|
|
@ -43,7 +43,8 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
|
|||
# Create the inputs/targets as 1D tensors
|
||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||
targets_cpu = scratch[1:]
|
||||
# Reshape to 2D and move to GPU async
|
||||
inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
|
||||
# Reshape to 2D and move to device async
|
||||
device_type = get_device_type()
|
||||
inputs = inputs_cpu.view(B, T).to(device=device_type, dtype=torch.int32, non_blocking=True)
|
||||
targets = targets_cpu.view(B, T).to(device=device_type, dtype=torch.int64, non_blocking=True)
|
||||
yield inputs, targets
|
||||
|
|
|
|||
|
|
@ -308,7 +308,8 @@ if __name__ == "__main__":
|
|||
prompt_tokens = tokenizer.encode("The chemical formula of water is", prepend=bos_token_id)
|
||||
# generate the reference sequence using the model.generate() function
|
||||
generated_tokens = []
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
stream = model.generate(prompt_tokens, **kwargs)
|
||||
for token in stream:
|
||||
|
|
@ -316,7 +317,8 @@ if __name__ == "__main__":
|
|||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
print()
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
print(f"Reference time: {t1 - t0:.2f}s")
|
||||
reference_ids = generated_tokens
|
||||
|
|
@ -324,7 +326,8 @@ if __name__ == "__main__":
|
|||
generated_tokens = []
|
||||
engine = Engine(model, tokenizer)
|
||||
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for token_column, token_masks in stream:
|
||||
token = token_column[0] # only print out the first row
|
||||
|
|
@ -332,7 +335,8 @@ if __name__ == "__main__":
|
|||
chunk = tokenizer.decode([token])
|
||||
print(chunk, end="", flush=True)
|
||||
print()
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
print(f"Engine time: {t1 - t0:.2f}s")
|
||||
# compare the two sequences
|
||||
|
|
|
|||
|
|
@ -22,12 +22,34 @@ dependencies = [
|
|||
requires = ["maturin>=1.7,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
# target torch to cuda 12.8
|
||||
# target torch to cuda 12.8 or CPU
|
||||
[project.optional-dependencies]
|
||||
cpu = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
gpu = [
|
||||
"torch>=2.8.0",
|
||||
]
|
||||
|
||||
[tool.uv]
|
||||
conflicts = [
|
||||
[
|
||||
{ extra = "cpu" },
|
||||
{ extra = "gpu" },
|
||||
],
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
torch = [
|
||||
{ index = "pytorch-cu128" },
|
||||
{ index = "pytorch-cpu", extra = "cpu" },
|
||||
{ index = "pytorch-cu128", extra = "gpu" },
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cpu"
|
||||
url = "https://download.pytorch.org/whl/cpu"
|
||||
explicit = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "pytorch-cu128"
|
||||
url = "https://download.pytorch.org/whl/cu128"
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import yaml
|
|||
import pandas as pd
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_device_type
|
||||
from nanochat.tokenizer import HuggingFaceTokenizer
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.core_eval import evaluate_task
|
||||
|
|
@ -122,7 +122,7 @@ def main():
|
|||
|
||||
# distributed / precision setup
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
|
||||
|
||||
# Load model and tokenizer from command line or from file system
|
||||
if len(sys.argv) >= 2:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
|||
import os
|
||||
import torch
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.common import compute_init, print0, compute_cleanup
|
||||
from nanochat.common import compute_init, print0, compute_cleanup, get_device_type
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -28,7 +28,7 @@ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=mode
|
|||
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
|
||||
|
||||
# Set up the precision we'll run with
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
|
||||
|
||||
# Evaluate the loss on each split
|
||||
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import torch
|
|||
|
||||
from nanochat.gpt import GPT, GPTConfig
|
||||
from nanochat.dataloader import tokenizing_distributed_data_loader
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_device_type
|
||||
from nanochat.tokenizer import get_tokenizer, get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -59,7 +59,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
|
|||
# Compute init
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
@ -96,7 +96,7 @@ model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_la
|
|||
with torch.device("meta"):
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
model.to_empty(device="cuda")
|
||||
model.to_empty(device=device)
|
||||
model.init_weights()
|
||||
orig_model = model # original, uncompiled model, for saving raw model state_dict
|
||||
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
|
||||
|
|
@ -252,7 +252,8 @@ for step in range(num_iterations + 1):
|
|||
# -------------------------------------------------------------------------
|
||||
# single training step
|
||||
# evaluate the gradient
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
|
|
@ -275,7 +276,8 @@ for step in range(num_iterations + 1):
|
|||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
|
@ -304,7 +306,8 @@ for step in range(num_iterations + 1):
|
|||
})
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
|
||||
peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024 if torch.cuda.is_available() else 0
|
||||
print0(f"Peak memory usage: {peak_mem:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
|
|
@ -330,7 +333,7 @@ get_report().log(section="Base model training", data=[
|
|||
"MFU %": f"{mfu:.2f}%",
|
||||
"Total training flops": f"{flops_so_far:e}",
|
||||
"Total training time": f"{total_training_time/60:.2f}m",
|
||||
"Peak memory usage": f"{torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB",
|
||||
"Peak memory usage": f"{peak_mem:.2f}MiB",
|
||||
}
|
||||
])
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ python -m scripts.chat_cli -i mid
|
|||
"""
|
||||
import argparse
|
||||
import torch
|
||||
from nanochat.common import compute_init
|
||||
from nanochat.common import compute_init, get_device_type
|
||||
from nanochat.engine import Engine
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
|
||||
|
|
@ -21,7 +21,7 @@ args = parser.parse_args()
|
|||
|
||||
# Init the model and tokenizer
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
|
||||
# Special tokens for the chat state machine
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from functools import partial
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0
|
||||
from nanochat.common import compute_init, compute_cleanup, get_dist_info, print0, get_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
|
|
@ -195,7 +195,7 @@ if __name__ == "__main__":
|
|||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype)
|
||||
autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=ptdtype)
|
||||
|
||||
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
|
||||
engine = Engine(model, tokenizer)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ import wandb
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, DummyWandb, get_device_type
|
||||
from nanochat.checkpoint_manager import save_checkpoint, load_model
|
||||
from nanochat.engine import Engine
|
||||
from tasks.gsm8k import GSM8K
|
||||
|
|
@ -57,7 +57,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
|
|||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=dtype)
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ import wandb
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb
|
||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, get_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.engine import Engine
|
||||
|
|
@ -63,7 +63,7 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
|
|||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
master_process = ddp_rank == 0
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=dtype)
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
|
|||
from pydantic import BaseModel
|
||||
from typing import List, Optional, AsyncGenerator
|
||||
|
||||
from nanochat.common import compute_init
|
||||
from nanochat.common import compute_init, get_device_type
|
||||
from nanochat.checkpoint_manager import load_model
|
||||
from nanochat.engine import Engine
|
||||
|
||||
|
|
@ -32,7 +32,7 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
|
|||
args = parser.parse_args()
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
||||
autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import time
|
|||
import wandb
|
||||
import torch
|
||||
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir
|
||||
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, get_device_type
|
||||
from nanochat.tokenizer import get_token_bytes
|
||||
from nanochat.checkpoint_manager import save_checkpoint
|
||||
from nanochat.loss_eval import evaluate_bpb
|
||||
|
|
@ -53,7 +53,7 @@ user_config = {k: globals()[k] for k in config_keys} # possibly useful for loggi
|
|||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
|
||||
master_process = ddp_rank == 0
|
||||
dtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=dtype)
|
||||
autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=dtype)
|
||||
|
||||
# wandb logging init
|
||||
use_dummy_wandb = run == "dummy" or not master_process
|
||||
|
|
@ -214,7 +214,8 @@ while True:
|
|||
# -------------------------------------------------------------------------
|
||||
# single training step
|
||||
# evaluate the gradient
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t0 = time.time()
|
||||
for micro_step in range(grad_accum_steps):
|
||||
with autocast_ctx:
|
||||
|
|
@ -235,7 +236,8 @@ while True:
|
|||
for opt in optimizers:
|
||||
opt.step()
|
||||
model.zero_grad(set_to_none=True)
|
||||
torch.cuda.synchronize()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
dt = t1 - t0
|
||||
# -------------------------------------------------------------------------
|
||||
|
|
@ -267,7 +269,8 @@ while True:
|
|||
})
|
||||
|
||||
# print a few more stats
|
||||
print0(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1024 / 1024:.2f}MiB")
|
||||
peak_mem = torch.cuda.max_memory_allocated() / 1024 / 1024 if torch.cuda.is_available() else 0
|
||||
print0(f"Peak memory usage: {peak_mem:.2f}MiB")
|
||||
print0(f"Total training time: {total_training_time/60:.2f}m")
|
||||
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
|
||||
|
||||
|
|
|
|||
50
speedrun.sh
50
speedrun.sh
|
|
@ -22,8 +22,19 @@ mkdir -p $NANOCHAT_BASE_DIR
|
|||
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# create a .venv local virtual environment (if it doesn't exist)
|
||||
[ -d ".venv" ] || uv venv
|
||||
# install the repo dependencies
|
||||
uv sync
|
||||
|
||||
# Detect hardware and install appropriate PyTorch version
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
echo "Detected macOS - installing CPU version of PyTorch"
|
||||
uv sync --extra cpu
|
||||
elif command -v nvidia-smi &> /dev/null && nvidia-smi &> /dev/null; then
|
||||
echo "Detected NVIDIA GPU(s) - installing CUDA version of PyTorch"
|
||||
uv sync --extra gpu
|
||||
else
|
||||
echo "No GPU detected - installing CPU version of PyTorch"
|
||||
uv sync --extra cpu
|
||||
fi
|
||||
|
||||
# activate venv so that `python` uses the project's venv instead of system python
|
||||
source .venv/bin/activate
|
||||
|
||||
|
|
@ -70,6 +81,23 @@ python -m scripts.tok_train --max_chars=2000000000
|
|||
# evaluate the tokenizer (report compression ratio etc.)
|
||||
python -m scripts.tok_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Platform detection for compute configuration
|
||||
|
||||
# Check if running on macOS
|
||||
if [[ "$OSTYPE" == "darwin"* ]]; then
|
||||
echo "Detected macOS - running in CPU mode (single process)"
|
||||
TORCHRUN_CMD="python"
|
||||
# Check if CUDA/GPUs are available
|
||||
elif command -v nvidia-smi &> /dev/null && nvidia-smi &> /dev/null; then
|
||||
GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
|
||||
echo "Detected $GPU_COUNT GPU(s) - running in GPU mode"
|
||||
TORCHRUN_CMD="torchrun --standalone --nproc_per_node=$GPU_COUNT"
|
||||
else
|
||||
echo "No GPUs detected - running in CPU mode (single process)"
|
||||
TORCHRUN_CMD="python"
|
||||
fi
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Base model (pretraining)
|
||||
|
||||
|
|
@ -92,25 +120,25 @@ echo "Waiting for dataset download to complete..."
|
|||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# pretrain the d20 model
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
|
||||
$TORCHRUN_CMD -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
|
||||
# evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
|
||||
$TORCHRUN_CMD -m scripts.base_loss
|
||||
# evaluate the model on CORE tasks
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
|
||||
$TORCHRUN_CMD -m scripts.base_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
||||
|
||||
# run midtraining and eval the model
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
|
||||
$TORCHRUN_CMD -m scripts.mid_train -- --run=$WANDB_RUN
|
||||
$TORCHRUN_CMD -m scripts.chat_eval -- -i mid
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)
|
||||
|
||||
# train sft and re-eval right away (should see a small bump)
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
|
||||
$TORCHRUN_CMD -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||
$TORCHRUN_CMD -m scripts.chat_eval -- -i sft
|
||||
|
||||
# chat with the model over CLI! Leave out the -p to chat interactively
|
||||
# python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
|
@ -123,9 +151,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
|
|||
# (optional)
|
||||
|
||||
# run reinforcement learning
|
||||
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
|
||||
# $TORCHRUN_CMD -m scripts.chat_rl -- --run=$WANDB_RUN
|
||||
# eval the RL model only on GSM8K
|
||||
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
|
||||
# $TORCHRUN_CMD -m scripts.chat_eval -- -i rl -a GSM8K
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Generate the full report by putting together all the sections
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user