Mirror of https://github.com/karpathy/nanochat.git
feat: dynamic dtype selection

commit 662ff7eb7a (parent 447567634c)
@@ -91,6 +91,13 @@ def get_device_type():
         return "cpu"
     return "cuda"
 
+def get_default_dtype():
+    """Get the default dtype for training: bfloat16 on GPU, float32 on CPU."""
+    # bfloat16 is well-supported on modern GPUs but may have issues on CPU
+    if torch.cuda.is_available():
+        return torch.bfloat16
+    return torch.float32
+
 def get_dist_info():
     if is_ddp():
         assert all(var in os.environ for var in ['RANK', 'LOCAL_RANK', 'WORLD_SIZE'])
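For readers skimming the diff, a standalone copy of the selection logic (a minimal sketch; the real definition is the one added to nanochat.common above):

```python
import torch

def get_default_dtype():
    """Default training dtype: bfloat16 on CUDA GPUs, float32 otherwise (e.g. CPU-only runs)."""
    if torch.cuda.is_available():
        return torch.bfloat16
    return torch.float32

# bf16 keeps fp32's exponent range at half the memory, which is why it can be
# flipped on per-platform without touching the rest of the training code.
print(get_default_dtype())  # torch.bfloat16 on a CUDA machine, torch.float32 on CPU
```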
@@ -19,7 +19,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from nanochat.common import get_dist_info, print0
+from nanochat.common import get_dist_info, print0, get_default_dtype
 from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
 
@@ -169,8 +169,9 @@ class GPT(nn.Module):
         cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
-        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
-        self.transformer.wte.to(dtype=torch.bfloat16)
+        # Cast the embeddings to the default dtype: optim can tolerate it and it saves memory: both in the model and the activations
+        default_dtype = get_default_dtype()
+        self.transformer.wte.to(dtype=default_dtype)
 
     def init_weights(self):
         self.apply(self._init_weights)
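For intuition on the embedding cast above, a small hedged sketch (the vocab size and width are made-up numbers, not nanochat's config): casting the token-embedding table to bfloat16 halves its memory footprint, while CPU runs keep float32.

```python
import torch
import torch.nn as nn

# Toy stand-in for transformer.wte; 50,000 x 768 is illustrative only.
wte = nn.Embedding(50_000, 768)
print(wte.weight.dtype, wte.weight.element_size())   # torch.float32, 4 bytes per element

default_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
wte.to(dtype=default_dtype)                          # casts the module's parameters in place
print(wte.weight.dtype, wte.weight.element_size())   # 2 bytes per element under bf16 on GPU
```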
@@ -210,7 +211,9 @@ class GPT(nn.Module):
         # calculate the rotation frequencies at each (time, channel) pair
         freqs = torch.outer(t, inv_freq)
         cos, sin = freqs.cos(), freqs.sin()
-        cos, sin = cos.bfloat16(), sin.bfloat16() # keep them in bfloat16
+        # keep them in the default dtype (bfloat16 on GPU, float32 on CPU)
+        default_dtype = get_default_dtype()
+        cos, sin = cos.to(default_dtype), sin.to(default_dtype)
         cos, sin = cos[None, :, None, :], sin[None, :, None, :] # add batch and head dims for later broadcasting
         return cos, sin
 
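To see where the cast lands in context, here is a simplified, hedged version of the rotary precompute path this hunk touches (the function name and the half-width table are illustrative; nanochat's real buffers are shaped (1, seq_len, 1, head_dim)):

```python
import torch

def precompute_rotary(seq_len: int, head_dim: int, base: float = 10000.0):
    # Simplified sketch of the step above; only the dtype handling mirrors the diff.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(seq_len).float()
    freqs = torch.outer(t, inv_freq)              # (seq_len, head_dim // 2)
    cos, sin = freqs.cos(), freqs.sin()
    default_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    cos, sin = cos.to(default_dtype), sin.to(default_dtype)
    # add batch and head dims for later broadcasting
    return cos[None, :, None, :], sin[None, :, None, :]

cos, sin = precompute_rotary(seq_len=8, head_dim=64)
print(cos.shape, cos.dtype)  # torch.Size([1, 8, 1, 32]); fp32 on CPU, bf16 on GPU
```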
@@ -262,7 +265,9 @@ class GPT(nn.Module):
         # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
         assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
         assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
-        assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
+        # Rotary embeddings should match the default dtype for the platform
+        expected_dtype = get_default_dtype()
+        assert self.cos.dtype == expected_dtype, f"Rotary embeddings must be in {expected_dtype}, but got {self.cos.dtype}"
         # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
         T0 = 0 if kv_cache is None else kv_cache.get_pos()
         cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length
@@ -19,7 +19,7 @@ import yaml
 import pandas as pd
 import torch
 
-from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, get_base_dir, get_device_type, get_default_dtype
 from nanochat.tokenizer import HuggingFaceTokenizer
 from nanochat.checkpoint_manager import load_model
 from nanochat.core_eval import evaluate_task
@@ -122,7 +122,7 @@ def main():
 
     # distributed / precision setup
    ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-    autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+    autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
 
    # Load model and tokenizer from command line or from file system
    if len(sys.argv) >= 2:
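The same one-line autocast change is repeated in the loss, training, and chat scripts below; a hedged sketch of what it does at runtime (the small Linear layer and inline dtype/device stand-ins are placeholders, not repo code):

```python
import torch

# Inline stand-ins mirroring nanochat.common.get_device_type / get_default_dtype, for a self-contained demo.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
default_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

layer = torch.nn.Linear(8, 8).to(device_type)   # placeholder layer, not from the repo
x = torch.randn(2, 8, device=device_type)

# Same wiring the scripts in this commit switch to: the autocast dtype now follows the platform.
with torch.amp.autocast(device_type=device_type, dtype=default_dtype):
    out = layer(x)
print(out.dtype)  # torch.bfloat16 on a CUDA GPU; torch.float32 on CPU, where fp32 autocast effectively disables itself
```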
@@ -9,7 +9,7 @@ torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
 import os
 import torch
 from nanochat.checkpoint_manager import load_model
-from nanochat.common import compute_init, print0, compute_cleanup, get_device_type
+from nanochat.common import compute_init, print0, compute_cleanup, get_device_type, get_default_dtype
 from nanochat.dataloader import tokenizing_distributed_data_loader
 from nanochat.tokenizer import get_token_bytes
 from nanochat.loss_eval import evaluate_bpb
@@ -28,7 +28,7 @@ model, tokenizer, meta = load_model("base", device, phase="eval", model_tag=mode
 sequence_len = meta["model_config"]["sequence_len"] # could be arbitrary really
 
 # Set up the precision we'll run with
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
 
 # Evaluate the loss on each split
 tokens_per_step = device_batch_size * sequence_len * ddp_world_size
@@ -16,7 +16,7 @@ import torch
 
 from nanochat.gpt import GPT, GPTConfig
 from nanochat.dataloader import tokenizing_distributed_data_loader
-from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_device_type
+from nanochat.common import compute_init, compute_cleanup, print0, DummyWandb, print_banner, get_base_dir, get_device_type, get_default_dtype
 from nanochat.tokenizer import get_tokenizer, get_token_bytes
 from nanochat.checkpoint_manager import save_checkpoint
 from nanochat.loss_eval import evaluate_bpb
@@ -59,7 +59,7 @@ user_config = {k: globals()[k] for k in config_keys} # will be useful for loggin
 # Compute init
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
 
 # wandb logging init
 use_dummy_wandb = run == "dummy" or not master_process
@@ -6,7 +6,7 @@ python -m scripts.chat_cli -i mid
 """
 import argparse
 import torch
-from nanochat.common import compute_init, get_device_type
+from nanochat.common import compute_init, get_device_type, get_default_dtype
 from nanochat.engine import Engine
 from nanochat.checkpoint_manager import load_model
 
@@ -21,7 +21,7 @@ args = parser.parse_args()
 
 # Init the model and tokenizer
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
 model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
 
 # Special tokens for the chat state machine
@@ -16,7 +16,7 @@ from fastapi.responses import StreamingResponse, HTMLResponse, FileResponse
 from pydantic import BaseModel
 from typing import List, Optional, AsyncGenerator
 
-from nanochat.common import compute_init, get_device_type
+from nanochat.common import compute_init, get_device_type, get_default_dtype
 from nanochat.checkpoint_manager import load_model
 from nanochat.engine import Engine
 
@@ -32,7 +32,7 @@ parser.add_argument('--host', type=str, default='0.0.0.0', help='Host to bind th
 args = parser.parse_args()
 
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
-autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=torch.bfloat16)
+autocast_ctx = torch.amp.autocast(device_type=get_device_type(), dtype=get_default_dtype())
 
 class ChatMessage(BaseModel):
     role: str