diff --git a/nanochat/checkpoint_manager.py b/nanochat/checkpoint_manager.py
index cca6294..8179c96 100644
--- a/nanochat/checkpoint_manager.py
+++ b/nanochat/checkpoint_manager.py
@@ -98,7 +98,7 @@ def build_model(checkpoint_dir, step, device, phase):
     model = GPT(model_config)
     # Load the model state
     model.to_empty(device=device)
-    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
+    model.setup_buffers()
     model.load_state_dict(model_data, strict=True, assign=True)
     # Put the model in the right training phase / mode
     if phase == "eval":
diff --git a/nanochat/gpt.py b/nanochat/gpt.py
index 81ccb0c..29ca8f6 100644
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -176,6 +176,18 @@ class GPT(nn.Module):
         self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
 
+    def setup_buffers(self):
+        """
+        Initialize only the non-persistent buffers (rotary embeddings) and apply dtype casts.
+        """
+        head_dim = self.config.n_embd // self.config.n_head
+        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
+        self.cos, self.sin = cos, sin
+
+        # Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
+        if self.transformer.wte.weight.device.type == "cuda":
+            self.transformer.wte.to(dtype=torch.bfloat16)
+
     def init_weights(self):
         """
         Initialize the full model in this one function for maximum clarity.
@@ -211,14 +223,7 @@ class GPT(nn.Module):
         self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init
         self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init
 
-        # Rotary embeddings
-        head_dim = self.config.n_embd // self.config.n_head
-        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
-        self.cos, self.sin = cos, sin
-
-        # Cast token embeddings to bf16: optimizer can tolerate it and it saves memory
-        if self.transformer.wte.weight.device.type == "cuda":
-            self.transformer.wte.to(dtype=torch.bfloat16)
+        self.setup_buffers()
 
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # TODO: bump base theta more? e.g. 100K is more common more recently
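
Note on why this refactor works (not part of the patch): buffers registered with persistent=False are excluded from the state_dict, so after materializing a model with to_empty() they hold uninitialized memory and are not restored by load_state_dict(). Rebuilding only those buffers in a dedicated method avoids re-running the full weight initialization that load_state_dict immediately overwrites anyway. The sketch below is a minimal, self-contained illustration of that loading pattern using a hypothetical Toy module; it is not nanochat code.

import torch
import torch.nn as nn

class Toy(nn.Module):
    """Hypothetical module mirroring the GPT weight/buffer split."""
    def __init__(self, dim=4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Like cos/sin in GPT: not saved to the checkpoint.
        self.register_buffer("scale", torch.ones(dim), persistent=False)

    def setup_buffers(self):
        # Recompute only the non-persistent buffers; weights are left untouched.
        self.scale = torch.ones(self.proj.in_features, device=self.proj.weight.device)

saved = Toy().state_dict()                # "scale" is absent from the checkpoint
model = Toy().to_empty(device="cpu")      # params/buffers become uninitialized storage
model.setup_buffers()                     # rebuild the rotary-style buffers only
model.load_state_dict(saved, strict=True, assign=True)  # weights come from the checkpoint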