Mirror of https://github.com/karpathy/nanochat.git (synced 2026-02-21 02:50:25 +00:00)
Refactor GPT initialization to optimize checkpoint loading
parent 4610a838a1 · commit 16d691d5f1
@@ -98,7 +98,7 @@ def build_model(checkpoint_dir, step, device, phase):
     model = GPT(model_config)
     # Load the model state
     model.to_empty(device=device)
-    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
+    model.setup_buffers()
     model.load_state_dict(model_data, strict=True, assign=True)
     # Put the model in the right training phase / mode
     if phase == "eval":
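The hunk above follows PyTorch's meta-device loading pattern: the module is constructed without allocating real storage, to_empty() materializes uninitialized tensors on the target device, setup_buffers() rebuilds only the non-persistent rotary buffers, and load_state_dict(..., assign=True) adopts the checkpoint tensors directly. A minimal, self-contained sketch of that pattern (TinyModel and its fields are made up for illustration; this is not nanochat's GPT class):

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    """Stand-in for GPT: one parameter plus a non-persistent buffer."""
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        # Like cos/sin in the diff: rebuilt at runtime, never stored in the checkpoint.
        self.register_buffer("scale", torch.ones(dim), persistent=False)

    def setup_buffers(self):
        # Recreate the non-persistent buffer on whatever device the params live on.
        dim = self.proj.weight.shape[0]
        self.scale = torch.ones(dim, device=self.proj.weight.device)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Pretend this came from torch.load(checkpoint_path, map_location=device).
state = {k: v.to(device) for k, v in TinyModel().state_dict().items()}

with torch.device("meta"):
    model = TinyModel()        # parameters live on the meta device; no real memory yet
model.to_empty(device=device)  # materialize uninitialized storage on the target device
model.setup_buffers()          # non-persistent buffers are absent from `state`; recreate them
model.load_state_dict(state, strict=True, assign=True)  # adopt the checkpoint tensors in place

Replacing init_weights() with setup_buffers() in build_model avoids running the full (random) weight initialization only to immediately overwrite it from the checkpoint.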
@@ -176,6 +176,18 @@ class GPT(nn.Module):
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
 
+    def setup_buffers(self):
+        """
+        Only initializes the buffers (Rotary Embeddings) and strict type casts.
+        """
+        head_dim = self.config.n_embd // self.config.n_head
+        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+        self.cos, self.sin = cos, sin
+
+        if self.transformer.wte.weight.device.type == "cuda":
+            self.transformer.wte.to(dtype=torch.bfloat16)
+
+
     def init_weights(self):
         """
         Initialize the full model in this one function for maximum clarity.
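The persistent=False flag above is why a separate setup_buffers() is needed at all: non-persistent buffers never appear in the state dict, so they cannot be restored from a checkpoint and must be recomputed after to_empty(). A quick illustration of that behavior (the Rope class here is hypothetical, used only to show the flag):

import torch
import torch.nn as nn

class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cos", torch.randn(4), persistent=False)  # excluded from state_dict
        self.register_buffer("sin", torch.randn(4), persistent=True)   # included in state_dict

print(Rope().state_dict().keys())  # odict_keys(['sin']) -- 'cos' is missing by design

Because load_state_dict never expects non-persistent buffers, strict=True in build_model still passes even though cos/sin are absent from the checkpoint.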
@@ -211,14 +223,7 @@ class GPT(nn.Module):
         self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
         self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
 
-        # Rotary embeddings
-        head_dim = self.config.n_embd // self.config.n_head
-        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
-        self.cos, self.sin = cos, sin
-
-        # Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
-        if self.transformer.wte.weight.device.type == "cuda":
-            self.transformer.wte.to(dtype=torch.bfloat16)
+        self.setup_buffers()
 
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # TODO: bump base theta more? e.g. 100K is more common more recently
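The body of _precompute_rotary_embeddings is not part of this diff, but a standard RoPE precompute consistent with that signature looks roughly like the following. This is a sketch of the usual formulation under the default base theta of 10000, not necessarily nanochat's exact code:

import torch

def precompute_rotary_embeddings(seq_len, head_dim, base=10000, device=None):
    # Inverse frequencies for each channel pair: base^(-2i / head_dim).
    channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (base ** (channel_range / head_dim))
    # Outer product of positions and frequencies -> (seq_len, head_dim // 2) angles.
    positions = torch.arange(seq_len, dtype=torch.float32, device=device)
    freqs = torch.outer(positions, inv_freq)
    # cos/sin tables consumed by attention when rotating the query and key vectors.
    return freqs.cos(), freqs.sin()

The TODO in the hunk refers to raising the base theta (e.g. to 100K), which stretches the rotation wavelengths and is a common choice for longer context windows.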