feat: dynamic dtype selection

Author: Kirk Lin
Date: 2025-10-14 12:22:57 +08:00
Parent: 447567634c
Commit: 662ff7eb7a
7 changed files with 27 additions and 15 deletions

View File

@@ -91,6 +91,13 @@ def get_device_type():
return "cpu"
return "cuda"
+def get_default_dtype():
+    """Get the default dtype for training: bfloat16 on GPU, float32 on CPU."""
+    # bfloat16 is well-supported on modern GPUs but may have issues on CPU
+    if torch.cuda.is_available():
+        return torch.bfloat16
+    return torch.float32
def get_dist_info():
if is_ddp():
assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
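The new helper is meant to be paired with get_device_type() wherever precision is chosen. A minimal usage sketch (illustrative only, not part of this commit; it assumes torch and the nanochat.common helpers are importable):

    # Usage sketch: select the precision once, then build the autocast context from it,
    # mirroring the pattern used by the scripts changed below.
    import torch
    from nanochat.common import get_device_type, get_default_dtype

    dtype = get_default_dtype()  # torch.bfloat16 on a CUDA machine, torch.float32 on CPU
    autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=dtype)
    with autocast_ctx:
        pass  # run the model's forward pass under the selected precision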

View File

@@ -19,7 +19,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from nanochat.common import get_dist_info, print0
+from nanochat.common import get_dist_info, print0, get_default_dtype
from nanochat.muon import Muon, DistMuon
from nanochat.adamw import DistAdamW
@@ -169,8 +169,9 @@ class GPT(nn.Module):
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
self.register_buffer("sin", sin, persistent=False)
-        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
-        self.transformer.wte.to(dtype=torch.bfloat16)
+        # Cast the embeddings to the default dtype: optim can tolerate it and it saves memory: both in the model and the activations
+        default_dtype = get_default_dtype()
+        self.transformer.wte.to(dtype=default_dtype)
def init_weights(self):
self.apply(self._init_weights)
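To make the memory argument in that comment concrete, here is a back-of-the-envelope sizing example; the vocabulary and embedding sizes are made up for illustration and are not taken from this commit:

    # Illustrative sizing only; the shapes below are assumptions, not values from the diff.
    vocab_size, n_embd = 65536, 1280         # hypothetical token-embedding shape
    fp32_bytes = vocab_size * n_embd * 4     # 4 bytes per float32 element
    bf16_bytes = vocab_size * n_embd * 2     # 2 bytes per bfloat16 element
    print(fp32_bytes // 2**20, bf16_bytes // 2**20)  # 320 MiB vs 160 MiB for the table alone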
@@ -210,7 +211,9 @@ class GPT(nn.Module):
# calculate the rotation frequencies at each (time, channel) pair
freqs = torch.outer(t, inv_freq)
cos, sin = freqs.cos(), freqs.sin()
-        cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
+        # keep them in the default dtype (bfloat16 on GPU, float32 on CPU)
+        default_dtype = get_default_dtype()
+        cos, sin = cos.to(default_dtype), sin.to(default_dtype)
cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
return cos, sin
@@ -262,7 +265,9 @@ class GPT(nn.Module):
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
-        assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
+        # Rotary embeddings should match the default dtype for the platform
+        expected_dtype = get_default_dtype()
+        assert self.cos.dtype == expected_dtype, f"Rotary embeddings must be in {expected_dtype}, but got {self.cos.dtype}"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
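The relaxed assertion matters mostly for CPU-only runs, where the rotary cache is built in float32. A small sketch of the difference (hypothetical shapes, not code from the repository):

    # Hypothetical check: the rotary cache carries whatever the platform default is.
    import torch
    from nanochat.common import get_default_dtype

    cos = torch.randn(1, 16, 1, 64).to(get_default_dtype())  # stand-in for the cached buffer
    # Old check: would raise on a CPU-only machine, where the cache is float32.
    # assert cos.dtype == torch.bfloat16
    # New check: tracks the platform default, so it passes on both CPU and GPU.
    assert cos.dtype == get_default_dtype()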

View File

@@ -19,7 +19,7 @@ import yaml
import pandas as pd
import torch
-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_device_type, get_default_dtype
from nanochat.tokenizer import HuggingFaceTokenizer
from nanochat.checkpoint_manager import load_model
from nanochat.core_eval import evaluate_task
@@ -122,7 +122,7 @@ def main():
# distributed / precision setup
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
# Load model and tokenizer from command line or from file system
if len(sys.argv) >= 2:

View File

@@ -9,7 +9,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
import os
import torch
from nanochat.checkpoint_manager import load_model
-from nanochat.common import compute_init, print0, compute_cleanup, get_device_type
+from nanochat.common import compute_init, print0, compute_cleanup, get_device_type, get_default_dtype
from nanochat.dataloader import tokenizing_distributed_data_loader
from nanochat.tokenizer import get_token_bytes
from nanochat.loss_eval import evaluate_bpb
@@ -28,7 +28,7 @@ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=mode
sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
# Set up the precision we'll run with
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
# Evaluate the loss on each split
tokens_per_step = device_batch_size * sequence_len * ddp_world_size
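For scale, the tokens-per-step formula in this hunk works out as follows under assumed settings (the batch size, sequence length, and world size below are illustrative, not the script's defaults):

    # Illustrative arithmetic for tokens_per_step; all three values are assumptions.
    device_batch_size, sequence_len, ddp_world_size = 32, 2048, 8
    tokens_per_step = device_batch_size * sequence_len * ddp_world_size
    print(tokens_per_step)  # 524288 tokens processed per evaluation step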

View File

@@ -16,7 +16,7 @@ import torch
from nanochat.gpt import GPT, GPTConfig
from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_device_type, get_default_dtype
from nanochat.tokenizer import get_tokenizer, get_token_bytes
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
@@ -59,7 +59,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for logging
# Compute init
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
# wandb logging init
use_dummy_wandb = run == "dummy" or not master_process

View File

@@ -6,7 +6,7 @@ python -m scripts.chat_cli -i mid
"""
import argparse
import torch
-from nanochat.common import compute_init, get_device_type
+from nanochat.common import compute_init, get_device_type, get_default_dtype
from nanochat.engine import Engine
from nanochat.checkpoint_manager import load_model
@@ -21,7 +21,7 @@ args = parser.parse_args()
# Init the model and tokenizer
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
# Special tokens for the chat state machine

View File

@@ -16,7 +16,7 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
from pydantic import BaseModel
from typing import List, Optional, AsyncGenerator
-from nanochat.common import compute_init, get_device_type
+from nanochat.common import compute_init, get_device_type, get_default_dtype
from nanochat.checkpoint_manager import load_model
from nanochat.engine import Engine
@@ -32,7 +32,7 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
args = parser.parse_args()
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
class ChatMessage(BaseModel):
role: str