hypercube-abacus

This commit is contained in:
Nda-jiya Suberu 2025-10-31 03:27:16 +00:00
parent dfc88334b6
commit 317e4b65df
9 changed files with 966 additions and 137 deletions

config/custom_config.py Normal file

@ -0,0 +1,2 @@
n_layer = 24
n_embd = 1024

nanochat/abacus_encoder.py Normal file

@ -0,0 +1,24 @@
import torch
import torch.nn as nn
class AbacusEncoder(nn.Module):
def __init__(self, input_dim: int, embedding_dim: int):
super().__init__()
self.input_dim = input_dim
self.embedding_dim = embedding_dim
# Simple linear layer to encode abacus-like patterns into the embedding space
self.encoder_layer = nn.Linear(input_dim, embedding_dim)
def encode(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
# abacus_pattern is expected to be a tensor of shape (batch_size, input_dim)
if abacus_pattern.shape[-1] != self.input_dim:
raise ValueError(f"Expected abacus_pattern to have last dimension {self.input_dim}, but got {abacus_pattern.shape[-1]}")
return self.encoder_layer(abacus_pattern)
def decode(self, concept_vector: torch.Tensor) -> torch.Tensor:
# Placeholder for decoding functionality
raise NotImplementedError("Decoding from concept vector to abacus pattern is not yet implemented.")
def forward(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
return self.encode(abacus_pattern)
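A minimal usage sketch of the encoder as committed; the 64/768 sizes are illustrative and mirror defaults used elsewhere in this diff (abacus_input_dim=64, n_embd=768):

import torch
from nanochat.abacus_encoder import AbacusEncoder

encoder = AbacusEncoder(input_dim=64, embedding_dim=768)
pattern = torch.rand(4, 64)       # a batch of 4 abacus-like patterns
concepts = encoder(pattern)       # forward() delegates to encode()
print(concepts.shape)             # torch.Size([4, 768])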

nanochat/abacus_state_memory.py Normal file

@ -0,0 +1,26 @@
import torch
import torch.nn as nn
class AbacusStateMemory(nn.Module):
def __init__(self, max_memory_size: int = 100, abacus_input_dim: int = 64):
super().__init__()
self.max_memory_size = max_memory_size
self.abacus_input_dim = abacus_input_dim
self.memory = [] # Stores abacus patterns (tensors)
def store(self, abacus_pattern: torch.Tensor):
if len(self.memory) >= self.max_memory_size:
self.memory.pop(0) # Remove the oldest entry
self.memory.append(abacus_pattern.detach().cpu()) # Store detached CPU tensor
def retrieve(self, num_to_retrieve: int = 1) -> list[torch.Tensor]:
# Retrieve the most recent abacus patterns
return self.memory[-num_to_retrieve:]
def clear(self):
self.memory = []
def forward(self, abacus_pattern: torch.Tensor):
# For now, forward simply stores the pattern. More complex logic can be added later.
self.store(abacus_pattern)
return abacus_pattern
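A short usage sketch of the FIFO state memory (sizes illustrative):

import torch
from nanochat.abacus_state_memory import AbacusStateMemory

memory = AbacusStateMemory(max_memory_size=3, abacus_input_dim=8)
for step in range(5):
    memory(torch.rand(1, 8))               # forward() stores the pattern and returns it unchanged
recent = memory.retrieve(2)                # the two most recent patterns; the oldest two were evicted
print(len(memory.memory), len(recent))     # 3 2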

nanochat/conscious_integration.py Normal file

@ -0,0 +1,48 @@
import torch
import torch.nn as nn
from nanochat.abacus_encoder import AbacusEncoder
class ConsciousIntegrationLayer(nn.Module):
def __init__(self, config, abacus_encoder: AbacusEncoder):
super().__init__()
self.config = config
self.abacus_encoder = abacus_encoder
# Linear layer to project the integrated state to the concept vocabulary
self.concept_projection = nn.Linear(config.n_embd, config.num_concept_ids)
def forward(self, id_output: torch.Tensor, ego_output: torch.Tensor, superego_output: torch.Tensor, long_term_memory_embeddings: torch.Tensor, memetic_fitness: torch.Tensor | None, abacus_state: torch.Tensor) -> torch.Tensor:
# Ensure all inputs are of the same shape for integration
# For simplicity, let's assume they are all (B, T, C) or (B, C)
# If they are (B, C), we might need to unsqueeze for sequence dimension if T > 1
# Example integration: simple summation. More complex mechanisms can be added here.
# The goal is to synthesize these into a unified conceptual state.
synthesized_state = id_output + ego_output + superego_output
if long_term_memory_embeddings is not None:
# Ensure dimensions match for addition. long_term_memory_embeddings might be (B, C)
# If synthesized_state is (B, T, C), expand long_term_memory_embeddings
if synthesized_state.dim() == 3 and long_term_memory_embeddings.dim() == 2:
long_term_memory_embeddings = long_term_memory_embeddings.unsqueeze(1).expand(-1, synthesized_state.size(1), -1)
synthesized_state = synthesized_state + long_term_memory_embeddings
if memetic_fitness is not None:
# Ensure dimensions match for addition
if synthesized_state.dim() == 3 and memetic_fitness.dim() == 2:
memetic_fitness = memetic_fitness.unsqueeze(1).expand(-1, synthesized_state.size(1), -1)
synthesized_state = synthesized_state + memetic_fitness
# Abacus Encoder for logical consistency
# The abacus_state is already an encoded representation, so we might use it to gate or modulate
# the synthesized_state, or simply include it in the synthesis.
# For now, let's add it, assuming it's compatible.
if synthesized_state.dim() == 3 and abacus_state.dim() == 2:
abacus_state = abacus_state.unsqueeze(1).expand(-1, synthesized_state.size(1), -1)
synthesized_state = synthesized_state + abacus_state
# Project the synthesized state to the vocabulary size
concept_logits_from_synthesis = self.concept_projection(synthesized_state)
return concept_logits_from_synthesis
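A shape-level sketch of the integration (toy sizes; SimpleNamespace stands in for the real config, and the 2-D memory/abacus inputs are broadcast over the sequence dimension exactly as in the forward above):

import torch
from types import SimpleNamespace
from nanochat.abacus_encoder import AbacusEncoder
from nanochat.conscious_integration import ConsciousIntegrationLayer

cfg = SimpleNamespace(n_embd=32, num_concept_ids=16)
layer = ConsciousIntegrationLayer(cfg, AbacusEncoder(input_dim=8, embedding_dim=32))
B, T, C = 2, 5, 32
logits = layer(
    id_output=torch.randn(B, T, C),
    ego_output=torch.randn(B, T, C),
    superego_output=torch.randn(B, T, C),
    long_term_memory_embeddings=torch.randn(B, C),
    memetic_fitness=None,
    abacus_state=torch.randn(B, C),
)
print(logits.shape)   # torch.Size([2, 5, 16])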

nanochat/engine.py

@ -11,7 +11,9 @@ Notes:
The whole thing is made as efficient as possible.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import signal
import warnings
@ -19,6 +21,15 @@ from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.gpt import GPT
from nanochat.kv_cache import KVCache
from nanochat.hypercube import HypercubeTopology, HypercubeEmbeddingLayer, LongTermMemory # Import HypercubeTopology
from nanochat.abacus_encoder import AbacusEncoder
from nanochat.self_model import InternalSelfModel
from nanochat.abacus_state_memory import AbacusStateMemory
from nanochat.memetic_learning import MemeticLearningLayer
from nanochat.conscious_integration import ConsciousIntegrationLayer
from nanochat.gpt import GPT, PsycheController
# -----------------------------------------------------------------------------
# Calculator tool helpers
@ -151,76 +162,302 @@ class KVCache:
self.pos = t1
return key_view, value_view
def retrieve_episode(self, start_pos, end_pos):
"""
Retrieves a slice of the KV cache representing an 'episode' or a specific attention span.
"""
assert self.kv_cache is not None, "KV cache is empty, cannot retrieve episode."
assert start_pos >= 0 and end_pos <= self.pos, "Invalid start or end position for episode retrieval."
assert start_pos < end_pos, "Start position must be less than end position."
# Return a view of the relevant part of the cache
# kv_cache shape: (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
# We want to retrieve for all layers, and both k and v
episode_k = self.kv_cache[:, 0, :, :, start_pos:end_pos, :]
episode_v = self.kv_cache[:, 1, :, :, start_pos:end_pos, :]
return episode_k, episode_v
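A small helper sketch showing how the generation loop below uses this to build a working-memory window; it assumes an already-populated cache and the get_pos() accessor used elsewhere in this file:

def working_memory_episode(kv_cache, window: int):
    # Mirror of the episodic retrieval in Engine.generate: take the last `window`
    # positions of the cache as an (episode_k, episode_v) pair, or None if too short.
    if window <= 0 or kv_cache.get_pos() <= window:
        return None
    start = kv_cache.get_pos() - window
    return kv_cache.retrieve_episode(start, kv_cache.get_pos())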
# -----------------------------------------------------------------------------
@torch.inference_mode()
def sample_from_logits(logits, rng, temperature, top_k):
"""Sample a single next concept ID from given logits of shape (B, num_concept_ids). Returns (B, 1)."""
assert temperature >= 0.0, "temperature must be non-negative"
if temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1, generator=rng)
return idx_next
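For reference, a quick sanity check of the sampler run in the same module (4096 matches the default num_concept_ids; the logits are cloned because the top-k filter edits them in place):

import torch

rng = torch.Generator().manual_seed(0)
logits = torch.randn(2, 4096)
next_ids = sample_from_logits(logits.clone(), rng, temperature=0.8, top_k=50)
print(next_ids.shape)   # torch.Size([2, 1])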
class RowState:
# Per-row state tracking during generation
def __init__(self, current_tokens=None):
self.current_tokens = current_tokens or [] # Current token sequence for this row
self.forced_tokens = deque() # Queue of tokens to force inject
self.in_python_block = False # Whether we are inside a python block
self.python_expr_tokens = [] # Tokens of the current python expression
self.completed = False # Whether this row has completed generation
class Engine:
def __init__(self, config: GPTConfig):
self.config = config
self.model = GPT(config)
self.model.eval()
self.model.init_weights()
# Initialize HypercubeTopology (the embedding layer below needs it)
n_hypercube_dim = int(math.log2(config.num_concept_ids))
if not (2**n_hypercube_dim == config.num_concept_ids):
raise ValueError("config.num_concept_ids must be a power of 2 for hypercube topology.")
self.hypercube_topology = HypercubeTopology(n_hypercube_dim)
# Initialize HypercubeEmbeddingLayer
self.concept_embedding_layer = HypercubeEmbeddingLayer(
num_raw_concepts=config.num_concept_ids,
embedding_dim=config.n_embd,
hypercube_topology=self.hypercube_topology
)
# Initialize the raw concept embedding table (the layer exposes no single .weight)
nn.init.normal_(self.concept_embedding_layer.raw_embedding_layer.weight, mean=0.0, std=1.0)
# Initialize AbacusEncoder
self.abacus_encoder = AbacusEncoder(
input_dim=config.abacus_input_dim,
embedding_dim=config.n_embd
)
# Initialize LongTermMemory
self.long_term_memory = LongTermMemory(
embedding_dim=config.n_embd,
max_memory_size=1000, # Default max memory size
top_k_retrieval=5 # Default top-k retrieval
)
# Initialize InternalSelfModel
self.internal_self_model = InternalSelfModel(
embedding_dim=config.n_embd,
num_concepts=config.num_concept_ids
)
# Initialize AbacusStateMemory
self.abacus_state_memory = AbacusStateMemory(
max_memory_size=100, # Default max memory size
abacus_input_dim=config.abacus_input_dim
)
# Initialize MemeticLearningLayer
self.memetic_learning_layer = MemeticLearningLayer(
config=config,
abacus_encoder=self.abacus_encoder,
internal_self_model=self.internal_self_model
)
# Initialize PsycheController
self.psyche_controller = PsycheController(
config=config
)
# Initialize ConsciousIntegrationLayer
self.conscious_integration_layer = ConsciousIntegrationLayer(
config=config,
abacus_encoder=self.abacus_encoder
)
# Fusion projection for the concept-attention placeholder path (kept distinct from the placeholder method below)
self._concept_attention_fusion_linear = nn.Linear(config.n_embd * 2, config.n_embd)
# Placeholder for concept encoder: bind the placeholder method defined below
self.concept_encoder = self._concept_encoder_placeholder
def _concept_encoder_placeholder(self, concept_list: list[int] | dict) -> torch.Tensor:
"""
Placeholder for a more sophisticated concept encoder.
For now, it assumes concept_list directly contains integer concept IDs or a dict with 'abacus_pattern'.
"""
if isinstance(concept_list, dict) and "abacus_pattern" in concept_list:
# If an abacus pattern is provided, encode it using the AbacusEncoder
abacus_pattern = concept_list["abacus_pattern"]
# Ensure abacus_pattern is a tensor and has the correct shape
if not isinstance(abacus_pattern, torch.Tensor):
abacus_pattern = torch.tensor(abacus_pattern, dtype=torch.float32)
# Add batch dimension if missing
if abacus_pattern.dim() == 1:
abacus_pattern = abacus_pattern.unsqueeze(0)
return self.abacus_encoder(abacus_pattern)
elif isinstance(concept_list, list):
# Otherwise, assume it's a list of integer concept IDs
return torch.tensor(concept_list, dtype=torch.long, device=self.model.get_device())
else:
raise ValueError("concept_list must be a list of integers or a dict with 'abacus_pattern'.")
def _concept_memory_retrieve_placeholder(self, current_embedding: torch.Tensor) -> torch.Tensor:
# Placeholder for concept memory retrieval logic
# This would typically involve searching a memory bank for relevant concepts
# based on the current_embedding and returning their embeddings.
# For now, it returns a zero tensor of appropriate shape.
return torch.zeros_like(current_embedding)
def _concept_attention_fusion_placeholder(self, transformer_output: torch.Tensor, retrieved_concepts: torch.Tensor) -> torch.Tensor:
# Placeholder for concept attention fusion logic
# This would combine the transformer's output with the retrieved concept embeddings
# using some attention mechanism.
# For now, it just returns the transformer_output unchanged.
return transformer_output
def sample_from_logits(self, concept_logits: torch.Tensor, temperature: float = 1.0, top_k: int = None) -> torch.Tensor:
# Apply temperature
if temperature == 0.0:
next_concept_id = torch.argmax(concept_logits, dim=-1)
else:
concept_logits = concept_logits / temperature
# Apply top-k filtering
if top_k is not None:
v, _ = torch.topk(concept_logits, min(top_k, concept_logits.size(-1)))
concept_logits[concept_logits < v[:, [-1]]] = -float('Inf')
probs = torch.softmax(concept_logits, dim=-1)
next_concept_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
return next_concept_id
@torch.no_grad()
def generate(self, input_embeddings: torch.Tensor, max_new_concepts: int = 20, temperature: float = 1.0, abacus_embedding: torch.Tensor | None = None, working_memory_window: int = 0) -> torch.Tensor:
B, T, C = input_embeddings.size()
generated_embeddings = []
if abacus_embedding is None:
abacus_embedding = torch.zeros(B, 1, self.config.abacus_input_dim, device=input_embeddings.device)
# Long-Term Memory retrieval for prefill
prefill_long_term_memory_embeddings = self.long_term_memory.retrieve(input_embeddings[:, -1, :].squeeze(0))
# Get psyche weights from PsycheController for prefill
prefill_psyche_weights = self.psyche_controller(input_embeddings[:, -1, :])
# Prefill the model with input_embeddings
concept_logits, kv_cache, x_id_prefill, x_ego_prefill, x_superego_prefill = self.model.forward_prefill(
input_embeddings,
abacus_embedding=abacus_embedding,
long_term_memory_embeddings=prefill_long_term_memory_embeddings,
psyche_weights=prefill_psyche_weights
)
# Conscious Integration Layer for prefill
synthesized_state_prefill = self.conscious_integration_layer.forward(
id_output=x_id_prefill,
ego_output=x_ego_prefill,
superego_output=x_superego_prefill,
long_term_memory_embeddings=prefill_long_term_memory_embeddings,
memetic_fitness=None, # Memetic fitness is not available during prefill
abacus_state=abacus_embedding
)
# Combine concept_logits from GPT and synthesized_state_prefill
concept_logits = concept_logits + synthesized_state_prefill
# Sample the first concept ID from the last token's logits
next_concept_id = self.sample_from_logits(concept_logits[:, -1, :], temperature)
next_embedding = self.concept_embedding_layer(next_concept_id)
generated_embeddings.append(next_embedding)
self.long_term_memory.store(next_embedding.squeeze(0)) # Store the generated embedding
# Abacus State Memory and Encoder integration for the first step
abacus_pattern = self.abacus_encoder(next_embedding)
self.abacus_state_memory.store(abacus_pattern)
abacus_embedding = self.abacus_state_memory.retrieve()
for _ in range(max_new_concepts - 1):
# Working Memory retrieval
episodic_kv = None
if working_memory_window > 0 and kv_cache.get_pos() > working_memory_window:
start_pos = kv_cache.get_pos() - working_memory_window
end_pos = kv_cache.get_pos()
episodic_kv = kv_cache.retrieve_episode(start_pos, end_pos)
# Long-Term Memory retrieval
long_term_memory_embeddings = self.long_term_memory.retrieve(next_embedding.squeeze(0))
# Get psyche weights from PsycheController
psyche_weights = self.psyche_controller(next_embedding)
concept_logits, kv_cache, x_id_step, x_ego_step, x_superego_step = self.model.forward_step(
next_embedding,
kv_cache,
abacus_embedding=abacus_embedding,
episodic_kv=episodic_kv,
long_term_memory_embeddings=long_term_memory_embeddings,
psyche_weights=psyche_weights
)
# Memetic Learning Layer integration
memetic_fitness = self.memetic_learning_layer.forward(next_embedding, abacus_pattern)
# Conscious Integration Layer for step
synthesized_state_step = self.conscious_integration_layer.forward(
id_output=x_id_step,
ego_output=x_ego_step,
superego_output=x_superego_step,
long_term_memory_embeddings=long_term_memory_embeddings,
memetic_fitness=memetic_fitness,
abacus_state=abacus_embedding
)
# Combine concept_logits from GPT and synthesized_state_step
concept_logits = concept_logits + synthesized_state_step.squeeze(1) # Squeeze to match dimensions
next_concept_id = self.sample_from_logits(concept_logits, temperature)
next_embedding = self.concept_embedding_layer(next_concept_id)
generated_embeddings.append(next_embedding)
self.long_term_memory.store(next_embedding.squeeze(0)) # Store the generated embedding
# Abacus State Memory and Encoder integration for subsequent steps
abacus_pattern = self.abacus_encoder(next_embedding)
self.abacus_state_memory.store(abacus_pattern)
abacus_embedding = self.abacus_state_memory.retrieve() # Update abacus_embedding for the next step
# Memetic Learning Layer integration
memetic_fitness = self.memetic_learning_layer.forward(next_embedding, abacus_pattern)
return torch.stack(generated_embeddings, dim=1)
def generate_from_concepts(self, concept_list: list[int] | dict, max_new_concepts: int = 20, temperature: float = 1.0) -> torch.Tensor:
# Encode the concept_list into initial input embeddings
encoded_concepts = self.concept_encoder(concept_list)
if encoded_concepts.dtype == torch.long:
# If it's concept IDs, get embeddings from the concept_embedding_layer
input_embeddings = self.concept_embedding_layer(encoded_concepts)
abacus_embedding = None # No abacus embedding in this case
elif encoded_concepts.dtype == torch.float:
# If it's an abacus embedding, use it directly and set abacus_embedding
input_embeddings = encoded_concepts
abacus_embedding = encoded_concepts # The abacus embedding is the input embedding itself
else:
raise TypeError("Unexpected return type from concept_encoder.")
# Call the main generate method
return self.generate(input_embeddings, max_new_concepts, temperature, abacus_embedding=abacus_embedding)
# Special tokens are no longer directly used with concept embeddings.
# The tool use logic will need to be re-evaluated or removed if not applicable.
# get_special = lambda s: self.tokenizer.encode_special(s)
# python_start = get_special("<|python_start|>")
# python_end = get_special("<|python_end|>")
# output_start = get_special("<|output_start|>")
# output_end = get_special("<|output_end|>")
# assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
# bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
# 1) Run a batch 1 prefill of the prompt embeddings
m = self.model.config
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_cache_prefill = KVCache(
batch_size=input_embeddings.size(0),
seq_len=input_embeddings.size(1),
**kv_model_kwargs,
)
logits = self.model.forward(input_embeddings, kv_cache=kv_cache_prefill)
# Removed token-based sampling logic (replace with embedding generation logic later)
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (input_embeddings.size(1) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
batch_size=num_samples,
seq_len=kv_length_hint,
@ -230,7 +467,7 @@ class Engine:
del kv_cache_prefill # no need to keep this memory around
# 3) Initialize states for each sample
row_states = [RowState(input_embeddings[i].tolist()) for i in range(num_samples)] # Assuming input_embeddings is (B, T, C)
# 4) Main generation loop
num_generated = 0

nanochat/gpt.py

@ -25,13 +25,124 @@ from nanochat.adamw import DistAdamW
@dataclass
class GPTConfig:
def __init__(self,
n_layer=12,
n_head=12,
n_embd=768,
sequence_len=1024,
n_kv_head=None,
num_concept_ids=4096, # Updated to 4096 for n=12 hypercube
abacus_input_dim=64, # Default input dimension for the AbacusEncoder
dropout=0.0,
bias=False,
multiple_of=256,
norm_eps=1e-5,
rope_theta=10000,
# For training
batch_size=1,
gradient_accumulation_steps=1,
max_iters=0,
lr=6e-4,
min_lr=6e-5,
weight_decay=1e-1,
beta1=0.9,
beta2=0.95,
grad_clip=1.0,
decay_lr=True,
warmup_iters=2000,
lr_decay_iters=600000,
# For checkpointing
out_dir='out',
eval_interval=2000,
log_interval=1,
eval_iters=200,
eval_only=False,
always_save_checkpoint=True,
# For distributed training
backend='nccl',
# For system
device='cpu',
dtype='bfloat16',
compile=False,
# For data
dataset='openwebtext',
# For inference
init_from='scratch',
# For chat
chat=False,
# For concept
concept_memory_size=1000,
concept_memory_top_k=5,
use_concept_attention=False,
# For psyche
psyche_id_lr_scale=1.0,
psyche_ego_lr_scale=1.0,
psyche_superego_lr_scale=1.0,
**kwargs):
self.n_layer = n_layer
self.n_head = n_head
self.n_embd = n_embd
self.sequence_len = sequence_len
self.n_kv_head = n_kv_head if n_kv_head is not None else n_head
self.num_concept_ids = num_concept_ids
self.dropout = dropout
self.bias = bias
self.multiple_of = multiple_of
self.norm_eps = norm_eps
self.rope_theta = rope_theta
self.batch_size = batch_size
self.gradient_accumulation_steps = gradient_accumulation_steps
self.max_iters = max_iters
self.lr = lr
self.min_lr = min_lr
self.weight_decay = weight_decay
self.beta1 = beta1
self.beta2 = beta2
self.grad_clip = grad_clip
self.decay_lr = decay_lr
self.warmup_iters = warmup_iters
self.lr_decay_iters = lr_decay_iters
self.out_dir = out_dir
self.eval_interval = eval_interval
self.log_interval = log_interval
self.eval_iters = eval_iters
self.eval_only = eval_only
self.always_save_checkpoint = always_save_checkpoint
self.backend = backend
self.device = device
self.dtype = dtype
self.compile = compile
self.dataset = dataset
self.init_from = init_from
self.chat = chat
self.concept_memory_size = concept_memory_size
self.concept_memory_top_k = concept_memory_top_k
self.use_concept_attention = use_concept_attention
self.psyche_id_lr_scale = psyche_id_lr_scale
self.psyche_ego_lr_scale = psyche_ego_lr_scale
self.psyche_superego_lr_scale = psyche_superego_lr_scale
for k, v in kwargs.items():
setattr(self, k, v)
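A toy configuration sketch; num_concept_ids must be a power of two so that the hypercube dimension n = log2(num_concept_ids) is an integer (the default 4096 gives n = 12):

from nanochat.gpt import GPTConfig

config = GPTConfig(n_layer=6, n_head=4, n_embd=128, num_concept_ids=256, abacus_input_dim=16)
print(config.n_kv_head)        # 4  (falls back to n_head when n_kv_head is None)
print(config.num_concept_ids)  # 256 -> an 8-dimensional hypercube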
def norm(x):
# Purely functional rmsnorm with no learnable params
@ -63,7 +174,7 @@ class CausalSelfAttention(nn.Module):
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None):
B, T, C = x.size()
# Project the input to get queries, keys, and values
@ -80,6 +191,14 @@ class CausalSelfAttention(nn.Module):
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
# If episodic_kv is provided, prepend it to the current k and v
if episodic_kv is not None:
episode_k_layer = episodic_kv[0][self.layer_idx]
episode_v_layer = episodic_kv[1][self.layer_idx]
k = torch.cat([episode_k_layer, k], dim=2)
v = torch.cat([episode_v_layer, v], dim=2)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
@ -135,15 +254,42 @@ class Block(nn.Module):
return x
class PsycheController(nn.Module):
def __init__(self, config):
super().__init__()
self.controller_head = nn.Linear(config.n_embd, 3) # 3 for id, ego, superego
def forward(self, x):
# For now, just return equal weights for each psyche layer
# In the future, this will be trained to dynamically blend psyche outputs
return torch.softmax(self.controller_head(x), dim=-1)
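A quick check of the controller in isolation (SimpleNamespace stands in for the config; only n_embd is read):

import torch
from types import SimpleNamespace
from nanochat.gpt import PsycheController

controller = PsycheController(SimpleNamespace(n_embd=32))
weights = controller(torch.randn(2, 32))
print(weights.shape, weights.sum(dim=-1))   # torch.Size([2, 3]), each row sums to 1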
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict({
"h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
# Partition transformer layers into psyche layers
total_layers = config.n_layer
id_end = total_layers // 3
ego_end = 2 * total_layers // 3
self.id_layers = self.transformer.h[:id_end]
self.ego_layers = self.transformer.h[id_end:ego_end]
self.superego_layers = self.transformer.h[ego_end:]
self.psyche_registry = {
"id": self.id_layers,
"ego": self.ego_layers,
"superego": self.superego_layers
}
self.concept_head = nn.Linear(config.n_embd, config.num_concept_ids, bias=False) # New concept head
self.psyche_controller = PsycheController(config) # Initialize PsycheController
# To support meta device initialization, we init the rotary embeddings here, but it's fake
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them, but assert fail if we ever reach that amount.
@ -156,8 +302,8 @@ class GPT(nn.Module):
def init_weights(self):
self.apply(self._init_weights)
# zero out concept_head weights
torch.nn.init.zeros_(self.concept_head.weight)
# zero out c_proj weights in all blocks
for block in self.transformer.h:
torch.nn.init.zeros_(block.mlp.c_proj.weight)
@ -166,9 +312,6 @@ class GPT(nn.Module):
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
def _init_weights(self, module):
if isinstance(module, nn.Linear):
@ -186,7 +329,7 @@ class GPT(nn.Module):
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# autodetect the device from model embeddings
if device is None:
device = self.concept_head.weight.device
# stride the channels
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
inv_freq = 1.0 / (base ** (channel_range / head_dim))
@ -200,12 +343,13 @@ class GPT(nn.Module):
return cos, sin
def get_device(self):
# Get device from concept_head weight
return self.concept_head.weight.device
def estimate_flops(self):
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
nparams = sum(p.numel() for p in self.parameters())
nparams_embedding = 0 # No separate embedding layer now
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
return num_flops_per_token
@ -213,95 +357,190 @@ class GPT(nn.Module):
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# Separate out all parameters into 3 groups (matrix, embedding, concept_head)
# matrix_params = list(self.transformer.h.parameters())
id_params = list(self.id_layers.parameters())
ego_params = list(self.ego_layers.parameters())
superego_params = list(self.superego_layers.parameters())
embedding_params = [] # No separate embedding layer now
concept_head_params = list(self.concept_head.parameters()) # New concept head params
psyche_controller_params = list(self.psyche_controller.parameters()) # Psyche controller params
# assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(concept_head_params)
# Create the AdamW optimizer for the embedding and concept_head
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
if rank == 0:
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
adam_groups = [
dict(params=concept_head_params, lr=unembedding_lr * dmodel_lr_scale), # Use unembedding_lr for concept_head
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
dict(params=id_params, lr=3e-4 * dmodel_lr_scale), # Id layers learning rate
dict(params=ego_params, lr=1e-4 * dmodel_lr_scale), # Ego layers learning rate
dict(params=superego_params, lr=5e-5 * dmodel_lr_scale), # Superego layers learning rate
dict(params=psyche_controller_params, lr=1e-4 * dmodel_lr_scale), # Psyche controller learning rate
]
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
# Create the Muon optimizer for the linear layers
# muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
# MuonFactory = DistMuon if ddp else Muon
# muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
# Combine the optimizers into one list
optimizers = [adamw_optimizer]
for opt in optimizers:
for group in opt.param_groups:
group["initial_lr"] = group["lr"]
return optimizers
def _run_layers(self, layers, x, cos_sin, kv_cache, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None):
for block in layers:
x = block(x, cos_sin, kv_cache, episodic_kv)
return x
def forward(self, input_embeddings: torch.Tensor, kv_cache=None, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None):
B, T, C = input_embeddings.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
cos_sin = self.cos[:, :T, :, :], self.sin[:, :T, :, :]
x = input_embeddings
# Process Id layers
x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv)
# Process Ego layers
x_ego = x_id
if abacus_embedding is not None:
# Broadcast abacus_embedding to match the sequence length of x_ego
# Assuming abacus_embedding is (B, C) and x_ego is (B, T, C)
abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + abacus_broadcast # Inject abacus_embedding into ego layer
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_ego
# Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + long_term_memory_broadcast # Inject long_term_memory_embeddings into ego layer
x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv)
# Process Superego layers
x_superego = x_ego
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_superego
# Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
x_superego = x_superego + long_term_memory_broadcast
x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv)
# Dynamically blend the outputs based on psyche_weights
# Reshape psyche_weights for broadcasting: (B, 1, 3)
psyche_weights_reshaped = psyche_weights.unsqueeze(1)
# Stack the outputs and apply weighted sum
# Stack will result in (B, T, 3, C)
stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2)
# Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C)
x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2)
# Final concept head
return self.concept_head(x), kv_cache, x_id, x_ego, x_superego
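The psyche blend in isolation, to make the broadcasting explicit (same unsqueeze pattern as above; toy sizes):

import torch

B, T, C = 2, 5, 8
x_id, x_ego, x_superego = (torch.randn(B, T, C) for _ in range(3))
psyche_weights = torch.softmax(torch.randn(B, 3), dim=-1)          # as produced by PsycheController
stacked = torch.stack([x_id, x_ego, x_superego], dim=2)            # (B, T, 3, C)
blended = (stacked * psyche_weights.unsqueeze(1).unsqueeze(-1)).sum(dim=2)
print(blended.shape)   # torch.Size([2, 5, 8])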
def forward_prefill(self, input_embeddings: torch.Tensor, kv_cache=None, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None):
B, T, C = input_embeddings.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
cos_sin = self.cos[:, :T, :, :], self.sin[:, :T, :, :]
x = input_embeddings
# Process Id layers
x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv)
# Process Ego layers
x_ego = x_id
if abacus_embedding is not None:
# Broadcast abacus_embedding to match the sequence length of x_ego
# Assuming abacus_embedding is (B, C) and x_ego is (B, T, C)
abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + abacus_broadcast # Inject abacus_embedding into ego layer
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_ego
# Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + long_term_memory_broadcast # Inject long_term_memory_embeddings into ego layer
x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv)
# Process Superego layers
x_superego = x_ego
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_superego
# Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
x_superego = x_superego + long_term_memory_broadcast
x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv)
# Dynamically blend the outputs based on psyche_weights
# Reshape psyche_weights for broadcasting: (B, 1, 3)
psyche_weights_reshaped = psyche_weights.unsqueeze(1)
# Stack the outputs and apply weighted sum
# Stack will result in (B, T, 3, C)
stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2)
# Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C)
x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2)
# Final concept head
return self.concept_head(x), kv_cache, x_id, x_ego, x_superego
def forward_step(self, next_embedding: torch.Tensor, kv_cache, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None):
B, C = next_embedding.size()
T = kv_cache.get_pos() + 1 # Current sequence length after adding next_embedding
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
cos_sin = self.cos[:, T-1:T, :, :], self.sin[:, T-1:T, :, :]
x = next_embedding.unsqueeze(1) # Add sequence dimension for consistency
# Process Id layers
x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv)
# Process Ego layers
x_ego = x_id
if abacus_embedding is not None:
# Broadcast abacus_embedding to match the sequence length of x_ego
# Assuming abacus_embedding is (B, C) and x_ego is (B, T, C)
abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + abacus_broadcast # Inject abacus_embedding into ego layer
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_ego
# Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + long_term_memory_broadcast # Inject long_term_memory_embeddings into ego layer
x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv)
# Process Superego layers
x_superego = x_ego
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_superego
# Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
x_superego = x_superego + long_term_memory_broadcast
x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv)
# Dynamically blend the outputs based on psyche_weights
# Reshape psyche_weights for broadcasting: (B, 1, 3)
psyche_weights_reshaped = psyche_weights.unsqueeze(1)
# Stack the outputs and apply weighted sum
# Stack will result in (B, T, 3, C)
stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2)
# Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C)
x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2)
# Forward the trunk of the Transformer
x = self.transformer.wte(idx)
x = norm(x) x = norm(x)
for block in self.transformer.h: return self.concept_head(x.squeeze(1)), kv_cache, x_id, x_ego, x_superego
x = block(x, cos_sin, kv_cache)
x = norm(x)
# Forward the lm_head (compute logits)
softcap = 15
if targets is not None:
# training mode: compute and return the loss
# TODO: experiment with Liger Kernels / chunked cross-entropy etc.
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
logits = logits.float() # use tf32/fp32 for logits
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction)
return loss
else:
# inference mode: compute and return the logits
logits = self.lm_head(x)
logits = softcap * torch.tanh(logits / softcap) # logits softcap
return logits
@torch.inference_mode()
def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42):
"""
Naive autoregressive streaming inference.
To make it super simple, let's assume:
- batch size is 1
- ids and the yielded tokens are simple Python lists and ints
"""
assert isinstance(tokens, list)
device = self.get_device()
rng = None
if temperature > 0:
rng = torch.Generator(device=device)
rng.manual_seed(seed)
ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim
for _ in range(max_tokens):
logits = self.forward(ids) # (B, T, vocab_size)
logits = logits[:, -1, :] # (B, vocab_size)
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float('Inf')
if temperature > 0:
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
next_ids = torch.multinomial(probs, num_samples=1, generator=rng)
else:
next_ids = torch.argmax(logits, dim=-1, keepdim=True)
ids = torch.cat((ids, next_ids), dim=1)
token = next_ids.item()
yield token

nanochat/hypercube.py Normal file

@ -0,0 +1,149 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class HypercubeTopology:
def __init__(self, n: int):
if not (1 <= n <= 16): # Limiting n for practical reasons, can be adjusted
raise ValueError("Hypercube dimension 'n' must be between 1 and 16.")
self.n = n
self.num_vertices = 2**n
def to_binary(self, vertex_id: int) -> str:
if not (0 <= vertex_id < self.num_vertices):
raise ValueError(f"Vertex ID {vertex_id} is out of range for a {self.n}-dimensional hypercube.")
return format(vertex_id, f'0{self.n}b')
def from_binary(self, binary_string: str) -> int:
if len(binary_string) != self.n:
raise ValueError(f"Binary string length must be {self.n} for a {self.n}-dimensional hypercube.")
return int(binary_string, 2)
def get_neighbors(self, vertex_id: int) -> list[int]:
if not (0 <= vertex_id < self.num_vertices):
raise ValueError(f"Vertex ID {vertex_id} is out of range for a {self.n}-dimensional hypercube.")
neighbors = []
for i in range(self.n):
# Flip the i-th bit
neighbor_id = vertex_id ^ (1 << i)
neighbors.append(neighbor_id)
return neighbors
def get_all_edges(self) -> list[tuple[int, int]]:
edges = []
for i in range(self.num_vertices):
for neighbor in self.get_neighbors(i):
# Ensure each edge is added only once (e.g., (0,1) not (1,0))
if i < neighbor:
edges.append((i, neighbor))
return edges
def get_random_vertex(self) -> int:
return torch.randint(0, self.num_vertices, (1,)).item()
def get_random_path(self, start_vertex: int, length: int) -> list[int]:
if not (0 <= start_vertex < self.num_vertices):
raise ValueError(f"Start vertex ID {start_vertex} is out of range.")
if length < 1:
raise ValueError("Path length must be at least 1.")
path = [start_vertex]
current_vertex = start_vertex
for _ in range(length - 1):
neighbors = self.get_neighbors(current_vertex)
if not neighbors:
break # Should not happen in a hypercube unless n=0
# Choose a random neighbor (revisiting the previous vertex is allowed)
next_vertex = neighbors[torch.randint(0, len(neighbors), (1,)).item()]
path.append(next_vertex)
current_vertex = next_vertex
return path
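A quick sketch of the topology helpers:

from nanochat.hypercube import HypercubeTopology

topo = HypercubeTopology(3)          # 8 vertices, labelled 000..111
print(topo.to_binary(5))             # '101'
print(topo.get_neighbors(5))         # [4, 7, 1] -- flip bit 0, 1, 2
print(len(topo.get_all_edges()))     # 12 == n * 2**(n-1)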
class HypercubeEmbeddingLayer(nn.Module):
def __init__(self, num_raw_concepts, embedding_dim, hypercube_topology):
super().__init__()
self.raw_embedding_layer = nn.Embedding(num_raw_concepts, embedding_dim)
self.hypercube_topology = hypercube_topology
# Embeddings for hypercube vertices. The number of vertices is hypercube_topology.num_vertices
self.vertex_embeddings = nn.Embedding(hypercube_topology.num_vertices, embedding_dim)
def forward(self, concept_ids):
# Get initial embeddings for the input concept_ids
initial_embeddings = self.raw_embedding_layer(concept_ids) # (batch_size, embedding_dim)
# Find the nearest hypercube vertex for each initial embedding
# Get all hypercube vertex embeddings
# Ensure all_vertex_ids are on the same device as concept_ids
all_vertex_ids = torch.arange(self.hypercube_topology.num_vertices, device=concept_ids.device)
all_vertex_embeddings = self.vertex_embeddings(all_vertex_ids) # (num_vertices, embedding_dim)
# Calculate squared Euclidean distance
# initial_embeddings_expanded: (batch_size, 1, embedding_dim)
# all_vertex_embeddings_expanded: (1, num_vertices, embedding_dim)
initial_embeddings_expanded = initial_embeddings.unsqueeze(1)
all_vertex_embeddings_expanded = all_vertex_embeddings.unsqueeze(0)
distances = torch.sum((initial_embeddings_expanded - all_vertex_embeddings_expanded)**2, dim=2) # (batch_size, num_vertices)
# Find the index of the nearest vertex for each initial embedding
nearest_vertex_indices = torch.argmin(distances, dim=1) # (batch_size,)
# Retrieve the embeddings of the nearest vertices
final_embeddings = self.vertex_embeddings(nearest_vertex_indices)
return final_embeddings
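Usage sketch (toy sizes); note that the nearest-vertex lookup is a hard argmin, so gradients reach vertex_embeddings but not raw_embedding_layer through this path:

import torch
from nanochat.hypercube import HypercubeTopology, HypercubeEmbeddingLayer

topo = HypercubeTopology(4)          # 16 vertices
layer = HypercubeEmbeddingLayer(num_raw_concepts=100, embedding_dim=32, hypercube_topology=topo)
ids = torch.tensor([3, 42, 99])
out = layer(ids)                     # each concept ID snapped to its nearest vertex embedding
print(out.shape)                     # torch.Size([3, 32])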
class LongTermMemory(nn.Module):
def __init__(self, embedding_dim: int, max_memory_size: int = 1000, top_k_retrieval: int = 5):
super().__init__()
self.embedding_dim = embedding_dim
self.max_memory_size = max_memory_size
self.top_k_retrieval = top_k_retrieval
# Initialize an empty memory bank. We'll use a list of tensors for simplicity initially,
# but this could be replaced with a more efficient data structure or a learnable embedding layer.
self.memory_bank = []
self.memory_bank_tensor = None # Will store concatenated memories for efficient retrieval
def store(self, embedding: torch.Tensor):
# Store a new embedding in the memory bank
if len(self.memory_bank) >= self.max_memory_size:
# Simple FIFO eviction for now
self.memory_bank.pop(0)
self.memory_bank.append(embedding.detach().cpu()) # Store on CPU to save GPU memory
self.memory_bank_tensor = None # Invalidate cached tensor
def retrieve(self, query_embedding: torch.Tensor) -> torch.Tensor:
# Retrieve top-k most similar embeddings from the memory bank
if not self.memory_bank:
return torch.zeros(query_embedding.shape[0], self.top_k_retrieval, self.embedding_dim, device=query_embedding.device)
if self.memory_bank_tensor is None:
self.memory_bank_tensor = torch.stack(self.memory_bank).to(query_embedding.device)
# Normalize query and memory bank for cosine similarity
query_norm = F.normalize(query_embedding, p=2, dim=-1)
memory_norm = F.normalize(self.memory_bank_tensor, p=2, dim=-1)
# Calculate cosine similarity
# query_norm: (batch_size, embedding_dim)
# memory_norm: (num_memories, embedding_dim)
# similarities: (batch_size, num_memories)
similarities = torch.matmul(query_norm, memory_norm.transpose(0, 1))
# Get top-k similar memories
# top_k_values: (batch_size, top_k_retrieval)
# top_k_indices: (batch_size, top_k_retrieval)
top_k_values, top_k_indices = torch.topk(similarities, min(self.top_k_retrieval, len(self.memory_bank)), dim=-1)
# Retrieve the actual embeddings
retrieved_memories = self.memory_bank_tensor[top_k_indices]
return retrieved_memories
def forward(self, query_embedding: torch.Tensor) -> torch.Tensor:
return self.retrieve(query_embedding)
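Usage sketch of the cosine-similarity retrieval (sizes illustrative):

import torch
from nanochat.hypercube import LongTermMemory

memory = LongTermMemory(embedding_dim=32, max_memory_size=100, top_k_retrieval=2)
for _ in range(5):
    memory.store(torch.randn(32))
hits = memory.retrieve(torch.randn(3, 32))   # top-2 most similar stored embeddings per query
print(hits.shape)                            # torch.Size([3, 2, 32])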

nanochat/memetic_learning.py Normal file

@ -0,0 +1,46 @@
import torch
import torch.nn as nn
class MemeticLearningLayer(nn.Module):
def __init__(self, config, abacus_encoder, internal_self_model):
super().__init__()
self.config = config
self.abacus_encoder = abacus_encoder
self.internal_self_model = internal_self_model
# Placeholder for memetic fitness evaluation mechanism
# This could be a simple linear layer or a more complex network
self.fitness_evaluator = nn.Linear(config.abacus_input_dim, 1)
# Placeholder for concept mapping expansion (analogy/metaphor)
# This might involve a transformation of embeddings or a retrieval mechanism
self.concept_mapper = nn.Sequential(
nn.Linear(config.n_embd * 2, config.n_embd), # Input: two concept embeddings
nn.ReLU(),
nn.Linear(config.n_embd, config.n_embd)
)
def evaluate_memetic_fitness(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
# Evaluate the fitness of a memetic pattern using the abacus encoder output
fitness_score = self.fitness_evaluator(abacus_pattern)
return fitness_score
def expand_concept_mapping(self, concept1_embedding: torch.Tensor, concept2_embedding: torch.Tensor) -> torch.Tensor:
# Expand concept mapping via analogy and metaphor
# This takes two concept embeddings and generates a new, related concept embedding
combined_concepts = torch.cat([concept1_embedding, concept2_embedding], dim=-1)
new_concept_embedding = self.concept_mapper(combined_concepts)
return new_concept_embedding
def forward(self, current_concept_embedding: torch.Tensor, abacus_pattern: torch.Tensor):
# Orchestrates the memetic learning process
# 1. Evaluate memetic fitness
fitness = self.evaluate_memetic_fitness(abacus_pattern)
# 2. Potentially update internal self-model based on fitness or new concepts
# self.internal_self_model.update_beliefs(current_concept_embedding) # Example interaction
# 3. Generate new concepts or analogies
# new_concept = self.expand_concept_mapping(current_concept_embedding, some_other_concept)
return fitness
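A minimal forward sketch; the encoder and self-model collaborators are not used by forward() yet, so None is passed here purely for illustration:

import torch
from types import SimpleNamespace
from nanochat.memetic_learning import MemeticLearningLayer

cfg = SimpleNamespace(n_embd=32, abacus_input_dim=8)
layer = MemeticLearningLayer(cfg, abacus_encoder=None, internal_self_model=None)
fitness = layer(torch.randn(4, 32), torch.randn(4, 8))   # (current_concept_embedding, abacus_pattern)
print(fitness.shape)   # torch.Size([4, 1])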

nanochat/self_model.py Normal file

@ -0,0 +1,58 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class InternalSelfModel(nn.Module):
def __init__(self, embedding_dim: int, num_concepts: int):
super().__init__()
self.embedding_dim = embedding_dim
self.num_concepts = num_concepts
# Placeholder for belief representation (e.g., a learned embedding or a set of parameters)
self.belief_embedding = nn.Parameter(torch.randn(embedding_dim))
# Placeholder for a simple prediction head to calculate prediction error
self.prediction_head = nn.Linear(embedding_dim, 1) # Predicts a scalar error signal
# Placeholder for conceptual growth mechanism (e.g., a small MLP or attention mechanism)
self.conceptual_growth_mlp = nn.Sequential(
nn.Linear(embedding_dim * 2, embedding_dim), # Input: current belief + new concept
nn.ReLU(),
nn.Linear(embedding_dim, embedding_dim)
)
def update_beliefs(self, current_concept_embedding: torch.Tensor):
# This method would update the internal belief representation based on new information.
# For now, a simple update rule (e.g., moving average or attention-based update)
# In a more advanced implementation, this could involve a recurrent neural network.
# Example: simple weighted average
alpha = 0.1 # Learning rate for belief update
self.belief_embedding.data = (1 - alpha) * self.belief_embedding.data + alpha * current_concept_embedding.mean(dim=0)
def calculate_prediction_error(self, predicted_output: torch.Tensor, actual_output: torch.Tensor) -> torch.Tensor:
# This method calculates the discrepancy between predicted and actual outcomes.
# For now, a simple mean squared error.
error = F.mse_loss(predicted_output, actual_output)
return error
def promote_conceptual_growth(self, current_belief_embedding: torch.Tensor, new_concept_embedding: torch.Tensor) -> torch.Tensor:
# This method facilitates the integration of new concepts into the existing conceptual framework.
# For now, a simple MLP that takes the concatenation of current belief and new concept.
combined_embedding = torch.cat([current_belief_embedding, new_concept_embedding], dim=-1)
updated_concept_embedding = self.conceptual_growth_mlp(combined_embedding)
return updated_concept_embedding
def forward(self, current_concept_embedding: torch.Tensor, predicted_output: torch.Tensor, actual_output: torch.Tensor):
# Update beliefs
self.update_beliefs(current_concept_embedding)
# Calculate prediction error
error = self.calculate_prediction_error(predicted_output, actual_output)
# Promote conceptual growth (example: using the current belief and concept embedding)
# This part would be more complex in a full implementation, potentially involving attention
# or other mechanisms to decide how to grow concepts based on error and new info.
# For demonstration, let's assume conceptual growth is triggered by new concept embeddings.
# updated_concept = self.promote_conceptual_growth(self.belief_embedding, current_concept_embedding)
return error, self.belief_embedding
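Usage sketch, assuming the signature adjusted above to match how the Engine constructs this module (embedding_dim, num_concepts):

import torch
from nanochat.self_model import InternalSelfModel

model = InternalSelfModel(embedding_dim=32, num_concepts=16)
error, belief = model(
    torch.randn(4, 32),    # current concept embeddings
    torch.randn(4, 1),     # predicted outcome
    torch.randn(4, 1),     # actual outcome
)
print(error.item(), belief.shape)   # scalar MSE, torch.Size([32])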