Compare commits

...

8 Commits

Author SHA1 Message Date
Gaurav
4695fdb6ec
Merge 65865df300 into 788dadeb88 2026-02-16 18:26:59 +01:00
Andrej Karpathy
788dadeb88 a number of upgrades to SFT script to bring it up to date w.r.t. pretraining and tuning some of its kwargs based on sweeps 2026-02-16 14:41:53 +00:00
svlandeg
65865df300 Merge branch 'master' into master_goderr 2026-01-15 20:47:26 +01:00
svlandeg
c79559674b Merge branch 'master' into master_goderr 2025-12-30 15:13:42 +01:00
Goderr
a3b701af8d change in code block display 2025-10-21 20:17:03 +05:30
Goderr
e9c5e911f6 change in linespacing of response 2025-10-21 20:14:15 +05:30
Goderr
07348bb1d3 forgot to change the max_tokens 2025-10-21 18:01:41 +05:30
Goderr
5ec267fabc Change in UI 2025-10-21 17:49:08 +05:30
4 changed files with 351 additions and 65 deletions

View File

@ -170,3 +170,22 @@ def load_model(source, *args, **kwargs):
base_dir = get_base_dir()
checkpoints_dir = os.path.join(base_dir, model_dir)
return load_model_from_dir(checkpoints_dir, *args, **kwargs)
def load_optimizer_state(source, device, rank, model_tag=None, step=None):
"""Load just the optimizer shard for a given rank, without re-loading the model."""
model_dir = {
"base": "base_checkpoints",
"sft": "chatsft_checkpoints",
"rl": "chatrl_checkpoints",
}[source]
base_dir = get_base_dir()
checkpoints_dir = os.path.join(base_dir, model_dir)
if model_tag is None:
model_tag = find_largest_model(checkpoints_dir)
checkpoint_dir = os.path.join(checkpoints_dir, model_tag)
if step is None:
step = find_last_step(checkpoint_dir)
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
log0(f"Loading optimizer state from {optimizer_path}")
optimizer_data = torch.load(optimizer_path, map_location=device)
return optimizer_data

View File

@ -1,10 +1,13 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0, viewport-fit=cover">
<title>NanoChat</title>
<link rel="icon" type="image/svg+xml" href="/logo.svg">
<script src="https://cdn.jsdelivr.net/npm/marked@9.1.6/marked.min.js"></script>
<script src="https://cdn.jsdelivr.net/npm/dompurify@3.0.5/dist/purify.min.js"></script>
<style>
:root {
color-scheme: light;
@ -104,11 +107,14 @@
}
.message-content {
white-space: pre-wrap;
line-height: 1.6;
max-width: 100%;
}
.message.user .message-content {
white-space: pre-wrap;
}
.message.assistant .message-content {
background: transparent;
border: none;
@ -224,8 +230,16 @@
}
@keyframes typing {
0%, 60%, 100% { opacity: 0.2; }
30% { opacity: 1; }
0%,
60%,
100% {
opacity: 0.2;
}
30% {
opacity: 1;
}
}
.error-message {
@ -236,13 +250,109 @@
border-radius: 0.75rem;
margin-top: 0.5rem;
}
/* Markdown styling */
.message-content table {
border-collapse: collapse;
margin: 0.25rem 0;
width: 100%;
}
.message-content th,
.message-content td {
border: 1px solid #e5e7eb;
padding: 0.5rem;
text-align: left;
}
.message-content th {
background-color: #f9fafb;
font-weight: 600;
}
.message-content code {
background-color: #f3f4f6;
padding: 0.125rem 0.25rem;
border-radius: 0.25rem;
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
font-size: 0.875em;
}
.message-content pre {
background-color: #f3f4f6;
padding: 0.75rem;
border-radius: 0.5rem;
overflow-x: auto;
margin: 0.25rem 0;
position: relative;
}
.message-content pre code {
background: none;
padding: 0;
}
.code-language-label {
position: absolute;
top: 0.5rem;
right: 0.5rem;
background-color: #e5e7eb;
color: #6b7280;
padding: 0.125rem 0.375rem;
border-radius: 0.25rem;
font-size: 0.75rem;
font-family: ui-sans-serif, -apple-system, system-ui, "Segoe UI", Helvetica, Arial, sans-serif;
text-transform: lowercase;
font-weight: 500;
}
.message-content h1,
.message-content h2,
.message-content h3,
.message-content h4,
.message-content h5,
.message-content h6 {
margin: 0.75rem 0 0.25rem 0;
font-weight: 600;
}
.message-content h1 {
font-size: 1.5em;
}
.message-content h2 {
font-size: 1.3em;
}
.message-content h3 {
font-size: 1.1em;
}
.message-content ul,
.message-content ol {
margin: 0.25rem 0;
padding-left: 1.5rem;
}
.message-content li {
margin: 0.1rem 0;
}
.message-content blockquote {
border-left: 4px solid #e5e7eb;
padding-left: 1rem;
margin: 0.5rem 0;
color: #6b7280;
}
</style>
</head>
<body>
<div class="header">
<div class="header-left">
<button class="new-conversation-btn" onclick="newConversation()" title="New Conversation (Ctrl+Shift+N)">
<svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<svg width="18" height="18" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"
stroke-linecap="round" stroke-linejoin="round">
<path d="M12 5v14"></path>
<path d="M5 12h14"></path>
</svg>
@ -259,15 +369,11 @@
<div class="input-container">
<div class="input-wrapper">
<textarea
id="chatInput"
class="chat-input"
placeholder="Ask anything"
rows="1"
onkeydown="handleKeyDown(event)"
></textarea>
<textarea id="chatInput" class="chat-input" placeholder="Ask anything" rows="1"
onkeydown="handleKeyDown(event)"></textarea>
<button id="sendButton" class="send-button" onclick="sendMessage()" disabled>
<svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<svg width="22" height="22" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"
stroke-linecap="round" stroke-linejoin="round">
<path d="M22 2L11 13"></path>
<path d="M22 2l-7 20-4-9-9-4 20-7z"></path>
</svg>
@ -287,7 +393,65 @@
let currentTemperature = 0.8;
let currentTopK = 50;
chatInput.addEventListener('input', function() {
// Configure marked.js for better rendering
if (typeof marked !== 'undefined') {
marked.setOptions({
breaks: true, // Convert line breaks to <br>
gfm: true, // GitHub Flavored Markdown
sanitize: false, // We'll use DOMPurify instead
smartypants: false // Disable smart quotes for simplicity
});
}
// Markdown rendering function
function renderMarkdown(content) {
try {
if (typeof marked !== 'undefined' && typeof DOMPurify !== 'undefined') {
const rawHtml = marked.parse(content);
const sanitizedHtml = DOMPurify.sanitize(rawHtml);
return addLanguageLabels(sanitizedHtml);
} else {
console.warn('Markdown libraries not loaded, falling back to plain text');
return escapeHtml(content);
}
} catch (error) {
console.error('Markdown rendering failed:', error);
return escapeHtml(content);
}
}
// Add language labels to code blocks
function addLanguageLabels(html) {
const tempDiv = document.createElement('div');
tempDiv.innerHTML = html;
const codeBlocks = tempDiv.querySelectorAll('pre code');
codeBlocks.forEach(codeBlock => {
const pre = codeBlock.parentElement;
const className = codeBlock.className;
// Extract language from class (e.g., "language-javascript" -> "javascript")
const languageMatch = className.match(/language-(\w+)/);
if (languageMatch) {
const language = languageMatch[1];
const label = document.createElement('span');
label.className = 'code-language-label';
label.textContent = language;
pre.appendChild(label);
}
});
return tempDiv.innerHTML;
}
// HTML escape fallback function
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
chatInput.addEventListener('input', function () {
this.style.height = 'auto';
this.style.height = Math.min(this.scrollHeight, 200) + 'px';
sendButton.disabled = !this.value.trim() || isGenerating;
@ -300,7 +464,7 @@
}
}
document.addEventListener('keydown', function(event) {
document.addEventListener('keydown', function (event) {
// Ctrl+Shift+N for new conversation
if (event.ctrlKey && event.shiftKey && event.key === 'N') {
event.preventDefault();
@ -326,13 +490,19 @@
const contentDiv = document.createElement('div');
contentDiv.className = 'message-content';
contentDiv.textContent = content;
// Render markdown for assistant messages, plain text for others
if (role === 'assistant' && content) {
contentDiv.innerHTML = renderMarkdown(content);
} else {
contentDiv.textContent = content;
}
// Add click handler for user messages to enable editing
if (role === 'user' && messageIndex !== null) {
contentDiv.setAttribute('data-message-index', messageIndex);
contentDiv.setAttribute('title', 'Click to edit and restart from here');
contentDiv.addEventListener('click', function() {
contentDiv.addEventListener('click', function () {
if (!isGenerating) {
editMessage(messageIndex);
}
@ -343,7 +513,7 @@
if (role === 'assistant' && messageIndex !== null) {
contentDiv.setAttribute('data-message-index', messageIndex);
contentDiv.setAttribute('title', 'Click to regenerate this response');
contentDiv.addEventListener('click', function() {
contentDiv.addEventListener('click', function () {
if (!isGenerating) {
regenerateMessage(messageIndex);
}
@ -426,7 +596,8 @@
const data = JSON.parse(line.slice(6));
if (data.token) {
fullResponse += data.token;
assistantContent.textContent = fullResponse;
// Render markdown in real-time as tokens arrive
assistantContent.innerHTML = renderMarkdown(fullResponse);
chatContainer.scrollTop = chatContainer.scrollHeight;
}
} catch (e) {
@ -441,7 +612,7 @@
// Add click handler to regenerate this assistant message
assistantContent.setAttribute('data-message-index', assistantMessageIndex);
assistantContent.setAttribute('title', 'Click to regenerate this response');
assistantContent.addEventListener('click', function() {
assistantContent.addEventListener('click', function () {
if (!isGenerating) {
regenerateMessage(assistantMessageIndex);
}
@ -563,4 +734,5 @@
});
</script>
</body>
</html>
</html>

View File

@ -468,6 +468,7 @@ while True:
"user_config": user_config, # inputs to the training script
"device_batch_size": args.device_batch_size,
"max_seq_len": args.max_seq_len,
"total_batch_size": total_batch_size,
"dataloader_state_dict": dataloader_state_dict,
"loop_state": { # all loop state (other than step) so that we can resume training
"min_val_bpb": min_val_bpb,

View File

@ -9,6 +9,7 @@ Or torchrun for training:
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --device-batch-size=16
"""
import gc
import argparse
import os
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
@ -16,12 +17,14 @@ import time
import wandb
import torch
from contextlib import nullcontext
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, get_base_dir, autodetect_device_type, get_peak_flops
from nanochat.tokenizer import get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.checkpoint_manager import save_checkpoint, load_model, load_optimizer_state
from nanochat.loss_eval import evaluate_bpb
from nanochat.checkpoint_manager import load_model
import torch.distributed as dist
from nanochat.flash_attention import HAS_FA3
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
from tasks.common import TaskMixture
from tasks.gsm8k import GSM8K
@ -37,27 +40,30 @@ parser = argparse.ArgumentParser(description="Supervised fine-tuning (SFT) the m
parser.add_argument("--run", type=str, default="dummy", help="wandb run name ('dummy' disables wandb logging)")
# Runtime
parser.add_argument("--device-type", type=str, default="", help="cuda|cpu|mps (empty = autodetect)")
parser.add_argument("--dtype", type=str, default="bfloat16", help="float32|bfloat16")
# Model loading
parser.add_argument("--model-tag", type=str, default=None, help="model tag to load from")
parser.add_argument("--model-step", type=int, default=None, help="model step to load from")
parser.add_argument("--load-optimizer", type=int, default=0, help="warm-start optimizer from pretrained checkpoint (0=no, 1=yes)")
# Training horizon
parser.add_argument("--num-iterations", type=int, default=-1, help="number of optimization steps (-1 = full epoch)")
# Batch sizes
parser.add_argument("--max-seq-len", type=int, default=2048, help="max context length")
parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size")
parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens")
# Optimization
parser.add_argument("--embedding-lr", type=float, default=0.3, help="learning rate for embedding parameters (Adam)")
parser.add_argument("--unembedding-lr", type=float, default=0.004, help="learning rate for unembedding parameters (Adam)")
parser.add_argument("--matrix-lr", type=float, default=0.02, help="learning rate for matrix parameters (Muon)")
parser.add_argument("--weight-decay", type=float, default=0.0, help="weight decay for embedding/unembedding parameters (Adam)")
parser.add_argument("--init-lr-frac", type=float, default=1.0, help="initial LR as fraction of base LR")
# Batch sizes (default: inherit from pretrained checkpoint)
parser.add_argument("--max-seq-len", type=int, default=None, help="max context length (default: inherit from pretrain)")
parser.add_argument("--device-batch-size", type=int, default=None, help="per-device batch size (default: inherit from pretrain)")
parser.add_argument("--total-batch-size", type=int, default=None, help="total batch size in tokens (default: inherit from pretrain)")
# Optimization (default: inherit from pretrained checkpoint)
parser.add_argument("--embedding-lr", type=float, default=None, help="learning rate for embedding parameters (Adam) (default: inherit from pretrain)")
parser.add_argument("--unembedding-lr", type=float, default=None, help="learning rate for unembedding parameters (Adam) (default: inherit from pretrain)")
parser.add_argument("--matrix-lr", type=float, default=None, help="learning rate for matrix parameters (Muon) (default: inherit from pretrain)")
parser.add_argument("--init-lr-frac", type=float, default=0.8, help="initial LR as fraction of base LR")
parser.add_argument("--warmup-ratio", type=float, default=0.0, help="ratio of iterations for LR warmup")
parser.add_argument("--warmdown-ratio", type=float, default=0.5, help="ratio of iterations for LR warmdown")
parser.add_argument("--final-lr-frac", type=float, default=0.0, help="final LR as fraction of initial LR")
# Evaluation
parser.add_argument("--eval-every", type=int, default=150, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=20*524288, help="number of tokens to evaluate val loss on")
# Output
parser.add_argument("--dry-run", action="store_true", help="log to wandb but skip checkpoints/report")
parser.add_argument("--eval-every", type=int, default=200, help="evaluate val bpb every N steps (-1 = disable)")
parser.add_argument("--eval-tokens", type=int, default=40*524288, help="number of tokens to evaluate val loss on")
parser.add_argument("--chatcore-every", type=int, default=200, help="evaluate ChatCORE metric every N steps (-1 = disable)")
parser.add_argument("--chatcore-max-cat", type=int, default=-1, help="max problems per categorical task for ChatCORE")
parser.add_argument("--chatcore-max-sample", type=int, default=24, help="max problems per generative task for ChatCORE")
args = parser.parse_args()
user_config = vars(args).copy()
# -----------------------------------------------------------------------------
@ -66,20 +72,48 @@ user_config = vars(args).copy()
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0
ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
if device_type == "cuda":
gpu_device_name = torch.cuda.get_device_name(0)
gpu_peak_flops = get_peak_flops(gpu_device_name)
print0(f"GPU: {gpu_device_name} | Peak FLOPS (BF16): {gpu_peak_flops:.2e}")
else:
gpu_peak_flops = float('inf') # MFU not meaningful for CPU/MPS
# wandb logging init
use_dummy_wandb = args.run == "dummy" or not master_process
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=args.run, config=user_config)
# Flash Attention status
if not HAS_FA3:
print0("WARNING: Flash Attention 3 not available, using PyTorch SDPA fallback. Training will be less efficient.")
# Load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="train", model_tag=args.model_tag, step=args.model_step)
pretrain_batch_size = meta.get("device_batch_size", None)
if pretrain_batch_size is not None and args.device_batch_size > pretrain_batch_size:
print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device-batch-size to this script?")
# Inherit training hyperparameters from pretrained checkpoint (None = inherit, explicit value = override)
pretrain_user_config = meta.get("user_config", {})
for name, fallback, source in [
("max_seq_len", 2048, meta),
("device_batch_size", 32, meta),
("total_batch_size", 524288, meta),
("embedding_lr", 0.3, pretrain_user_config),
("unembedding_lr", 0.004, pretrain_user_config),
("matrix_lr", 0.02, pretrain_user_config),
]:
arg_val = getattr(args, name)
pretrain_val = source.get(name)
if arg_val is None:
resolved = pretrain_val if pretrain_val is not None else fallback
setattr(args, name, resolved)
print0(f"Inherited {name}={resolved} from pretrained checkpoint")
elif pretrain_val is not None and arg_val != pretrain_val:
print0(f"NOTE: --{name.replace('_', '-')}={arg_val} overrides pretrained value of {pretrain_val}")
else:
print0(f"Using {name}={arg_val}")
orig_model = model
model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
@ -94,14 +128,23 @@ print0(f"Total batch size {args.total_batch_size:,} => gradient accumulation ste
token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=args.weight_decay)
# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
base_dir = get_base_dir()
if args.load_optimizer:
optimizer_data = load_optimizer_state("base", device, rank=ddp_rank, model_tag=args.model_tag, step=args.model_step)
optimizer.load_state_dict(optimizer_data)
del optimizer_data
print0("Loaded optimizer state from pretrained checkpoint")
# Override the initial learning rate as a fraction of the base learning rate
for group in optimizer.param_groups:
group["lr"] = group["lr"] * args.init_lr_frac
group["initial_lr"] = group["lr"]
# SFT data mixture and DataLoader
base_dir = get_base_dir()
identity_conversations_filepath = os.path.join(base_dir, "identity_conversations.jsonl")
train_dataset = TaskMixture([
SmolTalk(split="train"), # 460K rows of general conversations
@ -236,10 +279,17 @@ train_loader = sft_data_generator_bos_bestfit("train")
build_val_loader = lambda: sft_data_generator_bos_bestfit("val")
progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
# Learning rate schedule (linear warmup, constant, linear warmdown)
# Same shape as base_train but uses progress (0→1) instead of absolute step counts,
# because SFT doesn't always know num_iterations in advance (dataset-driven stopping).
def get_lr_multiplier(progress):
# first 80% of training: no decay, then linearly ramp down to 0.
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
if progress < args.warmup_ratio:
return (progress + 1e-8) / args.warmup_ratio
elif progress <= 1.0 - args.warmdown_ratio:
return 1.0
else:
decay = (progress - (1.0 - args.warmdown_ratio)) / args.warmdown_ratio
return (1 - decay) * 1.0 + decay * args.final_lr_frac
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
@ -282,8 +332,44 @@ while True:
})
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step and not args.dry_run:
# once in a while: estimate the ChatCORE metric (all ranks participate)
# use the original uncompiled model because the inputs keep changing shape
chatcore_results = {}
if args.chatcore_every > 0 and (last_step or (step > 0 and step % args.chatcore_every == 0)):
model.eval()
engine = Engine(orig_model, tokenizer)
all_tasks = ['ARC-Easy', 'ARC-Challenge', 'MMLU', 'GSM8K', 'HumanEval', 'SpellingBee']
categorical_tasks = {'ARC-Easy', 'ARC-Challenge', 'MMLU'}
baseline_accuracies = {
'ARC-Easy': 0.25, 'ARC-Challenge': 0.25, 'MMLU': 0.25,
'GSM8K': 0.0, 'HumanEval': 0.0, 'SpellingBee': 0.0,
}
task_results = {}
for task_name in all_tasks:
limit = args.chatcore_max_cat if task_name in categorical_tasks else args.chatcore_max_sample
max_problems = None if limit < 0 else limit # -1 means no limit
with autocast_ctx:
acc = run_chat_eval(task_name, orig_model, tokenizer, engine,
batch_size=args.device_batch_size, max_problems=max_problems)
task_results[task_name] = acc
print0(f" {task_name}: {100*acc:.2f}%")
# Compute ChatCORE metrics (mean centered accuracy, ranges from 0=random to 1=perfect)
def centered_mean(tasks):
return sum((task_results[t] - baseline_accuracies[t]) / (1.0 - baseline_accuracies[t]) for t in tasks) / len(tasks)
chatcore = centered_mean(all_tasks)
chatcore_cat = centered_mean(categorical_tasks)
print0(f"Step {step:05d} | ChatCORE: {chatcore:.4f} | ChatCORE_cat: {chatcore_cat:.4f}")
wandb_run.log({
"step": step,
"total_training_flops": flops_so_far,
"chatcore_metric": chatcore,
"chatcore_cat": chatcore_cat,
**{f"chatcore/{task_name}": acc for task_name, acc in task_results.items()},
})
model.train()
# save checkpoint at the end of the run (all ranks participate so each saves its optimizer shard)
if last_step:
output_dirname = args.model_tag if args.model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "chatsft_checkpoints", output_dirname)
save_checkpoint(
@ -304,7 +390,8 @@ while True:
"window_pattern": model.config.window_pattern,
},
"user_config": user_config, # inputs to the training script
}
},
rank=ddp_rank,
)
if last_step:
@ -346,8 +433,7 @@ while True:
pct_done = 100 * progress
tok_per_sec = int(args.total_batch_size / dt)
flops_per_sec = num_flops_per_token * args.total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
mfu = 100 * flops_per_sec / (gpu_peak_flops * ddp_world_size)
if step > 10:
total_training_time += dt # only count the time after the first 10 steps
print0(f"step {step:05d} ({pct_done:.2f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt * 1000:.2f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.2f} | epoch: {current_epoch} | total time: {total_training_time/60:.2f}m")
@ -364,24 +450,32 @@ while True:
"train/epoch": current_epoch,
})
# The garbage collector spends ~500ms scanning for cycles quite frequently.
# We manually manage it to avoid these pauses during training.
if step == 1:
gc.collect() # manually collect a lot of garbage from setup
gc.freeze() # freeze all currently surviving objects and exclude them from GC
gc.disable() # disable GC entirely except:
elif step % 5000 == 0: # every 5000 steps...
gc.collect() # manually collect, just to be safe for very long runs
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
if not args.dry_run:
from nanochat.report import get_report
get_report().log(section="SFT", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of iterations": step,
"DDP world size": ddp_world_size,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
}
])
from nanochat.report import get_report
get_report().log(section="SFT", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of iterations": step,
"DDP world size": ddp_world_size,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
}
])
# cleanup
wandb_run.finish() # wandb run finish