mirror of
https://github.com/karpathy/nanochat.git
synced 2026-06-20 04:59:08 +00:00
Remove ordering
- Removed size-based ordering for comms and work. - Moved the comments around. - Consolidated the gather futures into one list.
This commit is contained in:
parent
d1595fb2d1
commit
7aceb020dd
|
|
@ -1,7 +1,3 @@
|
|||
"""
|
||||
Distributed AdamW optimizer with a fused step function.
|
||||
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
|
||||
"""
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
|
@ -61,33 +57,6 @@ Some of the changes in nanochat implementation:
|
|||
- Makes no assumptions about model architecture (e.g. that attention weights are fused into QKVO format)
|
||||
"""
|
||||
|
||||
"""
|
||||
Muon - MomentUm Orthogonalized by Newton-schulz
|
||||
|
||||
https://kellerjordan.github.io/posts/muon/
|
||||
|
||||
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
||||
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
||||
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
||||
the advantage that it can be stably run in bfloat16 on the GPU.
|
||||
|
||||
Some warnings:
|
||||
- This optimizer should not be used for the embedding layer, the final fully connected layer,
|
||||
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
||||
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
||||
|
||||
Arguments:
|
||||
lr: The learning rate used by the internal SGD.
|
||||
momentum: The momentum used by the internal SGD.
|
||||
ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
import torch.distributed as dist
|
||||
|
||||
# Coefficients for Polar Express (computed for num_iters=5, safety_factor=2e-2, cushion=2)
|
||||
# From https://arxiv.org/pdf/2505.16932
|
||||
polar_express_coeffs = [
|
||||
|
|
@ -157,27 +126,43 @@ def muon_step_fused(
|
|||
class MuonAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Combined optimizer: Muon for 2D matrix params, AdamW for others.
|
||||
Non-distributed version for single-GPU training.
|
||||
|
||||
Args:
|
||||
|
||||
AdamW - Distributed AdamW optimizer with a fused step function.
|
||||
|
||||
Muon - MomentUm Orthogonalized by Newton-schulz
|
||||
|
||||
https://kellerjordan.github.io/posts/muon/
|
||||
|
||||
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
|
||||
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
|
||||
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
|
||||
the advantage that it can be stably run in bfloat16 on the GPU.
|
||||
|
||||
Some warnings:
|
||||
- The Muon optimizer should not be used for the embedding layer, the final fully connected layer,
|
||||
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
|
||||
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
|
||||
|
||||
AdamW Arguments:
|
||||
adamw_groups: List of dicts with 'params' and optional 'lr' for AdamW params
|
||||
muon_params: List of 2D tensors to optimize with Muon
|
||||
adamw_lr: Default learning rate for AdamW (default: 1e-3)
|
||||
adamw_betas: Beta coefficients for AdamW (default: (0.9, 0.999))
|
||||
adamw_eps: Epsilon for AdamW numerical stability (default: 1e-8)
|
||||
adamw_eps: Epsilon for AdamW numerical stability (default: 1e-8)
|
||||
adamw_weight_decay: Weight decay for AdamW (default: 0.01)
|
||||
muon_lr: Learning rate for Muon (default: 0.02)
|
||||
muon_momentum: Momentum for Muon (default: 0.95)
|
||||
muon_ns_steps: Number of Newton-Schulz iterations (default: 5)
|
||||
muon_beta2: Second moment decay for Muon variance reduction (default: 0.95)
|
||||
muon_weight_decay: Cautious weight decay for Muon (default: 0.0)
|
||||
|
||||
Muon Arguments:
|
||||
muon_lr: The learning rate used by the internal SGD.
|
||||
muon_momentum: The momentum used by the internal SGD.
|
||||
muon_ns_steps: The number of Newton-Schulz iteration steps to use.
|
||||
muon_beta2: The decay rate for the second moment (variance) estimate. Set to None to disable.
|
||||
muon_weight_decay: Cautious weight decay coefficient. Only decays where update and weight agree.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
adamw_groups: list[dict],
|
||||
muon_params,
|
||||
# AdamW hyperparams
|
||||
adamw_lr: float = 1e-3,
|
||||
adamw_lr: float = 1e-3, # can be overridden per-group
|
||||
adamw_betas: tuple[float, float] = (0.9, 0.999),
|
||||
adamw_eps: float = 1e-8,
|
||||
adamw_weight_decay: float = 0.01,
|
||||
|
|
@ -198,7 +183,7 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
for group in adamw_groups:
|
||||
assert isinstance(group, dict) and 'params' in group
|
||||
params = list(group['params'])
|
||||
lr = group.get('lr', adamw_lr)
|
||||
lr = group.get('lr', adamw_lr) # AdamW supports per-group learning rates
|
||||
for p in params:
|
||||
print(f"AdamW: 1 param of shape {p.shape}")
|
||||
param_groups.append(dict(
|
||||
|
|
@ -217,7 +202,7 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
beta2=muon_beta2, weight_decay=muon_weight_decay,
|
||||
))
|
||||
|
||||
defaults = dict(lr=adamw_lr)
|
||||
defaults = dict(lr=adamw_lr) # torch.optim.Optimizer requires a default lr
|
||||
super().__init__(param_groups, defaults)
|
||||
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
|
|
@ -331,21 +316,16 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
class DistMuonAdamW(torch.optim.Optimizer):
|
||||
"""
|
||||
Combined distributed optimizer: Muon for 2D matrix params, AdamW for others.
|
||||
|
||||
Communication optimization: starts communications for largest tensors first,
|
||||
then processes smallest tensors first while large tensor comms complete.
|
||||
This overlaps communication with computation for better efficiency.
|
||||
|
||||
Args:
|
||||
adamw_groups: List of dicts with 'params' and optional 'lr' for AdamW params
|
||||
muon_params: List of 2D tensors to optimize with Muon
|
||||
adamw_betas: Beta coefficients for AdamW (default: (0.9, 0.999))
|
||||
adamw_eps: Epsilon for AdamW numerical stability (default: 1e-8)
|
||||
muon_lr: Learning rate for Muon (default: 0.02)
|
||||
muon_momentum: Momentum for Muon (default: 0.95)
|
||||
muon_ns_steps: Number of Newton-Schulz iterations (default: 5)
|
||||
muon_beta2: Second moment decay for Muon variance reduction (default: 0.95)
|
||||
muon_weight_decay: Cautious weight decay for Muon (default: 0.0)
|
||||
|
||||
(See MuonAdamW for algorithmic details.)
|
||||
|
||||
AdamW Communication:
|
||||
In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction.
|
||||
A bunch of ideas (e.g. dist comms in slices) are borrowed from modded-nanogpt.
|
||||
|
||||
Muon Communication:
|
||||
Parameters are grouped by shape, then stacked into single Tensors for efficient communication.
|
||||
Comms are launched for every group in param_groups order, and groups are then processed in that same order (no size-based ordering).
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -374,11 +354,11 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
|
||||
rank = dist.get_rank()
|
||||
world_size = dist.get_world_size()
|
||||
# Validate
|
||||
if rank == 0:
|
||||
|
||||
# AdamW groups: each input group becomes one param_group
|
||||
for group in adamw_groups:
|
||||
|
||||
# AdamW groups: each input group becomes one param_group
|
||||
for group in adamw_groups:
|
||||
# Validate
|
||||
if rank == 0:
|
||||
assert isinstance(group, dict), "expecting param_groups to be a list of dicts"
|
||||
assert isinstance(group['params'], list), "expecting group['params'] to be a list of tensors"
|
||||
for p in group['params']:
|
||||
|
|
@ -386,18 +366,16 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
print(f"AdamW: 1 param of shape {p.shape}, sliced={sliced}")
|
||||
if sliced: # large parameter tensors will be operated on in slices
|
||||
assert p.shape[0] % world_size == 0, f"First dim of parameter shape {p.shape} must be divisible by world size {world_size}"
|
||||
|
||||
# AdamW groups: each input group becomes one param_group
|
||||
for group in adamw_groups:
|
||||
# Add to param_groups
|
||||
params = list(group['params'])
|
||||
lr = group.get('lr', adamw_lr)
|
||||
lr = group.get('lr', adamw_lr) # AdamW supports per-group learning rates
|
||||
param_groups.append(dict(
|
||||
params=params, lr=lr, kind='adamw',
|
||||
betas=adamw_betas, eps=adamw_eps, weight_decay=adamw_weight_decay,
|
||||
))
|
||||
|
||||
# Muon groups: group by shape for stacking, with all Muon hyperparams in the group
|
||||
muon_shapes = sorted({p.shape for p in muon_params})
|
||||
muon_shapes = sorted({p.shape for p in muon_params}) # sort for deterministic ordering across ranks
|
||||
for shape in muon_shapes:
|
||||
group_params = [p for p in muon_params if p.shape == shape]
|
||||
device, dtype = group_params[0].device, group_params[0].dtype
|
||||
|
|
@ -413,9 +391,9 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
beta2=muon_beta2, weight_decay=muon_weight_decay,
|
||||
))
|
||||
|
||||
defaults = dict(lr=adamw_lr)
|
||||
defaults = dict(lr=adamw_lr) # torch.optim.Optimizer requires a default lr
|
||||
super().__init__(param_groups, defaults)
|
||||
|
||||
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
# AdamW tensors
|
||||
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
|
@ -429,25 +407,6 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
|
||||
# Precompute group order sorted by communication size for optimal overlap
|
||||
# We launch comms largest-first, then process smallest-first so large comms finish in time
|
||||
group_comm_sizes = []
|
||||
for group_idx, group in enumerate(self.param_groups):
|
||||
if group['kind'] == 'adamw':
|
||||
# AdamW group comm size = size of first param (all params in group have same shape)
|
||||
comm_size = group['params'][0].numel()
|
||||
else: # muon
|
||||
# Muon group comm size = padded stacked tensor size
|
||||
chunk_size = group['chunk_size']
|
||||
comm_size = chunk_size * world_size * group['params'][0].numel()
|
||||
group_comm_sizes.append((comm_size, group_idx))
|
||||
|
||||
# Sort: largest first for comms, we'll reverse for compute
|
||||
group_comm_sizes.sort(key=lambda x: x[0], reverse=True)
|
||||
self._group_order = [group_idx for _, group_idx in group_comm_sizes]
|
||||
if rank == 0:
|
||||
print(f"DistMuonAdamW: {len(self._group_order)} groups, comm order by size (largest first)")
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self):
|
||||
|
|
@ -457,13 +416,11 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
# Ensure all grads exist
|
||||
assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
|
||||
|
||||
# First pass: launch all async communications (largest groups first)
|
||||
# First pass: launch all async communications
|
||||
adamw_infos: dict[Tensor, dict] = {} # param -> {reduce_future, grad_slice, is_small}
|
||||
muon_infos: dict[int, dict] = {} # group_idx -> {reduce_future, grad_chunk, stacked_grads}
|
||||
|
||||
for group_idx in self._group_order:
|
||||
group = self.param_groups[group_idx]
|
||||
|
||||
for group_idx, group in enumerate(self.param_groups):
|
||||
if group['kind'] == 'adamw':
|
||||
params: list[Tensor] = group['params']
|
||||
for p in params:
|
||||
|
|
@ -474,7 +431,6 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
adamw_infos[p] = dict(reduce_future=reduce_future, grad_slice=grad, is_small=True)
|
||||
# Large param: reduce_scatter
|
||||
else:
|
||||
|
||||
rank_size = grad.shape[0] // world_size # p.shape[0] % world_size == 0 is checked in __init__
|
||||
grad_slice = torch.empty_like(grad[:rank_size])
|
||||
reduce_future = dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future()
|
||||
|
|
@ -508,13 +464,10 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
stacked_grads=stacked_grads, # reuse for all_gather output
|
||||
)
|
||||
|
||||
# Second pass: process groups (smallest first, so large comms finish in time) wait for reduce, compute batched updates, kick off all_gather
|
||||
gather_futures: list[torch.Future] = []
|
||||
muon_gather_infos: list[dict] = []
|
||||
|
||||
for group_idx in reversed(self._group_order):
|
||||
group = self.param_groups[group_idx]
|
||||
# Second pass: wait for reduce, compute updates, kick off all_gather
|
||||
gather_infos: list[dict] = [] # unified list for both AdamW and Muon gathers
|
||||
|
||||
for group_idx, group in enumerate(self.param_groups):
|
||||
if group['kind'] == 'adamw':
|
||||
beta1, beta2 = group['betas']
|
||||
eps = group['eps']
|
||||
|
|
@ -561,7 +514,8 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
|
||||
# Only large params need all_gather
|
||||
if not info['is_small']:
|
||||
gather_futures.append(dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future())
|
||||
gather_future = dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
|
||||
gather_infos.append(dict(gather_future=gather_future, params=None))
|
||||
|
||||
else: # muon
|
||||
info = muon_infos[group_idx]
|
||||
|
|
@ -642,19 +596,18 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
stacked_params, updated_params, async_op=True
|
||||
).get_future()
|
||||
|
||||
muon_gather_infos.append(dict(
|
||||
gather_infos.append(dict(
|
||||
gather_future=gather_future,
|
||||
stacked_params=stacked_params,
|
||||
params=params,
|
||||
))
|
||||
|
||||
# Final pass: wait for all_gather and copy back to params
|
||||
if gather_futures:
|
||||
torch.futures.collect_all(gather_futures).wait()
|
||||
|
||||
for info in muon_gather_infos:
|
||||
# Final pass: wait for every all_gather; only Muon groups then copy back to params
|
||||
for info in gather_infos:
|
||||
info["gather_future"].wait()
|
||||
stacked_params = info["stacked_params"]
|
||||
params = info["params"]
|
||||
# Batched copy back (single kernel instead of N individual copies)
|
||||
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
|
||||
# Muon params need to be copied back from stacked buffer
|
||||
if info["params"] is not None:
|
||||
stacked_params = info["stacked_params"]
|
||||
params = info["params"]
|
||||
# Batched copy back (single kernel instead of N individual copies)
|
||||
torch._foreach_copy_(params, list(stacked_params[:len(params)].unbind(0)))
|
||||
153
speedrun.sh
Normal file
153
speedrun.sh
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
#!/bin/bash
|
||||
|
||||
# This script is the "Best ChatGPT clone that $100 can buy".
|
||||
# It is designed to run in ~4 hours on 8XH100 node at $3/GPU/hour.
|
||||
|
||||
# 1) Example launch (simplest):
|
||||
# bash speedrun.sh
|
||||
# 2) Example launch in a screen session (because the run takes ~4 hours):
|
||||
# screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
# 3) Example launch with wandb logging, but see below for setting up wandb first:
|
||||
# WANDB_RUN=speedrun screen -L -Logfile speedrun.log -S speedrun bash speedrun.sh
|
||||
|
||||
# Default intermediate artifacts directory is in ~/.cache/nanochat
|
||||
export OMP_NUM_THREADS=1
|
||||
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
|
||||
mkdir -p $NANOCHAT_BASE_DIR
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Python venv setup with uv
|
||||
|
||||
# install uv (if not already installed)
|
||||
if ! command -v uv &> /dev/null; then
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
fi
|
||||
# Add uv to PATH (it installs to ~/.local/bin)
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
# create a .venv local virtual environment (if it doesn't exist)
|
||||
[ -d ".venv" ] || uv venv
|
||||
# install the repo dependencies
|
||||
uv sync --extra gpu
|
||||
# activate venv so that `python` uses the project's venv instead of system python
|
||||
source .venv/bin/activate
|
||||
# Ensure we're using the venv Python and torchrun
|
||||
PYTHON=".venv/bin/python"
|
||||
TORCHRUN=".venv/bin/torchrun"
|
||||
|
||||
# Install flash_attn if the wheel exists (for A100 compatibility)
|
||||
if [ -f "flash_attn-2.8.3+cu128torch2.9-cp310-cp310-linux_x86_64.whl" ]; then
|
||||
uv pip install flash_attn-2.8.3+cu128torch2.9-cp310-cp310-linux_x86_64.whl
|
||||
fi
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# wandb setup
|
||||
# If you wish to use wandb for logging (it's nice!, recommended).
|
||||
# You can authenticate in one of two ways:
|
||||
# 1) Set WANDB_API_KEY environment variable before running:
|
||||
# `export WANDB_API_KEY=your_api_key_here`
|
||||
# `bash speedrun.sh`
|
||||
# 2) Or run `wandb login` after the venv is set up (the venv will be active)
|
||||
# The script will automatically use wandb if WANDB_API_KEY is set or if you've logged in.
|
||||
# Set the WANDB_RUN environment variable when running this script, e.g.:
|
||||
# `WANDB_RUN=d26 bash speedrun.sh`
|
||||
if [ -z "$WANDB_RUN" ]; then
|
||||
# by default use "dummy" : it's handled as a special case, skips logging to wandb
|
||||
WANDB_RUN=dummy
|
||||
fi
|
||||
|
||||
# If WANDB_API_KEY is set, export it so wandb can use it automatically
|
||||
if [ -n "$WANDB_API_KEY" ]; then
|
||||
export WANDB_API_KEY
|
||||
echo "Using WANDB_API_KEY from environment for wandb authentication"
|
||||
fi
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# During the course of the run, we will be writing markdown reports to the report/
|
||||
# directory in the base dir. This command clears it out and writes a header section
|
||||
# with a bunch of system info and a timestamp that marks the start of the run.
|
||||
$PYTHON -m nanochat.report reset
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Tokenizer
|
||||
|
||||
# Download the first ~2B characters of pretraining dataset
|
||||
# look at dev/repackage_data_reference.py for details on how this data was prepared
|
||||
# each data shard is ~250M chars
|
||||
# so we download 2e9 / 250e6 = 8 data shards at this point
|
||||
# each shard is ~100MB of text (compressed), so this is about ~800MB of data on disk
|
||||
$PYTHON -m nanochat.dataset -n 8
|
||||
# Immediately also kick off downloading more shards in the background while tokenizer trains
|
||||
# See comment below for why 370 is the right number here
|
||||
$PYTHON -m nanochat.dataset -n 370 &
|
||||
DATASET_DOWNLOAD_PID=$!
|
||||
# train the tokenizer with vocab size 50304 on ~20M characters of data
|
||||
$PYTHON -m scripts.tok_train --max-chars=20000000 --vocab-size=50304
|
||||
# evaluate the tokenizer (report compression ratio etc.)
|
||||
$PYTHON -m scripts.tok_eval
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Base model (pretraining)
|
||||
|
||||
# The d20 model is 561M parameters.
|
||||
# Chinchilla says #tokens = 20X #params, so we need 561e6 * 20 = 11.2B tokens.
|
||||
# Assume our tokenizer is 4.8 chars/token, this is 11.2B * 4.8 ~= 54B chars.
|
||||
# At 250M chars/shard, this is 54B / 250M ~= 216 shards needed for pretraining.
|
||||
# Round up to 240 for safety. Also, the new DataLoader wastes about 35% of tokens to cropping
|
||||
# so 240 / (1 - 0.35) = 370 shards are needed.
|
||||
# At ~100MB/shard, this downloads ~37GB of data to disk.
|
||||
# (The total number of shards available in the entire dataset is 1822.)
|
||||
echo "Waiting for dataset download to complete..."
|
||||
wait $DATASET_DOWNLOAD_PID
|
||||
|
||||
# Number of processes/GPUs to use
|
||||
NPROC_PER_NODE=8
|
||||
# Per-device batch size (default is 32). Reduce this if you hit OOM - gradient accumulation will automatically increase.
|
||||
# To match modded-nanogpt initial batch: 8 seqs * 2048 seq_len * 8 GPUs = 131,072 tokens
|
||||
DEVICE_BATCH_SIZE=8
|
||||
TOTAL_BATCH_SIZE=131072
|
||||
|
||||
# pretrain the base model (this invocation uses --depth=11, not the d20 configuration sized above)
|
||||
#$TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train-mine -- --depth=12 --target-param-data-ratio=20 --device-batch-size=$DEVICE_BATCH_SIZE --run=$WANDB_RUN
|
||||
$TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train-main-profiled -- --depth=11 --target-param-data-ratio=20 --device-batch-size=$DEVICE_BATCH_SIZE --total-batch-size=$TOTAL_BATCH_SIZE --run=$WANDB_RUN
|
||||
# # evaluate the model on a larger chunk of train/val data and draw some samples
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_loss
|
||||
# # evaluate the model on CORE tasks
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_eval
|
||||
|
||||
# # -----------------------------------------------------------------------------
|
||||
# # Midtraining (teach the model conversation special tokens, tool use, multiple choice)
|
||||
|
||||
# # download 2.3MB of synthetic identity conversations to impart a personality to nanochat
|
||||
# # see dev/gen_synthetic_data.py for details on how this data was prepared and to get a sense of how you can easily tune it
|
||||
# curl -L -o $NANOCHAT_BASE_DIR/identity_conversations.jsonl https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
|
||||
|
||||
# # run midtraining and eval the model
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.mid_train -- --device-batch-size=$DEVICE_BATCH_SIZE --run=$WANDB_RUN
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i mid
|
||||
|
||||
# # -----------------------------------------------------------------------------
|
||||
# # Supervised Finetuning (domain adaptation to each sequence all by itself per row)
|
||||
|
||||
# # train sft and re-eval right away (should see a small bump)
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_sft -- --run=$WANDB_RUN
|
||||
# $TORCHRUN --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i sft
|
||||
|
||||
# # chat with the model over CLI! Leave out the -p to chat interactively
|
||||
# # python -m scripts.chat_cli -p "Why is the sky blue?"
|
||||
|
||||
# # even better, chat with your model over a pretty WebUI ChatGPT style
|
||||
# # python -m scripts.chat_web
|
||||
|
||||
# # -----------------------------------------------------------------------------
|
||||
# # Reinforcement Learning. Optional, and currently only on GSM8K
|
||||
# # (optional)
|
||||
|
||||
# # run reinforcement learning
|
||||
# # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_rl -- --run=$WANDB_RUN
|
||||
# # eval the RL model only on GSM8K
|
||||
# # torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.chat_eval -- -i rl -a GSM8K
|
||||
|
||||
# # -----------------------------------------------------------------------------
|
||||
# # Generate the full report by putting together all the sections
|
||||
# # report.md is the output and will be copied to current directory for convenience
|
||||
# $PYTHON -m nanochat.report generate
|
||||
Loading…
Reference in New Issue
Block a user