diff --git a/nanochat/common.py b/nanochat/common.py
index d48350f..b3a717d 100644
--- a/nanochat/common.py
+++ b/nanochat/common.py
@@ -91,6 +91,13 @@ def get_device_type():
         return "cpu"
     return "cuda"
 
+def get_default_dtype():
+    """Get the default dtype for training: bfloat16 on GPU, float32 on CPU."""
+    # bfloat16 is well-supported on modern GPUs but may have issues on CPU
+    if torch.cuda.is_available():
+        return torch.bfloat16
+    return torch.float32
+
 def get_dist_info():
     if is_ddp():
         assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 5a066b2..649c362 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -19,7 +19,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from nanochat.common import get_dist_info, print0
+from nanochat.common import get_dist_info, print0, get_default_dtype
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
@@ -169,8 +169,9 @@ class GPT(nn.Module):
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
-        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
-        self.transformer.wte.to(dtype=torch.bfloat16)
+        # Cast the embeddings to the default dtype: optim can tolerate it and it saves memory: both in the model and the activations
+        default_dtype = get_default_dtype()
+        self.transformer.wte.to(dtype=default_dtype)
 
     def init_weights(self):
         self.apply(self._init_weights)
@@ -210,7 +211,9 @@
         # calculate the rotation frequencies at each (time, channel) pair
         freqs = torch.outer(t, inv_freq)
         cos, sin = freqs.cos(), freqs.sin()
-        cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
+        # keep them in the default dtype (bfloat16 on GPU, float32 on CPU)
+        default_dtype = get_default_dtype()
+        cos, sin = cos.to(default_dtype), sin.to(default_dtype)
         cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
         return cos, sin
 
@@ -262,7 +265,9 @@
         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
         assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
         assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
-        assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
+        # Rotary embeddings should match the default dtype for the platform
+        expected_dtype = get_default_dtype()
+        assert self.cos.dtype == expected_dtype, f"Rotary embeddings must be in {expected_dtype}, but got {self.cos.dtype}"
         # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
         T0 = 0 if kv_cache is None else kv_cache.get_pos()
         cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
diff --git a/scripts/base_eval.py b/scripts/base_eval.py
index 73e82f2..ef4d064 100644
--- a/scripts/base_eval.py
+++ b/scripts/base_eval.py
@@ -19,7 +19,7 @@
 import yaml
 import pandas as pd
 import torch
-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_device_type, get_default_dtype
 from nanochat.tokenizer import HuggingFaceTokenizer
 from nanochat.checkpoint_manager import load_model
 from nanochat.core_eval import evaluate_task
@@ -122,7 +122,7 @@ def main():
 
     # distributed / precision setup
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-    autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+    autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
 
     # Load model and tokenizer from command line or from file system
     if len(sys.argv) >= 2:
diff --git a/scripts/base_loss.py b/scripts/base_loss.py
index b75ee58..e954688 100644
--- a/scripts/base_loss.py
+++ b/scripts/base_loss.py
@@ -9,7 +9,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
 import os
 import torch
 from nanochat.checkpoint_manager import load_model
-from nanochat.common import compute_init, print0, compute_cleanup, get_device_type
+from nanochat.common import compute_init, print0, compute_cleanup, get_device_type, get_default_dtype
 from nanochat.dataloader import tokenizing_distributed_data_loader
 from nanochat.tokenizer import get_token_bytes
 from nanochat.loss_eval import evaluate_bpb
@@ -28,7 +28,7 @@ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=mode
 sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
 
 # Set up the precision we'll run with
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
 
 # Evaluate the loss on each split
 tokens_per_step = device_batch_size * sequence_len * ddp_world_size
diff --git a/scripts/base_train.py b/scripts/base_train.py
index ead1c09..4875701 100644
--- a/scripts/base_train.py
+++ b/scripts/base_train.py
@@ -16,7 +16,7 @@
 import torch
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_device_type, get_default_dtype
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -59,7 +59,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
 # Compute init
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
 
 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process
diff --git a/scripts/chat_cli.py b/scripts/chat_cli.py
index e90d084..2f8f9d7 100644
--- a/scripts/chat_cli.py
+++ b/scripts/chat_cli.py
@@ -6,7 +6,7 @@ python -m scripts.chat_cli -i mid
 """
 import argparse
 import torch
-from nanochat.common import compute_init, get_device_type
+from nanochat.common import compute_init, get_device_type, get_default_dtype
 from nanochat.engine import Engine
 from nanochat.checkpoint_manager import load_model
@@ -21,7 +21,7 @@ args = parser.parse_args()
 
 # Init the model and tokenizer
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
 model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
 
 # Special tokens for the chat state machine
diff --git a/scripts/chat_web.py b/scripts/chat_web.py
index 412ccd6..ea120eb 100644
--- a/scripts/chat_web.py
+++ b/scripts/chat_web.py
@@ -16,7 +16,7 @@
 from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
 from typing import List, Optional, AsyncGenerator
 
-from nanochat.common import compute_init, get_device_type
+from nanochat.common import compute_init, get_device_type, get_default_dtype
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine
@@ -32,7 +32,7 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
 args = parser.parse_args()
 
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
 
 class ChatMessage(BaseModel):
     role: str
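Not part of the patch: a minimal standalone sketch of how the new helper is intended to interact with torch.amp.autocast at the call sites above. The helper body mirrors the one added to nanochat/common.py; the throwaway Linear "model" and variable names are illustrative assumptions, not nanochat code. On CPU, recent PyTorch versions warn and simply disable autocast for an unsupported dtype such as float32, so the same call site runs unchanged on both platforms.

import torch

def get_default_dtype():
    # Mirrors the helper added in nanochat/common.py: bfloat16 on GPU, float32 on CPU
    return torch.bfloat16 if torch.cuda.is_available() else torch.float32

device_type = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-in for the real model, just to exercise the autocast call shape
model = torch.nn.Linear(16, 16).to(device_type)
x = torch.randn(4, 16, device=device_type)

# Same pattern as the patched scripts: the dtype follows the platform, so the
# code path also works on CPU, where bfloat16 may be poorly supported
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=get_default_dtype())
with autocast_ctx:
    y = model(x)

print(y.dtype)  # torch.bfloat16 under CUDA autocast, torch.float32 on CPU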