Compare commits

...

16 Commits

Author SHA1 Message Date
Evgeny
9d901a4937 Merge 04862cbfea into f66a780f68 2025-11-14 16:44:10 -05:00
Andrej
f66a780f68 Fix torch.dtype mismatching when running engine inline test. 2025-11-14 07:28:29 -08:00
Andrej
4763ce612a Small fixes to typos 2025-11-14 07:25:59 -08:00
Sofie Van Landeghem
c6f5bd67db revert change of base to sft for quick inline test 2025-11-14 12:20:03 +01:00
svlandeg
a2fb3c83a6 fix typos 2025-11-14 11:20:25 +01:00
svlandeg
e5efb4b471 add test_engine.py to file structure 2025-11-14 11:13:42 +01:00
Andrej Karpathy
9a71d13688 typo oops 2025-11-13 16:08:30 +00:00
Andrej Karpathy
7b7fd0fe71 thank you Sophie for your help with nanochat 2025-11-13 16:07:54 +00:00
Andrej Karpathy
c6abcdfe3a big change: add pretraining resumption logic so that checkpoints can now be approximately resumed and training can continue. this is useful for very long runs when you don't want the anxiety of your run crashing for some reason. alternatively, it's a way to recover training in the event of loss spikes. i mean, this should have been there in v0 but it's ok. the resumption is approximate to control complexity and bloat, but it's possible we want to change that in the future. to use, set --save_every to a step interval to write checkpoints with, and then use --resume_from_step to resume optimization from a given step. only base model training (pretraining) supports this atm, but it's ok because midtraining is comparably quite a bit faster. 2025-11-13 15:34:40 +00:00
Andrej Karpathy
91f09ccd0d minor fix comment in engine 2025-11-13 15:28:18 +00:00
Andrej Karpathy
adb5d4a16c uv lock has to change when we removed numpy the other commit 2025-11-13 15:16:27 +00:00
howardgao@outlook.com
b399e43168 fix engine test bug 2025-11-06 08:56:45 +08:00
svlandeg
52e85aaf80 Merge branch 'master' into fix/typo 2025-11-02 13:41:13 +01:00
svlandeg
70319851fc fix typo 2025-10-29 19:48:34 +01:00
Evgeny Sorokin
04862cbfea improved error handling for openai sdk 2025-10-16 14:29:39 +02:00
Evgeny Sorokin
28e6b9b9c2 added web_cpu script 2025-10-16 12:49:40 +02:00
11 changed files with 969 additions and 88 deletions

View File

@ -32,6 +32,34 @@ python -m scripts.chat_web
And then visit the URL shown. Make sure to access it correctly, e.g. on Lambda use the public IP of the node you're on, followed by the port, so for example [http://209.20.xxx.xxx:8000/](http://209.20.xxx.xxx:8000/), etc. Then talk to your LLM as you'd normally talk to ChatGPT! Get it to write stories or poems. Ask it to tell you who you are to see a hallucination. Ask it why the sky is blue. Or why it's green. The speedrun is a 4e19 FLOPs capability model so it's a bit like talking to a kindergartener :).
### CPU Inference
If you want to run inference on CPU (e.g., on your laptop or a machine without GPU), use the CPU web server:
```bash
python -m scripts.chat_web_cpu --model-dir /tmp/nanochat
```
This script automatically converts the model to float32 and runs inference on CPU. You can then access the web UI at `http://localhost:8000` or use it via the OpenAI-compatible API.
The CPU web server (`chat_web_cpu.py`) is compatible with the OpenAI API specification, which means you can use any OpenAI SDK, tool, or framework with your NanoChat models:
```python
from openai import OpenAI
client = OpenAI(
    api_key="not_set",
    base_url="http://localhost:8000/v1",
)
response = client.chat.completions.create(
    model="nanochat",
    messages=[{"role": "user", "content": "Hello!"}]
)
print(response.choices[0].message.content)
```
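Streaming also works through the same endpoint via the standard `stream=True` flag (a minimal sketch, reusing the `client` from above; the server emits standard `chat.completion.chunk` events):
```python
stream = client.chat.completions.create(
    model="nanochat",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:  # the first and last chunks may carry no content
        print(delta, end="", flush=True)
```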
---
<img width="2672" height="1520" alt="image" src="https://github.com/user-attachments/assets/ed39ddf8-2370-437a-bedc-0f39781e76b5" />
@ -184,6 +212,7 @@ python -m pytest tests/test_rustbpe.py -v -s
│ ├── smoltalk.py # Conglomerate dataset of SmolTalk from HF
│ └── spellingbee.py # Task teaching model to spell/count letters
├── tests
│ ├── test_engine.py
│ └── test_rustbpe.py
└── uv.lock
```
@ -201,6 +230,7 @@ Current LLM policy: disclosure. When submitting a PR, please declare any parts t
- Thank you to [HuggingFace](https://huggingface.co/) for fineweb and smoltalk.
- Thank you [Lambda](https://lambda.ai/service/gpu-cloud) for the compute used in developing this project.
- Thank you to chief LLM whisperer 🧙‍♂️ Alec Radford for advice/guidance.
- Thank you to the repo czar Sofie [@svlandeg](https://github.com/svlandeg) for help with managing issues, pull requests and discussions of nanochat.
## Cite

View File

@ -20,33 +20,32 @@ def log0(message):
if int(os.environ.get('RANK', 0)) == 0:
logger.info(message)
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data):
assert int(os.environ.get('RANK', 0)) == 0 # prevent footguns for now
def save_checkpoint(checkpoint_dir, step, model_data, optimizer_data, meta_data, rank=0):
if rank == 0:
os.makedirs(checkpoint_dir, exist_ok=True)
# Save the model state (parameters)
# Save the model state parameters
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
torch.save(model_data, model_path)
log0(f"Saved model file to: {model_path}")
# Save the optimizer state (useful for SFT or any other fine-tuning)
if optimizer_data is not None:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
torch.save(optimizer_data, optimizer_path)
log0(f"Saved optimizer file to: {optimizer_path}")
logger.info(f"Saved model parameters to: {model_path}")
# Save the metadata dict as json
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta_data, f, indent=2)
log0(f"Saved metadata file to: {meta_path}")
logger.info(f"Saved metadata to: {meta_path}")
# Note that optimizer state is sharded across ranks, so each rank must save its own.
if optimizer_data is not None:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
torch.save(optimizer_data, optimizer_path)
logger.info(f"Saved optimizer state to: {optimizer_path}")
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False):
def load_checkpoint(checkpoint_dir, step, device, load_optimizer=False, rank=0):
# Load the model state
model_path = os.path.join(checkpoint_dir, f"model_{step:06d}.pt")
model_data = torch.load(model_path, map_location=device)
# Load the optimizer state if requested
optimizer_data = None
if load_optimizer:
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}.pt")
optimizer_path = os.path.join(checkpoint_dir, f"optim_{step:06d}_rank{rank:d}.pt")
optimizer_data = torch.load(optimizer_path, map_location=device)
# Load the metadata
meta_path = os.path.join(checkpoint_dir, f"meta_{step:06d}.json")
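The per-rank sharding above changes the calling convention: every rank now calls `save_checkpoint`/`load_checkpoint`, with rank 0 writing the shared model/metadata files and each rank writing only its own optimizer shard. A minimal sketch of the round trip (the checkpoint directory, step, and the surrounding `model`/`optimizers` objects are illustrative):
```python
import os

rank = int(os.environ.get("RANK", 0))
save_checkpoint("/tmp/ckpt", 250,
                model.state_dict(),
                [opt.state_dict() for opt in optimizers],
                {"step": 250},
                rank=rank)
model_data, optimizer_data, meta_data = load_checkpoint(
    "/tmp/ckpt", 250, device="cuda", load_optimizer=True, rank=rank)
```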

View File

@ -148,6 +148,8 @@ def compute_init(device_type="cuda"): # cuda|cpu|mps
assert torch.backends.mps.is_available(), "Your PyTorch installation is not configured for MPS but device_type is 'mps'"
# Reproducibility
# Note that we set the global seeds here, but most of the code uses explicit rng objects.
# The only place where global rng might be used is nn.Module initialization of the model weights.
torch.manual_seed(42)
if device_type == "cuda":
torch.cuda.manual_seed(42)

View File

@ -1,49 +1,87 @@
from collections import deque
import torch
import pyarrow.parquet as pq
from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.dataset import list_parquet_files
from nanochat.tokenizer import get_tokenizer
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda"):
"""Stream pretraining text from parquet files, tokenize, yield training batches."""
def tokenizing_distributed_data_loader_with_state(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128, device="cuda", resume_state_dict=None):
"""
Stream pretraining text from parquet files, tokenize, yield training batches.
This implementation became a bit more complex because we wish to support approximate resume training.
Instead of turning this into a Class, we opt to return the state_dict with every batch,
and then the caller can pass in a state_dict to resume training from a desired point.
Note that this resumption is atm only *approximate* for simplicity.
We won't repeat the same documents but we might skip a few.
The state_dict that is returned can be later passed into this function via `resume_state_dict` to approximately resume.
Perfect state resumption is possible but would be a lot more bloated, probably not worth it atm.
"""
assert split in ["train", "val"], "split must be 'train' or 'val'"
# infinite iterator over document batches (list of text strings)
ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
def document_batches():
parquet_paths = list_parquet_files()
parquet_paths = parquet_paths[:-1] if split == "train" else parquet_paths[-1:]
resume_pq_idx = resume_state_dict["pq_idx"] if resume_state_dict is not None else 0
resume_rg_idx = resume_state_dict["rg_idx"] if resume_state_dict is not None else None
pq_idx = resume_pq_idx # we kick off parquet files at the resume index (or by default just 0)
while True: # iterate infinitely (multi-epoch)
while pq_idx < len(parquet_paths): # iterate over all parquet files
filepath = parquet_paths[pq_idx]
pf = pq.ParquetFile(filepath)
# Start from resume point if resuming on same file, otherwise from DDP rank
# I know this state resumption is a little bit tricky and a little bit hacky... sigh.
if resume_rg_idx is not None:
base_idx = resume_rg_idx // ddp_world_size # in units of ddp_world_size
base_idx += 1 # advance by 1 so that we definitely don't repeat data after resuming
rg_idx = base_idx * ddp_world_size + ddp_rank
resume_rg_idx = None # set to None as we only want to do this a single time
else:
rg_idx = ddp_rank
while rg_idx < pf.num_row_groups:
rg = pf.read_row_group(rg_idx)
batch = rg.column('text').to_pylist() # each batch is a parquet group, e.g. 1024 rows
# the tokenizer encode might want to go in even smaller batches, e.g. 128 rows
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size], (pq_idx, rg_idx)
rg_idx += ddp_world_size # advance to the next row group (in DDP)
pq_idx += 1 # advance to the next parquet file
batches = document_batches()
# Now emit batches of tokens.
needed_tokens = B * T + 1 # +1 is because we also need the target at the last token
# get the tokenizer and the bos token
tokenizer = get_tokenizer()
bos_token = tokenizer.get_bos_token_id()
# scratch buffer holds the tokens for one iteration
token_buffer = deque() # we stream tokens on the right and pop from the left
# infinite iterator over document batches
def document_batches():
while True:
# batch will iterate in group size of the parquet files, usually e.g. 1024 rows
for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
# for the tokenizer we might want to go in usually smaller batches, e.g. 128 rows
for i in range(0, len(batch), tokenizer_batch_size):
yield batch[i:i+tokenizer_batch_size]
batches = document_batches()
batch_index = 0
while True:
# Accumulate enough tokens for one iteration before yielding.
while len(token_buffer) < needed_tokens:
doc_batch = next(batches)
doc_batch, (pq_idx, rg_idx) = next(batches)
token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
for tokens in token_lists:
token_buffer.extend(tokens)
batch_index += 1
# Move tokens from the deque into the scratch buffer
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
# CUDA supports memory pinning for faster transfers between CPU and GPU:
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))
# CUDA supports memory pinning for asynchronous transfers between CPU and GPU
use_cuda_optimizations = device == "cuda"
scratch = torch.tensor(tokens, dtype=torch.long, pin_memory=use_cuda_optimizations) # in PyTorch, long=int64
# Create the inputs/targets as 1D tensors
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
inputs_cpu = scratch[:-1]
targets_cpu = scratch[1:]
# Reshape to 2D and move to GPU async
inputs = inputs_cpu.view(B, T).to(device=device, dtype=torch.int32, non_blocking=True)
targets = targets_cpu.view(B, T).to(device=device, dtype=torch.int64, non_blocking=True)
inputs = inputs_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
targets = targets_cpu.view(B, T).to(device=device, non_blocking=use_cuda_optimizations)
state_dict = {"pq_idx": pq_idx, "rg_idx": rg_idx} # we need this in case we wish to approximately resume training
yield inputs, targets, state_dict
def tokenizing_distributed_data_loader(*args, **kwargs):
# helper function that only emits the inputs/targets and not the state_dict
for inputs, targets, state_dict in tokenizing_distributed_data_loader_with_state(*args, **kwargs):
yield inputs, targets
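In other words, the caller threads the returned `state_dict` through its own checkpointing and hands it back on restart. A minimal sketch of that contract (batch size, sequence length, and device are illustrative):
```python
loader = tokenizing_distributed_data_loader_with_state(8, 1024, split="train", device="cuda")
inputs, targets, state = next(loader)  # state == {"pq_idx": ..., "rg_idx": ...}
# persist `state` alongside the model checkpoint; after a crash, resume with:
loader = tokenizing_distributed_data_loader_with_state(
    8, 1024, split="train", device="cuda", resume_state_dict=state)
```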

View File

@ -17,8 +17,9 @@ import signal
import warnings
from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init
from nanochat.common import compute_init, autodetect_device_type
from nanochat.checkpoint_manager import load_model
from contextlib import nullcontext
# -----------------------------------------------------------------------------
# Calculator tool helpers
@ -107,8 +108,9 @@ class KVCache:
assert self.kv_cache is None, "Cannot prefill a non-empty KV cache"
assert other.kv_cache is not None, "Cannot prefill with a None KV cache"
for ix, (dim1, dim2) in enumerate(zip(self.kv_shape, other.kv_shape)):
# ix 0: num_layers, 1: k/v, 2: batch_size, 3: num_heads, 4: seq_len, 5: head_dim
if ix in [0, 1, 3, 5]:
# num_layers, batch_size, num_heads, head_dim must match
# num_layers, k/v, num_heads, head_dim must match
assert dim1 == dim2, f"Dim {ix} mismatch: {dim1} != {dim2}"
elif ix == 2:
# batch_size can be expanded
@ -327,6 +329,9 @@ if __name__ == "__main__":
import time
# init compute
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
device_type = autodetect_device_type()
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
# load the model and tokenizer
model, tokenizer, meta = load_model("base", device, phase="eval")
bos_token_id = tokenizer.get_bos_token_id()
@ -339,6 +344,7 @@ if __name__ == "__main__":
torch.cuda.synchronize()
t0 = time.time()
stream = model.generate(prompt_tokens, **kwargs)
with autocast_ctx:
for token in stream:
generated_tokens.append(token)
chunk = tokenizer.decode([token])
@ -354,6 +360,7 @@ if __name__ == "__main__":
stream = engine.generate(prompt_tokens, num_samples=1, **kwargs) # note: runs in fp32
torch.cuda.synchronize()
t0 = time.time()
with autocast_ctx:
for token_column, token_masks in stream:
token = token_column[0] # only print out the first row
generated_tokens.append(token)

View File

@ -9,9 +9,9 @@ import torch.distributed as dist
def evaluate_bpb(model, batches, steps, token_bytes):
"""
Instead of the naive 'mean loss', this function returns the bits per byte (bpb),
which is a tokenization vocab size-indepedent metric, meaning you are still comparing
which is a tokenization vocab size-independent metric, meaning you are still comparing
apples:apples if you change the vocab size. The way this works is that instead of just
calculating the average loss as usual, you calculate the sum loss, and indepependently
calculating the average loss as usual, you calculate the sum loss, and independently
also the sum bytes (of all the target tokens), and divide. This normalizes the loss by
the number of bytes that the target tokens represent.
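Concretely, the normalization converts the summed loss from nats to bits and divides by the summed target bytes. A toy sketch of the arithmetic (not the repo's exact code; the numbers are illustrative):
```python
import math

def bits_per_byte(sum_loss_nats, sum_target_bytes):
    # nats -> bits via division by ln(2), then normalize per byte
    return sum_loss_nats / (math.log(2) * sum_target_bytes)

# e.g. 2.0 nats of summed loss over targets spanning 4 bytes:
print(bits_per_byte(2.0, 4))  # ~0.72 bpb, comparable across vocab sizes
```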

View File

@ -20,10 +20,10 @@ import wandb
import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.dataloader import tokenizing_distributed_data_loader, tokenizing_distributed_data_loader_with_state
from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, autodetect_device_type
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.checkpoint_manager import save_checkpoint, load_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from scripts.base_eval import evaluate_model
@ -52,12 +52,14 @@ grad_clip = 1.0 # gradient clipping value (0.0 = disabled)
warmup_ratio = 0.0 # ratio of iterations for LR warmup
warmdown_ratio = 0.2 # ratio of iterations for LR warmdown
final_lr_frac = 0.0 # final LR is this fraction of the initial LR
resume_from_step = -1 # resume training from this step of the optimization (-1 = disable)
# Evaluation
eval_every = 250 # every how many steps to evaluate the model for val bpb
eval_tokens = 20*524288 # number of tokens to evaluate val loss on
core_metric_every = 2000 # every how many steps to evaluate the core metric (-1 = disable)
core_metric_max_per_task = 500 # examples per task in estimating the core metric
sample_every = 2000 # every how many steps to sample from the model
save_every = -1 # every how many steps to save model checkpoints (-1 = disable, and save only at the end of the run)
# Output
model_tag = "" # optionally override the model tag for the output checkpoint directory name
# now allow CLI to override the settings via the configurator lol
@ -103,16 +105,31 @@ grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# -----------------------------------------------------------------------------
# Initialize the Model
# Create a new model with random weights
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)
model.to_empty(device=device)
model.init_weights()
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
# If we are resuming, overwrite the model parameters with those of the checkpoint
base_dir = get_base_dir()
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
resuming = resume_from_step != -1
if resuming:
print0(f"Resuming optimization from step {resume_from_step}")
model_data, optimizer_data, meta_data = load_checkpoint(checkpoint_dir, resume_from_step, device, load_optimizer=True, rank=ddp_rank)
model.load_state_dict(model_data, strict=True, assign=True)
del model_data # free up this memory after the copy
orig_model = model # original, uncompiled model, for saving raw model state_dict and for inference/evaluation (because input shapes may vary there, while the compiled model assumes fixed shapes)
model = torch.compile(model, dynamic=False) # the inputs to model will never change shape so dynamic=False is safe
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
@ -143,12 +160,18 @@ print0(f"Total training FLOPs estimate: {num_flops_per_token * total_tokens:e}")
optimizers = model.setup_optimizers(unembedding_lr=unembedding_lr, embedding_lr=embedding_lr, matrix_lr=matrix_lr, weight_decay=weight_decay)
adamw_optimizer, muon_optimizer = optimizers
if resuming:
for opt, dat in zip(optimizers, optimizer_data):
opt.load_state_dict(dat)
del optimizer_data # free up the memory
# -----------------------------------------------------------------------------
# Initialize the DataLoaders for train/val
base_dir = get_base_dir()
tokens_dir = os.path.join(base_dir, "tokenized_data")
train_loader = tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="train", device=device)
dataloader_resume_state_dict = None if not resuming else meta_data["dataloader_state_dict"]
train_loader = tokenizing_distributed_data_loader_with_state(device_batch_size, max_seq_len, split="train", device=device, resume_state_dict=dataloader_resume_state_dict)
build_val_loader = lambda: tokenizing_distributed_data_loader(device_batch_size, max_seq_len, split="val", device=device)
x, y = next(train_loader) # kick off load of the very first batch of data
x, y, dataloader_state_dict = next(train_loader) # kick off load of the very first batch of data
# -----------------------------------------------------------------------------
# Set up hyperparameter schedulers
@ -171,15 +194,25 @@ def get_muon_momentum(it):
momentum = (1 - frac) * 0.85 + frac * 0.95
return momentum
# -----------------------------------------------------------------------------
# Loop state (variables updated by the training loop)
if not resuming:
step = 0
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
total_training_time = 0 # total wall-clock time of training
else:
step = meta_data["step"]
loop_state = meta_data["loop_state"]
min_val_bpb = loop_state["min_val_bpb"]
smooth_train_loss = loop_state["smooth_train_loss"]
total_training_time = loop_state["total_training_time"]
# -----------------------------------------------------------------------------
# Training loop
min_val_bpb = float("inf")
smooth_train_loss = 0 # EMA of training loss
ema_beta = 0.9 # EMA decay factor
total_training_time = 0 # total wall-clock time of training
# note that we run +1 steps only so that we can eval and save at the end
for step in range(num_iterations + 1):
last_step = step == num_iterations
while True:
last_step = step == num_iterations # loop runs num_iterations+1 times so that we can eval/save at the end
flops_so_far = num_flops_per_token * total_batch_size * step
# once in a while: evaluate the val bpb (all ranks participate)
@ -237,25 +270,31 @@ for step in range(num_iterations + 1):
print0(tokenizer.decode(sample[0]))
model.train()
# save checkpoint at the end of the run (only on master process)
if master_process and last_step:
output_dirname = model_tag if model_tag else f"d{depth}" # e.g. d12
checkpoint_dir = os.path.join(base_dir, "base_checkpoints", output_dirname)
# save checkpoint: at the end of the run, or every save_every steps, except at the first step or the resume step
if last_step or (step > 0 and step != resume_from_step and save_every > 0 and step % save_every == 0):
save_checkpoint(
checkpoint_dir,
step,
orig_model.state_dict(),
[opt.state_dict() for opt in optimizers], # TODO: make sure saving across ranks is done correctly
{
orig_model.state_dict(), # model parameters
[opt.state_dict() for opt in optimizers], # optimizer states
{ # metadata saved as json
"step": step,
"val_bpb": val_bpb, # loss at last step
"model_config": model_config_kwargs,
"user_config": user_config, # inputs to the training script
"device_batch_size": device_batch_size,
"max_seq_len": max_seq_len,
}
"dataloader_state_dict": dataloader_state_dict,
"loop_state": { # all loop state (other than step) so that we can resume training
"min_val_bpb": min_val_bpb,
"smooth_train_loss": smooth_train_loss,
"total_training_time": total_training_time,
},
},
rank=ddp_rank,
)
# termination conditions (TODO: possibly also add loss explosions etc.)
if last_step:
break
@ -270,7 +309,7 @@ for step in range(num_iterations + 1):
train_loss = loss.detach() # for logging
loss = loss / grad_accum_steps # each .backward() is a grad sum => normalize loss here
loss.backward()
x, y = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
x, y, dataloader_state_dict = next(train_loader) # prefetch the next batch while the GPU is busy with forward/backward
# gradient clipping
grad_clip_enabled = grad_clip > 0.0
if grad_clip_enabled:
@ -293,6 +332,7 @@ for step in range(num_iterations + 1):
# -------------------------------------------------------------------------
# logging
ema_beta = 0.9 # EMA decay factor for some smoothing just for nicer logging
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * step / num_iterations
@ -319,6 +359,9 @@ for step in range(num_iterations + 1):
log_data["train/grad_norm"] = grad_norm
wandb_run.log(log_data)
# state update
step += 1
# print a few more stats
print0(f"Peak memory usage: {get_max_memory() / 1024 / 1024:.2f}MiB")
print0(f"Total training time: {total_training_time/60:.2f}m")

View File

@ -1,6 +1,6 @@
"""
Evaluate the Chat model.
All the generic code lives here, and all the evlauation-specific
All the generic code lives here, and all the evaluation-specific
code lives in nanochat directory and is imported from here.
Example runs:

View File

@ -192,7 +192,7 @@ for step in range(num_iterations):
})
model.train()
# evlauate accuracy of the multiple choice tasks (which are quick to run)
# evaluate accuracy of the multiple choice tasks (which are quick to run)
if last_step or (step > 0 and step % eval_metrics_every == 0):
model.eval()
metrics = {}

scripts/chat_web_cpu.py (new file, 764 lines)
View File

@ -0,0 +1,764 @@
#!/usr/bin/env python3
"""
CPU-compatible web chat server - serves both UI and API from a single FastAPI instance.
Run with: python -m scripts.chat_web_cpu --model-dir /path/to/model
Then open http://localhost:8000 in your browser.
"""
import argparse
import json
import os
import glob
import pickle
import math
import time
import uuid
import torch
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse, JSONResponse
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, Field
from typing import List, Optional, AsyncGenerator, Literal, Union, Dict, Any
from dataclasses import dataclass
import torch.nn as nn
import torch.nn.functional as F
# -----------------------------------------------------------------------------
# Minimal GPT implementation (copied from generate_cpu.py)
@dataclass
class GPTConfig:
sequence_len: int = 1024
vocab_size: int = 50304
n_layer: int = 12
n_head: int = 6
n_kv_head: int = 6
n_embd: int = 768
def norm(x):
return F.rms_norm(x, (x.size(-1),))
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:]
y1 = x1 * cos + x2 * sin
y2 = x1 * (-sin) + x2 * cos
out = torch.cat([y1, y2], 3)
out = out.to(x.dtype)
return out
def repeat_kv(x, n_rep):
if n_rep == 1:
return x
bs, n_kv_heads, slen, head_dim = x.shape
return (
x[:, :, None, :, :]
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
)
class CausalSelfAttention(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
self.n_head = config.n_head
self.n_kv_head = config.n_kv_head
self.n_embd = config.n_embd
self.head_dim = self.n_embd // self.n_head
assert self.n_embd % self.n_head == 0
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache):
B, T, C = x.size()
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
cos, sin = cos_sin
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
q, k = norm(q), norm(k)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
Tq = q.size(2)
Tk = k.size(2)
nrep = self.n_head // self.n_kv_head
k, v = repeat_kv(k, nrep), repeat_kv(v, nrep)
if kv_cache is None or Tq == Tk:
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
elif Tq == 1:
y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
else:
attn_mask = torch.zeros((Tq, Tk), dtype=torch.bool, device=q.device)
prefix_len = Tk - Tq
if prefix_len > 0:
attn_mask[:, :prefix_len] = True
attn_mask[:, prefix_len:] = torch.tril(torch.ones((Tq, Tq), dtype=torch.bool, device=q.device))
y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
y = y.transpose(1, 2).contiguous().view(B, T, -1)
y = self.c_proj(y)
return y
class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
def forward(self, x):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
return x
class Block(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.attn = CausalSelfAttention(config, layer_idx)
self.mlp = MLP(config)
def forward(self, x, cos_sin, kv_cache):
x = x + self.attn(norm(x), cos_sin, kv_cache)
x = x + self.mlp(norm(x))
return x
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict({
"wte": nn.Embedding(config.vocab_size, config.n_embd),
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.rotary_seq_len = config.sequence_len * 10
head_dim = config.n_embd // config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
def init_weights(self):
self.apply(self._init_weights)
torch.nn.init.zeros_(self.lm_head.weight)
for block in self.transformer.h:
torch.nn.init.zeros_(block.mlp.c_proj.weight)
torch.nn.init.zeros_(block.attn.c_proj.weight)
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
def _init_weights(self, module):
if isinstance(module, nn.Linear):
fan_out = module.weight.size(0)
fan_in = module.weight.size(1)
std = 1.0 / math.sqrt(fan_in) * min(1.0, math.sqrt(fan_out / fan_in))
torch.nn.init.normal_(module.weight, mean=0.0, std=std)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
if device is None:
device = self.transformer.wte.weight.device
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
inv_freq = 1.0 / (base ** (channel_range / head_dim))
t = torch.arange(seq_len, dtype=torch.float32, device=device)
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
cos, sin = cos[None, :, None, :], sin[None, :, None, :]
return cos, sin
def forward(self, idx, targets=None, kv_cache=None):
B, T = idx.size()
assert T <= self.cos.size(1)
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]
x = self.transformer.wte(idx)
x = norm(x)
for block in self.transformer.h:
x = block(x, cos_sin, kv_cache)
x = norm(x)
softcap = 15
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap)
return logits
# -----------------------------------------------------------------------------
# Simple tokenizer wrapper
class SimpleTokenizer:
def __init__(self, enc):
self.enc = enc
try:
self.bos_token_id = enc.encode_single_token("<|bos|>")
        except KeyError:
try:
self.bos_token_id = enc.encode_single_token("<|endoftext|>")
            except KeyError:
self.bos_token_id = 0
# Get special tokens
try:
self.user_start = enc.encode_single_token("<|user_start|>")
self.user_end = enc.encode_single_token("<|user_end|>")
self.assistant_start = enc.encode_single_token("<|assistant_start|>")
self.assistant_end = enc.encode_single_token("<|assistant_end|>")
        except KeyError:
# Fallback if special tokens don't exist
self.user_start = 0
self.user_end = 0
self.assistant_start = 0
self.assistant_end = 0
def get_bos_token_id(self):
return self.bos_token_id
def encode_special(self, token):
try:
return self.enc.encode_single_token(token)
        except KeyError:
return 0
def encode(self, text):
return self.enc.encode_ordinary(text)
def decode(self, tokens):
return self.enc.decode(tokens)
# -----------------------------------------------------------------------------
# Simple generator (no Engine class needed)
def generate_tokens(model, input_tokens, max_tokens=512, temperature=0.8, top_k=50, device='cpu'):
"""Generate tokens one at a time."""
x = torch.tensor([input_tokens], dtype=torch.long, device=device)
generated = []
with torch.inference_mode():
for _ in range(max_tokens):
logits = model(x)
logits = logits[:, -1, :] / temperature
if top_k > 0:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
probs = torch.nn.functional.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
generated.append(next_token.item())
x = torch.cat([x, next_token], dim=1)
yield next_token.item()
# -----------------------------------------------------------------------------
# FastAPI app
parser = argparse.ArgumentParser(description='NanoChat Web Server (CPU)')
parser.add_argument('--model-dir', type=str, required=True, help='Path to model directory containing model_*.pt, meta_*.json, and tokenizer.pkl')
parser.add_argument('-t', '--temperature', type=float, default=0.8, help='Default temperature for generation')
parser.add_argument('-k', '--top-k', type=int, default=50, help='Default top-k sampling parameter')
parser.add_argument('-m', '--max-tokens', type=int, default=512, help='Default max tokens for generation')
parser.add_argument('-p', '--port', type=int, default=8000, help='Port to run the server on')
parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind the server to')
args = parser.parse_args()
device = torch.device("cpu")
# OpenAI-compatible request/response models
class ChatMessage(BaseModel):
role: Literal["system", "user", "assistant"]
content: str # Only text content supported
name: Optional[str] = None
class ChatCompletionRequest(BaseModel):
model: str = Field(default="nanochat", description="Model to use for completion")
messages: List[ChatMessage]
# Supported parameters
temperature: Optional[float] = Field(default=None, ge=0, le=2)
max_tokens: Optional[int] = Field(default=None, ge=1)
top_k: Optional[int] = Field(default=None, ge=1, description="Top-k sampling (NanoChat-specific)")
stream: Optional[bool] = False
# Accepted but not supported (will be rejected if provided)
top_p: Optional[float] = Field(default=None, ge=0, le=1)
n: Optional[int] = Field(default=None, ge=1)
stop: Optional[Union[str, List[str]]] = None
presence_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
logit_bias: Optional[Dict[str, float]] = None
user: Optional[str] = None
# Not supported features
tools: Optional[List[Dict[str, Any]]] = None
tool_choice: Optional[Union[str, Dict[str, Any]]] = None
functions: Optional[List[Dict[str, Any]]] = None
function_call: Optional[Union[str, Dict[str, Any]]] = None
class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: Dict[str, Any]
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
class UsageInfo(BaseModel):
prompt_tokens: int
completion_tokens: int
total_tokens: int
class ChatCompletionResponse(BaseModel):
id: str
object: Literal["chat.completion"] = "chat.completion"
created: int
model: str
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class ChatCompletionStreamResponse(BaseModel):
id: str
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int
model: str
choices: List[ChatCompletionResponseStreamChoice]
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Load model on startup."""
print(f"Loading model from {args.model_dir}...")
# Find model and meta files
model_files = glob.glob(os.path.join(args.model_dir, "model_*.pt"))
if not model_files:
raise FileNotFoundError(f"No model files found in {args.model_dir}")
model_file = model_files[0]
meta_files = glob.glob(os.path.join(args.model_dir, "meta_*.json"))
if not meta_files:
raise FileNotFoundError(f"No meta files found in {args.model_dir}")
meta_file = meta_files[0]
# Load metadata
with open(meta_file, 'r') as f:
meta = json.load(f)
model_config_kwargs = meta["model_config"]
print(f"Model config: {model_config_kwargs}")
# Build the model
model_config = GPTConfig(**model_config_kwargs)
with torch.device("meta"):
model = GPT(model_config)
# Load model weights
print("Loading model weights...")
model_data = torch.load(model_file, map_location=device, weights_only=False)
    # strip torch.compile's "_orig_mod." prefix; removeprefix avoids str.lstrip's character-set footgun
    model_data = {k.removeprefix("_orig_mod."): v for k, v in model_data.items()}
# Convert bfloat16 to float32 for CPU
print("Converting model to float32 for CPU...")
model_data = {k: v.float() if v.dtype == torch.bfloat16 else v for k, v in model_data.items()}
model.to_empty(device=device)
model.init_weights()
model.load_state_dict(model_data, strict=True, assign=True)
model.eval()
# Load tokenizer
print("Loading tokenizer...")
tokenizer_path = os.path.join(args.model_dir, "tokenizer.pkl")
if not os.path.exists(tokenizer_path):
raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
with open(tokenizer_path, "rb") as f:
import tiktoken
enc = pickle.load(f)
tokenizer = SimpleTokenizer(enc)
app.state.model = model
app.state.tokenizer = tokenizer
print(f"✓ Model loaded successfully!")
print(f"✓ Server ready at http://localhost:{args.port}")
yield
app = FastAPI(lifespan=lifespan)
# Custom exception handler for OpenAI-compatible error responses
class OpenAIError(Exception):
"""Custom exception that returns OpenAI-compatible error format."""
def __init__(self, message: str, error_type: str = "invalid_request_error", param: str = None, code: str = None):
self.message = message
self.error_type = error_type
self.param = param
self.code = code
super().__init__(message)
@app.exception_handler(OpenAIError)
async def openai_error_handler(request: Request, exc: OpenAIError):
"""Return errors in OpenAI API format."""
return JSONResponse(
status_code=400,
content={
"error": {
"message": exc.message,
"type": exc.error_type,
"param": exc.param,
"code": exc.code
}
}
)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
"""Handle Pydantic validation errors in OpenAI format."""
errors = exc.errors()
if errors:
first_error = errors[0]
param = ".".join(str(x) for x in first_error.get("loc", []))
message = first_error.get("msg", "Invalid request")
else:
param = None
message = "Invalid request"
return JSONResponse(
status_code=400,
content={
"error": {
"message": message,
"type": "invalid_request_error",
"param": param,
"code": None
}
}
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
"""Serve the chat UI."""
ui_html_path = os.path.join("nanochat", "ui.html")
with open(ui_html_path, "r") as f:
html_content = f.read()
# Replace the API_URL to use the same origin
html_content = html_content.replace(
"const API_URL = `http://${window.location.hostname}:8000`;",
"const API_URL = '';"
)
return HTMLResponse(content=html_content)
@app.get("/logo.svg")
async def logo():
"""Serve the NanoChat logo for favicon and header."""
logo_path = os.path.join("nanochat", "logo.svg")
return FileResponse(logo_path, media_type="image/svg+xml")
async def generate_stream(
model,
tokenizer,
tokens,
completion_id: str,
model_name: str,
created: int,
temperature=None,
max_new_tokens=None,
top_k=None
) -> AsyncGenerator[str, None]:
"""Generate assistant response with OpenAI-compatible streaming.
Supported parameters: temperature, max_new_tokens, top_k
"""
temperature = temperature if temperature is not None else args.temperature
    # Treat temperature <= 0 as effectively greedy decoding (a tiny temperature avoids division by zero)
if temperature is not None and temperature <= 0:
temperature = 1e-6
max_new_tokens = max_new_tokens if max_new_tokens is not None else args.max_tokens
# Enforce max 1000 cap
if max_new_tokens is None:
max_new_tokens = 256
max_new_tokens = max(1, min(1000, int(max_new_tokens)))
top_k = top_k if top_k is not None else args.top_k
if top_k is None:
top_k = 50
vocab_size = getattr(app.state.model.config, 'vocab_size', 50257)
top_k = max(1, min(int(top_k), int(vocab_size)))
assistant_end = tokenizer.encode_special("<|assistant_end|>")
bos = tokenizer.get_bos_token_id()
# Send initial chunk with role
chunk = ChatCompletionStreamResponse(
id=completion_id,
created=created,
model=model_name,
choices=[ChatCompletionResponseStreamChoice(
index=0,
delta={"role": "assistant", "content": ""},
finish_reason=None
)]
)
yield f"data: {chunk.model_dump_json()}\n\n"
finish_reason = "length"
for token in generate_tokens(model, tokens, max_new_tokens, temperature, top_k, device):
if token == assistant_end or token == bos:
finish_reason = "stop"
break
token_text = tokenizer.decode([token])
# Send content chunk
chunk = ChatCompletionStreamResponse(
id=completion_id,
created=created,
model=model_name,
choices=[ChatCompletionResponseStreamChoice(
index=0,
delta={"content": token_text},
finish_reason=None
)]
)
yield f"data: {chunk.model_dump_json()}\n\n"
# Send final chunk with finish_reason
chunk = ChatCompletionStreamResponse(
id=completion_id,
created=created,
model=model_name,
choices=[ChatCompletionResponseStreamChoice(
index=0,
delta={},
finish_reason=finish_reason
)]
)
yield f"data: {chunk.model_dump_json()}\n\n"
# OpenAI sends [DONE] at the end
yield "data: [DONE]\n\n"
@app.post("/chat/completions")
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
"""
OpenAI-compatible chat completion endpoint.
Supported parameters:
- messages: Array of message objects (text only)
- temperature: Sampling temperature (0-2)
- max_tokens: Maximum tokens to generate
- top_k: Top-k sampling (NanoChat-specific)
- stream: Enable streaming responses
Not supported (rejected with clear errors):
- top_p, n, stop, presence_penalty, frequency_penalty, logit_bias, user
- tools, functions (function calling not supported)
- Multi-modal content (only text messages supported)
"""
model = app.state.model
tokenizer = app.state.tokenizer
# Validate unsupported features
if request.tools or request.tool_choice or request.functions or request.function_call:
raise OpenAIError(
message="Function calling and tools are not supported by this model. Only text completion is available.",
error_type="invalid_request_error",
code="unsupported_feature"
)
# Reject any unsupported standard params if provided
unsupported_fields = []
if request.n is not None:
unsupported_fields.append("n")
if request.top_p is not None:
unsupported_fields.append("top_p")
if request.stop is not None:
unsupported_fields.append("stop")
if request.presence_penalty is not None:
unsupported_fields.append("presence_penalty")
if request.frequency_penalty is not None:
unsupported_fields.append("frequency_penalty")
if request.logit_bias is not None:
unsupported_fields.append("logit_bias")
if request.user is not None:
unsupported_fields.append("user")
if unsupported_fields:
raise OpenAIError(
message=f"Unsupported parameters for this model: {', '.join(unsupported_fields)}. Supported only: messages, temperature, max_tokens, top_k, stream.",
error_type="invalid_request_error",
param=unsupported_fields[0],
code="unsupported_parameter"
)
# Validate messages are text-only
for i, msg in enumerate(request.messages):
if not isinstance(msg.content, str):
raise OpenAIError(
message=f"Message at index {i} contains non-text content. Only text messages are supported.",
error_type="invalid_request_error",
param=f"messages[{i}].content",
code="invalid_message_content"
)
# Generate unique completion ID and timestamp
completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
created = int(time.time())
model_name = request.model
# Build conversation tokens
bos = tokenizer.get_bos_token_id()
user_start = tokenizer.encode_special("<|user_start|>")
user_end = tokenizer.encode_special("<|user_end|>")
assistant_start = tokenizer.encode_special("<|assistant_start|>")
assistant_end = tokenizer.encode_special("<|assistant_end|>")
system_start = tokenizer.encode_special("<|system_start|>")
system_end = tokenizer.encode_special("<|system_end|>")
conversation_tokens = [bos]
for message in request.messages:
if message.role == "user":
conversation_tokens.append(user_start)
conversation_tokens.extend(tokenizer.encode(message.content))
conversation_tokens.append(user_end)
elif message.role == "assistant":
conversation_tokens.append(assistant_start)
conversation_tokens.extend(tokenizer.encode(message.content))
conversation_tokens.append(assistant_end)
elif message.role == "system":
# Handle system messages if supported
if system_start != 0 and system_end != 0:
conversation_tokens.append(system_start)
conversation_tokens.extend(tokenizer.encode(message.content))
conversation_tokens.append(system_end)
else:
# Fallback: treat system message as user message
conversation_tokens.append(user_start)
conversation_tokens.extend(tokenizer.encode(message.content))
conversation_tokens.append(user_end)
conversation_tokens.append(assistant_start)
prompt_tokens = len(conversation_tokens)
# Use only supported parameters: temperature, max_tokens, top_k
if request.stream:
return StreamingResponse(
generate_stream(
model,
tokenizer,
conversation_tokens,
completion_id=completion_id,
model_name=model_name,
created=created,
temperature=request.temperature,
max_new_tokens=request.max_tokens,
top_k=request.top_k
),
media_type="text/event-stream"
)
else:
# Non-streaming response
temperature = request.temperature if request.temperature is not None else args.temperature
# Enforce max 1000 tokens cap
max_tokens = request.max_tokens if request.max_tokens is not None else args.max_tokens
if max_tokens is None:
max_tokens = 256
max_tokens = max(1, min(1000, int(max_tokens)))
# Validate top_k: 1..vocab_size
top_k = request.top_k if request.top_k is not None else args.top_k
if top_k is None:
top_k = 50
vocab_size = getattr(app.state.model.config, 'vocab_size', 50257)
top_k = max(1, min(int(top_k), int(vocab_size)))
generated_tokens = []
finish_reason = "length"
for token in generate_tokens(model, conversation_tokens, max_tokens, temperature, top_k, device):
if token == assistant_end or token == bos:
finish_reason = "stop"
break
generated_tokens.append(token)
response_text = tokenizer.decode(generated_tokens)
completion_tokens = len(generated_tokens)
return ChatCompletionResponse(
id=completion_id,
created=created,
model=model_name,
choices=[ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role="assistant", content=response_text),
finish_reason=finish_reason
)],
usage=UsageInfo(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens
)
)
@app.get("/v1/models")
@app.get("/models")
async def list_models():
"""
List available models (OpenAI-compatible endpoint).
Returns model information with capabilities annotation.
"""
return {
"object": "list",
"data": [
{
"id": "nanochat",
"object": "model",
"created": int(time.time()),
"owned_by": "nanochat",
"permission": [],
"root": "nanochat",
"parent": None
}
]
}
@app.get("/health")
async def health():
"""Health check endpoint."""
return {
"status": "ok",
"ready": hasattr(app.state, 'model') and app.state.model is not None
}
if __name__ == "__main__":
import uvicorn
print(f"Starting NanoChat Web Server (CPU mode)")
print(f"Temperature: {args.temperature}, Top-k: {args.top_k}, Max tokens: {args.max_tokens}")
uvicorn.run(app, host=args.host, port=args.port)
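Outside the server, the same building blocks compose into a direct generation loop. A hypothetical snippet, assuming `model` and `tokenizer` were loaded as in the `lifespan` handler above:
```python
# build a chat-formatted prompt and stream tokens until the assistant stops
prompt = [tokenizer.get_bos_token_id(),
          tokenizer.encode_special("<|user_start|>"),
          *tokenizer.encode("Why is the sky blue?"),
          tokenizer.encode_special("<|user_end|>"),
          tokenizer.encode_special("<|assistant_start|>")]
for tok in generate_tokens(model, prompt, max_tokens=64, temperature=0.8, top_k=50, device="cpu"):
    if tok == tokenizer.encode_special("<|assistant_end|>"):
        break
    print(tokenizer.decode([tok]), end="", flush=True)
```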

uv.lock (18 lines changed)
View File

@ -311,7 +311,7 @@ name = "exceptiongroup"
version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions", marker = "python_full_version < '3.12' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "typing-extensions", marker = "python_full_version < '3.11' or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749, upload-time = "2025-05-10T17:42:51.123Z" }
wheels = [
@ -777,7 +777,6 @@ dependencies = [
{ name = "datasets" },
{ name = "fastapi" },
{ name = "files-to-prompt" },
{ name = "numpy" },
{ name = "psutil" },
{ name = "regex" },
{ name = "setuptools" },
@ -811,7 +810,6 @@ requires-dist = [
{ name = "datasets", specifier = ">=4.0.0" },
{ name = "fastapi", specifier = ">=0.117.1" },
{ name = "files-to-prompt", specifier = ">=0.6" },
{ name = "numpy", specifier = "==1.26.4" },
{ name = "psutil", specifier = ">=7.1.0" },
{ name = "regex", specifier = ">=2025.9.1" },
{ name = "setuptools", specifier = ">=80.9.0" },
@ -951,7 +949,7 @@ name = "nvidia-cudnn-cu12"
version = "9.10.2.21"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/fa/41/e79269ce215c857c935fd86bcfe91a451a584dfc27f1e068f568b9ad1ab7/nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:c9132cc3f8958447b4910a1720036d9eff5928cc3179b0a51fb6d167c6cc87d8", size = 705026878, upload-time = "2025-06-06T21:52:51.348Z" },
@ -964,7 +962,7 @@ name = "nvidia-cufft-cu12"
version = "11.3.3.83"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/60/bc/7771846d3a0272026c416fbb7e5f4c1f146d6d80704534d0b187dd6f4800/nvidia_cufft_cu12-11.3.3.83-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:848ef7224d6305cdb2a4df928759dca7b1201874787083b6e7550dd6765ce69a", size = 193109211, upload-time = "2025-03-07T01:44:56.873Z" },
@ -996,9 +994,9 @@ name = "nvidia-cusolver-cu12"
version = "11.7.3.90"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "nvidia-cusparse-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "nvidia-cublas-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-cusparse-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/32/f7cd6ce8a7690544d084ea21c26e910a97e077c9b7f07bf5de623ee19981/nvidia_cusolver_cu12-11.7.3.90-py3-none-manylinux_2_27_aarch64.whl", hash = "sha256:db9ed69dbef9715071232caa9b69c52ac7de3a95773c2db65bdba85916e4e5c0", size = 267229841, upload-time = "2025-03-07T01:46:54.356Z" },
@ -1011,7 +1009,7 @@ name = "nvidia-cusparse-cu12"
version = "12.5.8.93"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "nvidia-nvjitlink-cu12", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/f7/cd777c4109681367721b00a106f491e0d0d15cfa1fd59672ce580ce42a97/nvidia_cusparse_cu12-12.5.8.93-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b6c161cb130be1a07a27ea6923df8141f3c295852f4b260c65f18f3e0a091dc", size = 288117129, upload-time = "2025-03-07T01:47:40.407Z" },
@ -1955,7 +1953,7 @@ name = "triton"
version = "3.4.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "setuptools", marker = "extra == 'extra-8-nanochat-gpu'" },
{ name = "setuptools", marker = "(sys_platform == 'linux' and extra == 'extra-8-nanochat-gpu') or (extra == 'extra-8-nanochat-cpu' and extra == 'extra-8-nanochat-gpu')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/62/ee/0ee5f64a87eeda19bbad9bc54ae5ca5b98186ed00055281fd40fb4beb10e/triton-3.4.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ff2785de9bc02f500e085420273bb5cc9c9bb767584a4aa28d6e360cec70128", size = 155430069, upload-time = "2025-07-30T19:58:21.715Z" },