Mirror of https://github.com/karpathy/nanochat.git, synced 2025-12-14 08:12:17 +00:00
Instrument main gpt model with logging

commit 62cfe4d4c3, parent 8a40915246
nanochat/gpt.py: 118 changes
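Note: the change imports `log` from a local `logger` module that is not part of this diff and applies it as a decorator to the module-level helpers and to the `__init__`/`forward` methods below. The real decorator is not shown; the following is a minimal sketch of what such a call-tracing decorator could look like, purely as an assumption.

import functools
import logging

_logger = logging.getLogger("nanochat.gpt")

def log(fn):
    # Hypothetical sketch: the actual `logger.log` in this fork may differ.
    # Records entry into and exit from the wrapped function or method.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        _logger.debug("enter %s", fn.__qualname__)
        result = fn(*args, **kwargs)
        _logger.debug("exit %s", fn.__qualname__)
        return result
    return wrapper

Because the decorator wraps every `forward`, any real implementation would fire on each training and inference step, so gating or sampling the log output would likely be needed in practice.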
--- a/nanochat/gpt.py
+++ b/nanochat/gpt.py
@@ -12,16 +12,17 @@ Notable features:
 """
 
 import math
-from functools import partial
 from dataclasses import dataclass
+from functools import partial
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from logger import log
 
-from nanochat.common import get_dist_info, print0
-from nanochat.muon import Muon, DistMuon
 from nanochat.adamw import DistAdamW
+from nanochat.common import get_dist_info, print0
+from nanochat.muon import DistMuon, Muon
 from nanochat.vsa import HRROperations
 
 
@@ -40,12 +41,13 @@ class GPTConfig:
     d_vsa: int = 256
 
 
+@log
 def norm(x):
     # Purely functional rmsnorm with no learnable params
     return F.rms_norm(x, (x.size(-1),))
 
 
+@log
 def apply_rotary_emb(x, cos, sin):
     assert x.ndim == 4  # multihead attention
     d = x.shape[3] // 2
@@ -58,6 +60,7 @@ def apply_rotary_emb(x, cos, sin):
 
 
 class CausalSelfAttention(nn.Module):
+    @log
     def __init__(self, config, layer_idx):
         super().__init__()
         self.layer_idx = layer_idx
@@ -72,6 +75,7 @@ class CausalSelfAttention(nn.Module):
         self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
         self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
 
+    @log
     def forward(self, x, cos_sin, kv_cache):
         B, T, C = x.size()
 
@@ -112,15 +116,11 @@ class CausalSelfAttention(nn.Module):
         if kv_cache is None or Tq == Tk:
             # During training (no KV cache), attend as usual with causal attention
             # And even if there is KV cache, we can still use this simple version when Tq == Tk
-            y = F.scaled_dot_product_attention(
-                q, k, v, is_causal=True
-            )
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
         elif Tq == 1:
             # During inference but with a single query in this forward pass:
             # The query has to attend to all the keys/values in the cache
-            y = F.scaled_dot_product_attention(
-                q, k, v, is_causal=False
-            )
+            y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
         else:
             # During inference AND we have a chunk of queries in this forward pass:
             # First, each query attends to all the cached keys/values (i.e. full prefix)
@@ -134,9 +134,7 @@ class CausalSelfAttention(nn.Module):
             attn_mask[:, prefix_len:] = torch.tril(
                 torch.ones((Tq, Tq), dtype=torch.bool, device=q.device)
             )
-            y = F.scaled_dot_product_attention(
-                q, k, v, attn_mask=attn_mask
-            )
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
 
         # Re-assemble the heads side by side and project back to residual stream
         y = y.transpose(1, 2).contiguous().view(B, T, -1)
@@ -145,18 +143,22 @@ class CausalSelfAttention(nn.Module):
 
 
 class MLP(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
         self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
 
+    @log
     def forward(self, x):
         x = self.c_fc(x)
         x = F.relu(x).square()
         x = self.c_proj(x)
         return x
 
 
 class Router(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.num_experts
@@ -173,19 +175,26 @@ class Router(nn.Module):
         # Project tokens to VSA space
         self.token_to_vsa = nn.Linear(config.n_embd, self.d_vsa, bias=False)
         # Fixed expert signatures in VSA space (not learnable)
-        self.register_buffer('expert_signatures', torch.randn(self.num_experts, self.d_vsa))
+        self.register_buffer(
+            "expert_signatures", torch.randn(self.num_experts, self.d_vsa)
+        )
         # Single VSA superposition cache containing all token-expert bindings
-        self.register_buffer('vsa_cache', torch.zeros(self.d_vsa))
+        self.register_buffer("vsa_cache", torch.zeros(self.d_vsa))
         # Combine normal and VSA routing
-        self.vsa_weight = nn.Parameter(torch.tensor(0.5))  # Learnable weight for VSA contribution
+        self.vsa_weight = nn.Parameter(
+            torch.tensor(0.5)
+        )  # Learnable weight for VSA contribution
 
+    @log
     def forward(self, x):
         # x: (B, T, D)
         B, T, D = x.size()
 
         # Normal routing to determine which experts are activated
         normal_logits = self.w_gate(x)  # (B, T, num_experts)
-        topk_logits, topk_indices = torch.topk(normal_logits, self.top_k, dim=-1)  # (B, T, top_k)
+        topk_logits, topk_indices = torch.topk(
+            normal_logits, self.top_k, dim=-1
+        )  # (B, T, top_k)
 
         if self.use_vsa:
             # Project tokens to VSA space
@@ -196,7 +205,9 @@ class Router(nn.Module):
             if self.training:
                 # Vectorized cache update
                 # Flatten indices and tokens for batch processing
-                flat_tokens = token_vsa.unsqueeze(2).expand(-1, -1, self.top_k, -1)  # (B, T, top_k, d_vsa)
+                flat_tokens = token_vsa.unsqueeze(2).expand(
+                    -1, -1, self.top_k, -1
+                )  # (B, T, top_k, d_vsa)
                 flat_tokens = flat_tokens.reshape(-1, self.d_vsa)  # (B*T*top_k, d_vsa)
                 flat_indices = topk_indices.reshape(-1)  # (B*T*top_k,)
 
@@ -211,40 +222,66 @@ class Router(nn.Module):
                 self.vsa_cache = self.vsa_cache + cache_update.detach()
 
             # Compute VSA routing scores by querying the single superposition
-            vsa_scores = torch.zeros(B, T, self.num_experts, device=x.device, dtype=x.dtype)
+            vsa_scores = torch.zeros(
+                B, T, self.num_experts, device=x.device, dtype=x.dtype
+            )
             if torch.norm(self.vsa_cache) > 1e-8:  # Only if cache has content
                 # Vectorized similarity computation
                 # Expand tokens to (B, T, num_experts, d_vsa) and expert_signatures to (B, T, num_experts, d_vsa)
-                token_expanded = token_vsa.unsqueeze(2).expand(B, T, self.num_experts, self.d_vsa)  # (B, T, num_experts, d_vsa)
-                expert_expanded = self.expert_signatures.unsqueeze(0).unsqueeze(0).expand(B, T, self.num_experts, self.d_vsa)  # (B, T, num_experts, d_vsa)
+                token_expanded = token_vsa.unsqueeze(2).expand(
+                    B, T, self.num_experts, self.d_vsa
+                )  # (B, T, num_experts, d_vsa)
+                expert_expanded = (
+                    self.expert_signatures.unsqueeze(0)
+                    .unsqueeze(0)
+                    .expand(B, T, self.num_experts, self.d_vsa)
+                )  # (B, T, num_experts, d_vsa)
 
                 # Vectorized binding: bind all tokens with all expert signatures
-                queries = self.vsa_ops.bind(token_expanded.reshape(-1, self.d_vsa), expert_expanded.reshape(-1, self.d_vsa))  # (B*T*num_experts, d_vsa)
-                queries = queries.reshape(B, T, self.num_experts, self.d_vsa)  # (B, T, num_experts, d_vsa)
+                queries = self.vsa_ops.bind(
+                    token_expanded.reshape(-1, self.d_vsa),
+                    expert_expanded.reshape(-1, self.d_vsa),
+                )  # (B*T*num_experts, d_vsa)
+                queries = queries.reshape(
+                    B, T, self.num_experts, self.d_vsa
+                )  # (B, T, num_experts, d_vsa)
 
                 # Compute similarities with the cache for all queries at once
                 # Flatten queries and compute all similarities in one batch
-                queries_flat = queries.reshape(-1, self.d_vsa)  # (B*T*num_experts, d_vsa)
-                cache_expanded = self.vsa_cache.unsqueeze(0).expand(queries_flat.size(0), -1)  # (B*T*num_experts, d_vsa)
+                queries_flat = queries.reshape(
+                    -1, self.d_vsa
+                )  # (B*T*num_experts, d_vsa)
+                cache_expanded = self.vsa_cache.unsqueeze(0).expand(
+                    queries_flat.size(0), -1
+                )  # (B*T*num_experts, d_vsa)
 
                 # Use VSA similarity operation for all queries
-                similarities_flat = self.vsa_ops.similarity(queries_flat, cache_expanded)  # (B*T*num_experts,)
+                similarities_flat = self.vsa_ops.similarity(
+                    queries_flat, cache_expanded
+                )  # (B*T*num_experts,)
                 vsa_scores = similarities_flat.reshape(B, T, self.num_experts)
 
             # Combine normal and VSA routing with learnable weight
-            combined_logits = normal_logits + torch.sigmoid(self.vsa_weight) * vsa_scores
-            topk_logits, topk_indices = torch.topk(combined_logits, self.top_k, dim=-1)  # (B, T, top_k)
+            combined_logits = (
+                normal_logits + torch.sigmoid(self.vsa_weight) * vsa_scores
+            )
+            topk_logits, topk_indices = torch.topk(
+                combined_logits, self.top_k, dim=-1
+            )  # (B, T, top_k)
 
         topk_gates = F.softmax(topk_logits, dim=-1)  # (B, T, top_k)
         return topk_indices, topk_gates  # both are (B, T, top_k)
 
 
 class MoE(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.num_experts
         self.experts = nn.ModuleList([MLP(config) for _ in range(self.num_experts)])
         self.router = Router(config)
 
+    @log
     def forward(self, x):
         # x: (B, T, D)
         B, T, D = x.size()
@@ -259,14 +296,18 @@ class MoE(nn.Module):
 
         for i in range(self.num_experts):
             # Find which tokens use expert i and their corresponding gate values
-            expert_mask = (topk_indices_flat == i)  # (B*T, top_k)
+            expert_mask = topk_indices_flat == i  # (B*T, top_k)
             if expert_mask.any():
                 # Get tokens that route to this expert
                 token_indices, k_indices = torch.where(expert_mask)
                 if len(token_indices) > 0:
                     selected_tokens = x_flat[token_indices]  # (num_selected, D)
-                    expert_output = self.experts[i](selected_tokens)  # (num_selected, D)
-                    gate_weights = topk_gates_flat[token_indices, k_indices].unsqueeze(-1)  # (num_selected, 1)
+                    expert_output = self.experts[i](
+                        selected_tokens
+                    )  # (num_selected, D)
+                    gate_weights = topk_gates_flat[token_indices, k_indices].unsqueeze(
+                        -1
+                    )  # (num_selected, 1)
 
                     # Add weighted expert output back to the corresponding tokens
                     output_flat[token_indices] += expert_output * gate_weights
@@ -275,6 +316,7 @@ class MoE(nn.Module):
 
 
 class Block(nn.Module):
+    @log
     def __init__(self, config, layer_idx):
         super().__init__()
         self.attn = CausalSelfAttention(config, layer_idx)
@@ -283,6 +325,7 @@ class Block(nn.Module):
         else:
             self.mlp = MLP(config)
 
+    @log
     def forward(self, x, cos_sin, kv_cache):
         x = x + self.attn(norm(x), cos_sin, kv_cache)
         x = x + self.mlp(norm(x))
@@ -290,6 +333,7 @@ class Block(nn.Module):
 
 
 class GPT(nn.Module):
+    @log
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -316,15 +360,16 @@ class GPT(nn.Module):
         )  # persistent=False means it's not saved to the checkpoint
         self.register_buffer("sin", sin, persistent=False)
 
+    @log
     def init_weights(self):
         self.apply(self._init_weights)
         # zero out classifier weights
         torch.nn.init.zeros_(self.lm_head.weight)
         # zero out c_proj weights in all blocks
         for block in self.transformer.h:
-            if hasattr(block.mlp, 'c_proj'):  # Regular MLP
+            if hasattr(block.mlp, "c_proj"):  # Regular MLP
                 torch.nn.init.zeros_(block.mlp.c_proj.weight)
-            elif hasattr(block.mlp, 'experts'):  # MoE
+            elif hasattr(block.mlp, "experts"):  # MoE
                 for expert in block.mlp.experts:
                     torch.nn.init.zeros_(expert.c_proj.weight)
             torch.nn.init.zeros_(block.attn.c_proj.weight)
@@ -336,6 +381,7 @@ class GPT(nn.Module):
         if self.transformer.wte.weight.device.type == "cuda":
             self.transformer.wte.to(dtype=torch.bfloat16)
 
+    @log
     def _init_weights(self, module):
         if isinstance(module, nn.Linear):
             # https://arxiv.org/pdf/2310.17813
@@ -349,6 +395,7 @@ class GPT(nn.Module):
             torch.nn.init.normal_(module.weight, mean=0.0, std=1.0)
 
     # TODO: bump base theta more, e.g. 100K is more common more recently
+    @log
     def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
         # autodetect the device from model embeddings
         if device is None:
@@ -368,9 +415,11 @@ class GPT(nn.Module):
         )  # add batch and head dims for later broadcasting
         return cos, sin
 
+    @log
     def get_device(self):
         return self.transformer.wte.weight.device
 
+    @log
     def estimate_flops(self):
         """Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311"""
         nparams = sum(p.numel() for p in self.parameters())
@@ -384,6 +433,7 @@ class GPT(nn.Module):
         num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
         return num_flops_per_token
 
+    @log
     def setup_optimizers(
         self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0
     ):
@@ -421,6 +471,7 @@ class GPT(nn.Module):
             group["initial_lr"] = group["lr"]
         return optimizers
 
+    @log
     def forward(self, idx, targets=None, kv_cache=None, loss_reduction="mean"):
         B, T = idx.size()
 
@@ -467,6 +518,7 @@ class GPT(nn.Module):
         logits = softcap * torch.tanh(logits / softcap)  # logits softcap
         return logits
 
+    @log
     @torch.inference_mode()
     def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
         """
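The routing code above queries `self.vsa_ops`, an `HRROperations` instance from `nanochat.vsa`, whose `bind` and `similarity` methods are not shown in this diff. In holographic reduced representations, binding is commonly circular convolution and probing a superposition is commonly cosine similarity; the sketch below assumes that style and is not the repository's actual implementation.

import torch
import torch.nn.functional as F

class HRROperations:
    # Assumed HRR-style operations; the real nanochat.vsa module may differ.

    def bind(self, a, b):
        # Circular convolution via FFT: binds pairs of hypervectors, (N, d) x (N, d) -> (N, d).
        return torch.fft.irfft(torch.fft.rfft(a) * torch.fft.rfft(b), n=a.size(-1))

    def similarity(self, a, b):
        # Cosine similarity: how strongly each query is present in the cache, (N, d) -> (N,).
        return F.cosine_similarity(a, b, dim=-1)

Under this reading, `vsa_cache` is an additive superposition of token-expert bindings, and probing it with a fresh token/signature binding scores highly only for pairings that were actually stored, which is what the `vsa_scores` term contributes on top of the learned gate logits.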