mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-03 14:15:26 +00:00
adding cut_cross_entropy
This commit is contained in:
parent
0f007889dd
commit
7a337f3d5d
|
|
@ -91,6 +91,7 @@ And a bit more about computing environments that will run nanochat:
|
|||
- The code will run just fine on the Ampere 8XA100 GPU node as well, but a bit slower.
|
||||
- All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer.
|
||||
- If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Less than that you'll have to know a bit more what you're doing and get more creative.
|
||||
- Alternatively, you can enable chunked cross-entropy by setting the flag `--use_chunked_ce=True`. This computes the cross-entropy without materializing the logits in GPU memory, allowing larger models to fit, though it may run slightly slower.
|
||||
- Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't implemented this out of the box so it might take a bit of tinkering.
|
||||
|
||||
## Running on CPU / MPS
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from dataclasses import dataclass
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from cut_cross_entropy import linear_cross_entropy
|
||||
|
||||
from nanochat.common import get_dist_info, print0
|
||||
from nanochat.muon import Muon, DistMuon
|
||||
|
|
@ -31,6 +32,7 @@ class GPTConfig:
|
|||
n_head: int = 6 # number of query heads
|
||||
n_kv_head: int = 6 # number of key/value heads (MQA)
|
||||
n_embd: int = 768
|
||||
use_chunked_ce: bool = False # uses cut_cross_entropy (https://arxiv.org/pdf/2411.09009)
|
||||
|
||||
|
||||
def norm(x):
|
||||
|
|
@ -278,11 +280,13 @@ class GPT(nn.Module):
|
|||
softcap = 15
|
||||
if targets is not None:
|
||||
# training mode: compute and return the loss
|
||||
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
|
||||
logits = self.lm_head(x)
|
||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
logits = logits.float() # use tf32/fp32 for logits
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||
if (self.config.use_chunked_ce):
|
||||
loss = linear_cross_entropy(x, self.lm_head.weight, targets=targets, softcap=softcap, ignore_index=-1, reduction=loss_reduction).view(-1)
|
||||
else:
|
||||
logits = self.lm_head(x)
|
||||
logits = softcap * torch.tanh(logits / softcap) # logits softcap
|
||||
logits = logits.float() # use tf32/fp32 for logits
|
||||
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
|
||||
return loss
|
||||
else:
|
||||
# inference mode: compute and return the logits
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ dependencies = [
|
|||
"torch>=2.8.0",
|
||||
"uvicorn>=0.36.0",
|
||||
"wandb>=0.21.3",
|
||||
"cut-cross-entropy>=25.1.1",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ print_banner()
|
|||
# -----------------------------------------------------------------------------
|
||||
# User settings
|
||||
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
|
||||
use_chunked_ce = False # If True, compute cross-entropy online but slower (no logits materialize)
|
||||
# Model architecture
|
||||
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
|
||||
max_seq_len = 2048 # max context length
|
||||
|
|
@ -92,7 +93,7 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
|||
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
|
||||
# -----------------------------------------------------------------------------
|
||||
# Initialize the Model
|
||||
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
|
||||
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, use_chunked_ce=use_chunked_ce)
|
||||
with torch.device("meta"):
|
||||
model_config = GPTConfig(**model_config_kwargs)
|
||||
model = GPT(model_config)
|
||||
|
|
|
|||
15
uv.lock
15
uv.lock
|
|
@ -254,6 +254,19 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cut-cross-entropy"
|
||||
version = "25.1.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "torch" },
|
||||
{ name = "triton", marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7e/97/45ff09cfcda7b200389204daa0125168e6544fba257adbbcdf728501d4f9/cut_cross_entropy-25.1.1.tar.gz", hash = "sha256:5fe5924509248b1aea5c890f8887c6a7759f7c8b1ebc0490e42c247c4f7c1e34", size = 22972, upload-time = "2025-01-07T12:21:53.896Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/df/5f/62fdb048f84d19e2123b6bbd722fe09c8c79b4964c50094d1e979db808e2/cut_cross_entropy-25.1.1-py3-none-any.whl", hash = "sha256:e46f26d348f6a67927d17e65c5a212e795be13dcad5b10a77a200d6b8102d9d1", size = 22672, upload-time = "2025-01-07T12:21:51.678Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "datasets"
|
||||
version = "4.0.0"
|
||||
|
|
@ -755,6 +768,7 @@ name = "nanochat"
|
|||
version = "0.1.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "cut-cross-entropy" },
|
||||
{ name = "datasets" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "files-to-prompt" },
|
||||
|
|
@ -776,6 +790,7 @@ dev = [
|
|||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "cut-cross-entropy", specifier = ">=25.1.1" },
|
||||
{ name = "datasets", specifier = ">=4.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.117.1" },
|
||||
{ name = "files-to-prompt", specifier = ">=0.6" },
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user