Merge branch 'master' into rocm-support

Commit 1bbba0f0d3 by Lawrence R Kincheloe III, 2025-10-16 18:47:37 -05:00 (committed by GitHub)
7 changed files with 574 additions and 106 deletions


@@ -6,6 +6,10 @@
This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh) that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
## Talk to it
To get a sense of the endpoint of this repo, you can currently find [nanochat d32](https://github.com/karpathy/nanochat/discussions/8) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d32" means that this model has 32 layers in the Transformer neural network. The model has 1.9 billion parameters and was trained on 38 billion tokens by simply running the single script [run1000.sh](run1000.sh); the total cost of training was ~$800 (about 33 hours of training time on an 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of modern Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
## Quick start
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
@@ -129,6 +133,10 @@ The `speedrun.sh` script has been configured to run on a single GPU by default,
If your GPU has less than 80GB of VRAM, you may need to reduce the `device_batch_size` in the training scripts to avoid running out of memory. This will increase training time but will allow the model to train successfully on lower-VRAM cards.
## Running on CPU / MPS
If you'd like to tinker with nanochat on your Macbook or a CPU machine, there is a work-in-progress [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) up here. If you're on a Macbook, use `--device_type=mps` when running `base_train.py`. See the PR and its diff for more. You're not going to get too far without GPU nodes, but at least you'll be able to run the code and maybe train a very tiny LLM with some patience.
## Questions
nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:


@@ -108,6 +108,15 @@
background: transparent;
border: none;
padding: 0.25rem 0;
cursor: pointer;
border-radius: 0.5rem;
padding: 0.5rem;
margin-left: -0.5rem;
transition: background-color 0.2s ease;
}
.message.assistant .message-content:hover {
background-color: #f9fafb;
}
.message.user .message-content {
@@ -115,6 +124,21 @@
border-radius: 1.25rem;
padding: 0.8rem 1rem;
max-width: 65%;
cursor: pointer;
transition: background-color 0.2s ease;
}
.message.user .message-content:hover {
background-color: #e5e7eb;
}
.message.console .message-content {
font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
font-size: 0.875rem;
background-color: #fafafa;
padding: 0.75rem 1rem;
color: #374151;
max-width: 80%;
}
.input-container {
@@ -255,6 +279,8 @@
let messages = [];
let isGenerating = false;
let currentTemperature = 0.8;
let currentTopK = 50;
chatInput.addEventListener('input', function() {
this.style.height = 'auto';
@@ -289,7 +315,7 @@
chatInput.focus();
}
function addMessage(role, content) {
function addMessage(role, content, messageIndex = null) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${role}`;
@@ -297,6 +323,28 @@
contentDiv.className = 'message-content';
contentDiv.textContent = content;
// Add click handler for user messages to enable editing
if (role === 'user' && messageIndex !== null) {
contentDiv.setAttribute('data-message-index', messageIndex);
contentDiv.setAttribute('title', 'Click to edit and restart from here');
contentDiv.addEventListener('click', function() {
if (!isGenerating) {
editMessage(messageIndex);
}
});
}
// Add click handler for assistant messages to enable regeneration
if (role === 'assistant' && messageIndex !== null) {
contentDiv.setAttribute('data-message-index', messageIndex);
contentDiv.setAttribute('title', 'Click to regenerate this response');
contentDiv.addEventListener('click', function() {
if (!isGenerating) {
regenerateMessage(messageIndex);
}
});
}
messageDiv.appendChild(contentDiv);
chatWrapper.appendChild(messageDiv);
@@ -304,17 +352,35 @@
return contentDiv;
}
async function sendMessage() {
const message = chatInput.value.trim();
if (!message || isGenerating) return;
function editMessage(messageIndex) {
// Find the message in the messages array
if (messageIndex < 0 || messageIndex >= messages.length) return;
isGenerating = true;
chatInput.value = '';
const messageToEdit = messages[messageIndex];
if (messageToEdit.role !== 'user') return;
// Copy message content to input
chatInput.value = messageToEdit.content;
chatInput.style.height = 'auto';
sendButton.disabled = true;
chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
messages.push({ role: 'user', content: message });
addMessage('user', message);
// Remove this message and all subsequent messages from the array
messages = messages.slice(0, messageIndex);
// Remove message elements from DOM starting from messageIndex
const allMessages = chatWrapper.querySelectorAll('.message');
for (let i = messageIndex; i < allMessages.length; i++) {
allMessages[i].remove();
}
// Enable send button and focus input
sendButton.disabled = false;
chatInput.focus();
}
async function generateAssistantResponse() {
isGenerating = true;
sendButton.disabled = true;
const assistantContent = addMessage('assistant', '');
assistantContent.innerHTML = '<span class="typing-indicator"></span>';
@@ -327,8 +393,8 @@
},
body: JSON.stringify({
messages: messages,
stream: true,
temperature: 0.8,
temperature: currentTemperature,
top_k: currentTopK,
max_tokens: 512
}),
});
@@ -364,8 +430,18 @@
}
}
const assistantMessageIndex = messages.length;
messages.push({ role: 'assistant', content: fullResponse });
// Add click handler to regenerate this assistant message
assistantContent.setAttribute('data-message-index', assistantMessageIndex);
assistantContent.setAttribute('title', 'Click to regenerate this response');
assistantContent.addEventListener('click', function() {
if (!isGenerating) {
regenerateMessage(assistantMessageIndex);
}
});
} catch (error) {
console.error('Error:', error);
assistantContent.innerHTML = `<div class="error-message">Error: ${error.message}</div>`;
@@ -375,6 +451,97 @@
}
}
async function regenerateMessage(messageIndex) {
// Find the message in the messages array
if (messageIndex < 0 || messageIndex >= messages.length) return;
const messageToRegenerate = messages[messageIndex];
if (messageToRegenerate.role !== 'assistant') return;
// Remove this message and all subsequent messages from the array
messages = messages.slice(0, messageIndex);
// Remove message elements from DOM starting from messageIndex
const allMessages = chatWrapper.querySelectorAll('.message');
for (let i = messageIndex; i < allMessages.length; i++) {
allMessages[i].remove();
}
// Regenerate the assistant response
await generateAssistantResponse();
}
function handleSlashCommand(command) {
const parts = command.trim().split(/\s+/);
const cmd = parts[0].toLowerCase();
const arg = parts[1];
if (cmd === '/temperature') {
if (arg === undefined) {
addMessage('console', `Current temperature: ${currentTemperature}`);
} else {
const temp = parseFloat(arg);
if (isNaN(temp) || temp < 0 || temp > 2) {
addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
} else {
currentTemperature = temp;
addMessage('console', `Temperature set to ${currentTemperature}`);
}
}
return true;
} else if (cmd === '/topk') {
if (arg === undefined) {
addMessage('console', `Current top-k: ${currentTopK}`);
} else {
const topk = parseInt(arg);
if (isNaN(topk) || topk < 1 || topk > 200) {
addMessage('console', 'Invalid top-k. Must be between 1 and 200');
} else {
currentTopK = topk;
addMessage('console', `Top-k set to ${currentTopK}`);
}
}
return true;
} else if (cmd === '/clear') {
newConversation();
return true;
} else if (cmd === '/help') {
addMessage('console',
'Available commands:\n' +
'/temperature - Show current temperature\n' +
'/temperature <value> - Set temperature (0.0-2.0)\n' +
'/topk - Show current top-k\n' +
'/topk <value> - Set top-k (1-200)\n' +
'/clear - Clear conversation\n' +
'/help - Show this help message'
);
return true;
}
return false;
}
async function sendMessage() {
const message = chatInput.value.trim();
if (!message || isGenerating) return;
// Handle slash commands
if (message.startsWith('/')) {
chatInput.value = '';
chatInput.style.height = 'auto';
handleSlashCommand(message);
return;
}
chatInput.value = '';
chatInput.style.height = 'auto';
const userMessageIndex = messages.length;
messages.push({ role: 'user', content: message });
addMessage('user', message, userMessageIndex);
await generateAssistantResponse();
}
sendButton.disabled = false;
// Autofocus the chat input on page load

run1000.sh (new file, 94 lines)

@@ -0,0 +1,94 @@
# The $1000 tier of nanochat
# Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
# A bit sparser on comments, see speedrun.sh for more detail
# all the setup stuff
export OMP_NUM_THREADS=1
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
[ -d ".venv" ] || uv venv
uv sync
source .venv/bin/activate
if [ -z "$WANDB_RUN" ]; then
WANDB_RUN=dummy
fi
python -m nanochat.report reset
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
source "$HOME/.cargo/env"
uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
unzip -q eval_bundle.zip
rm eval_bundle.zip
mv eval_bundle $NANOCHAT_BASE_DIR
fi
# train tokenizer on ~4B characters and kick off download of the rest for pretraining
python -m nanochat.dataset -n 16
# start downloading the rest of the shards for a total of 800 (see below why 800)
python -m nanochat.dataset -n 800 &
# todo: download the rest of it
python -m scripts.tok_train --max_chars=4000000000
python -m scripts.tok_eval
# Documenting my process for determining the hyperparameters for this run1000.sh script:
# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
# 1) I guessed the model size for this to be about depth=32
# 2) Determine the device_batch_size that fits:
# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
# So the training script was running ok and showed:
# Vocab size: 65,536
# num_layers: 32
# model_dim: 2048
# num_heads: 16
# num_kv_heads: 16
# Tokens / micro-batch / rank: 8 x 2048 = 16,384
# Tokens / micro-batch: 131,072
# Total batch size 524,288 => gradient accumulation steps: 4
# Number of parameters: 1,879,048,192
# Estimated FLOPs per token: 1.207960e+10
# Calculated number of iterations from target data:param ratio: 71,680
# Total number of training tokens: 37,580,963,840
# Tokens : Params ratio: 20.00
# Total training FLOPs estimate: 4.539628e+20
# step 00004/71680 (0.01%) | loss: 8.813754 | lrm: 1.00 | dt: 1571.88ms | tok/sec: 83,385 | mfu: 50.92 | total time: 0.00m
# step 00005/71680 (0.01%) | loss: 8.488074 | lrm: 1.00 | dt: 1572.76ms | tok/sec: 83,338 | mfu: 50.89 | total time: 0.00m
# ...
# 3) validate that the runtime fits our budget:
# The training script uses the Chinchilla scaling law to compute-optimally set #tokens = 20 * #params. In particular:
# The script shows that we will be training for 71,680 steps, and each step takes 1.574s so:
# estimated time to train: 71,680 * 1.574s / 60 / 60 = 31.3 hours.
# This is OK, fits our budget, and leaves ~10 hours for midtraining and SFT and evals and maybe RL.
# It's possible that we might even fit depth=33 or depth=34, but for now let's go along with this.
# 4) The last thing to pay attention to is the amount of training data required for the run.
# The script above calculated that "Total number of training tokens: 37,580,963,840"
# The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
# So ~38B tokens * ~4.8 chars/token = ~185B chars.
# Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
# For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
# If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
# start to overfit hard.
# 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script.
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=32 --device_batch_size=8
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
# midtrain
# NOTE: ensure that we use the same device_batch_size here as the base training script.
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
# sft
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
# generate final report
python -m nanochat.report generate
# talk to it
python -m scripts.chat_web
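The budget arithmetic in the comments above is easy to reproduce. Below is a back-of-envelope sketch in Python (not part of the repo); the 12*d^2-per-layer parameter count is an assumed model shape, but it reproduces the printed totals exactly:

# Reproduce the run1000.sh budget math (illustrative sketch, assumed shapes).
depth, model_dim, vocab = 32, 2048, 65536
params = depth * 12 * model_dim**2 + 2 * vocab * model_dim  # attn 4d^2 + MLP 8d^2 per layer, untied embeddings
assert params == 1_879_048_192               # matches "Number of parameters" above

tokens = 20 * params                         # Chinchilla-style tokens:params = 20
assert tokens == 37_580_963_840              # matches "Total number of training tokens"
assert tokens // 524_288 == 71_680           # steps at total batch size 524,288

shards = tokens * 4.8 / 250e6                # ~4.8 chars/token, ~250M chars/shard
print(round(shards))                         # ~722; the comment above rounds to ~740, padded to 800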


@@ -11,7 +11,6 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import copy
import wandb
import torch
@@ -23,11 +22,9 @@ from nanochat.checkpoint_manager import save_checkpoint
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
from tasks.common import TaskMixture, TaskSequence
from tasks.mmlu import MMLU
from tasks.common import TaskMixture
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
from tasks.humaneval import HumanEval
from tasks.smoltalk import SmolTalk
# -----------------------------------------------------------------------------
@@ -186,7 +183,7 @@ for step in range(num_iterations):
})
model.train()
# evaluate MMLU accuracy
# evaluate accuracy of the multiple choice tasks (which are quick to run)
if last_step or (step > 0 and step % eval_metrics_every == 0):
model.eval()
metrics = {}
@@ -194,8 +191,6 @@ for step in range(num_iterations):
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics["gsm8k_acc"] = run_chat_eval("GSM8K", model, tokenizer, engine, max_problems=64)
metrics["humaneval_acc"] = run_chat_eval("HumanEval", model, tokenizer, engine, max_problems=64)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log({


@@ -1,26 +1,67 @@
#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.
Run with: python web_chat.py
Then open http://localhost:8000 in your browser.
Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
a full copy of the model, and incoming requests are distributed to available workers.
Launch examples:
- single available GPU (default)
python -m scripts.chat_web
- 4 GPUs
python -m scripts.chat_web --num-gpus 4
To chat, open the URL printed in the console. (If on a cloud box, make sure to use the public IP)
Endpoints:
GET / - Chat UI
POST /chat/completions - Chat API (streaming only)
GET /health - Health check with worker pool status
GET /stats - Worker pool statistics and GPU utilization
Abuse Prevention:
- Maximum 500 messages per request
- Maximum 8000 characters per message
- Maximum 32000 characters total conversation length
- Temperature restricted to 0.0-2.0 (out-of-range requests are rejected with 400)
- Top-k restricted to 1-200
- Max tokens restricted to 1-4096
"""
import argparse
import json
import os
import torch
import asyncio
import logging
import random
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
from dataclasses import dataclass
from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
# Abuse prevention limits
MAX_MESSAGES_PER_REQUEST = 500
MAX_MESSAGE_LENGTH = 8000
MAX_TOTAL_CONVERSATION_LENGTH = 32000
MIN_TEMPERATURE = 0.0
MAX_TEMPERATURE = 2.0
MIN_TOP_K = 1
MAX_TOP_K = 200
MIN_MAX_TOKENS = 1
MAX_MAX_TOKENS = 4096
parser = argparse.ArgumentParser(description='NanoChat Web Server')
parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
@@ -31,8 +72,64 @@ parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run th
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
# Configure logging for conversation traffic
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16)
@dataclass
class Worker:
"""A worker with a model loaded on a specific GPU."""
gpu_id: int
device: torch.device
engine: Engine
tokenizer: object
autocast_ctx: torch.amp.autocast
class WorkerPool:
"""Pool of workers, each with a model replica on a different GPU."""
def __init__(self, num_gpus: Optional[int] = None):
self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
self.workers: List[Worker] = []
self.available_workers: asyncio.Queue = asyncio.Queue()
async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
"""Load model on each GPU."""
print(f"Initializing worker pool with {self.num_gpus} GPUs...")
for gpu_id in range(self.num_gpus):
device = torch.device(f"cuda:{gpu_id}")
print(f"Loading model on GPU {gpu_id}...")
model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
engine = Engine(model, tokenizer)
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
worker = Worker(
gpu_id=gpu_id,
device=device,
engine=engine,
tokenizer=tokenizer,
autocast_ctx=autocast_ctx
)
self.workers.append(worker)
await self.available_workers.put(worker)
print(f"All {self.num_gpus} workers initialized!")
async def acquire_worker(self) -> Worker:
"""Get an available worker from the pool."""
return await self.available_workers.get()
async def release_worker(self, worker: Worker):
"""Return a worker to the pool."""
await self.available_workers.put(worker)
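# Aside (not in the original source): the asyncio.Queue is also the scheduling
# policy. acquire_worker() suspends the request's coroutine until a replica is
# free, so concurrent requests are served FIFO across GPUs with no extra locks;
# chat_completions below pairs every acquire with a release on both the
# success and error paths.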
class ChatMessage(BaseModel):
role: str
@@ -43,14 +140,76 @@ class ChatRequest(BaseModel):
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_k: Optional[int] = None
stream: Optional[bool] = True
def validate_chat_request(request: ChatRequest):
"""Validate chat request to prevent abuse."""
# Check number of messages
if len(request.messages) == 0:
raise HTTPException(status_code=400, detail="At least one message is required")
if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
raise HTTPException(
status_code=400,
detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
)
# Check individual message lengths and total conversation length
total_length = 0
for i, message in enumerate(request.messages):
if not message.content:
raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
msg_length = len(message.content)
if msg_length > MAX_MESSAGE_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
)
total_length += msg_length
if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
raise HTTPException(
status_code=400,
detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
)
# Validate role values
for i, message in enumerate(request.messages):
if message.role not in ["user", "assistant"]:
raise HTTPException(
status_code=400,
detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
)
# Validate temperature
if request.temperature is not None:
if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
raise HTTPException(
status_code=400,
detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
)
# Validate top_k
if request.top_k is not None:
if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
raise HTTPException(
status_code=400,
detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
)
# Validate max_tokens
if request.max_tokens is not None:
if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
raise HTTPException(
status_code=400,
detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
)
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load model on startup."""
print("Loading nanochat model...")
app.state.model, app.state.tokenizer, _ = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
app.state.engine = Engine(app.state.model, app.state.tokenizer)
"""Load models on all GPUs on startup."""
print("Loading nanochat models across GPUs...")
app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
print(f"Server ready at http://localhost:{args.port}")
yield
@@ -85,8 +244,7 @@ async def logo():
return FileResponse(logo_path, media_type="image/svg+xml")
async def generate_stream(
engine,
tokenizer,
worker: Worker,
tokens,
temperature=None,
max_new_tokens=None,
@@ -97,98 +255,141 @@
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
top_k = top_k if top_k is not None else args.top_k
assistant_end = tokenizer.encode_special("<|assistant_end|>")
bos = tokenizer.get_bos_token_id()
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
bos = worker.tokenizer.get_bos_token_id()
with autocast_ctx:
for token_column, token_masks in engine.generate(
# Accumulate tokens to properly handle multi-byte UTF-8 characters (like emojis)
accumulated_tokens = []
# Track the last complete UTF-8 string (without replacement characters)
last_clean_text = ""
with worker.autocast_ctx:
for token_column, token_masks in worker.engine.generate(
tokens,
num_samples=1,
max_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k
top_k=top_k,
seed=random.randint(0, 2**31 - 1)
):
token = token_column[0]
# Stopping criteria
if token == assistant_end or token == bos:
break
token_text = tokenizer.decode([token])
yield f"data: {json.dumps({'token': token_text})}\n\n"
# Append the token to sequence
accumulated_tokens.append(token)
# Decode all accumulated tokens to get proper UTF-8 handling
# Note that decode is a quite efficient operation, basically table lookup and string concat
current_text = worker.tokenizer.decode(accumulated_tokens)
# Only emit text if it doesn't end with a replacement character
# This ensures we don't emit incomplete UTF-8 sequences
if not current_text.endswith('\ufffd'):
# Extract only the new text since last clean decode
new_text = current_text[len(last_clean_text):]
if new_text: # Only yield if there's new content
yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
last_clean_text = current_text
yield f"data: {json.dumps({'done': True})}\n\n"
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
"""Chat completion endpoint with streaming."""
engine = app.state.engine
tokenizer = app.state.tokenizer
"""Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
# Build conversation tokens
bos = tokenizer.get_bos_token_id()
user_start = tokenizer.encode_special("<|user_start|>")
user_end = tokenizer.encode_special("<|user_end|>")
assistant_start = tokenizer.encode_special("<|assistant_start|>")
assistant_end = tokenizer.encode_special("<|assistant_end|>")
# Basic validation to prevent abuse
validate_chat_request(request)
conversation_tokens = [bos]
for message in request.messages:
if message.role == "user":
conversation_tokens.append(user_start)
conversation_tokens.extend(tokenizer.encode(message.content))
conversation_tokens.append(user_end)
elif message.role == "assistant":
conversation_tokens.append(assistant_start)
conversation_tokens.extend(tokenizer.encode(message.content))
conversation_tokens.append(assistant_end)
# Log incoming conversation to console
logger.info("="*20)
for i, message in enumerate(request.messages):
logger.info(f"[{message.role.upper()}]: {message.content}")
logger.info("-"*20)
conversation_tokens.append(assistant_start)
# Acquire a worker from the pool (will wait if all are busy)
worker_pool = app.state.worker_pool
worker = await worker_pool.acquire_worker()
try:
# Build conversation tokens
bos = worker.tokenizer.get_bos_token_id()
user_start = worker.tokenizer.encode_special("<|user_start|>")
user_end = worker.tokenizer.encode_special("<|user_end|>")
assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
conversation_tokens = [bos]
for message in request.messages:
if message.role == "user":
conversation_tokens.append(user_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(user_end)
elif message.role == "assistant":
conversation_tokens.append(assistant_start)
conversation_tokens.extend(worker.tokenizer.encode(message.content))
conversation_tokens.append(assistant_end)
conversation_tokens.append(assistant_start)
# Streaming response with worker release after completion
response_tokens = []
async def stream_and_release():
try:
async for chunk in generate_stream(
worker,
conversation_tokens,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
):
# Accumulate response for logging
chunk_data = json.loads(chunk.replace("data: ", "").strip())
if "token" in chunk_data:
response_tokens.append(chunk_data["token"])
yield chunk
finally:
# Log the assistant response to console
full_response = "".join(response_tokens)
logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
logger.info("="*20)
# Release worker back to pool after streaming is done
await worker_pool.release_worker(worker)
if request.stream:
return StreamingResponse(
generate_stream(
engine,
tokenizer,
conversation_tokens,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
),
stream_and_release(),
media_type="text/event-stream"
)
else:
# Non-streaming response
temperature = request.temperature if request.temperature is not None else args.temperature
max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
top_k = request.top_k if request.top_k is not None else args.top_k
with autocast_ctx:
result_tokens, masks = engine.generate_batch(
conversation_tokens,
num_samples=1,
max_tokens=max_tokens,
temperature=temperature,
top_k=top_k
)[0]
response_tokens = result_tokens[len(conversation_tokens):]
response_text = tokenizer.decode(response_tokens)
return {
"choices": [{
"message": {
"role": "assistant",
"content": response_text
},
"finish_reason": "stop"
}]
}
except Exception as e:
# Make sure to release worker even on error
await worker_pool.release_worker(worker)
raise e
@app.get("/health")
async def health():
"""Health check endpoint."""
worker_pool = getattr(app.state, 'worker_pool', None)
return {
"status": "ok",
"ready": hasattr(app.state, 'model') and app.state.model is not None
"ready": worker_pool is not None and len(worker_pool.workers) > 0,
"num_gpus": worker_pool.num_gpus if worker_pool else 0,
"available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
}
@app.get("/stats")
async def stats():
"""Get worker pool statistics."""
worker_pool = app.state.worker_pool
return {
"total_workers": len(worker_pool.workers),
"available_workers": worker_pool.available_workers.qsize(),
"busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
"workers": [
{
"gpu_id": w.gpu_id,
"device": str(w.device)
} for w in worker_pool.workers
]
}
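For reference, a minimal streaming client for the endpoint above might look like the sketch below (not part of the repo; it assumes the `requests` package, but any HTTP client that can iterate an SSE body works):

import json
import requests  # assumption; any SSE-capable HTTP client works

resp = requests.post(
    "http://localhost:8000/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Why is the sky blue?"}],
        "temperature": 0.8,   # validated server-side (0.0-2.0)
        "top_k": 50,          # validated server-side (1-200)
        "max_tokens": 256,    # validated server-side (1-4096)
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if line.startswith(b"data: "):
        event = json.loads(line[len(b"data: "):])
        if event.get("done"):
            break
        print(event["token"], end="", flush=True)  # 'gpu' field says which worker served it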
if __name__ == "__main__":


@@ -40,10 +40,10 @@ embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
eval_every = 150
eval_tokens = 20*524288
total_batch_size = 524288
dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
@@ -141,7 +141,8 @@ progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
def get_lr_multiplier(progress):
return progress * 1.0 + (1 - progress) * final_lr_frac
# first 80% of training: no decay, then linearly ramp down to 0.
return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
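# Spot-check of the new schedule (illustrative):
#   get_lr_multiplier(0.5) -> 1      (constant for the first 80%)
#   get_lr_multiplier(0.9) -> 0.5    (halfway down the final ramp)
#   get_lr_multiplier(1.0) -> 0.0    (fully decayed at the end)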
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
@@ -185,7 +186,7 @@ while True:
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step:
if master_process and last_step and not dry_run:
output_dirname = f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
save_checkpoint(
@@ -273,17 +274,18 @@ print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
from nanochat.report import get_report
get_report().log(section="Midtraining", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of iterations": step,
"DDP world size": ddp_world_size,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
}
])
if not dry_run:
from nanochat.report import get_report
get_report().log(section="Midtraining", data=[
user_config, # CLI args
{ # stats about the training setup
"Number of iterations": step,
"DDP world size": ddp_world_size,
},
{ # stats about training outcomes
"Minimum validation bpb": min_val_bpb,
}
])
# cleanup
wandb_run.finish() # wandb run finish


@@ -17,6 +17,7 @@ export OMP_NUM_THREADS=1
# For example, for a gfx1151 GPU, we can use gfx1100 (11.0.0).
export HSA_OVERRIDE_GFX_VERSION=11.0.0
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
# -----------------------------------------------------------------------------