Refactor GPT initialization to optimize checkpoint loading

Pranoy 2026-01-12 16:30:11 +05:30
parent 4610a838a1
commit 16d691d5f1
2 changed files with 14 additions and 9 deletions


@@ -98,7 +98,7 @@ def build_model(checkpoint_dir, step, device, phase):
     model = GPT(model_config)
     # Load the model state
     model.to_empty(device=device)
-    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
+    model.setup_buffers()
     model.load_state_dict(model_data, strict=True, assign=True)
     # Put the model in the right training phase / mode
     if phase == "eval":

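The hunk above drops the full `init_weights()` call from the checkpoint-loading path: every parameter is about to be overwritten by `load_state_dict()` anyway, so only the non-persistent rotary buffers (plus the token-embedding dtype cast) need to be recreated, which is exactly what the new `setup_buffers()` does. A minimal sketch of the resulting loading pattern, assuming the model is constructed under the meta device (the usual companion of `to_empty()`) and using a hypothetical `load_from_checkpoint` wrapper that is not part of this repo:

```python
import torch

def load_from_checkpoint(model_config, model_data, device):
    # hypothetical wrapper illustrating the pattern in build_model() above;
    # GPT is the model class defined in this repo
    with torch.device("meta"):
        model = GPT(model_config)      # parameters exist only as shapes, no storage yet
    model.to_empty(device=device)      # allocate real (uninitialized) storage on the target device
    model.setup_buffers()              # recompute rotary cos/sin; they are not in the checkpoint
    model.load_state_dict(model_data, strict=True, assign=True)  # checkpoint supplies every parameter
    return model
```

With `assign=True`, the checkpoint tensors are assigned onto the module directly instead of being copied element-wise, so the uninitialized storages created by `to_empty()` are simply replaced rather than filled in.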

@@ -176,6 +176,18 @@ class GPT(nn.Module):
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
+    def setup_buffers(self):
+        """
+        Only initializes the buffers (rotary embeddings) and applies the token-embedding dtype cast.
+        """
+        head_dim = self.config.n_embd // self.config.n_head
+        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+        self.cos, self.sin = cos, sin
+        if self.transformer.wte.weight.device.type == "cuda":
+            self.transformer.wte.to(dtype=torch.bfloat16)
     def init_weights(self):
         """
         Initialize the full model in this one function for maximum clarity.
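The reason `setup_buffers()` has to exist at all is the `persistent=False` registration shown above: non-persistent buffers are excluded from `state_dict()`, so `load_state_dict()` can never restore the rotary cos/sin tables and they must be recomputed after the parameters are loaded. A small self-contained illustration (not code from this repo):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4))
        # non-persistent buffer: lives on the module but is never checkpointed
        self.register_buffer("cos", torch.zeros(4), persistent=False)

print(list(Demo().state_dict().keys()))  # ['weight'] -- 'cos' is absent from the checkpoint
```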
@@ -211,14 +223,7 @@ class GPT(nn.Module):
         self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
         self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
         # Rotary embeddings
-        head_dim = self.config.n_embd // self.config.n_head
-        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
-        self.cos, self.sin = cos, sin
-        # Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
-        if self.transformer.wte.weight.device.type == "cuda":
-            self.transformer.wte.to(dtype=torch.bfloat16)
+        self.setup_buffers()
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # TODO: bump base theta more? e.g. 100K is more common more recently
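For context, `_precompute_rotary_embeddings` is now called from both `init_weights()` and the new `setup_buffers()`. Its body is not shown in this diff; a conventional RoPE precompute with this signature looks roughly like the sketch below (an assumption based on the standard formulation, not necessarily this file's exact implementation; `base` is the theta referenced by the TODO about bumping to ~100K):

```python
import torch

def precompute_rotary_embeddings(seq_len, head_dim, base=10000, device=None):
    # inverse frequencies: base^(-2i/head_dim) for each pair of channels
    channel = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (base ** (channel / head_dim))
    # angle for every (position, frequency) pair -> shape (seq_len, head_dim // 2)
    positions = torch.arange(seq_len, dtype=torch.float32, device=device)
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()
```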