adding cut_cross_entropy

This commit is contained in:
mnehete32 2025-10-21 03:25:28 +05:30
parent 0f007889dd
commit 7a337f3d5d
5 changed files with 28 additions and 6 deletions

View File

@ -91,6 +91,7 @@ And a bit more about computing environments that will run nanochat:
- The code will run just fine on the Ampere 8XA100 GPU node as well, but a bit slower.
- All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer.
- If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Below that, you'll have to know a bit more about what you're doing and get more creative.
- Alternatively, you can enable chunked cross-entropy by setting the flag `--use_chunked_ce=True`. This computes the cross-entropy without materializing the logits in GPU memory, allowing larger models to fit, though it may run slightly slower.
- Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't implemented this out of the box so it might take a bit of tinkering.
## Running on CPU / MPS

View File

@ -18,6 +18,7 @@ from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from cut_cross_entropy import linear_cross_entropy
from nanochat.common import get_dist_info, print0
from nanochat.muon import Muon, DistMuon
@ -31,6 +32,7 @@ class GPTConfig:
n_head: int = 6 # number of query heads
n_kv_head: int = 6 # number of key/value heads (MQA)
n_embd: int = 768
use_chunked_ce: bool = False # uses cut_cross_entropy (https://arxiv.org/pdf/2411.09009)
def norm(x):
@ -278,11 +280,13 @@ class GPT(nn.Module):
softcap = 15
if targets is not None:
# training mode: compute and return the loss
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = logits.float() # use tf32/fp32 for logits
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
if (self.config.use_chunked_ce):
loss = linear_cross_entropy(x, self.lm_head.weight, targets=targets, softcap=softcap, ignore_index=-1, reduction=loss_reduction).view(-1)
else:
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = logits.float() # use tf32/fp32 for logits
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
return loss
else:
# inference mode: compute and return the logits

View File

@ -16,6 +16,7 @@ dependencies = [
"torch>=2.8.0",
"uvicorn>=0.36.0",
"wandb>=0.21.3",
"cut-cross-entropy>=25.1.1",
]
[build-system]

View File

@ -27,6 +27,7 @@ print_banner()
# -----------------------------------------------------------------------------
# User settings
run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb)
use_chunked_ce = False # If True, compute cross-entropy without materializing the logits in GPU memory (saves memory, slightly slower)
# Model architecture
depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived
max_seq_len = 2048 # max context length
@ -92,7 +93,7 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}")
# -----------------------------------------------------------------------------
# Initialize the Model
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim)
model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, use_chunked_ce=use_chunked_ce)
with torch.device("meta"):
model_config = GPTConfig(**model_config_kwargs)
model = GPT(model_config)

15
uv.lock
View File

@ -254,6 +254,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" },
]
[[package]]
name = "cut-cross-entropy"
version = "25.1.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "torch" },
{ name = "triton", marker = "sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/7e/97/45ff09cfcda7b200389204daa0125168e6544fba257adbbcdf728501d4f9/cut_cross_entropy-25.1.1.tar.gz", hash = "sha256:5fe5924509248b1aea5c890f8887c6a7759f7c8b1ebc0490e42c247c4f7c1e34", size = 22972, upload-time = "2025-01-07T12:21:53.896Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/df/5f/62fdb048f84d19e2123b6bbd722fe09c8c79b4964c50094d1e979db808e2/cut_cross_entropy-25.1.1-py3-none-any.whl", hash = "sha256:e46f26d348f6a67927d17e65c5a212e795be13dcad5b10a77a200d6b8102d9d1", size = 22672, upload-time = "2025-01-07T12:21:51.678Z" },
]
[[package]]
name = "datasets"
version = "4.0.0"
@ -755,6 +768,7 @@ name = "nanochat"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "cut-cross-entropy" },
{ name = "datasets" },
{ name = "fastapi" },
{ name = "files-to-prompt" },
@ -776,6 +790,7 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "cut-cross-entropy", specifier = ">=25.1.1" },
{ name = "datasets", specifier = ">=4.0.0" },
{ name = "fastapi", specifier = ">=0.117.1" },
{ name = "files-to-prompt", specifier = ">=0.6" },