diff --git a/miniseries.sh b/miniseries.sh index 9a4512b6..c42544e3 100644 --- a/miniseries.sh +++ b/miniseries.sh @@ -61,7 +61,6 @@ for d in "${DEPTHS[@]}"; do # No --target-flops, let it use the default ratio from base_train torchrun --standalone --nproc_per_node=$NPROC_PER_NODE -m scripts.base_train -- \ --depth=$d \ - --target-param-data-ratio=8 \ --run="${WANDB_RUN}_d${d}" \ --model-tag="${TAG}" \ --core-metric-every=999999 \ diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 86f440bf..cb4bd05b 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -28,8 +28,8 @@ from nanochat.flash_attention import flash_attn @dataclass class GPTConfig: - sequence_len: int = 1024 - vocab_size: int = 50304 + sequence_len: int = 2048 + vocab_size: int = 32768 n_layer: int = 12 n_head: int = 6 # number of query heads n_kv_head: int = 6 # number of key/value heads (GQA) @@ -37,7 +37,7 @@ class GPTConfig: # Sliding window attention pattern string, tiled across layers. Final layer always L. # Characters: L=long (full context), S=short (half context) # Examples: "L"=all full context, "SL"=alternating, "SSL"=two short then one long - window_pattern: str = "L" + window_pattern: str = "SSSL" def norm(x): @@ -45,6 +45,10 @@ def norm(x): return F.rms_norm(x, (x.size(-1),)) +def has_ve(layer_idx, n_layer): + """Returns True if GPT layer should have Value Embedding (alternating, last layer always included).""" + return layer_idx % 2 == (n_layer - 1) % 2 + def apply_rotary_emb(x, cos, sin): assert x.ndim == 4 # multihead attention d = x.shape[3] // 2 @@ -67,8 +71,10 @@ class CausalSelfAttention(nn.Module): self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) + self.ve_gate_channels = 32 + self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None - def forward(self, x, cos_sin, window_size, kv_cache): + def forward(self, x, ve, cos_sin, window_size, kv_cache): B, T, C = x.size() # Project the input to get queries, keys, and values @@ -77,6 +83,12 @@ class CausalSelfAttention(nn.Module): k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim) v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim) + # Value residual (ResFormer): mix in value embedding with input-dependent gate per head + if ve is not None: + ve = ve.view(B, T, self.n_kv_head, self.head_dim) + gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels])) # (B, T, n_kv_head), range (0, 2) + v = v + gate.unsqueeze(-1) * ve + # Apply Rotary Embeddings to queries and keys to get relative positional encoding cos, sin = cos_sin q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin) @@ -126,8 +138,8 @@ class Block(nn.Module): self.attn = CausalSelfAttention(config, layer_idx) self.mlp = MLP(config) - def forward(self, x, cos_sin, window_size, kv_cache): - x = x + self.attn(norm(x), cos_sin, window_size, kv_cache) + def forward(self, x, ve, cos_sin, window_size, kv_cache): + x = x + self.attn(norm(x), ve, cos_sin, window_size, kv_cache) x = x + self.mlp(norm(x)) return x @@ -160,6 +172,10 @@ class GPT(nn.Module): # Separate parameters so they can have different optimizer treatment self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer)) # fake init, real init in init_weights() self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer)) # fake init, real init in init_weights() + # Value embeddings (ResFormer-style): alternating layers, last layer always included + head_dim = config.n_embd // config.n_head + kv_dim = config.n_kv_head * head_dim + self.value_embeds = nn.ModuleDict({str(i): nn.Embedding(padded_vocab_size, kv_dim) for i in range(config.n_layer) if has_ve(i, config.n_layer)}) # To support meta device initialization, we init the rotary embeddings here, but it's just "fake" meta tensors only. # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them by 10X, but assert fail if we ever reach that amount. @@ -170,6 +186,7 @@ class GPT(nn.Module): self.register_buffer("cos", cos, persistent=False) # persistent=False means it's not saved to the checkpoint self.register_buffer("sin", sin, persistent=False) + @torch.no_grad() def init_weights(self): """ Initialize the full model in this one function for maximum clarity. @@ -201,18 +218,28 @@ class GPT(nn.Module): torch.nn.init.zeros_(block.mlp.c_proj.weight) # Per-layer scalars - with torch.no_grad(): - self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init - self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init + self.resid_lambdas.fill_(1.0) # 1.0 => typical residual connections at init + self.x0_lambdas.fill_(0.0) # 0.0 => skip connection to input is disabled at init + + # Value embeddings (init like c_v: uniform with same std) + for ve in self.value_embeds.values(): + torch.nn.init.uniform_(ve.weight, -s, s) + + # Gate weights init to zero so gates start at sigmoid(0) = 0.5, scaled by 2 -> 1.0 (neutral) + for block in self.transformer.h: + if block.attn.ve_gate is not None: + torch.nn.init.zeros_(block.attn.ve_gate.weight) # Rotary embeddings head_dim = self.config.n_embd // self.config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.cos, self.sin = cos, sin - # Cast token embeddings to bf16: optimizer can tolerate it and it saves memory + # Cast embeddings to bf16: optimizer can tolerate it and it saves memory if self.transformer.wte.weight.device.type == "cuda": self.transformer.wte.to(dtype=torch.bfloat16) + for ve in self.value_embeds.values(): + ve.to(dtype=torch.bfloat16) def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # TODO: bump base theta more? e.g. 100K is more common more recently @@ -277,7 +304,9 @@ class GPT(nn.Module): """ nparams = sum(p.numel() for p in self.parameters()) # Exclude non-matmul params: embeddings and per-layer scalars - nparams_exclude = self.transformer.wte.weight.numel() + self.resid_lambdas.numel() + self.x0_lambdas.numel() + value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values()) + nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel + + self.resid_lambdas.numel() + self.x0_lambdas.numel()) h, q, t = self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len # Sum attention FLOPs per layer, accounting for sliding window attn_flops = 0 @@ -303,13 +332,14 @@ class GPT(nn.Module): def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() - # Separate out all parameters into 5 groups (matrix, embedding, lm_head, resid_lambdas, x0_lambdas) + # Separate out all parameters into groups matrix_params = list(self.transformer.h.parameters()) + value_embeds_params = list(self.value_embeds.parameters()) embedding_params = list(self.transformer.wte.parameters()) lm_head_params = list(self.lm_head.parameters()) resid_params = [self.resid_lambdas] x0_params = [self.x0_lambdas] - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(resid_params) + len(x0_params) + assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params) # Create the AdamW optimizer for the embedding, lm_head, and per-layer scalars # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 @@ -317,6 +347,7 @@ class GPT(nn.Module): adam_groups = [ dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), + dict(params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale), # same LR as token embedding dict(params=resid_params, lr=scalar_lr * 0.01), # these are a lot more sensitive because they accumulate in the residual stream dict(params=x0_params, lr=scalar_lr), ] @@ -351,7 +382,8 @@ class GPT(nn.Module): x0 = x # save initial normalized embedding for x0 residual for i, block in enumerate(self.transformer.h): x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0 - x = block(x, cos_sin, self.window_sizes[i], kv_cache) + ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None + x = block(x, ve, cos_sin, self.window_sizes[i], kv_cache) x = norm(x) # Forward the lm_head (compute logits) diff --git a/scaling_laws.sh b/scaling_laws.sh index 321b286a..1f9dab87 100644 --- a/scaling_laws.sh +++ b/scaling_laws.sh @@ -1,20 +1,23 @@ #!/bin/bash +LABEL="jan16" + FLOPS_BUDGETS=( 1e18 3e18 6e18 ) -DEPTHS=(8 10 12 14 16 18 20) +DEPTHS=(6 7 8 9 10 11 12 13 14) + NPROC_PER_NODE="${NPROC_PER_NODE:-8}" -WANDB_RUN="${WANDB_RUN:-scaling}" +WANDB_RUN="${WANDB_RUN:-scaling_${LABEL}}" EVAL_TOKENS=$((100 * 524288)) # ~100M tokens for final eval (default is ~10M) export OMP_NUM_THREADS=1 export NANOCHAT_BASE_DIR="${NANOCHAT_BASE_DIR:-$HOME/.cache/nanochat}" source .venv/bin/activate -RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results" +RESULTS_DIR="$NANOCHAT_BASE_DIR/scaling_laws_results_${LABEL}" mkdir -p "$RESULTS_DIR" RESULTS_FILE="$RESULTS_DIR/results.csv" diff --git a/scripts/base_train.py b/scripts/base_train.py index e051f99b..2d614774 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -47,7 +47,7 @@ parser.add_argument("--window-pattern", type=str, default="SSSL", help="sliding # Training horizon (only one used, in order of precedence) parser.add_argument("--num-iterations", type=int, default=-1, help="explicit number of optimization steps (-1 = disable)") parser.add_argument("--target-flops", type=float, default=-1.0, help="calculate num_iterations to reach target_flops (-1 = disable)") -parser.add_argument("--target-param-data-ratio", type=int, default=8, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") +parser.add_argument("--target-param-data-ratio", type=int, default=4, help="calculate num_iterations to maintain data:param ratio (Chinchilla=20, -1 = disable)") # Optimization parser.add_argument("--device-batch-size", type=int, default=32, help="per-device batch size") parser.add_argument("--total-batch-size", type=int, default=524288, help="total batch size in tokens") @@ -112,21 +112,19 @@ vocab_size = tokenizer.get_vocab_size() print0(f"Vocab size: {vocab_size:,}") # Model kwargs are derived from the desired depth of the model +# We nudge model_dim up to the nearest multiple of head_dim to ensure clean division +# (FA3 requires head_dim divisible by 8, and this guarantees head_dim == args.head_dim exactly) +# (For very small depths, this gives a slight "unfair" advantage to models with odd depths) num_layers = args.depth -model_dim = args.depth * args.aspect_ratio -def find_num_heads(model_dim, target_head_dim): - # Find num_heads that divides model_dim evenly, with head_dim closest to target. - ideal = max(1, round(model_dim / target_head_dim)) - for offset in range(model_dim): - for candidate in [ideal + offset, ideal - offset]: - if candidate > 0 and model_dim % candidate == 0: - return candidate - return 1 -num_heads = find_num_heads(model_dim, args.head_dim) +base_dim = args.depth * args.aspect_ratio +model_dim = ((base_dim + args.head_dim - 1) // args.head_dim) * args.head_dim +num_heads = model_dim // args.head_dim num_kv_heads = num_heads # default is 1:1 GQA (Group Query Attention) ratio (i.e. GQA is disabled) +head_dim = model_dim // num_heads print0(f"num_layers: {num_layers}") -print0(f"model_dim: {model_dim}") +print0(f"model_dim: {model_dim} (base: {base_dim}, nudge: {model_dim - base_dim:+d})") print0(f"num_heads: {num_heads}") +print0(f"head_dim: {head_dim}") print0(f"num_kv_heads: {num_kv_heads}") # Optimizer / data / training length related hyperparameters