hypercube-abacus
commit 317e4b65df, parent dfc88334b6
config/custom_config.py (new file, +2)
@@ -0,0 +1,2 @@
n_layer = 24
n_embd = 1024
nanochat/abacus_encoder.py (new file, +24)
@@ -0,0 +1,24 @@
import torch
import torch.nn as nn

class AbacusEncoder(nn.Module):
    def __init__(self, input_dim: int, embedding_dim: int):
        super().__init__()
        self.input_dim = input_dim
        self.embedding_dim = embedding_dim

        # Simple linear layer to encode abacus-like patterns into the embedding space
        self.encoder_layer = nn.Linear(input_dim, embedding_dim)

    def encode(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
        # abacus_pattern is expected to be a tensor of shape (batch_size, input_dim)
        if abacus_pattern.shape[-1] != self.input_dim:
            raise ValueError(f"Expected abacus_pattern to have last dimension {self.input_dim}, but got {abacus_pattern.shape[-1]}")
        return self.encoder_layer(abacus_pattern)

    def decode(self, concept_vector: torch.Tensor) -> torch.Tensor:
        # Placeholder for decoding functionality
        raise NotImplementedError("Decoding from concept vector to abacus pattern is not yet implemented.")

    def forward(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
        return self.encode(abacus_pattern)
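A minimal usage sketch for the encoder above; the 64-dimensional pattern and 768-dimensional embedding are illustrative values matching defaults used elsewhere in this commit, not requirements of the class:

import torch
from nanochat.abacus_encoder import AbacusEncoder

encoder = AbacusEncoder(input_dim=64, embedding_dim=768)
pattern = torch.rand(2, 64)      # a batch of two abacus-like patterns
vec = encoder(pattern)           # forward() delegates to encode(); result is (2, 768)
assert vec.shape == (2, 768)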
nanochat/abacus_state_memory.py (new file, +26)
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn

class AbacusStateMemory(nn.Module):
    def __init__(self, max_memory_size: int = 100, abacus_input_dim: int = 64):
        super().__init__()
        self.max_memory_size = max_memory_size
        self.abacus_input_dim = abacus_input_dim
        self.memory = []  # Stores abacus patterns (tensors)

    def store(self, abacus_pattern: torch.Tensor):
        if len(self.memory) >= self.max_memory_size:
            self.memory.pop(0)  # Remove the oldest entry
        self.memory.append(abacus_pattern.detach().cpu())  # Store detached CPU tensor

    def retrieve(self, num_to_retrieve: int = 1) -> list[torch.Tensor]:
        # Retrieve the most recent abacus patterns
        return self.memory[-num_to_retrieve:]

    def clear(self):
        self.memory = []

    def forward(self, abacus_pattern: torch.Tensor):
        # For now, forward simply stores the pattern. More complex logic can be added later.
        self.store(abacus_pattern)
        return abacus_pattern
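A short sketch of the FIFO behaviour of the memory above (the sizes are arbitrary, chosen for illustration):

import torch
from nanochat.abacus_state_memory import AbacusStateMemory

mem = AbacusStateMemory(max_memory_size=3, abacus_input_dim=8)
for i in range(5):
    mem.store(torch.full((1, 8), float(i)))   # once 3 patterns are held, the oldest is evicted
recent = mem.retrieve(num_to_retrieve=2)       # the two most recent patterns, as detached CPU tensors
assert len(mem.memory) == 3 and len(recent) == 2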
nanochat/conscious_integration.py (new file, +48)
@@ -0,0 +1,48 @@
import torch
import torch.nn as nn

from nanochat.abacus_encoder import AbacusEncoder

class ConsciousIntegrationLayer(nn.Module):
    def __init__(self, config, abacus_encoder: AbacusEncoder):
        super().__init__()
        self.config = config
        self.abacus_encoder = abacus_encoder

        # Linear layer to project the integrated state to the vocabulary size
        self.concept_projection = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, id_output: torch.Tensor, ego_output: torch.Tensor, superego_output: torch.Tensor, long_term_memory_embeddings: torch.Tensor, memetic_fitness: torch.Tensor | None, abacus_state: torch.Tensor) -> torch.Tensor:
        # Ensure all inputs are of the same shape for integration
        # For simplicity, let's assume they are all (B, T, C) or (B, C)
        # If they are (B, C), we might need to unsqueeze for sequence dimension if T > 1

        # Example integration: simple summation. More complex mechanisms can be added here.
        # The goal is to synthesize these into a unified conceptual state.
        synthesized_state = id_output + ego_output + superego_output

        if long_term_memory_embeddings is not None:
            # Ensure dimensions match for addition. long_term_memory_embeddings might be (B, C)
            # If synthesized_state is (B, T, C), expand long_term_memory_embeddings
            if synthesized_state.dim() == 3 and long_term_memory_embeddings.dim() == 2:
                long_term_memory_embeddings = long_term_memory_embeddings.unsqueeze(1).expand(-1, synthesized_state.size(1), -1)
            synthesized_state = synthesized_state + long_term_memory_embeddings

        if memetic_fitness is not None:
            # Ensure dimensions match for addition
            if synthesized_state.dim() == 3 and memetic_fitness.dim() == 2:
                memetic_fitness = memetic_fitness.unsqueeze(1).expand(-1, synthesized_state.size(1), -1)
            synthesized_state = synthesized_state + memetic_fitness

        # Abacus Encoder for logical consistency
        # The abacus_state is already an encoded representation, so we might use it to gate or modulate
        # the synthesized_state, or simply include it in the synthesis.
        # For now, let's add it, assuming it's compatible.
        if synthesized_state.dim() == 3 and abacus_state.dim() == 2:
            abacus_state = abacus_state.unsqueeze(1).expand(-1, synthesized_state.size(1), -1)
        synthesized_state = synthesized_state + abacus_state

        # Project the synthesized state to the vocabulary size
        concept_logits_from_synthesis = self.concept_projection(synthesized_state)

        return concept_logits_from_synthesis
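A hedged sketch of how the layer above combines its inputs; the SimpleNamespace config is a stand-in for the GPTConfig used elsewhere in this commit, and note that the projection here reads config.vocab_size rather than config.num_concept_ids:

import torch
from types import SimpleNamespace
from nanochat.abacus_encoder import AbacusEncoder
from nanochat.conscious_integration import ConsciousIntegrationLayer

cfg = SimpleNamespace(n_embd=16, vocab_size=32)              # stand-in config for illustration
layer = ConsciousIntegrationLayer(cfg, AbacusEncoder(8, 16))
B, T, C = 2, 4, 16
x = torch.randn(B, T, C)
logits = layer(
    id_output=x, ego_output=x, superego_output=x,            # (B, T, C) streams are summed
    long_term_memory_embeddings=torch.randn(B, C),           # (B, C) inputs are broadcast over T
    memetic_fitness=None,
    abacus_state=torch.randn(B, C),
)
assert logits.shape == (B, T, 32)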
nanochat/engine.py
@@ -11,7 +11,9 @@ Notes:
The whole thing is made as efficient as possible.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import signal
import warnings
@@ -19,6 +21,15 @@ from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.gpt import GPT
from nanochat.kv_cache import KVCache
from nanochat.hypercube import HypercubeTopology, HypercubeEmbeddingLayer, LongTermMemory  # Import HypercubeTopology
from nanochat.abacus_encoder import AbacusEncoder
from nanochat.self_model import InternalSelfModel
from nanochat.abacus_state_memory import AbacusStateMemory
from nanochat.memetic_learning import MemeticLearningLayer
from nanochat.conscious_integration import ConsciousIntegrationLayer
from nanochat.gpt import GPT, PsycheController

# -----------------------------------------------------------------------------
# Calculator tool helpers
@@ -151,76 +162,302 @@ class KVCache:
        self.pos = t1
        return key_view, value_view

    def retrieve_episode(self, start_pos, end_pos):
        """
        Retrieves a slice of the KV cache representing an 'episode' or a specific attention span.
        """
        assert self.kv_cache is not None, "KV cache is empty, cannot retrieve episode."
        assert start_pos >= 0 and end_pos <= self.pos, "Invalid start or end position for episode retrieval."
        assert start_pos < end_pos, "Start position must be less than end position."

        # Return a view of the relevant part of the cache
        # kv_cache shape: (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
        # We want to retrieve for all layers, and both k and v
        episode_k = self.kv_cache[:, 0, :, :, start_pos:end_pos, :]
        episode_v = self.kv_cache[:, 1, :, :, start_pos:end_pos, :]
        return episode_k, episode_v


# -----------------------------------------------------------------------------
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=1.0, top_k=None):
    """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1)."""
def sample_from_logits(logits, rng, temperature, top_k):
    """Sample a single next concept ID from given logits of shape (B, num_concept_ids). Returns (B, 1)."""
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    if top_k is not None:
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)
        vals = vals / temperature
        probs = F.softmax(vals, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=rng)
        return idx.gather(1, choice)
    else:
        logits = logits / temperature
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')

# -----------------------------------------------------------------------------
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)
    idx_next = torch.multinomial(probs, num_samples=1, generator=rng)
    return idx_next

class RowState:
    # Per-row state tracking during generation
    def __init__(self, current_tokens=None):
        self.current_tokens = current_tokens or []  # Current token sequence for this row
        self.forced_tokens = deque()  # Queue of tokens to force inject
        self.in_python_block = False  # Whether we are inside a python block
        self.python_expr_tokens = []  # Tokens of the current python expression
        self.completed = False  # Whether this row has completed generation

class Engine:

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer  # needed for tool use
    def __init__(self, config: GPTConfig):
        self.config = config
        self.model = GPT(config)
        self.model.eval()
        self.model.init_weights()

    @torch.inference_mode()
    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
        """Same as generate, but does single prefill and then clones the KV cache."""
        assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
        device = self.model.get_device()
        rng = torch.Generator(device=device)
        rng.manual_seed(seed)
        # Initialize HypercubeEmbeddingLayer
        self.concept_embedding_layer = HypercubeEmbeddingLayer(
            num_raw_concepts=config.num_concept_ids,
            embedding_dim=config.n_embd,
            hypercube_topology=self.hypercube_topology
        )

        # Get the special tokens we need to coordinate the tool use state machine
        get_special = lambda s: self.tokenizer.encode_special(s)
        python_start = get_special("<|python_start|>")
        python_end = get_special("<|python_end|>")
        output_start = get_special("<|output_start|>")
        output_end = get_special("<|output_end|>")
        assistant_end = get_special("<|assistant_end|>")  # if sampled, ends row
        bos = self.tokenizer.get_bos_token_id()  # if sampled, ends row
        # Initialize AbacusEncoder
        self.abacus_encoder = AbacusEncoder(
            input_dim=config.abacus_input_dim,
            embedding_dim=config.n_embd
        )
        nn.init.normal_(self.concept_embedding_layer.weight, mean=0.0, std=1.0)

        # 1) Run a batch 1 prefill of the prompt tokens
        # Initialize HypercubeTopology
        n_hypercube_dim = int(math.log2(config.num_concept_ids))
        if not (2**n_hypercube_dim == config.num_concept_ids):
            raise ValueError("config.num_concept_ids must be a power of 2 for hypercube topology.")
        self.hypercube_topology = HypercubeTopology(n_hypercube_dim)

        # Initialize LongTermMemory
        self.long_term_memory = LongTermMemory(
            embedding_dim=config.n_embd,
            max_memory_size=1000,  # Default max memory size
            top_k_retrieval=5  # Default top-k retrieval
        )

        # Initialize InternalSelfModel
        self.internal_self_model = InternalSelfModel(
            embedding_dim=config.n_embd,
            num_concepts=config.num_concept_ids
        )

        # Initialize AbacusStateMemory
        self.abacus_state_memory = AbacusStateMemory(
            max_memory_size=100,  # Default max memory size
            abacus_input_dim=config.abacus_input_dim
        )

        # Initialize MemeticLearningLayer
        self.memetic_learning_layer = MemeticLearningLayer(
            config=config,
            abacus_encoder=self.abacus_encoder,
            internal_self_model=self.internal_self_model
        )

        # Initialize PsycheController
        self.psyche_controller = PsycheController(
            config=config
        )

        # Initialize ConsciousIntegrationLayer
        self.conscious_integration_layer = ConsciousIntegrationLayer(
            config=config,
            abacus_encoder=self.abacus_encoder
        )

        self._concept_encoder_placeholder = None  # This will be set later
        self._concept_attention_fusion_placeholder = nn.Linear(config.n_embd * 2, config.n_embd)

        # Placeholder for concept encoder
        self.concept_encoder = self._concept_encoder_placeholder

    def _concept_encoder_placeholder(self, concept_list: list[int] | dict) -> torch.Tensor:
        """
        Placeholder for a more sophisticated concept encoder.
        For now, it assumes concept_list directly contains integer concept IDs or a dict with 'abacus_pattern'.
        """
        if isinstance(concept_list, dict) and "abacus_pattern" in concept_list:
            # If an abacus pattern is provided, encode it using the AbacusEncoder
            abacus_pattern = concept_list["abacus_pattern"]
            # Ensure abacus_pattern is a tensor and has the correct shape
            if not isinstance(abacus_pattern, torch.Tensor):
                abacus_pattern = torch.tensor(abacus_pattern, dtype=torch.float32)
            # Add batch dimension if missing
            if abacus_pattern.dim() == 1:
                abacus_pattern = abacus_pattern.unsqueeze(0)
            return self.abacus_encoder(abacus_pattern)
        elif isinstance(concept_list, list):
            # Otherwise, assume it's a list of integer concept IDs
            return torch.tensor(concept_list, dtype=torch.long, device=self.gpt.get_device())
        else:
            raise ValueError("concept_list must be a list of integers or a dict with 'abacus_pattern'.")

    def _concept_memory_retrieve_placeholder(self, current_embedding: torch.Tensor) -> torch.Tensor:
        # Placeholder for concept memory retrieval logic
        # This would typically involve searching a memory bank for relevant concepts
        # based on the current_embedding and returning their embeddings.
        # For now, it returns a zero tensor of appropriate shape.
        return torch.zeros_like(current_embedding)

    def _concept_attention_fusion_placeholder(self, transformer_output: torch.Tensor, retrieved_concepts: torch.Tensor) -> torch.Tensor:
        # Placeholder for concept attention fusion logic
        # This would combine the transformer's output with the retrieved concept embeddings
        # using some attention mechanism.
        # For now, it just returns the transformer_output unchanged.
        return transformer_output

    def sample_from_logits(self, concept_logits: torch.Tensor, temperature: float = 1.0, top_k: int = None) -> torch.Tensor:
        # Apply temperature
        if temperature == 0.0:
            next_concept_id = torch.argmax(concept_logits, dim=-1)
        else:
            concept_logits = concept_logits / temperature
            # Apply top-k filtering
            if top_k is not None:
                v, _ = torch.topk(concept_logits, min(top_k, concept_logits.size(-1)))
                concept_logits[concept_logits < v[:, [-1]]] = -float('Inf')
            probs = torch.softmax(concept_logits, dim=-1)
            next_concept_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
        return next_concept_id

    @torch.no_grad()
    def generate(self, input_embeddings: torch.Tensor, max_new_concepts: int = 20, temperature: float = 1.0, abacus_embedding: torch.Tensor | None = None, working_memory_window: int = 0) -> list[int]:
        B, T, C = input_embeddings.size()
        generated_embeddings = []

        if abacus_embedding is None:
            abacus_embedding = torch.zeros(B, 1, self.config.abacus_input_dim, device=input_embeddings.device)

        # Long-Term Memory retrieval for prefill
        prefill_long_term_memory_embeddings = self.long_term_memory.retrieve(input_embeddings[:, -1, :].squeeze(0))

        # Get psyche weights from PsycheController for prefill
        prefill_psyche_weights = self.psyche_controller(input_embeddings[:, -1, :])

        # Prefill the model with input_embeddings
        concept_logits, kv_cache, x_id_prefill, x_ego_prefill, x_superego_prefill = self.model.forward_prefill(
            input_embeddings,
            abacus_embedding=abacus_embedding,
            long_term_memory_embeddings=prefill_long_term_memory_embeddings,
            psyche_weights=prefill_psyche_weights
        )

        # Conscious Integration Layer for prefill
        synthesized_state_prefill = self.conscious_integration_layer.forward(
            id_output=x_id_prefill,
            ego_output=x_ego_prefill,
            superego_output=x_superego_prefill,
            long_term_memory_embeddings=prefill_long_term_memory_embeddings,
            memetic_fitness=None,  # Memetic fitness is not available during prefill
            abacus_state=abacus_embedding
        )

        # Combine concept_logits from GPT and synthesized_state_prefill
        concept_logits = concept_logits + synthesized_state_prefill

        # Sample the first concept ID from the last token's logits
        next_concept_id = self.sample_from_logits(concept_logits[:, -1, :], temperature)
        next_embedding = self.concept_embedding_layer(next_concept_id)
        generated_embeddings.append(next_embedding)
        self.long_term_memory.store(next_embedding.squeeze(0))  # Store the generated embedding

        # Abacus State Memory and Encoder integration for the first step
        abacus_pattern = self.abacus_encoder(next_embedding)
        self.abacus_state_memory.store(abacus_pattern)
        abacus_embedding = self.abacus_state_memory.retrieve()

        for _ in range(max_new_concepts - 1):
            # Working Memory retrieval
            episodic_kv = None
            if working_memory_window > 0 and kv_cache.get_pos() > working_memory_window:
                start_pos = kv_cache.get_pos() - working_memory_window
                end_pos = kv_cache.get_pos()
                episodic_kv = kv_cache.retrieve_episode(start_pos, end_pos)

            # Long-Term Memory retrieval
            long_term_memory_embeddings = self.long_term_memory.retrieve(next_embedding.squeeze(0))

            # Get psyche weights from PsycheController
            psyche_weights = self.psyche_controller(next_embedding)

            concept_logits, kv_cache, x_id_step, x_ego_step, x_superego_step = self.model.forward_step(
                next_embedding,
                kv_cache,
                abacus_embedding=abacus_embedding,
                episodic_kv=episodic_kv,
                long_term_memory_embeddings=long_term_memory_embeddings,
                psyche_weights=psyche_weights
            )

            # Memetic Learning Layer integration
            memetic_fitness = self.memetic_learning_layer.forward(next_embedding, abacus_pattern)

            # Conscious Integration Layer for step
            synthesized_state_step = self.conscious_integration_layer.forward(
                id_output=x_id_step,
                ego_output=x_ego_step,
                superego_output=x_superego_step,
                long_term_memory_embeddings=long_term_memory_embeddings,
                memetic_fitness=memetic_fitness,
                abacus_state=abacus_embedding
            )

            # Combine concept_logits from GPT and synthesized_state_step
            concept_logits = concept_logits + synthesized_state_step.squeeze(1)  # Squeeze to match dimensions

            next_concept_id = self.sample_from_logits(concept_logits, temperature)
            next_embedding = self.concept_embedding_layer(next_concept_id)
            generated_embeddings.append(next_embedding)
            self.long_term_memory.store(next_embedding.squeeze(0))  # Store the generated embedding

            # Abacus State Memory and Encoder integration for subsequent steps
            abacus_pattern = self.abacus_encoder(next_embedding)
            self.abacus_state_memory.store(abacus_pattern)
            abacus_embedding = self.abacus_state_memory.retrieve()  # Update abacus_embedding for the next step

            # Memetic Learning Layer integration
            memetic_fitness = self.memetic_learning_layer.forward(next_embedding, abacus_pattern)

        return torch.stack(generated_embeddings, dim=1)

    def generate_from_concepts(self, concept_list: list[int] | dict, max_new_concepts: int = 20, temperature: float = 1.0) -> list[int]:
        # Encode the concept_list into initial input embeddings
        encoded_concepts = self.concept_encoder(concept_list)

        if encoded_concepts.dtype == torch.long:
            # If it's concept IDs, get embeddings from the concept_embedding_layer
            input_embeddings = self.concept_embedding_layer(encoded_concepts)
            abacus_embedding = None  # No abacus embedding in this case
        elif encoded_concepts.dtype == torch.float:
            # If it's an abacus embedding, use it directly and set abacus_embedding
            input_embeddings = encoded_concepts
            abacus_embedding = encoded_concepts  # The abacus embedding is the input embedding itself
        else:
            raise TypeError("Unexpected return type from concept_encoder.")

        # Call the main generate method
        return self.generate(input_embeddings, max_new_concepts, temperature, abacus_embedding=abacus_embedding)


        # Special tokens are no longer directly used with concept embeddings.
        # The tool use logic will need to be re-evaluated or removed if not applicable.
        # get_special = lambda s: self.tokenizer.encode_special(s)
        # python_start = get_special("<|python_start|>")
        # python_end = get_special("<|python_end|>")
        # output_start = get_special("<|output_start|>")
        # output_end = get_special("<|output_end|>")
        # assistant_end = get_special("<|assistant_end|>")  # if sampled, ends row
        # bos = self.tokenizer.get_bos_token_id()  # if sampled, ends row

        # 1) Run a batch 1 prefill of the prompt embeddings
        m = self.model.config
        kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
        kv_cache_prefill = KVCache(
            batch_size=1,
            seq_len=len(tokens),
            batch_size=input_embeddings.size(0),
            seq_len=input_embeddings.size(1),
            **kv_model_kwargs,
        )
        ids = torch.tensor([tokens], dtype=torch.long, device=device)
        logits = self.model.forward(ids, kv_cache=kv_cache_prefill)
        logits = logits[:, -1, :]
        next_ids = sample_next_token(logits, rng, temperature, top_k)  # (B, 1)
        sampled_tokens = next_ids[:, 0].tolist()
        logits = self.model.forward(input_embeddings, kv_cache=kv_cache_prefill)
        # Removed token-based sampling logic (replace with embedding generation logic later)

        # 2) Replicate the KV cache for each sample/row
        kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_length_hint = (input_embeddings.size(1) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
        kv_cache_decode = KVCache(
            batch_size=num_samples,
            seq_len=kv_length_hint,
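Tying the pieces of this hunk together: the generation loop only asks for an episodic slice once the cache has grown past working_memory_window, and retrieve_episode() then returns per-layer K and V views over that span. A hedged sketch of the call pattern (kv_cache.get_pos() is the position counter used in the loop above):

# sketch: fetch the most recent `window` positions as an episodic KV slice
window = 32
pos = kv_cache.get_pos()
episodic_kv = None
if window > 0 and pos > window:
    # each returned view covers all layers: (num_layers, batch, num_heads, window, head_dim)
    episodic_kv = kv_cache.retrieve_episode(pos - window, pos)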
@@ -230,7 +467,7 @@ class Engine:
        del kv_cache_prefill  # no need to keep this memory around

        # 3) Initialize states for each sample
        row_states = [RowState(tokens.copy()) for _ in range(num_samples)]
        row_states = [RowState(input_embeddings[i].tolist()) for i in range(num_samples)]  # Assuming input_embeddings is (B, T, C)

        # 4) Main generation loop
        num_generated = 0
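Assuming the constructor wiring above is completed as intended (for example, the hypercube topology being created before the embedding layer that consumes it), the new concept-level Engine API is meant to be driven roughly like this; the small model sizes are illustrative:

from nanochat.gpt import GPTConfig
from nanochat.engine import Engine

config = GPTConfig(n_layer=6, n_head=6, n_embd=384)   # small illustrative model
engine = Engine(config)

# Drive generation either from raw concept ids (list form) or from an
# abacus pattern (dict form), which is routed through _concept_encoder_placeholder.
out = engine.generate_from_concepts(
    {"abacus_pattern": [0.0] * config.abacus_input_dim},
    max_new_concepts=8,
    temperature=0.8,
)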
nanochat/gpt.py (415 lines changed)
@@ -25,13 +25,124 @@ from nanochat.adamw import DistAdamW

@dataclass
class GPTConfig:
    sequence_len: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 6  # number of query heads
    n_kv_head: int = 6  # number of key/value heads (MQA)
    n_embd: int = 768

    def __init__(self,
                 n_layer=12,
                 n_head=12,
                 n_embd=768,
                 sequence_len=1024,
                 n_kv_head=None,
                 num_concept_ids=4096,  # Updated to 4096 for n=12 hypercube
                 abacus_input_dim=64,  # Default input dimension for the AbacusEncoder
                 dropout=0.0,
                 bias=False,
                 multiple_of=256,
                 norm_eps=1e-5,
                 rope_theta=10000,

                 # For training
                 batch_size=1,
                 gradient_accumulation_steps=1,
                 max_iters=0,
                 lr=6e-4,
                 min_lr=6e-5,
                 weight_decay=1e-1,
                 beta1=0.9,
                 beta2=0.95,
                 grad_clip=1.0,
                 decay_lr=True,
                 warmup_iters=2000,
                 lr_decay_iters=600000,

                 # For checkpointing
                 out_dir='out',
                 eval_interval=2000,
                 log_interval=1,
                 eval_iters=200,
                 eval_only=False,
                 always_save_checkpoint=True,

                 # For distributed training
                 backend='nccl',

                 # For system
                 device='cpu',
                 dtype='bfloat16',
                 compile=False,

                 # For data
                 dataset='openwebtext',

                 # For inference
                 init_from='scratch',

                 # For chat
                 chat=False,

                 # For concept
                 concept_memory_size=1000,
                 concept_memory_top_k=5,
                 use_concept_attention=False,

                 # For psyche
                 psyche_id_lr_scale=1.0,
                 psyche_ego_lr_scale=1.0,
                 psyche_superego_lr_scale=1.0,

                 **kwargs):
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.sequence_len = sequence_len
        self.n_kv_head = n_kv_head if n_kv_head is not None else n_head
        self.num_concept_ids = num_concept_ids
        self.abacus_input_dim = abacus_input_dim
        self.dropout = dropout
        self.bias = bias
        self.multiple_of = multiple_of
        self.norm_eps = norm_eps
        self.rope_theta = rope_theta

        self.batch_size = batch_size
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.max_iters = max_iters
        self.lr = lr
        self.min_lr = min_lr
        self.weight_decay = weight_decay
        self.beta1 = beta1
        self.beta2 = beta2
        self.grad_clip = grad_clip
        self.decay_lr = decay_lr
        self.warmup_iters = warmup_iters
        self.lr_decay_iters = lr_decay_iters

        self.out_dir = out_dir
        self.eval_interval = eval_interval
        self.log_interval = log_interval
        self.eval_iters = eval_iters
        self.eval_only = eval_only
        self.always_save_checkpoint = always_save_checkpoint

        self.backend = backend

        self.device = device
        self.dtype = dtype
        self.compile = compile

        self.dataset = dataset

        self.init_from = init_from

        self.chat = chat

        self.concept_memory_size = concept_memory_size
        self.concept_memory_top_k = concept_memory_top_k
        self.use_concept_attention = use_concept_attention

        self.psyche_id_lr_scale = psyche_id_lr_scale
        self.psyche_ego_lr_scale = psyche_ego_lr_scale
        self.psyche_superego_lr_scale = psyche_superego_lr_scale

        for k, v in kwargs.items():
            setattr(self, k, v)

def norm(x):
    # Purely functional rmsnorm with no learnable params
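The keyword-style GPTConfig above accepts arbitrary extra settings via **kwargs; for instance, the values in config/custom_config.py from this commit could be applied roughly like this (a sketch, assuming the config module is importable as shown):

from nanochat.gpt import GPTConfig
import config.custom_config as custom

cfg = GPTConfig(n_layer=custom.n_layer, n_embd=custom.n_embd)   # 24 layers, 1024-dim model
assert cfg.n_kv_head == cfg.n_head       # n_kv_head defaults to n_head when not given
assert cfg.num_concept_ids == 4096       # 2**12 vertices for the n=12 hypercube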
@@ -63,7 +174,7 @@ class CausalSelfAttention(nn.Module):
        self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x, cos_sin, kv_cache):
    def forward(self, x, cos_sin, kv_cache, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None):
        B, T, C = x.size()

        # Project the input to get queries, keys, and values
@@ -80,6 +191,14 @@ class CausalSelfAttention(nn.Module):
        # Apply KV cache: insert current k,v into cache, get the full view so far
        if kv_cache is not None:
            k, v = kv_cache.insert_kv(self.layer_idx, k, v)

        # If episodic_kv is provided, prepend it to the current k and v
        if episodic_kv is not None:
            episode_k_layer = episodic_kv[self.layer_idx, 0]
            episode_v_layer = episodic_kv[self.layer_idx, 1]
            k = torch.cat([episode_k_layer, k], dim=2)
            v = torch.cat([episode_v_layer, v], dim=2)

        Tq = q.size(2)  # number of queries in this forward pass
        Tk = k.size(2)  # number of keys/values in total (in the cache + current forward pass)
@@ -135,15 +254,42 @@ class Block(nn.Module):
        return x


class PsycheController(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.controller_head = nn.Linear(config.n_embd, 3)  # 3 for id, ego, superego

    def forward(self, x):
        # For now, just return equal weights for each psyche layer
        # In the future, this will be trained to dynamically blend psyche outputs
        return torch.softmax(self.controller_head(x), dim=-1)


class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            "wte": nn.Embedding(config.vocab_size, config.n_embd),
            "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
        })
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Partition transformer layers into psyche layers
        total_layers = config.n_layer
        id_end = total_layers // 3
        ego_end = 2 * total_layers // 3

        self.id_layers = self.transformer.h[:id_end]
        self.ego_layers = self.transformer.h[id_end:ego_end]
        self.superego_layers = self.transformer.h[ego_end:]

        self.psyche_registry = {
            "id": self.id_layers,
            "ego": self.ego_layers,
            "superego": self.superego_layers
        }

        self.concept_head = nn.Linear(config.n_embd, config.num_concept_ids, bias=False)  # New concept head
        self.psyche_controller = PsycheController(config)  # Initialize PsycheController
        # To support meta device initialization, we init the rotary embeddings here, but it's fake
        # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
        # so let's just over-compute them, but assert fail if we ever reach that amount.
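A small shape sketch of how the PsycheController weights are meant to blend the id/ego/superego streams; the same stack-and-weighted-sum pattern appears in GPT.forward below:

import torch

B, T, C = 2, 5, 16
x_id, x_ego, x_superego = (torch.randn(B, T, C) for _ in range(3))
psyche_weights = torch.softmax(torch.randn(B, 3), dim=-1)    # as produced by PsycheController

stacked = torch.stack([x_id, x_ego, x_superego], dim=2)      # (B, T, 3, C)
weights = psyche_weights.unsqueeze(1).unsqueeze(-1)          # (B, 1, 3, 1), broadcasts over T and C
blended = (stacked * weights).sum(dim=2)                     # back to (B, T, C)
assert blended.shape == (B, T, C)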
@@ -156,8 +302,8 @@ class GPT(nn.Module):

    def init_weights(self):
        self.apply(self._init_weights)
        # zero out classifier weights
        torch.nn.init.zeros_(self.lm_head.weight)
        # zero out concept_head weights
        torch.nn.init.zeros_(self.concept_head.weight)
        # zero out c_proj weights in all blocks
        for block in self.transformer.h:
            torch.nn.init.zeros_(block.mlp.c_proj.weight)
@@ -166,9 +312,6 @@ class GPT(nn.Module):
        head_dim = self.config.n_embd // self.config.n_head
        cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
        self.cos, self.sin = cos, sin
        # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations
        if self.transformer.wte.weight.device.type == "cuda":
            self.transformer.wte.to(dtype=torch.bfloat16)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
@@ -186,7 +329,7 @@ class GPT(nn.Module):
    def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
        # autodetect the device from model embeddings
        if device is None:
            device = self.transformer.wte.weight.device
            device = self.concept_head.weight.device
        # stride the channels
        channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (base ** (channel_range / head_dim))
@@ -200,12 +343,13 @@ class GPT(nn.Module):
        return cos, sin

    def get_device(self):
        return self.transformer.wte.weight.device
        # Get device from concept_head weight
        return self.concept_head.weight.device

    def estimate_flops(self):
        """ Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
        nparams = sum(p.numel() for p in self.parameters())
        nparams_embedding = self.transformer.wte.weight.numel()
        nparams_embedding = 0  # No separate embedding layer now
        l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
        num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
        return num_flops_per_token
@@ -213,95 +357,190 @@ class GPT(nn.Module):
    def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
        model_dim = self.config.n_embd
        ddp, rank, local_rank, world_size = get_dist_info()
        # Separate out all parameters into 3 groups (matrix, embedding, lm_head)
        matrix_params = list(self.transformer.h.parameters())
        embedding_params = list(self.transformer.wte.parameters())
        lm_head_params = list(self.lm_head.parameters())
        assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
        # Create the AdamW optimizer for the embedding and lm_head
        # Separate out all parameters into 3 groups (matrix, embedding, concept_head)
        # matrix_params = list(self.transformer.h.parameters())
        id_params = list(self.id_layers.parameters())
        ego_params = list(self.ego_layers.parameters())
        superego_params = list(self.superego_layers.parameters())

        embedding_params = []  # No separate embedding layer now
        concept_head_params = list(self.concept_head.parameters())  # New concept head params
        psyche_controller_params = list(self.psyche_controller.parameters())  # Psyche controller params

        # assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(concept_head_params)
        # Create the AdamW optimizer for the embedding and concept_head
        # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
        dmodel_lr_scale = (model_dim / 768) ** -0.5
        if rank == 0:
            print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
        adam_groups = [
            dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale),
            dict(params=concept_head_params, lr=unembedding_lr * dmodel_lr_scale),  # Use unembedding_lr for concept_head
            dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
            dict(params=id_params, lr=3e-4 * dmodel_lr_scale),  # Id layers learning rate
            dict(params=ego_params, lr=1e-4 * dmodel_lr_scale),  # Ego layers learning rate
            dict(params=superego_params, lr=5e-5 * dmodel_lr_scale),  # Superego layers learning rate
            dict(params=psyche_controller_params, lr=1e-4 * dmodel_lr_scale),  # Psyche controller learning rate
        ]
        adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
        AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
        adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
        # Create the Muon optimizer for the linear layers
        muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
        MuonFactory = DistMuon if ddp else Muon
        muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
        # muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
        # MuonFactory = DistMuon if ddp else Muon
        # muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
        # Combine the two optimizers into one list
        optimizers = [adamw_optimizer, muon_optimizer]
        optimizers = [adamw_optimizer]
        for opt in optimizers:
            for group in opt.param_groups:
                group["initial_lr"] = group["lr"]
        return optimizers

    def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
        B, T = idx.size()
    def _run_layers(self, layers, x, cos_sin, kv_cache, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None):
        for block in layers:
            x = block(x, cos_sin, kv_cache, episodic_kv)
        return x

    def forward(self, input_embeddings: torch.Tensor, kv_cache=None, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None):
        B, T, C = input_embeddings.size()

        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
        assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
        assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
        assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
        # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
        T0 = 0 if kv_cache is None else kv_cache.get_pos()
        cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T]  # truncate cache to current sequence length
        cos_sin = self.cos[:, :T, :, :], self.sin[:, :T, :, :]

        x = input_embeddings

        # Process Id layers
        x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv)

        # Process Ego layers
        x_ego = x_id
        if abacus_embedding is not None:
            # Broadcast abacus_embedding to match the sequence length of x_ego
            # Assuming abacus_embedding is (B, C) and x_ego is (B, T, C)
            abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1)
            x_ego = x_ego + abacus_broadcast  # Inject abacus_embedding into ego layer
        if long_term_memory_embeddings is not None:
            # Broadcast long_term_memory_embeddings to match the sequence length of x_ego
            # Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C)
            long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1)
            x_ego = x_ego + long_term_memory_broadcast  # Inject long_term_memory_embeddings into ego layer
        x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv)

        # Process Superego layers
        x_superego = x_ego
        if long_term_memory_embeddings is not None:
            # Broadcast long_term_memory_embeddings to match the sequence length of x_superego
            # Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C)
            long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
            x_superego = x_superego + long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
        x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv)

        # Dynamically blend the outputs based on psyche_weights
        # Reshape psyche_weights for broadcasting: (B, 1, 3)
        psyche_weights_reshaped = psyche_weights.unsqueeze(1)

        # Stack the outputs and apply weighted sum
        # Stack will result in (B, T, 3, C)
        stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2)
        # Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C)
        x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2)

        # Final concept head
        return self.concept_head(x), kv_cache, x_id, x_ego, x_superego

    def forward_prefill(self, input_embeddings: torch.Tensor, kv_cache=None, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None):
        B, T, C = input_embeddings.size()

        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
        assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
        cos_sin = self.cos[:, :T, :, :], self.sin[:, :T, :, :]

        x = input_embeddings

        # Process Id layers
        x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv)

        # Process Ego layers
        x_ego = x_id
        if abacus_embedding is not None:
            # Broadcast abacus_embedding to match the sequence length of x_ego
            # Assuming abacus_embedding is (B, C) and x_ego is (B, T, C)
            abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1)
            x_ego = x_ego + abacus_broadcast  # Inject abacus_embedding into ego layer
        if long_term_memory_embeddings is not None:
            # Broadcast long_term_memory_embeddings to match the sequence length of x_ego
            # Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C)
            long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1)
            x_ego = x_ego + long_term_memory_broadcast  # Inject long_term_memory_embeddings into ego layer
        x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv)

        # Process Superego layers
        x_superego = x_ego
        if long_term_memory_embeddings is not None:
            # Broadcast long_term_memory_embeddings to match the sequence length of x_superego
            # Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C)
            long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
            x_superego = x_superego + long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
        x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv)

        # Dynamically blend the outputs based on psyche_weights
        # Reshape psyche_weights for broadcasting: (B, 1, 3)
        psyche_weights_reshaped = psyche_weights.unsqueeze(1)

        # Stack the outputs and apply weighted sum
        # Stack will result in (B, T, 3, C)
        stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2)
        # Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C)
        x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2)

        # Final concept head
        return self.concept_head(x), kv_cache, x_id, x_ego, x_superego

    def forward_step(self, next_embedding: torch.Tensor, kv_cache, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None):
        B, C = next_embedding.size()
        T = kv_cache[0].size(1) + 1  # Current sequence length after adding next_embedding

        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
        assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
        cos_sin = self.cos[:, T-1:T, :, :], self.sin[:, T-1:T, :, :]

        x = next_embedding.unsqueeze(1)  # Add sequence dimension for consistency

        # Process Id layers
        x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv)

        # Process Ego layers
        x_ego = x_id
        if abacus_embedding is not None:
            # Broadcast abacus_embedding to match the sequence length of x_ego
            # Assuming abacus_embedding is (B, C) and x_ego is (B, T, C)
            abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1)
            x_ego = x_ego + abacus_broadcast  # Inject abacus_embedding into ego layer
        if long_term_memory_embeddings is not None:
            # Broadcast long_term_memory_embeddings to match the sequence length of x_ego
            # Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C)
            long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1)
            x_ego = x_ego + long_term_memory_broadcast  # Inject long_term_memory_embeddings into ego layer
        x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv)

        # Process Superego layers
        x_superego = x_ego
        if long_term_memory_embeddings is not None:
            # Broadcast long_term_memory_embeddings to match the sequence length of x_superego
            # Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C)
            long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
            x_superego = x_superego + long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
        x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv)

        # Dynamically blend the outputs based on psyche_weights
        # Reshape psyche_weights for broadcasting: (B, 1, 3)
        psyche_weights_reshaped = psyche_weights.unsqueeze(1)

        # Stack the outputs and apply weighted sum
        # Stack will result in (B, T, 3, C)
        stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2)
        # Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C)
        x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2)

        # Forward the trunk of the Transformer
        x = self.transformer.wte(idx)
        x = norm(x)
        for block in self.transformer.h:
            x = block(x, cos_sin, kv_cache)
        x = norm(x)

        # Forward the lm_head (compute logits)
        softcap = 15
        if targets is not None:
            # training mode: compute and return the loss
            # TODO: experiment with Liger Kernels / chunked cross-entropy etc.
            logits = self.lm_head(x)
            logits = softcap * torch.tanh(logits / softcap)  # logits softcap
            logits = logits.float()  # use tf32/fp32 for logits
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
            return loss
        else:
            # inference mode: compute and return the logits
            logits = self.lm_head(x)
            logits = softcap * torch.tanh(logits / softcap)  # logits softcap
            return logits

    @torch.inference_mode()
    def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
        """
        Naive autoregressive streaming inference.
        To make it super simple, let's assume:
        - batch size is 1
        - ids and the yielded tokens are simple Python lists and ints
        """
        assert isinstance(tokens, list)
        device = self.get_device()
        rng = None
        if temperature > 0:
            rng = torch.Generator(device=device)
            rng.manual_seed(seed)
        ids = torch.tensor([tokens], dtype=torch.long, device=device)  # add batch dim
        for _ in range(max_tokens):
            logits = self.forward(ids)  # (B, T, vocab_size)
            logits = logits[:, -1, :]  # (B, vocab_size)
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            if temperature > 0:
                logits = logits / temperature
                probs = F.softmax(logits, dim=-1)
                next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
            else:
                next_ids = torch.argmax(logits, dim=-1, keepdim=True)
            ids = torch.cat((ids, next_ids), dim=1)
            token = next_ids.item()
            yield token
        return self.concept_head(x.squeeze(1)), kv_cache, x_id, x_ego, x_superego
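For the optimizer setup above, every AdamW group is scaled by dmodel_lr_scale = (model_dim / 768) ** -0.5. A quick worked example for the 1024-dim model in config/custom_config.py:

model_dim = 1024
dmodel_lr_scale = (model_dim / 768) ** -0.5       # ~0.866
print(3e-4 * dmodel_lr_scale)                     # id-layer lr, roughly 2.6e-4
print(1e-4 * dmodel_lr_scale)                     # ego and psyche-controller lr, roughly 8.7e-5
print(5e-5 * dmodel_lr_scale)                     # superego lr, roughly 4.3e-5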
nanochat/hypercube.py (new file, +149)
@@ -0,0 +1,149 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class HypercubeTopology:
    def __init__(self, n: int):
        if not (1 <= n <= 16):  # Limiting n for practical reasons, can be adjusted
            raise ValueError("Hypercube dimension 'n' must be between 1 and 16.")
        self.n = n
        self.num_vertices = 2**n

    def to_binary(self, vertex_id: int) -> str:
        if not (0 <= vertex_id < self.num_vertices):
            raise ValueError(f"Vertex ID {vertex_id} is out of range for a {self.n}-dimensional hypercube.")
        return format(vertex_id, f'0{self.n}b')

    def from_binary(self, binary_string: str) -> int:
        if len(binary_string) != self.n:
            raise ValueError(f"Binary string length must be {self.n} for a {self.n}-dimensional hypercube.")
        return int(binary_string, 2)

    def get_neighbors(self, vertex_id: int) -> list[int]:
        if not (0 <= vertex_id < self.num_vertices):
            raise ValueError(f"Vertex ID {vertex_id} is out of range for a {self.n}-dimensional hypercube.")

        neighbors = []
        for i in range(self.n):
            # Flip the i-th bit
            neighbor_id = vertex_id ^ (1 << i)
            neighbors.append(neighbor_id)
        return neighbors

    def get_all_edges(self) -> list[tuple[int, int]]:
        edges = []
        for i in range(self.num_vertices):
            for neighbor in self.get_neighbors(i):
                # Ensure each edge is added only once (e.g., (0,1) not (1,0))
                if i < neighbor:
                    edges.append((i, neighbor))
        return edges

    def get_random_vertex(self) -> int:
        return torch.randint(0, self.num_vertices, (1,)).item()

    def get_random_path(self, start_vertex: int, length: int) -> list[int]:
        if not (0 <= start_vertex < self.num_vertices):
            raise ValueError(f"Start vertex ID {start_vertex} is out of range.")
        if length < 1:
            raise ValueError("Path length must be at least 1.")

        path = [start_vertex]
        current_vertex = start_vertex

        for _ in range(length - 1):
            neighbors = self.get_neighbors(current_vertex)
            if not neighbors:
                break  # Should not happen in a hypercube unless n=0

            # Choose a random neighbor that is not the previous vertex if possible
            next_vertex = neighbors[torch.randint(0, len(neighbors), (1,)).item()]
            path.append(next_vertex)
            current_vertex = next_vertex
        return path

class HypercubeEmbeddingLayer(nn.Module):
    def __init__(self, num_raw_concepts, embedding_dim, hypercube_topology):
        super().__init__()
        self.raw_embedding_layer = nn.Embedding(num_raw_concepts, embedding_dim)
        self.hypercube_topology = hypercube_topology
        # Embeddings for hypercube vertices. The number of vertices is hypercube_topology.num_vertices
        self.vertex_embeddings = nn.Embedding(hypercube_topology.num_vertices, embedding_dim)

    def forward(self, concept_ids):
        # Get initial embeddings for the input concept_ids
        initial_embeddings = self.raw_embedding_layer(concept_ids)  # (batch_size, embedding_dim)

        # Find the nearest hypercube vertex for each initial embedding

        # Get all hypercube vertex embeddings
        # Ensure all_vertex_ids are on the same device as concept_ids
        all_vertex_ids = torch.arange(self.hypercube_topology.num_vertices, device=concept_ids.device)
        all_vertex_embeddings = self.vertex_embeddings(all_vertex_ids)  # (num_vertices, embedding_dim)

        # Calculate squared Euclidean distance
        # initial_embeddings_expanded: (batch_size, 1, embedding_dim)
        # all_vertex_embeddings_expanded: (1, num_vertices, embedding_dim)
        initial_embeddings_expanded = initial_embeddings.unsqueeze(1)
        all_vertex_embeddings_expanded = all_vertex_embeddings.unsqueeze(0)

        distances = torch.sum((initial_embeddings_expanded - all_vertex_embeddings_expanded)**2, dim=2)  # (batch_size, num_vertices)

        # Find the index of the nearest vertex for each initial embedding
        nearest_vertex_indices = torch.argmin(distances, dim=1)  # (batch_size,)

        # Retrieve the embeddings of the nearest vertices
        final_embeddings = self.vertex_embeddings(nearest_vertex_indices)

        return final_embeddings


class LongTermMemory(nn.Module):
    def __init__(self, embedding_dim: int, max_memory_size: int = 1000, top_k_retrieval: int = 5):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.max_memory_size = max_memory_size
        self.top_k_retrieval = top_k_retrieval

        # Initialize an empty memory bank. We'll use a list of tensors for simplicity initially,
        # but this could be replaced with a more efficient data structure or a learnable embedding layer.
        self.memory_bank = []
        self.memory_bank_tensor = None  # Will store concatenated memories for efficient retrieval

    def store(self, embedding: torch.Tensor):
        # Store a new embedding in the memory bank
        if len(self.memory_bank) >= self.max_memory_size:
            # Simple FIFO eviction for now
            self.memory_bank.pop(0)
        self.memory_bank.append(embedding.detach().cpu())  # Store on CPU to save GPU memory
        self.memory_bank_tensor = None  # Invalidate cached tensor

    def retrieve(self, query_embedding: torch.Tensor) -> torch.Tensor:
        # Retrieve top-k most similar embeddings from the memory bank
        if not self.memory_bank:
            return torch.zeros(query_embedding.shape[0], self.top_k_retrieval, self.embedding_dim, device=query_embedding.device)

        if self.memory_bank_tensor is None:
            self.memory_bank_tensor = torch.stack(self.memory_bank).to(query_embedding.device)

        # Normalize query and memory bank for cosine similarity
        query_norm = F.normalize(query_embedding, p=2, dim=-1)
        memory_norm = F.normalize(self.memory_bank_tensor, p=2, dim=-1)

        # Calculate cosine similarity
        # query_norm: (batch_size, embedding_dim)
        # memory_norm: (num_memories, embedding_dim)
        # similarities: (batch_size, num_memories)
        similarities = torch.matmul(query_norm, memory_norm.transpose(0, 1))

        # Get top-k similar memories
        # top_k_values: (batch_size, top_k_retrieval)
        # top_k_indices: (batch_size, top_k_retrieval)
        top_k_values, top_k_indices = torch.topk(similarities, min(self.top_k_retrieval, len(self.memory_bank)), dim=-1)

        # Retrieve the actual embeddings
        retrieved_memories = self.memory_bank_tensor[top_k_indices]

        return retrieved_memories

    def forward(self, query_embedding: torch.Tensor) -> torch.Tensor:
        return self.retrieve(query_embedding)
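A brief usage sketch for the topology and memory classes above; n=3 keeps the example small, whereas the commit itself uses n=12 (4096 vertices):

import torch
from nanochat.hypercube import HypercubeTopology, LongTermMemory

cube = HypercubeTopology(3)                          # 8 vertices, ids 0..7
assert cube.to_binary(5) == "101"
assert sorted(cube.get_neighbors(5)) == [1, 4, 7]    # neighbours differ by exactly one bit
assert len(cube.get_all_edges()) == 12               # n * 2**n / 2 edges for Q_n

ltm = LongTermMemory(embedding_dim=4, max_memory_size=10, top_k_retrieval=2)
ltm.store(torch.randn(4))
ltm.store(torch.randn(4))
hits = ltm.retrieve(torch.randn(1, 4))               # (1, 2, 4): top-2 memories by cosine similarity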
nanochat/memetic_learning.py (new file, +46)
@@ -0,0 +1,46 @@
import torch
import torch.nn as nn

class MemeticLearningLayer(nn.Module):
    def __init__(self, config, abacus_encoder, internal_self_model):
        super().__init__()
        self.config = config
        self.abacus_encoder = abacus_encoder
        self.internal_self_model = internal_self_model

        # Placeholder for memetic fitness evaluation mechanism
        # This could be a simple linear layer or a more complex network
        self.fitness_evaluator = nn.Linear(config.abacus_input_dim, 1)

        # Placeholder for concept mapping expansion (analogy/metaphor)
        # This might involve a transformation of embeddings or a retrieval mechanism
        self.concept_mapper = nn.Sequential(
            nn.Linear(config.n_embd * 2, config.n_embd),  # Input: two concept embeddings
            nn.ReLU(),
            nn.Linear(config.n_embd, config.n_embd)
        )

    def evaluate_memetic_fitness(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
        # Evaluate the fitness of a memetic pattern using the abacus encoder output
        fitness_score = self.fitness_evaluator(abacus_pattern)
        return fitness_score

    def expand_concept_mapping(self, concept1_embedding: torch.Tensor, concept2_embedding: torch.Tensor) -> torch.Tensor:
        # Expand concept mapping via analogy and metaphor
        # This takes two concept embeddings and generates a new, related concept embedding
        combined_concepts = torch.cat([concept1_embedding, concept2_embedding], dim=-1)
        new_concept_embedding = self.concept_mapper(combined_concepts)
        return new_concept_embedding

    def forward(self, current_concept_embedding: torch.Tensor, abacus_pattern: torch.Tensor):
        # Orchestrates the memetic learning process
        # 1. Evaluate memetic fitness
        fitness = self.evaluate_memetic_fitness(abacus_pattern)

        # 2. Potentially update internal self-model based on fitness or new concepts
        # self.internal_self_model.update_beliefs(current_concept_embedding)  # Example interaction

        # 3. Generate new concepts or analogies
        # new_concept = self.expand_concept_mapping(current_concept_embedding, some_other_concept)

        return fitness
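A hedged sketch of calling the memetic layer above; the SimpleNamespace config is a stand-in that only needs n_embd and abacus_input_dim, and note that forward() currently returns only the fitness score:

import torch
from types import SimpleNamespace
from nanochat.abacus_encoder import AbacusEncoder
from nanochat.memetic_learning import MemeticLearningLayer
from nanochat.self_model import InternalSelfModel

cfg = SimpleNamespace(n_embd=16, abacus_input_dim=8)
layer = MemeticLearningLayer(cfg, AbacusEncoder(8, 16), InternalSelfModel(16, cfg))

concept = torch.randn(2, 16)                # current concept embeddings
abacus_pattern = torch.randn(2, 8)          # must match abacus_input_dim for the fitness head
fitness = layer(concept, abacus_pattern)    # (2, 1): one scalar fitness per row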
nanochat/self_model.py (new file, +58)
@@ -0,0 +1,58 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class InternalSelfModel(nn.Module):
    def __init__(self, embedding_dim: int, config):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.config = config

        # Placeholder for belief representation (e.g., a learned embedding or a set of parameters)
        self.belief_embedding = nn.Parameter(torch.randn(embedding_dim))

        # Placeholder for a simple prediction head to calculate prediction error
        self.prediction_head = nn.Linear(embedding_dim, 1)  # Predicts a scalar error signal

        # Placeholder for conceptual growth mechanism (e.g., a small MLP or attention mechanism)
        self.conceptual_growth_mlp = nn.Sequential(
            nn.Linear(embedding_dim * 2, embedding_dim),  # Input: current belief + new concept
            nn.ReLU(),
            nn.Linear(embedding_dim, embedding_dim)
        )

    def update_beliefs(self, current_concept_embedding: torch.Tensor):
        # This method would update the internal belief representation based on new information.
        # For now, a simple update rule (e.g., moving average or attention-based update)
        # In a more advanced implementation, this could involve a recurrent neural network.
        # Example: simple weighted average
        alpha = 0.1  # Learning rate for belief update
        self.belief_embedding.data = (1 - alpha) * self.belief_embedding.data + alpha * current_concept_embedding.mean(dim=0)

    def calculate_prediction_error(self, predicted_output: torch.Tensor, actual_output: torch.Tensor) -> torch.Tensor:
        # This method calculates the discrepancy between predicted and actual outcomes.
        # For now, a simple mean squared error.
        error = F.mse_loss(predicted_output, actual_output)
        return error

    def promote_conceptual_growth(self, current_belief_embedding: torch.Tensor, new_concept_embedding: torch.Tensor) -> torch.Tensor:
        # This method facilitates the integration of new concepts into the existing conceptual framework.
        # For now, a simple MLP that takes the concatenation of current belief and new concept.
        combined_embedding = torch.cat([current_belief_embedding, new_concept_embedding], dim=-1)
        updated_concept_embedding = self.conceptual_growth_mlp(combined_embedding)
        return updated_concept_embedding

    def forward(self, current_concept_embedding: torch.Tensor, predicted_output: torch.Tensor, actual_output: torch.Tensor):
        # Update beliefs
        self.update_beliefs(current_concept_embedding)

        # Calculate prediction error
        error = self.calculate_prediction_error(predicted_output, actual_output)

        # Promote conceptual growth (example: using the current belief and concept embedding)
        # This part would be more complex in a full implementation, potentially involving attention
        # or other mechanisms to decide how to grow concepts based on error and new info.
        # For demonstration, let's assume conceptual growth is triggered by new concept embeddings.
        # updated_concept = self.promote_conceptual_growth(self.belief_embedding, current_concept_embedding)

        return error, self.belief_embedding
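Finally, a short sketch of the self-model's forward pass; the config argument is stored but not otherwise used by the class, so a plain namespace suffices here:

import torch
from types import SimpleNamespace
from nanochat.self_model import InternalSelfModel

model = InternalSelfModel(embedding_dim=16, config=SimpleNamespace())
concepts = torch.randn(4, 16)
error, belief = model(concepts, predicted_output=torch.randn(4, 1), actual_output=torch.randn(4, 1))
# error is a scalar MSE; belief is the 16-dim belief embedding,
# updated in place as a moving average of the input concepts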