From 7a337f3d5d64e580bc8006bf99c586162f8f8fb4 Mon Sep 17 00:00:00 2001 From: mnehete32 Date: Tue, 21 Oct 2025 03:25:28 +0530 Subject: [PATCH] adding cut_cross_entropy --- README.md | 1 + nanochat/gpt.py | 14 +++++++++----- pyproject.toml | 1 + scripts/base_train.py | 3 ++- uv.lock | 15 +++++++++++++++ 5 files changed, 28 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 05a214b..23e95ae 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ And a bit more about computing environments that will run nanochat: - The code will run just fine on the Ampere 8XA100 GPU node as well, but a bit slower. - All code will run just fine on even a single GPU by omitting `torchrun`, and will produce ~identical results (code will automatically switch to gradient accumulation), but you'll have to wait 8 times longer. - If your GPU(s) have less than 80GB, you'll have to tune some of the hyperparameters or you will OOM / run out of VRAM. Look for `--device_batch_size` in the scripts and reduce it until things fit. E.g. from 32 (default) to 16, 8, 4, 2, or even 1. Less than that you'll have to know a bit more what you're doing and get more creative. +- Alternatively, you can enable chunked cross-entropy by setting the flag `--use_chunked_ce=True`. This computes the cross-entropy without materializing the logits in GPU memory, allowing larger models to fit, though it may run slightly slower. - Most of the code is fairly vanilla PyTorch so it should run on anything that supports that - xpu, mps, or etc, but I haven't implemented this out of the box so it might take a bit of tinkering. 
## Running on CPU / MPS diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 5a066b2..a3473c8 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -18,6 +18,7 @@ from dataclasses import dataclass import torch import torch.nn as nn import torch.nn.functional as F +from cut_cross_entropy import linear_cross_entropy from nanochat.common import get_dist_info, print0 from nanochat.muon import Muon, DistMuon @@ -31,6 +32,7 @@ class GPTConfig: n_head: int = 6 # number of query heads n_kv_head: int = 6 # number of key/value heads (MQA) n_embd: int = 768 + use_chunked_ce: bool = False # uses cut_cross_entropy (https://arxiv.org/pdf/2411.09009) def norm(x): @@ -278,11 +280,13 @@ class GPT(nn.Module): softcap = 15 if targets is not None: # training mode: compute and return the loss - # TODO: experiment with Liger Kernels / chunked cross-entropy etc. - logits = self.lm_head(x) - logits = softcap * torch.tanh(logits / softcap) # logits softcap - logits = logits.float() # use tf32/fp32 for logits - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction) + if (self.config.use_chunked_ce): + loss = linear_cross_entropy(x, self.lm_head.weight, targets=targets, softcap=softcap, ignore_index=-1, reduction=loss_reduction).view(-1) + else: + logits = self.lm_head(x) + logits = softcap * torch.tanh(logits / softcap) # logits softcap + logits = logits.float() # use tf32/fp32 for logits + loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction) return loss else: # inference mode: compute and return the logits diff --git a/pyproject.toml b/pyproject.toml index ef3833a..43d128a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "torch>=2.8.0", "uvicorn>=0.36.0", "wandb>=0.21.3", + "cut-cross-entropy>=25.1.1", ] [build-system] diff --git a/scripts/base_train.py b/scripts/base_train.py index 9f2cdff..f4f75e1 100644 --- 
a/scripts/base_train.py +++ b/scripts/base_train.py @@ -27,6 +27,7 @@ print_banner() # ----------------------------------------------------------------------------- # User settings run = "dummy" # wandb run name default ("dummy" is special - we won't log to wandb) +use_chunked_ce = False # If True, compute cross-entropy without materializing logits (saves memory, may be slightly slower) # Model architecture depth = 20 # the depth of the Transformer model to train, rest of the kwargs are derived max_seq_len = 2048 # max context length @@ -92,7 +93,7 @@ print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") print0(f"Total batch size {total_batch_size:,} => gradient accumulation steps: {grad_accum_steps}") # ----------------------------------------------------------------------------- # Initialize the Model -model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim) +model_config_kwargs = dict(sequence_len=max_seq_len, vocab_size=vocab_size, n_layer=num_layers, n_head=num_heads, n_kv_head=num_kv_heads, n_embd=model_dim, use_chunked_ce=use_chunked_ce) with torch.device("meta"): model_config = GPTConfig(**model_config_kwargs) model = GPT(model_config) diff --git a/uv.lock b/uv.lock index 7636b81..94597d6 100644 --- a/uv.lock +++ b/uv.lock @@ -254,6 +254,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "cut-cross-entropy" +version = "25.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "torch" }, + { name = "triton", marker = "sys_platform == 'linux'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/7e/97/45ff09cfcda7b200389204daa0125168e6544fba257adbbcdf728501d4f9/cut_cross_entropy-25.1.1.tar.gz", hash = "sha256:5fe5924509248b1aea5c890f8887c6a7759f7c8b1ebc0490e42c247c4f7c1e34", size = 22972, upload-time = "2025-01-07T12:21:53.896Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/df/5f/62fdb048f84d19e2123b6bbd722fe09c8c79b4964c50094d1e979db808e2/cut_cross_entropy-25.1.1-py3-none-any.whl", hash = "sha256:e46f26d348f6a67927d17e65c5a212e795be13dcad5b10a77a200d6b8102d9d1", size = 22672, upload-time = "2025-01-07T12:21:51.678Z" }, +] + [[package]] name = "datasets" version = "4.0.0" @@ -755,6 +768,7 @@ name = "nanochat" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "cut-cross-entropy" }, { name = "datasets" }, { name = "fastapi" }, { name = "files-to-prompt" }, @@ -776,6 +790,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "cut-cross-entropy", specifier = ">=25.1.1" }, { name = "datasets", specifier = ">=4.0.0" }, { name = "fastapi", specifier = ">=0.117.1" }, { name = "files-to-prompt", specifier = ">=0.6" },