mirror of https://github.com/karpathy/nanochat.git
synced 2025-12-27 14:42:29 +00:00
Improve Mac/MPS compatibility and device handling
Added dev/runmac_overnight.sh for an optimized overnight training run on a Mac. Updated the device-specific logic throughout the dataloader, GPT model, Muon optimizer, and training scripts to avoid CUDA-only features on MPS/CPU (e.g. torch.compile, pin_memory, non_blocking, bfloat16). Relaxed the torch version constraint in pyproject.toml and removed the Linux/CUDA-specific PyTorch index configuration for better macOS support.
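The recurring pattern across these changes is to resolve the device type once and gate every CUDA-only feature on it. A minimal standalone sketch of that pattern (the is_cuda helper and the autodetection below are illustrative; the diffs inline the same checks):

import torch

def is_cuda(device) -> bool:
    # Accept either a string ("cuda", "mps", "cpu") or a torch.device
    device_type = device if isinstance(device, str) else device.type
    return device_type == "cuda"

device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
use_cuda = is_cuda(device)

# pin_memory / non_blocking are CUDA-only host-to-device transfer optimizations
batch = torch.empty(1024, dtype=torch.int64, pin_memory=use_cuda)
batch = batch.to(device=device, non_blocking=use_cuda)

# bfloat16, fused AdamW, and torch.compile are likewise used only when use_cuda is True
dtype = torch.bfloat16 if use_cuda else torch.float32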
parent 50bea28ef9
commit 3e184d343e
dev/runcpu.sh              0     Normal file → Executable file
dev/runmac_overnight.sh    +126  New executable file
dev/runmac_overnight.sh
@@ -0,0 +1,126 @@
#!/bin/bash
# Optimized overnight training for Mac (MPS/Apple Silicon)
# Expected runtime: 8-12 hours
# Expected result: Much better chatbot with coherent responses

set -e  # Exit on error

echo "=================================="
echo "nanochat Mac Overnight Training"
echo "=================================="
echo "Started: $(date)"
echo ""

# Activate virtual environment
source .venv/bin/activate

# Configuration
DEPTH=6              # Bigger model (6 layers vs 4)
BASE_ITERATIONS=500  # More base training
MID_ITERATIONS=150   # More midtraining
SFT_ITERATIONS=150   # More SFT
DATA_SHARDS=50       # More training data

echo "Configuration:"
echo "  Model depth: $DEPTH (36.7M → 82M params)"
echo "  Base iterations: $BASE_ITERATIONS"
echo "  Mid iterations: $MID_ITERATIONS"
echo "  SFT iterations: $SFT_ITERATIONS"
echo "  Data shards: $DATA_SHARDS"
echo ""

# Clean up old run
echo "Cleaning up previous training..."
rm -f report.md
python -m scripts.report --reset

# Download training data
echo ""
echo "Step 1/6: Downloading training data ($DATA_SHARDS shards)..."
python -m nanochat.dataset -n $DATA_SHARDS

# Download identity conversations
echo ""
echo "Step 2/6: Downloading identity conversations..."
if [ ! -f ~/.cache/nanochat/identity_conversations.jsonl ]; then
    curl -L -o ~/.cache/nanochat/identity_conversations.jsonl \
        https://karpathy-public.s3.us-west-2.amazonaws.com/identity_conversations.jsonl
else
    echo "  Already downloaded, skipping."
fi

# Build tokenizer
echo ""
echo "Step 3/6: Training tokenizer..."
python -m nanochat.tokenizer

# Base model training
echo ""
echo "Step 4/6: Training base model ($BASE_ITERATIONS iterations)..."
echo "  This will take ~2-4 hours..."
python -m scripts.base_train \
    --depth=$DEPTH \
    --max_seq_len=1024 \
    --device_batch_size=1 \
    --total_batch_size=1024 \
    --num_iterations=$BASE_ITERATIONS \
    --eval_every=100 \
    --eval_tokens=8192 \
    --core_metric_every=250 \
    --core_metric_max_per_task=20 \
    --sample_every=100

# Evaluate base model
echo ""
echo "Evaluating base model..."
python -m scripts.base_loss
python -m scripts.base_eval

# Midtraining
echo ""
echo "Step 5/6: Midtraining ($MID_ITERATIONS iterations)..."
echo "  This will take ~2-3 hours..."
python -m scripts.mid_train \
    --num_iterations=$MID_ITERATIONS \
    --device_batch_size=1 \
    --max_seq_len=1024 \
    --total_batch_size=1024 \
    --eval_every=50

# SFT training
echo ""
echo "Step 6/6: Chat fine-tuning (SFT) ($SFT_ITERATIONS iterations)..."
echo "  This will take ~2-3 hours..."
python -m scripts.chat_sft \
    --num_iterations=$SFT_ITERATIONS \
    --device_batch_size=1 \
    --target_examples_per_step=8 \
    --eval_steps=10

# Final evaluation
echo ""
echo "Running final evaluations..."
python -m scripts.chat_eval -i sft || echo "Chat eval had issues, skipping..."

# Generate report
echo ""
echo "Generating final report..."
python -m scripts.report

# Copy report to current directory
cp ~/.cache/nanochat/report/report.md ./report_overnight.md

echo ""
echo "=================================="
echo "Training Complete!"
echo "=================================="
echo "Finished: $(date)"
echo ""
echo "Your chatbot is ready! Chat with it:"
echo "  python -m scripts.chat_cli -i sft"
echo ""
echo "Or start the web UI:"
echo "  python -m scripts.chat_web -i sft"
echo ""
echo "Report saved to: report_overnight.md"
echo "=================================="
nanochat/dataloader.py
@@ -16,6 +16,9 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
     bos_token = tokenizer.get_bos_token_id()
     # scratch buffer holds the tokens for one iteration
     token_buffer = deque() # we stream tokens on the right and pop from the left
+    # pin_memory and non_blocking only work on CUDA
+    device_type = device if isinstance(device, str) else device.type
+    use_cuda_optimizations = device_type == "cuda"
     # infinite iterator over document batches
     def document_batches():
@@ -38,11 +41,11 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
            batch_index += 1
        # Move tokens from the deque into the scratch buffer
        tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
-       scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True)
+       scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=use_cuda_optimizations)
        # Create the inputs/targets as 1D tensors
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
-       # Reshape to 2D and move to GPU async
-       inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
-       targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
+       # Reshape to 2D and move to device (async on CUDA)
+       inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=use_cuda_optimizations)
+       targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=use_cuda_optimizations)
        yield inputs, targets
nanochat/gpt.py
@@ -35,7 +35,13 @@ class GPTConfig:

 def norm(x):
     # Purely functional rmsnorm with no learnable params
-    return F.rms_norm(x, (x.size(-1),))
+    # Fallback for older PyTorch versions that don't have F.rms_norm
+    if hasattr(F, 'rms_norm'):
+        return F.rms_norm(x, (x.size(-1),))
+    else:
+        # Manual RMS norm implementation
+        variance = x.pow(2).mean(-1, keepdim=True)
+        return x * torch.rsqrt(variance + 1e-5)

 def apply_rotary_emb(x, cos, sin):
@@ -211,7 +217,9 @@ class GPT(nn.Module):
        # calculate the rotation frequencies at each (time, channel) pair
        freqs = torch.outer(t, inv_freq)
        cos, sin = freqs.cos(), freqs.sin()
-       cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
+       # Use bfloat16 on CUDA, float32 on MPS/CPU (MPS doesn't support bfloat16)
+       if device.type == "cuda":
+           cos, sin = cos.bfloat16(), sin.bfloat16()
        cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
        return cos, sin
@@ -244,7 +252,10 @@ class GPT(nn.Module):
            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
        ]
        adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
-       AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
+       # fused=True only works on CUDA, not MPS or CPU
+       device_type = self.get_device().type
+       use_fused = device_type == "cuda"
+       AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=use_fused)
        adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
        # Create the Muon optimizer for the linear layers
        muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
@@ -263,7 +274,9 @@ class GPT(nn.Module):
        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
        assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
        assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
-       assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
+       # Rotary embeddings are bfloat16 on CUDA, float32 on MPS/CPU
+       expected_dtype = torch.bfloat16 if self.cos.device.type == "cuda" else torch.float32
+       assert self.cos.dtype == expected_dtype, f"Rotary embeddings dtype mismatch: expected {expected_dtype}, got {self.cos.dtype}"
        # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
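A standalone way to sanity-check that the manual fallback in norm() above tracks F.rms_norm (a sketch, not part of the commit; eps is passed explicitly here, whereas the commit's call relies on F.rms_norm's default, so real outputs differ very slightly):

import torch
import torch.nn.functional as F

def manual_rms_norm(x, eps=1e-5):
    # Same math as the MPS/CPU fallback added in gpt.py
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps)

x = torch.randn(2, 8, 64)
if hasattr(F, "rms_norm"):  # only on PyTorch versions that ship F.rms_norm
    reference = F.rms_norm(x, (x.size(-1),), eps=1e-5)
    print(torch.allclose(reference, manual_rms_norm(x), atol=1e-5))  # expected: True
else:
    print(manual_rms_norm(x).shape)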
nanochat/muon.py
@@ -6,7 +6,6 @@ import torch
 from torch import Tensor
 import torch.distributed as dist

-@torch.compile
 def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
     """
     Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
@@ -19,7 +18,11 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
     """
     assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
     a, b, c = (3.4445, -4.7750, 2.0315)
-    X = G.bfloat16()
+    # Use bfloat16 on CUDA for speed, float32 on MPS/CPU (MPS doesn't support bfloat16)
+    if G.device.type == "cuda":
+        X = G.bfloat16()
+    else:
+        X = G.float()
     if G.size(-2) > G.size(-1):
         X = X.mT
@@ -35,6 +38,10 @@ def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
         X = X.mT
     return X

+# Compile the function for CUDA; use eager mode for MPS/CPU
+# We check at runtime in the optimizer which version to use
+_zeropower_compiled = torch.compile(zeropower_via_newtonschulz5)
+
 class Muon(torch.optim.Optimizer):
     """
     Muon - MomentUm Orthogonalized by Newton-schulz
@@ -65,6 +72,11 @@ class Muon(torch.optim.Optimizer):
            group = dict(params=[p for p in params if p.numel() == size])
            param_groups.append(group)
        super().__init__(param_groups, defaults)
+       # Determine which zeropower function to use based on device
+       # Check the first parameter to determine device
+       first_param = next(iter(params))
+       self._use_compiled = first_param.device.type == "cuda"
+       self._zeropower_fn = _zeropower_compiled if self._use_compiled else zeropower_via_newtonschulz5

    @torch.no_grad()
    def step(self):
@@ -79,7 +91,7 @@ class Muon(torch.optim.Optimizer):
                buf: Tensor = state["momentum_buffer"]
                buf.lerp_(g, 1 - group["momentum"])
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
-               g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+               g = self._zeropower_fn(g, steps=group["ns_steps"])
                p.add_(g, alpha=-group["lr"] * max(1, p.size(-2) / p.size(-1))**0.5)
@@ -122,6 +134,10 @@ class DistMuon(torch.optim.Optimizer):
            print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
            param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
        super().__init__(param_groups, defaults)
+       # Determine which zeropower function to use based on device
+       first_param = params[0]
+       self._use_compiled = first_param.device.type == "cuda"
+       self._zeropower_fn = _zeropower_compiled if self._use_compiled else zeropower_via_newtonschulz5

    @torch.no_grad()
    def step(self):
@@ -173,7 +189,7 @@ class DistMuon(torch.optim.Optimizer):
                buf: Tensor = state["momentum_buffer"]
                buf.lerp_(g, 1.0 - group["momentum"])
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
-               g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
+               g = self._zeropower_fn(g, steps=group["ns_steps"])
                scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                p.add_(g, alpha=-group["lr"] * scale)
                # Replicate updated parameters to all ranks
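The Muon change compiles the Newton-Schulz step once and picks the compiled or eager version per device when the optimizer is constructed. The same defer-to-runtime dispatch in a minimal, self-contained form (newton_schulz_like is a stand-in, not the repo's function):

import torch

def newton_schulz_like(g: torch.Tensor) -> torch.Tensor:
    # Stand-in for zeropower_via_newtonschulz5: any pure tensor function works here
    return g / (g.norm() + 1e-7)

# Compile once at import time; compilation is lazy, so this is cheap off-CUDA
_compiled = torch.compile(newton_schulz_like)

def orthogonalize(g: torch.Tensor) -> torch.Tensor:
    # Use the compiled path only on CUDA, eager mode on MPS/CPU
    fn = _compiled if g.device.type == "cuda" else newton_schulz_like
    return fn(g)

print(orthogonalize(torch.randn(4, 4)).shape)  # torch.Size([4, 4]); runs eagerly on CPU/MPS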
pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
     "setuptools>=80.9.0",
     "tiktoken>=0.11.0",
     "tokenizers>=0.22.0",
-    "torch>=2.8.0",
+    "torch>=2.0.0,!=2.3.*,!=2.4.*,!=2.5.*,!=2.6.*,!=2.7.*,!=2.8.*,!=2.9.*",
     "uvicorn>=0.36.0",
     "wandb>=0.21.3",
 ]
@@ -44,19 +44,5 @@ python_files = ["test_*.py"]
 python_classes = ["Test*"]
 python_functions = ["test_*"]

-# target torch to cuda 12.8
-[tool.uv.sources]
-torch = [
-    { index = "pytorch-cpu", marker = "sys_platform != 'linux'" },
-    { index = "pytorch-cu128", marker = "sys_platform == 'linux'" },
-]
-
-[[tool.uv.index]]
-name = "pytorch-cpu"
-url = "https://download.pytorch.org/whl/cpu"
-explicit = true
-
-[[tool.uv.index]]
-name = "pytorch-cu128"
-url = "https://download.pytorch.org/whl/cu128"
-explicit = true
+# Note: PyTorch configuration removed for macOS compatibility
+# On Linux with CUDA, you may need to reinstall from pytorch.org/whl/cu128
scripts/base_train.py
@@ -109,6 +109,8 @@ with torch.device("meta"):
 model.to_empty(device=device)
 model.init_weights()
 orig_model = model # original, uncompiled model, for saving raw model state_dict
-model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
+# torch.compile only works well on CUDA; skip on MPS/CPU
+if device_type == "cuda":
+    model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
 num_params = sum(p.numel() for p in model.parameters())
 print0(f"Number of parameters: {num_params:,}")
scripts/mid_train.py
@@ -15,6 +15,9 @@ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 import time
 import wandb
 import torch
+# Disable torch.compile's dynamo on MPS (not supported)
+import torch._dynamo
+torch._dynamo.config.suppress_errors = True
 from contextlib import nullcontext
 from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
 from nanochat.tokenizer import get_token_bytes
@@ -70,6 +73,8 @@ pretrain_batch_size = meta.get("device_batch_size", None)
 if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
     print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
 orig_model = model
-model = torch.compile(model, dynamic=False)
+# torch.compile only works well on CUDA; skip on MPS/CPU
+if device_type == "cuda":
+    model = torch.compile(model, dynamic=False)
 depth = model.config.n_layer
 num_flops_per_token = model.estimate_flops()
@@ -119,7 +124,9 @@ def mid_data_generator(split):
     assert dataset_size > 0
     needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
     token_buffer = deque()
-    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
+    # pin_memory only works on CUDA
+    use_pin_memory = device_type == "cuda"
+    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=use_pin_memory)
     cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
     it = 0 # iteration counter
     while True:
@@ -142,8 +149,9 @@ def mid_data_generator(split):
            scratch[i] = token_buffer.popleft()
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)
        targets_cpu = scratch[1:]
-       inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=True)
-       targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=True)
+       # non_blocking only works on CUDA
+       inputs = inputs_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int32, non_blocking=use_pin_memory)
+       targets = targets_cpu.view(device_batch_size, max_seq_len).to(device=device, dtype=torch.int64, non_blocking=use_pin_memory)
        if split == "train":
            if num_iterations > 0:
                approx_progress = it / num_iterations # calculate progress from the max number of iterations