diff --git a/README.md b/README.md
index 3da3c2a..30d9f4d 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,10 @@
This repo is a full-stack implementation of an LLM like ChatGPT in a single, clean, minimal, hackable, dependency-lite codebase. nanochat is designed to run on a single 8XH100 node via scripts like [speedrun.sh](speedrun.sh), that run the entire pipeline start to end. This includes tokenization, pretraining, finetuning, evaluation, inference, and web serving over a simple UI so that you can talk to your own LLM just like ChatGPT. nanochat will become the capstone project of the course LLM101n being developed by Eureka Labs.
+## Talk to it
+
+To get a sense of the endpoint of this repo, you can currently find [nanochat d32](https://github.com/karpathy/nanochat/discussions/8) hosted on [nanochat.karpathy.ai](https://nanochat.karpathy.ai/). "d32" means that this model has 32 layers in the Transformer neural network. This model has 1.9 billion parameters, it was trained on 38 billion tokens by simply running the single script [run1000.sh](run1000.sh), and the total cost of training was ~$800 (about 33 hours training time on 8XH100 GPU node). While today this is enough to outperform GPT-2 of 2019, it falls dramatically short of moden Large Language Models like GPT-5. When talking to these micro models, you'll see that they make a lot of mistakes, they are a little bit naive and silly and they hallucinate a ton, a bit like children. It's kind of amusing. But what makes nanochat unique is that it is fully yours - fully configurable, tweakable, hackable, and trained by you from start to end. To train and talk to your own, we turn to...
+
## Quick start
The fastest way to feel the magic is to run the speedrun script [speedrun.sh](speedrun.sh), which trains and inferences the $100 tier of nanochat. On an 8XH100 node at $24/hr, this gives a total run time of about 4 hours. Boot up a new 8XH100 GPU box from your favorite provider (e.g. I use and like [Lambda](https://lambda.ai/service/gpu-cloud)), and kick off the training script:
@@ -129,6 +133,10 @@ The `speedrun.sh` script has been configured to run on a single GPU by default,
If your GPU has less than 80GB of VRAM, you may need to reduce the `device_batch_size` in the training scripts to avoid running out of memory. This will increase training time but will allow the model to train successfully on lower-VRAM cards.
+## Running on CPU / MPS
+
+If you'd like to tinker with nanochat on your Macbook or a CPU machine, there is a work in progress [CPU|MPS PR](https://github.com/karpathy/nanochat/pull/88) up here. If you're on Macbook, use `--device_type=mps` when running `base_train.py`. See the PR and its diff for more. You're not going to get too far without GPU nodes, but at least you'll be able to run the code and maybe train a very tiny LLM with some patience.
+
## Questions
nanochat is designed to be short and sweet. One big advantage of this is that we can package up all of the files together and copy paste them to your favorite LLM to ask arbitrary questions. As an example, I like to package up the repo using the [files-to-prompt](https://github.com/simonw/files-to-prompt) utility like so:
diff --git a/nanochat/ui.html b/nanochat/ui.html
index 39e608f..b2b4605 100644
--- a/nanochat/ui.html
+++ b/nanochat/ui.html
@@ -108,6 +108,15 @@
background: transparent;
border: none;
padding: 0.25rem 0;
+ cursor: pointer;
+ border-radius: 0.5rem;
+ padding: 0.5rem;
+ margin-left: -0.5rem;
+ transition: background-color 0.2s ease;
+ }
+
+ .message.assistant .message-content:hover {
+ background-color: #f9fafb;
}
.message.user .message-content {
@@ -115,6 +124,21 @@
border-radius: 1.25rem;
padding: 0.8rem 1rem;
max-width: 65%;
+ cursor: pointer;
+ transition: background-color 0.2s ease;
+ }
+
+ .message.user .message-content:hover {
+ background-color: #e5e7eb;
+ }
+
+ .message.console .message-content {
+ font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', 'Consolas', 'Courier New', monospace;
+ font-size: 0.875rem;
+ background-color: #fafafa;
+ padding: 0.75rem 1rem;
+ color: #374151;
+ max-width: 80%;
}
.input-container {
@@ -255,6 +279,8 @@
let messages = [];
let isGenerating = false;
+ let currentTemperature = 0.8;
+ let currentTopK = 50;
chatInput.addEventListener('input', function() {
this.style.height = 'auto';
@@ -289,7 +315,7 @@
chatInput.focus();
}
- function addMessage(role, content) {
+ function addMessage(role, content, messageIndex = null) {
const messageDiv = document.createElement('div');
messageDiv.className = `message ${role}`;
@@ -297,6 +323,28 @@
contentDiv.className = 'message-content';
contentDiv.textContent = content;
+ // Add click handler for user messages to enable editing
+ if (role === 'user' && messageIndex !== null) {
+ contentDiv.setAttribute('data-message-index', messageIndex);
+ contentDiv.setAttribute('title', 'Click to edit and restart from here');
+ contentDiv.addEventListener('click', function() {
+ if (!isGenerating) {
+ editMessage(messageIndex);
+ }
+ });
+ }
+
+ // Add click handler for assistant messages to enable regeneration
+ if (role === 'assistant' && messageIndex !== null) {
+ contentDiv.setAttribute('data-message-index', messageIndex);
+ contentDiv.setAttribute('title', 'Click to regenerate this response');
+ contentDiv.addEventListener('click', function() {
+ if (!isGenerating) {
+ regenerateMessage(messageIndex);
+ }
+ });
+ }
+
messageDiv.appendChild(contentDiv);
chatWrapper.appendChild(messageDiv);
@@ -304,17 +352,35 @@
return contentDiv;
}
- async function sendMessage() {
- const message = chatInput.value.trim();
- if (!message || isGenerating) return;
+ function editMessage(messageIndex) {
+ // Find the message in the messages array
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
- isGenerating = true;
- chatInput.value = '';
+ const messageToEdit = messages[messageIndex];
+ if (messageToEdit.role !== 'user') return;
+
+ // Copy message content to input
+ chatInput.value = messageToEdit.content;
chatInput.style.height = 'auto';
- sendButton.disabled = true;
+ chatInput.style.height = Math.min(chatInput.scrollHeight, 200) + 'px';
- messages.push({ role: 'user', content: message });
- addMessage('user', message);
+ // Remove this message and all subsequent messages from the array
+ messages = messages.slice(0, messageIndex);
+
+ // Remove message elements from DOM starting from messageIndex
+ const allMessages = chatWrapper.querySelectorAll('.message');
+ for (let i = messageIndex; i < allMessages.length; i++) {
+ allMessages[i].remove();
+ }
+
+ // Enable send button and focus input
+ sendButton.disabled = false;
+ chatInput.focus();
+ }
+
+ async function generateAssistantResponse() {
+ isGenerating = true;
+ sendButton.disabled = true;
const assistantContent = addMessage('assistant', '');
assistantContent.innerHTML = '';
@@ -327,8 +393,8 @@
},
body: JSON.stringify({
messages: messages,
- stream: true,
- temperature: 0.8,
+ temperature: currentTemperature,
+ top_k: currentTopK,
max_tokens: 512
}),
});
@@ -364,8 +430,18 @@
}
}
+ const assistantMessageIndex = messages.length;
messages.push({ role: 'assistant', content: fullResponse });
+ // Add click handler to regenerate this assistant message
+ assistantContent.setAttribute('data-message-index', assistantMessageIndex);
+ assistantContent.setAttribute('title', 'Click to regenerate this response');
+ assistantContent.addEventListener('click', function() {
+ if (!isGenerating) {
+ regenerateMessage(assistantMessageIndex);
+ }
+ });
+
} catch (error) {
console.error('Error:', error);
assistantContent.innerHTML = `
Error: ${error.message}
`;
@@ -375,6 +451,97 @@
}
}
+ async function regenerateMessage(messageIndex) {
+ // Find the message in the messages array
+ if (messageIndex < 0 || messageIndex >= messages.length) return;
+
+ const messageToRegenerate = messages[messageIndex];
+ if (messageToRegenerate.role !== 'assistant') return;
+
+ // Remove this message and all subsequent messages from the array
+ messages = messages.slice(0, messageIndex);
+
+ // Remove message elements from DOM starting from messageIndex
+ const allMessages = chatWrapper.querySelectorAll('.message');
+ for (let i = messageIndex; i < allMessages.length; i++) {
+ allMessages[i].remove();
+ }
+
+ // Regenerate the assistant response
+ await generateAssistantResponse();
+ }
+
+ function handleSlashCommand(command) {
+ const parts = command.trim().split(/\s+/);
+ const cmd = parts[0].toLowerCase();
+ const arg = parts[1];
+
+ if (cmd === '/temperature') {
+ if (arg === undefined) {
+ addMessage('console', `Current temperature: ${currentTemperature}`);
+ } else {
+ const temp = parseFloat(arg);
+ if (isNaN(temp) || temp < 0 || temp > 2) {
+ addMessage('console', 'Invalid temperature. Must be between 0.0 and 2.0');
+ } else {
+ currentTemperature = temp;
+ addMessage('console', `Temperature set to ${currentTemperature}`);
+ }
+ }
+ return true;
+ } else if (cmd === '/topk') {
+ if (arg === undefined) {
+ addMessage('console', `Current top-k: ${currentTopK}`);
+ } else {
+ const topk = parseInt(arg);
+ if (isNaN(topk) || topk < 1 || topk > 200) {
+ addMessage('console', 'Invalid top-k. Must be between 1 and 200');
+ } else {
+ currentTopK = topk;
+ addMessage('console', `Top-k set to ${currentTopK}`);
+ }
+ }
+ return true;
+ } else if (cmd === '/clear') {
+ newConversation();
+ return true;
+ } else if (cmd === '/help') {
+ addMessage('console',
+ 'Available commands:\n' +
+ '/temperature - Show current temperature\n' +
+ '/temperature - Set temperature (0.0-2.0)\n' +
+ '/topk - Show current top-k\n' +
+ '/topk - Set top-k (1-200)\n' +
+ '/clear - Clear conversation\n' +
+ '/help - Show this help message'
+ );
+ return true;
+ }
+ return false;
+ }
+
+ async function sendMessage() {
+ const message = chatInput.value.trim();
+ if (!message || isGenerating) return;
+
+ // Handle slash commands
+ if (message.startsWith('/')) {
+ chatInput.value = '';
+ chatInput.style.height = 'auto';
+ handleSlashCommand(message);
+ return;
+ }
+
+ chatInput.value = '';
+ chatInput.style.height = 'auto';
+
+ const userMessageIndex = messages.length;
+ messages.push({ role: 'user', content: message });
+ addMessage('user', message, userMessageIndex);
+
+ await generateAssistantResponse();
+ }
+
sendButton.disabled = false;
// Autofocus the chat input on page load
diff --git a/run1000.sh b/run1000.sh
new file mode 100644
index 0000000..7d41327
--- /dev/null
+++ b/run1000.sh
@@ -0,0 +1,94 @@
+# The $1000 tier of nanochat
+# Designed to run end-to-end for $1000/24 ~= 41.6 hours on an 8XH100 node
+# A bit sparser on comments, see speedrun.sh for more detail
+
+# all the setup stuff
+export OMP_NUM_THREADS=1
+NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
+mkdir -p $NANOCHAT_BASE_DIR
+command -v uv &> /dev/null || curl -LsSf https://astral.sh/uv/install.sh | sh
+[ -d ".venv" ] || uv venv
+uv sync
+source .venv/bin/activate
+if [ -z "$WANDB_RUN" ]; then
+ WANDB_RUN=dummy
+fi
+python -m nanochat.report reset
+curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
+source "$HOME/.cargo/env"
+uv run maturin develop --release --manifest-path rustbpe/Cargo.toml
+EVAL_BUNDLE_URL=https://karpathy-public.s3.us-west-2.amazonaws.com/eval_bundle.zip
+if [ ! -d "$NANOCHAT_BASE_DIR/eval_bundle" ]; then
+ curl -L -o eval_bundle.zip $EVAL_BUNDLE_URL
+ unzip -q eval_bundle.zip
+ rm eval_bundle.zip
+ mv eval_bundle $NANOCHAT_BASE_DIR
+fi
+
+# train tokenizer on ~4B characters and kick off download of the rest for pretraining
+python -m nanochat.dataset -n 16
+# start downloading the rest of the shards for a total of 800 (see below why 800)
+python -m nanochat.dataset -n 800 &
+# todo: download the rest of it
+python -m scripts.tok_train --max_chars=4000000000
+python -m scripts.tok_eval
+
+# Documenting my process for determining the hyperparameters for this run1000.sh script:
+# We want a budget of approx. $1000 ~= 41.6 hours of 8XH100 compute
+# 1) I guessed the model size for this to be about depth=32
+# 2) Determine the device_batch_size that fits:
+# Running the base_train.py script with --depth=32, I saw that --device_batch_size=16
+# runs out of memory, but --device_batch_size=8 fits. Inspecting `nvidia-smi` during training,
+# I saw all GPUs were at about 78/80GB VRAM, so it just barely fits and we have good MFU at ~50%.
+# So the training script was running ok and showed:
+# Vocab size: 65,536
+# num_layers: 32
+# model_dim: 2048
+# num_heads: 16
+# num_kv_heads: 16
+# Tokens / micro-batch / rank: 8 x 2048 = 16,384
+# Tokens / micro-batch: 131,072
+# Total batch size 524,288 => gradient accumulation steps: 4
+# Number of parameters: 1,879,048,192
+# Estimated FLOPs per token: 1.207960e+10
+# Calculated number of iterations from target data:param ratio: 71,680
+# Total number of training tokens: 37,580,963,840
+# Tokens : Params ratio: 20.00
+# Total training FLOPs estimate: 4.539628e+20
+# step 00004/71680 (0.01%) | loss: 8.813754 | lrm: 1.00 | dt: 1571.88ms | tok/sec: 83,385 | mfu: 50.92 | total time: 0.00m
+# step 00005/71680 (0.01%) | loss: 8.488074 | lrm: 1.00 | dt: 1572.76ms | tok/sec: 83,338 | mfu: 50.89 | total time: 0.00m
+# ...
+# 3) validate that the runtime fits our budget:
+# The training script uses the Chinchilla scaling law to compute-optimally set #tokens = 20 * #params. In particular:
+# The script shows that we will be training for 71,680 steps, and each step takes 1.574s so:
+# estimated time to train: 71,680 * 1.574s / 60 / 60 = 31.3 hours.
+# This is OK, fits our budget, and leaves ~10 hours for midtraining and SFT and evals and maybe RL.
+# It's possible that we might even fit depth=33 or depth=34, but for now let's go along with this.
+# 4) The last thing to pay attention to is the amount of training data required for the run.
+# The script above calculated that "Total number of training tokens: 37,580,963,840"
+# The tok_eval.py script reports about ~4.8 chars/token on average for the default tokenizer settings.
+# So ~38B tokens # ~4.8 chars/token = ~185B chars.
+# Each data shard is ~250M chars, so we need ~185B / 250M ~= 740 shards.
+# For safety, I bumped that up to 800 shards, and that's why up above I used -n 800 when pre-downloading dataset shards.
+# If we didn't have enough data, the training script would loop around and do multiple epochs over the same data,
+# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
+# start to overfit hard.
+# 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script.
+torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=32 --device_batch_size=8
+torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
+torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
+
+# midtrain
+# NOTE: ensure that we use the same device_batch_size here as the base training script.
+torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
+
+# sft
+torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
+torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
+
+# generate final report
+python -m nanochat.report generate
+
+# talk to it
+python -m scripts.chat_web
diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index 0b5be36..433f9ba 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -11,7 +11,6 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-import copy
import wandb
import torch
@@ -23,11 +22,9 @@ from nanochat.checkpoint_manager import save_checkpoint
from nanochat.engine import Engine
from scripts.chat_eval import run_chat_eval
-from tasks.common import TaskMixture, TaskSequence
-from tasks.mmlu import MMLU
+from tasks.common import TaskMixture
from tasks.arc import ARC
from tasks.gsm8k import GSM8K
-from tasks.humaneval import HumanEval
from tasks.smoltalk import SmolTalk
# -----------------------------------------------------------------------------
@@ -186,7 +183,7 @@ for step in range(num_iterations):
})
model.train()
- # evlauate MMLU accuracy
+ # evlauate accuracy of the multiple choice tasks (which are quick to run)
if last_step or (step > 0 and step % eval_metrics_every == 0):
model.eval()
metrics = {}
@@ -194,8 +191,6 @@ for step in range(num_iterations):
# note that because these are inside no_grad, we can usually afford to at least ~2X the batch size
metrics["mmlu_acc"] = run_chat_eval("MMLU", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
metrics["arc_easy_acc"] = run_chat_eval("ARC-Easy", model, tokenizer, engine, batch_size=device_batch_size*2, max_problems=1024)
- metrics["gsm8k_acc"] = run_chat_eval("GSM8K", model, tokenizer, engine, max_problems=64)
- metrics["humaneval_acc"] = run_chat_eval("HumanEval", model, tokenizer, engine, max_problems=64)
metrics_str = ', '.join(f'{k}: {v:.6f}' for k, v in metrics.items())
print0(f"Step {step:05d} | {metrics_str}")
wandb_run.log({
diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 55abc79..c07725e 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -1,26 +1,67 @@
#!/usr/bin/env python3
"""
Unified web chat server - serves both UI and API from a single FastAPI instance.
-Run with: python web_chat.py
-Then open http://localhost:8000 in your browser.
+
+Uses data parallelism to distribute requests across multiple GPUs. Each GPU loads
+a full copy of the model, and incoming requests are distributed to available workers.
+
+Launch examples:
+
+- single available GPU (default)
+python -m scripts.chat_web
+
+- 4 GPUs
+python -m scripts.chat_web --num-gpus 4
+
+To chat, open the URL printed in the console. (If on cloud box, make sure to use public IP)
+
+Endpoints:
+ GET / - Chat UI
+ POST /chat/completions - Chat API (streaming only)
+ GET /health - Health check with worker pool status
+ GET /stats - Worker pool statistics and GPU utilization
+
+Abuse Prevention:
+ - Maximum 500 messages per request
+ - Maximum 8000 characters per message
+ - Maximum 32000 characters total conversation length
+ - Temperature clamped to 0.0-2.0
+ - Top-k clamped to 1-200
+ - Max tokens clamped to 1-4096
"""
import argparse
import json
import os
import torch
+import asyncio
+import logging
+import random
from contextlib import asynccontextmanager
-from fastapi import FastAPI
+from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
+from dataclasses import dataclass
from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
+# Abuse prevention limits
+MAX_MESSAGES_PER_REQUEST = 500
+MAX_MESSAGE_LENGTH = 8000
+MAX_TOTAL_CONVERSATION_LENGTH = 32000
+MIN_TEMPERATURE = 0.0
+MAX_TEMPERATURE = 2.0
+MIN_TOP_K = 1
+MAX_TOP_K = 200
+MIN_MAX_TOKENS = 1
+MAX_MAX_TOKENS = 4096
+
parser = argparse.ArgumentParser(description='NanoChat Web Server')
+parser.add_argument('-n', '--num-gpus', type=int, default=1, help='Number of GPUs to use (default: 1)')
parser.add_argument('-i', '--source', type=str, default="sft", help="Source of the model: sft|mid|rl")
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
@@ -31,8 +72,64 @@ parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run th
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
+# Configure logging for conversation traffic
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S'
+)
+logger = logging.getLogger(__name__)
+
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16)
+
+@dataclass
+class Worker:
+ """A worker with a model loaded on a specific GPU."""
+ gpu_id: int
+ device: torch.device
+ engine: Engine
+ tokenizer: object
+ autocast_ctx: torch.amp.autocast
+
+class WorkerPool:
+ """Pool of workers, each with a model replica on a different GPU."""
+
+ def __init__(self, num_gpus: Optional[int] = None):
+ self.num_gpus = num_gpus if num_gpus is not None else torch.cuda.device_count()
+ self.workers: List[Worker] = []
+ self.available_workers: asyncio.Queue = asyncio.Queue()
+
+ async def initialize(self, source: str, model_tag: Optional[str] = None, step: Optional[int] = None):
+ """Load model on each GPU."""
+ print(f"Initializing worker pool with {self.num_gpus} GPUs...")
+
+ for gpu_id in range(self.num_gpus):
+ device = torch.device(f"cuda:{gpu_id}")
+ print(f"Loading model on GPU {gpu_id}...")
+
+ model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
+ engine = Engine(model, tokenizer)
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+
+ worker = Worker(
+ gpu_id=gpu_id,
+ device=device,
+ engine=engine,
+ tokenizer=tokenizer,
+ autocast_ctx=autocast_ctx
+ )
+ self.workers.append(worker)
+ await self.available_workers.put(worker)
+
+ print(f"All {self.num_gpus} workers initialized!")
+
+ async def acquire_worker(self) -> Worker:
+ """Get an available worker from the pool."""
+ return await self.available_workers.get()
+
+ async def release_worker(self, worker: Worker):
+ """Return a worker to the pool."""
+ await self.available_workers.put(worker)
class ChatMessage(BaseModel):
role: str
@@ -43,14 +140,76 @@ class ChatRequest(BaseModel):
temperature: Optional[float] = None
max_tokens: Optional[int] = None
top_k: Optional[int] = None
- stream: Optional[bool] = True
+
+def validate_chat_request(request: ChatRequest):
+ """Validate chat request to prevent abuse."""
+ # Check number of messages
+ if len(request.messages) == 0:
+ raise HTTPException(status_code=400, detail="At least one message is required")
+ if len(request.messages) > MAX_MESSAGES_PER_REQUEST:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Too many messages. Maximum {MAX_MESSAGES_PER_REQUEST} messages allowed per request"
+ )
+
+ # Check individual message lengths and total conversation length
+ total_length = 0
+ for i, message in enumerate(request.messages):
+ if not message.content:
+ raise HTTPException(status_code=400, detail=f"Message {i} has empty content")
+
+ msg_length = len(message.content)
+ if msg_length > MAX_MESSAGE_LENGTH:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Message {i} is too long. Maximum {MAX_MESSAGE_LENGTH} characters allowed per message"
+ )
+ total_length += msg_length
+
+ if total_length > MAX_TOTAL_CONVERSATION_LENGTH:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Total conversation is too long. Maximum {MAX_TOTAL_CONVERSATION_LENGTH} characters allowed"
+ )
+
+ # Validate role values
+ for i, message in enumerate(request.messages):
+ if message.role not in ["user", "assistant"]:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Message {i} has invalid role. Must be 'user', 'assistant', or 'system'"
+ )
+
+ # Validate temperature
+ if request.temperature is not None:
+ if not (MIN_TEMPERATURE <= request.temperature <= MAX_TEMPERATURE):
+ raise HTTPException(
+ status_code=400,
+ detail=f"Temperature must be between {MIN_TEMPERATURE} and {MAX_TEMPERATURE}"
+ )
+
+ # Validate top_k
+ if request.top_k is not None:
+ if not (MIN_TOP_K <= request.top_k <= MAX_TOP_K):
+ raise HTTPException(
+ status_code=400,
+ detail=f"top_k must be between {MIN_TOP_K} and {MAX_TOP_K}"
+ )
+
+ # Validate max_tokens
+ if request.max_tokens is not None:
+ if not (MIN_MAX_TOKENS <= request.max_tokens <= MAX_MAX_TOKENS):
+ raise HTTPException(
+ status_code=400,
+ detail=f"max_tokens must be between {MIN_MAX_TOKENS} and {MAX_MAX_TOKENS}"
+ )
@asynccontextmanager
async def lifespan(app: FastAPI):
- """Load model on startup."""
- print("Loading nanochat model...")
- app.state.model, app.state.tokenizer, _ = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
- app.state.engine = Engine(app.state.model, app.state.tokenizer)
+ """Load models on all GPUs on startup."""
+ print("Loading nanochat models across GPUs...")
+ app.state.worker_pool = WorkerPool(num_gpus=args.num_gpus)
+ await app.state.worker_pool.initialize(args.source, model_tag=args.model_tag, step=args.step)
print(f"Server ready at http://localhost:{args.port}")
yield
@@ -85,8 +244,7 @@ async def logo():
return FileResponse(logo_path, media_type="image/svg+xml")
async def generate_stream(
- engine,
- tokenizer,
+ worker: Worker,
tokens,
temperature=None,
max_new_tokens=None,
@@ -97,98 +255,141 @@ async def generate_stream(
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
top_k = top_k if top_k is not None else args.top_k
- assistant_end = tokenizer.encode_special("<|assistant_end|>")
- bos = tokenizer.get_bos_token_id()
+ assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
+ bos = worker.tokenizer.get_bos_token_id()
- with autocast_ctx:
- for token_column, token_masks in engine.generate(
+ # Accumulate tokens to properly handle multi-byte UTF-8 characters (like emojis)
+ accumulated_tokens = []
+ # Track the last complete UTF-8 string (without replacement characters)
+ last_clean_text = ""
+
+ with worker.autocast_ctx:
+ for token_column, token_masks in worker.engine.generate(
tokens,
num_samples=1,
max_tokens=max_new_tokens,
temperature=temperature,
- top_k=top_k
+ top_k=top_k,
+ seed=random.randint(0, 2**31 - 1)
):
token = token_column[0]
+ # Stopping criteria
if token == assistant_end or token == bos:
break
- token_text = tokenizer.decode([token])
- yield f"data: {json.dumps({'token': token_text})}\n\n"
+ # Append the token to sequence
+ accumulated_tokens.append(token)
+ # Decode all accumulated tokens to get proper UTF-8 handling
+ # Note that decode is a quite efficient operation, basically table lookup and string concat
+ current_text = worker.tokenizer.decode(accumulated_tokens)
+ # Only emit text if it doesn't end with a replacement character
+ # This ensures we don't emit incomplete UTF-8 sequences
+ if not current_text.endswith('�'):
+ # Extract only the new text since last clean decode
+ new_text = current_text[len(last_clean_text):]
+ if new_text: # Only yield if there's new content
+ yield f"data: {json.dumps({'token': new_text, 'gpu': worker.gpu_id}, ensure_ascii=False)}\n\n"
+ last_clean_text = current_text
yield f"data: {json.dumps({'done': True})}\n\n"
@app.post("/chat/completions")
async def chat_completions(request: ChatRequest):
- """Chat completion endpoint with streaming."""
- engine = app.state.engine
- tokenizer = app.state.tokenizer
+ """Chat completion endpoint (streaming only) - uses worker pool for multi-GPU."""
- # Build conversation tokens
- bos = tokenizer.get_bos_token_id()
- user_start = tokenizer.encode_special("<|user_start|>")
- user_end = tokenizer.encode_special("<|user_end|>")
- assistant_start = tokenizer.encode_special("<|assistant_start|>")
- assistant_end = tokenizer.encode_special("<|assistant_end|>")
+ # Basic validation to prevent abuse
+ validate_chat_request(request)
- conversation_tokens = [bos]
- for message in request.messages:
- if message.role == "user":
- conversation_tokens.append(user_start)
- conversation_tokens.extend(tokenizer.encode(message.content))
- conversation_tokens.append(user_end)
- elif message.role == "assistant":
- conversation_tokens.append(assistant_start)
- conversation_tokens.extend(tokenizer.encode(message.content))
- conversation_tokens.append(assistant_end)
+ # Log incoming conversation to console
+ logger.info("="*20)
+ for i, message in enumerate(request.messages):
+ logger.info(f"[{message.role.upper()}]: {message.content}")
+ logger.info("-"*20)
- conversation_tokens.append(assistant_start)
+ # Acquire a worker from the pool (will wait if all are busy)
+ worker_pool = app.state.worker_pool
+ worker = await worker_pool.acquire_worker()
+
+ try:
+ # Build conversation tokens
+ bos = worker.tokenizer.get_bos_token_id()
+ user_start = worker.tokenizer.encode_special("<|user_start|>")
+ user_end = worker.tokenizer.encode_special("<|user_end|>")
+ assistant_start = worker.tokenizer.encode_special("<|assistant_start|>")
+ assistant_end = worker.tokenizer.encode_special("<|assistant_end|>")
+
+ conversation_tokens = [bos]
+ for message in request.messages:
+ if message.role == "user":
+ conversation_tokens.append(user_start)
+ conversation_tokens.extend(worker.tokenizer.encode(message.content))
+ conversation_tokens.append(user_end)
+ elif message.role == "assistant":
+ conversation_tokens.append(assistant_start)
+ conversation_tokens.extend(worker.tokenizer.encode(message.content))
+ conversation_tokens.append(assistant_end)
+
+ conversation_tokens.append(assistant_start)
+
+ # Streaming response with worker release after completion
+ response_tokens = []
+ async def stream_and_release():
+ try:
+ async for chunk in generate_stream(
+ worker,
+ conversation_tokens,
+ temperature=request.temperature,
+ max_new_tokens=request.max_tokens,
+ top_k=request.top_k
+ ):
+ # Accumulate response for logging
+ chunk_data = json.loads(chunk.replace("data: ", "").strip())
+ if "token" in chunk_data:
+ response_tokens.append(chunk_data["token"])
+ yield chunk
+ finally:
+ # Log the assistant response to console
+ full_response = "".join(response_tokens)
+ logger.info(f"[ASSISTANT] (GPU {worker.gpu_id}): {full_response}")
+ logger.info("="*20)
+ # Release worker back to pool after streaming is done
+ await worker_pool.release_worker(worker)
- if request.stream:
return StreamingResponse(
- generate_stream(
- engine,
- tokenizer,
- conversation_tokens,
- temperature=request.temperature,
- max_new_tokens=request.max_tokens,
- top_k=request.top_k
- ),
+ stream_and_release(),
media_type="text/event-stream"
)
- else:
- # Non-streaming response
- temperature = request.temperature if request.temperature is not None else args.temperature
- max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
- top_k = request.top_k if request.top_k is not None else args.top_k
-
- with autocast_ctx:
- result_tokens, masks = engine.generate_batch(
- conversation_tokens,
- num_samples=1,
- max_tokens=max_tokens,
- temperature=temperature,
- top_k=top_k
- )[0]
-
- response_tokens = result_tokens[len(conversation_tokens):]
- response_text = tokenizer.decode(response_tokens)
- return {
- "choices": [{
- "message": {
- "role": "assistant",
- "content": response_text
- },
- "finish_reason": "stop"
- }]
- }
+ except Exception as e:
+ # Make sure to release worker even on error
+ await worker_pool.release_worker(worker)
+ raise e
@app.get("/health")
async def health():
"""Health check endpoint."""
+ worker_pool = getattr(app.state, 'worker_pool', None)
return {
"status": "ok",
- "ready": hasattr(app.state, 'model') and app.state.model is not None
+ "ready": worker_pool is not None and len(worker_pool.workers) > 0,
+ "num_gpus": worker_pool.num_gpus if worker_pool else 0,
+ "available_workers": worker_pool.available_workers.qsize() if worker_pool else 0
+ }
+
+@app.get("/stats")
+async def stats():
+ """Get worker pool statistics."""
+ worker_pool = app.state.worker_pool
+ return {
+ "total_workers": len(worker_pool.workers),
+ "available_workers": worker_pool.available_workers.qsize(),
+ "busy_workers": len(worker_pool.workers) - worker_pool.available_workers.qsize(),
+ "workers": [
+ {
+ "gpu_id": w.gpu_id,
+ "device": str(w.device)
+ } for w in worker_pool.workers
+ ]
}
if __name__ == "__main__":
diff --git a/scripts/mid_train.py b/scripts/mid_train.py
index 18daedf..c45f349 100644
--- a/scripts/mid_train.py
+++ b/scripts/mid_train.py
@@ -40,10 +40,10 @@ embedding_lr = 0.2
matrix_lr = 0.02
init_lr_frac = 1.0 # initial learning rate is this fraction of the base learning rate
weight_decay = 0.0
-final_lr_frac = 0.0 # final LR is this fraction of the initial LR
eval_every = 150
eval_tokens = 20*524288
total_batch_size = 524288
+dry_run = 0 # dry_run=1 is for experiments: we will log to wandb but we won't write checkpoints or report
config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
@@ -141,7 +141,8 @@ progress = 0 # will go from 0 to 1 over the course of the epoch
# Learning rate scheduler
def get_lr_multiplier(progress):
- return progress * 1.0 + (1 - progress) * final_lr_frac
+ # first 80% of training: no decay, then linearly ramp down to 0.
+ return 1 if progress < 0.8 else 1 - (progress - 0.8) / 0.2
# Momentum scheduler for Muon optimizer
def get_muon_momentum(it):
@@ -185,7 +186,7 @@ while True:
model.train()
# save checkpoint at the end of the run (only on master process)
- if master_process and last_step:
+ if master_process and last_step and not dry_run:
output_dirname = f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "mid_checkpoints", output_dirname)
save_checkpoint(
@@ -273,17 +274,18 @@ print0(f"Total training time: {total_training_time/60:.2f}m")
print0(f"Minimum validation bpb: {min_val_bpb:.4f}")
# Log to report
-from nanochat.report import get_report
-get_report().log(section="Midtraining", data=[
- user_config, # CLI args
- { # stats about the training setup
- "Number of iterations": step,
- "DDP world size": ddp_world_size,
- },
- { # stats about training outcomes
- "Minimum validation bpb": min_val_bpb,
- }
-])
+if not dry_run:
+ from nanochat.report import get_report
+ get_report().log(section="Midtraining", data=[
+ user_config, # CLI args
+ { # stats about the training setup
+ "Number of iterations": step,
+ "DDP world size": ddp_world_size,
+ },
+ { # stats about training outcomes
+ "Minimum validation bpb": min_val_bpb,
+ }
+ ])
# cleanup
wandb_run.finish() # wandb run finish
diff --git a/speedrun.sh b/speedrun.sh
index add2a01..dccfb90 100644
--- a/speedrun.sh
+++ b/speedrun.sh
@@ -17,6 +17,7 @@ export OMP_NUM_THREADS=1
# For example, for a gfx1151 GPU, we can use gfx1100 (11.0.0).
export HSA_OVERRIDE_GFX_VERSION=11.0.0
NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
+export NANOCHAT_BASE_DIR="$HOME/.cache/nanochat"
mkdir -p $NANOCHAT_BASE_DIR
# -----------------------------------------------------------------------------