hypercube-abacus

This commit is contained in:
Nda-jiya Suberu 2025-10-31 03:27:16 +00:00
parent dfc88334b6
commit 317e4b65df
9 changed files with 966 additions and 137 deletions

2
config/custom_config.py Normal file

@ -0,0 +1,2 @@
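# Overrides of the GPTConfig defaults in nanochat/gpt.py (n_layer=12, n_embd=768): a deeper
# and wider model. How this file is loaded is left to the training scripts.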
n_layer = 24
n_embd = 1024

nanochat/abacus_encoder.py Normal file

@ -0,0 +1,24 @@
import torch
import torch.nn as nn
class AbacusEncoder(nn.Module):
def __init__(self, input_dim: int, embedding_dim: int):
super().__init__()
self.input_dim = input_dim
self.embedding_dim = embedding_dim
# Simple linear layer to encode abacus-like patterns into the embedding space
self.encoder_layer = nn.Linear(input_dim, embedding_dim)
def encode(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
# abacus_pattern is expected to be a tensor of shape (batch_size, input_dim)
if abacus_pattern.shape[-1] != self.input_dim:
raise ValueError(f"Expected abacus_pattern to have last dimension {self.input_dim}, but got {abacus_pattern.shape[-1]}")
return self.encoder_layer(abacus_pattern)
def decode(self, concept_vector: torch.Tensor) -> torch.Tensor:
# Placeholder for decoding functionality
raise NotImplementedError("Decoding from concept vector to abacus pattern is not yet implemented.")
def forward(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
return self.encode(abacus_pattern)
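# A minimal usage sketch (dimensions are illustrative; the engine wires input_dim to
# config.abacus_input_dim and embedding_dim to config.n_embd).
if __name__ == "__main__":
    encoder = AbacusEncoder(input_dim=64, embedding_dim=768)
    pattern = torch.rand(2, 64)   # (batch_size, input_dim) abacus-like pattern
    concept = encoder(pattern)    # routed through encode(): a single linear projection
    print(concept.shape)          # torch.Size([2, 768])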

nanochat/abacus_state_memory.py Normal file

@ -0,0 +1,26 @@
import torch
import torch.nn as nn
class AbacusStateMemory(nn.Module):
def __init__(self, max_memory_size: int = 100, abacus_input_dim: int = 64):
super().__init__()
self.max_memory_size = max_memory_size
self.abacus_input_dim = abacus_input_dim
self.memory = [] # Stores abacus patterns (tensors)
def store(self, abacus_pattern: torch.Tensor):
if len(self.memory) >= self.max_memory_size:
self.memory.pop(0) # Remove the oldest entry
self.memory.append(abacus_pattern.detach().cpu()) # Store detached CPU tensor
def retrieve(self, num_to_retrieve: int = 1) -> list[torch.Tensor]:
# Retrieve the most recent abacus patterns
return self.memory[-num_to_retrieve:]
def clear(self):
self.memory = []
def forward(self, abacus_pattern: torch.Tensor):
# For now, forward simply stores the pattern. More complex logic can be added later.
self.store(abacus_pattern)
return abacus_pattern
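# A minimal usage sketch showing the FIFO eviction behaviour (sizes are illustrative).
if __name__ == "__main__":
    mem = AbacusStateMemory(max_memory_size=3, abacus_input_dim=4)
    for i in range(5):
        mem(torch.full((1, 4), float(i)))     # forward() stores the pattern and returns it
    recent = mem.retrieve(num_to_retrieve=2)  # list of the two most recent stored patterns
    print(len(mem.memory), [t[0, 0].item() for t in recent])  # 3 [3.0, 4.0]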

nanochat/conscious_integration.py Normal file

@ -0,0 +1,48 @@
import torch
import torch.nn as nn
from nanochat.abacus_encoder import AbacusEncoder
class ConsciousIntegrationLayer(nn.Module):
def __init__(self, config, abacus_encoder: AbacusEncoder):
super().__init__()
self.config = config
self.abacus_encoder = abacus_encoder
# Linear layer to project the integrated state onto logits over the concept IDs
# (matching the GPT concept head so the two can be summed in the engine)
self.concept_projection = nn.Linear(config.n_embd, config.num_concept_ids)
def forward(self, id_output: torch.Tensor, ego_output: torch.Tensor, superego_output: torch.Tensor, long_term_memory_embeddings: torch.Tensor, memetic_fitness: torch.Tensor | None, abacus_state: torch.Tensor) -> torch.Tensor:
# Ensure all inputs are of the same shape for integration
# For simplicity, let's assume they are all (B, T, C) or (B, C)
# If they are (B, C), we might need to unsqueeze for sequence dimension if T > 1
# Example integration: simple summation. More complex mechanisms can be added here.
# The goal is to synthesize these into a unified conceptual state.
synthesized_state = id_output + ego_output + superego_output
if long_term_memory_embeddings is not None:
# Ensure dimensions match for addition. long_term_memory_embeddings might be (B, C)
# If synthesized_state is (B, T, C), expand long_term_memory_embeddings
if synthesized_state.dim() == 3 and long_term_memory_embeddings.dim() == 2:
long_term_memory_embeddings = long_term_memory_embeddings.unsqueeze(1).expand(-1, synthesized_state.size(1), -1)
synthesized_state = synthesized_state + long_term_memory_embeddings
if memetic_fitness is not None:
# Ensure dimensions match for addition
if synthesized_state.dim() == 3 and memetic_fitness.dim() == 2:
memetic_fitness = memetic_fitness.unsqueeze(1).expand(-1, synthesized_state.size(1), -1)
synthesized_state = synthesized_state + memetic_fitness
# Abacus Encoder for logical consistency
# The abacus_state is already an encoded representation, so we might use it to gate or modulate
# the synthesized_state, or simply include it in the synthesis.
# For now, let's add it, assuming it's compatible.
if synthesized_state.dim() == 3 and abacus_state.dim() == 2:
abacus_state = abacus_state.unsqueeze(1).expand(-1, synthesized_state.size(1), -1)
synthesized_state = synthesized_state + abacus_state
# Project the synthesized state to the vocabulary size
concept_logits_from_synthesis = self.concept_projection(synthesized_state)
return concept_logits_from_synthesis
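# A minimal usage sketch with a stand-in config (SimpleNamespace here is only an assumption
# for the demo; real usage passes the GPTConfig from nanochat/gpt.py, and only the fields
# read above are provided).
if __name__ == "__main__":
    from types import SimpleNamespace
    cfg = SimpleNamespace(n_embd=16, num_concept_ids=8)
    layer = ConsciousIntegrationLayer(cfg, AbacusEncoder(input_dim=4, embedding_dim=16))
    B, T, C = 2, 5, 16
    logits = layer(
        id_output=torch.randn(B, T, C),
        ego_output=torch.randn(B, T, C),
        superego_output=torch.randn(B, T, C),
        long_term_memory_embeddings=torch.randn(B, C),  # broadcast over T inside forward()
        memetic_fitness=None,                           # not available at prefill time
        abacus_state=torch.randn(B, C),
    )
    print(logits.shape)  # torch.Size([2, 5, 8])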

nanochat/engine.py

@ -11,7 +11,9 @@ Notes:
The whole thing is made as efficient as possible.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import signal
import warnings
@ -19,6 +21,15 @@ from contextlib import contextmanager
from collections import deque
from nanochat.common import compute_init
from nanochat.checkpoint_manager import load_model
from nanochat.gpt import GPT
from nanochat.kv_cache import KVCache
from nanochat.hypercube import HypercubeTopology, HypercubeEmbeddingLayer, LongTermMemory # Hypercube topology, embeddings, and long-term memory
from nanochat.abacus_encoder import AbacusEncoder
from nanochat.self_model import InternalSelfModel
from nanochat.abacus_state_memory import AbacusStateMemory
from nanochat.memetic_learning import MemeticLearningLayer
from nanochat.conscious_integration import ConsciousIntegrationLayer
from nanochat.gpt import GPT, GPTConfig, PsycheController
# -----------------------------------------------------------------------------
# Calculator tool helpers
@ -151,76 +162,302 @@ class KVCache:
self.pos = t1
return key_view, value_view
def retrieve_episode(self, start_pos, end_pos):
"""
Retrieves a slice of the KV cache representing an 'episode' or a specific attention span.
"""
assert self.kv_cache is not None, "KV cache is empty, cannot retrieve episode."
assert start_pos >= 0 and end_pos <= self.pos, "Invalid start or end position for episode retrieval."
assert start_pos < end_pos, "Start position must be less than end position."
# Return a view of the relevant part of the cache
# kv_cache shape: (num_layers, 2, batch_size, num_heads, seq_len, head_dim)
# We want to retrieve for all layers, and both k and v
episode_k = self.kv_cache[:, 0, :, :, start_pos:end_pos, :]
episode_v = self.kv_cache[:, 1, :, :, start_pos:end_pos, :]
return episode_k, episode_v
# -----------------------------------------------------------------------------
@torch.inference_mode()
def sample_from_logits(logits, rng, temperature, top_k):
    """Sample a single next concept ID from given logits of shape (B, num_concept_ids). Returns (B, 1)."""
    assert temperature >= 0.0, "temperature must be non-negative"
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None:
        # Keep only the top-k logits; everything else is masked to -inf (note: modifies logits in place)
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < v[:, [-1]]] = -float('Inf')
    logits = logits / temperature
    probs = F.softmax(logits, dim=-1)
    idx_next = torch.multinomial(probs, num_samples=1, generator=rng)
    return idx_next
# -----------------------------------------------------------------------------
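# Example (illustrative values): with top_k=2 only the two largest logits survive the
# in-place masking, so the sampled ID below is always 0 or 1. Note the clone(), since
# the helper modifies the logits tensor it is given.
#
#   logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
#   rng = torch.Generator().manual_seed(0)
#   next_id = sample_from_logits(logits.clone(), rng, temperature=0.8, top_k=2)  # shape (1, 1)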
class RowState:
# Per-row state tracking during generation
def __init__(self, current_tokens=None):
self.current_tokens = current_tokens or [] # Current token sequence for this row
self.forced_tokens = deque() # Queue of tokens to force inject
self.in_python_block = False # Whether we are inside a python block
self.python_expr_tokens = [] # Tokens of the current python expression
self.completed = False # Whether this row has completed generation
class Engine:
# (The old constructor took a prebuilt model and tokenizer:)
# def __init__(self, model, tokenizer):
#     self.model = model
#     self.tokenizer = tokenizer  # needed for tool use
def __init__(self, config: GPTConfig):
self.config = config
self.model = GPT(config)
self.model.eval()
self.model.init_weights()
# Initialize HypercubeTopology (the embedding layer below depends on it)
n_hypercube_dim = int(math.log2(config.num_concept_ids))
if not (2**n_hypercube_dim == config.num_concept_ids):
    raise ValueError("config.num_concept_ids must be a power of 2 for hypercube topology.")
self.hypercube_topology = HypercubeTopology(n_hypercube_dim)
# Initialize HypercubeEmbeddingLayer
self.concept_embedding_layer = HypercubeEmbeddingLayer(
    num_raw_concepts=config.num_concept_ids,
    embedding_dim=config.n_embd,
    hypercube_topology=self.hypercube_topology
)
# Initialize the raw concept embeddings (the vertex embeddings keep their default init)
nn.init.normal_(self.concept_embedding_layer.raw_embedding_layer.weight, mean=0.0, std=1.0)
# Initialize AbacusEncoder
self.abacus_encoder = AbacusEncoder(
    input_dim=config.abacus_input_dim,
    embedding_dim=config.n_embd
)
# Initialize LongTermMemory
self.long_term_memory = LongTermMemory(
embedding_dim=config.n_embd,
max_memory_size=1000, # Default max memory size
top_k_retrieval=5 # Default top-k retrieval
)
# Initialize InternalSelfModel
self.internal_self_model = InternalSelfModel(
embedding_dim=config.n_embd,
config=config # InternalSelfModel.__init__ takes (embedding_dim, config); see nanochat/self_model.py
)
# Initialize AbacusStateMemory
self.abacus_state_memory = AbacusStateMemory(
max_memory_size=100, # Default max memory size
abacus_input_dim=config.abacus_input_dim
)
# Initialize MemeticLearningLayer
self.memetic_learning_layer = MemeticLearningLayer(
config=config,
abacus_encoder=self.abacus_encoder,
internal_self_model=self.internal_self_model
)
# Initialize PsycheController
self.psyche_controller = PsycheController(
config=config
)
# Initialize ConsciousIntegrationLayer
self.conscious_integration_layer = ConsciousIntegrationLayer(
config=config,
abacus_encoder=self.abacus_encoder
)
# Placeholder fusion layer for combining the transformer output with retrieved concept embeddings
self.concept_attention_fusion = nn.Linear(config.n_embd * 2, config.n_embd)
# Placeholder for concept encoder: bind the helper method below until a learned encoder is plugged in
self.concept_encoder = self._concept_encoder_placeholder
def _concept_encoder_placeholder(self, concept_list: list[int] | dict) -> torch.Tensor:
"""
Placeholder for a more sophisticated concept encoder.
For now, it assumes concept_list directly contains integer concept IDs or a dict with 'abacus_pattern'.
"""
if isinstance(concept_list, dict) and "abacus_pattern" in concept_list:
# If an abacus pattern is provided, encode it using the AbacusEncoder
abacus_pattern = concept_list["abacus_pattern"]
# Ensure abacus_pattern is a tensor and has the correct shape
if not isinstance(abacus_pattern, torch.Tensor):
abacus_pattern = torch.tensor(abacus_pattern, dtype=torch.float32)
# Add batch dimension if missing
if abacus_pattern.dim() == 1:
abacus_pattern = abacus_pattern.unsqueeze(0)
return self.abacus_encoder(abacus_pattern)
elif isinstance(concept_list, list):
# Otherwise, assume it's a list of integer concept IDs
return torch.tensor(concept_list, dtype=torch.long, device=self.model.get_device())
else:
raise ValueError("concept_list must be a list of integers or a dict with 'abacus_pattern'.")
def _concept_memory_retrieve_placeholder(self, current_embedding: torch.Tensor) -> torch.Tensor:
# Placeholder for concept memory retrieval logic
# This would typically involve searching a memory bank for relevant concepts
# based on the current_embedding and returning their embeddings.
# For now, it returns a zero tensor of appropriate shape.
return torch.zeros_like(current_embedding)
def _concept_attention_fusion_placeholder(self, transformer_output: torch.Tensor, retrieved_concepts: torch.Tensor) -> torch.Tensor:
# Placeholder for concept attention fusion logic
# This would combine the transformer's output with the retrieved concept embeddings
# using some attention mechanism.
# For now, it just returns the transformer_output unchanged.
return transformer_output
def sample_from_logits(self, concept_logits: torch.Tensor, temperature: float = 1.0, top_k: int = None) -> torch.Tensor:
# Apply temperature
if temperature == 0.0:
next_concept_id = torch.argmax(concept_logits, dim=-1)
else:
concept_logits = concept_logits / temperature
# Apply top-k filtering
if top_k is not None:
v, _ = torch.topk(concept_logits, min(top_k, concept_logits.size(-1)))
concept_logits[concept_logits < v[:, [-1]]] = -float('Inf')
probs = torch.softmax(concept_logits, dim=-1)
next_concept_id = torch.multinomial(probs, num_samples=1).squeeze(-1)
return next_concept_id
@torch.no_grad()
def generate(self, input_embeddings: torch.Tensor, max_new_concepts: int = 20, temperature: float = 1.0, abacus_embedding: torch.Tensor | None = None, working_memory_window: int = 0) -> torch.Tensor:
B, T, C = input_embeddings.size()
generated_embeddings = []
if abacus_embedding is None:
    # Default to a zero abacus state in the model's embedding space, shaped (B, n_embd) so it broadcasts over the sequence inside the GPT forward passes
    abacus_embedding = torch.zeros(B, self.config.n_embd, device=input_embeddings.device)
# Long-Term Memory retrieval for prefill
prefill_long_term_memory_embeddings = self.long_term_memory.retrieve(input_embeddings[:, -1, :]).mean(dim=1) # Pool the top-k retrieved memories to (B, C) so they broadcast inside the model
# Get psyche weights from PsycheController for prefill
prefill_psyche_weights = self.psyche_controller(input_embeddings[:, -1, :])
# Prefill the model with input_embeddings
concept_logits, kv_cache, x_id_prefill, x_ego_prefill, x_superego_prefill = self.model.forward_prefill(
input_embeddings,
abacus_embedding=abacus_embedding,
long_term_memory_embeddings=prefill_long_term_memory_embeddings,
psyche_weights=prefill_psyche_weights
)
# Conscious Integration Layer for prefill
synthesized_state_prefill = self.conscious_integration_layer.forward(
id_output=x_id_prefill,
ego_output=x_ego_prefill,
superego_output=x_superego_prefill,
long_term_memory_embeddings=prefill_long_term_memory_embeddings,
memetic_fitness=None, # Memetic fitness is not available during prefill
abacus_state=abacus_embedding
)
# Combine concept_logits from GPT and synthesized_state_prefill
concept_logits = concept_logits + synthesized_state_prefill
# Sample the first concept ID from the last token's logits
next_concept_id = self.sample_from_logits(concept_logits[:, -1, :], temperature)
next_embedding = self.concept_embedding_layer(next_concept_id)
generated_embeddings.append(next_embedding)
self.long_term_memory.store(next_embedding.squeeze(0)) # Store the generated embedding
# Abacus State Memory and Encoder integration for the first step
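# NOTE: AbacusEncoder.encode() checks that its input's last dimension equals abacus_input_dim,
# so passing the n_embd-sized concept embedding here (and in the decode loop below) assumes
# abacus_input_dim == n_embd, or that a projection is inserted in between.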
abacus_pattern = self.abacus_encoder(next_embedding)
self.abacus_state_memory.store(abacus_pattern)
abacus_embedding = self.abacus_state_memory.retrieve()[-1].to(next_embedding.device) # retrieve() returns a list; keep the most recent pattern on the model device
for _ in range(max_new_concepts - 1):
# Working Memory retrieval
episodic_kv = None
if working_memory_window > 0 and kv_cache.get_pos() > working_memory_window:
start_pos = kv_cache.get_pos() - working_memory_window
end_pos = kv_cache.get_pos()
episodic_kv = kv_cache.retrieve_episode(start_pos, end_pos)
# Long-Term Memory retrieval
long_term_memory_embeddings = self.long_term_memory.retrieve(next_embedding).mean(dim=1) # Pool the top-k retrieved memories to (B, C)
# Get psyche weights from PsycheController
psyche_weights = self.psyche_controller(next_embedding)
concept_logits, kv_cache, x_id_step, x_ego_step, x_superego_step = self.model.forward_step(
next_embedding,
kv_cache,
abacus_embedding=abacus_embedding,
episodic_kv=episodic_kv,
long_term_memory_embeddings=long_term_memory_embeddings,
psyche_weights=psyche_weights
)
# Memetic Learning Layer integration
memetic_fitness = self.memetic_learning_layer.forward(next_embedding, abacus_pattern)
# Conscious Integration Layer for step
synthesized_state_step = self.conscious_integration_layer.forward(
id_output=x_id_step,
ego_output=x_ego_step,
superego_output=x_superego_step,
long_term_memory_embeddings=long_term_memory_embeddings,
memetic_fitness=memetic_fitness,
abacus_state=abacus_embedding
)
# Combine concept_logits from GPT and synthesized_state_step
concept_logits = concept_logits + synthesized_state_step.squeeze(1) # Squeeze to match dimensions
next_concept_id = self.sample_from_logits(concept_logits, temperature)
next_embedding = self.concept_embedding_layer(next_concept_id)
generated_embeddings.append(next_embedding)
self.long_term_memory.store(next_embedding.squeeze(0)) # Store the generated embedding
# Abacus State Memory and Encoder integration for subsequent steps
abacus_pattern = self.abacus_encoder(next_embedding)
self.abacus_state_memory.store(abacus_pattern)
abacus_embedding = self.abacus_state_memory.retrieve()[-1].to(next_embedding.device) # Update abacus_embedding for the next step (retrieve() returns a list)
# Memetic Learning Layer integration
memetic_fitness = self.memetic_learning_layer.forward(next_embedding, abacus_pattern)
return torch.stack(generated_embeddings, dim=1)
def generate_from_concepts(self, concept_list: list[int] | dict, max_new_concepts: int = 20, temperature: float = 1.0) -> torch.Tensor:
# Encode the concept_list into initial input embeddings
encoded_concepts = self.concept_encoder(concept_list)
if encoded_concepts.dtype == torch.long:
# If it's concept IDs, get embeddings from the concept_embedding_layer
input_embeddings = self.concept_embedding_layer(encoded_concepts)
abacus_embedding = None # No abacus embedding in this case
elif encoded_concepts.dtype == torch.float:
# If it's an abacus embedding, use it directly and set abacus_embedding
input_embeddings = encoded_concepts
abacus_embedding = encoded_concepts # The abacus embedding is the input embedding itself
else:
raise TypeError("Unexpected return type from concept_encoder.")
# Call the main generate method
return self.generate(input_embeddings, max_new_concepts, temperature, abacus_embedding=abacus_embedding)
# Special tokens are no longer directly used with concept embeddings.
# The tool use logic will need to be re-evaluated or removed if not applicable.
# get_special = lambda s: self.tokenizer.encode_special(s)
# python_start = get_special("<|python_start|>")
# python_end = get_special("<|python_end|>")
# output_start = get_special("<|output_start|>")
# output_end = get_special("<|output_end|>")
# assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
# bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
# 1) Run a batch 1 prefill of the prompt embeddings
m = self.model.config
kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer}
kv_cache_prefill = KVCache(
batch_size=input_embeddings.size(0),
seq_len=input_embeddings.size(1),
**kv_model_kwargs,
)
logits = self.model.forward(input_embeddings, kv_cache=kv_cache_prefill)
# Removed token-based sampling logic (replace with embedding generation logic later)
# 2) Replicate the KV cache for each sample/row
kv_length_hint = (input_embeddings.size(1) + max_tokens) if max_tokens is not None else self.model.config.sequence_len
kv_cache_decode = KVCache(
batch_size=num_samples,
seq_len=kv_length_hint,
@ -230,7 +467,7 @@ class Engine:
del kv_cache_prefill # no need to keep this memory around
# 3) Initialize states for each sample
row_states = [RowState(input_embeddings[i].tolist()) for i in range(num_samples)] # Assuming input_embeddings is (B, T, C)
# 4) Main generation loop
num_generated = 0

nanochat/gpt.py

@ -25,13 +25,124 @@ from nanochat.adamw import DistAdamW
class GPTConfig:
    # (The previous @dataclass fields, sequence_len=1024, vocab_size=50304, n_layer=12,
    # n_head=6, n_kv_head=6, n_embd=768, are superseded by the keyword arguments below;
    # the token vocab_size is replaced by num_concept_ids.)
def __init__(self,
n_layer=12,
n_head=12,
n_embd=768,
sequence_len=1024,
n_kv_head=None,
num_concept_ids=4096, # Updated to 4096 for n=12 hypercube
abacus_input_dim=64, # Default input dimension for the AbacusEncoder
dropout=0.0,
bias=False,
multiple_of=256,
norm_eps=1e-5,
rope_theta=10000,
# For training
batch_size=1,
gradient_accumulation_steps=1,
max_iters=0,
lr=6e-4,
min_lr=6e-5,
weight_decay=1e-1,
beta1=0.9,
beta2=0.95,
grad_clip=1.0,
decay_lr=True,
warmup_iters=2000,
lr_decay_iters=600000,
# For checkpointing
out_dir='out',
eval_interval=2000,
log_interval=1,
eval_iters=200,
eval_only=False,
always_save_checkpoint=True,
# For distributed training
backend='nccl',
# For system
device='cpu',
dtype='bfloat16',
compile=False,
# For data
dataset='openwebtext',
# For inference
init_from='scratch',
# For chat
chat=False,
# For concept
concept_memory_size=1000,
concept_memory_top_k=5,
use_concept_attention=False,
# For psyche
psyche_id_lr_scale=1.0,
psyche_ego_lr_scale=1.0,
psyche_superego_lr_scale=1.0,
**kwargs):
self.n_layer = n_layer
self.n_head = n_head
self.n_embd = n_embd
self.sequence_len = sequence_len
self.n_kv_head = n_kv_head if n_kv_head is not None else n_head
self.num_concept_ids = num_concept_ids
self.dropout = dropout
self.bias = bias
self.multiple_of = multiple_of
self.norm_eps = norm_eps
self.rope_theta = rope_theta
self.batch_size = batch_size
self.gradient_accumulation_steps = gradient_accumulation_steps
self.max_iters = max_iters
self.lr = lr
self.min_lr = min_lr
self.weight_decay = weight_decay
self.beta1 = beta1
self.beta2 = beta2
self.grad_clip = grad_clip
self.decay_lr = decay_lr
self.warmup_iters = warmup_iters
self.lr_decay_iters = lr_decay_iters
self.out_dir = out_dir
self.eval_interval = eval_interval
self.log_interval = log_interval
self.eval_iters = eval_iters
self.eval_only = eval_only
self.always_save_checkpoint = always_save_checkpoint
self.backend = backend
self.device = device
self.dtype = dtype
self.compile = compile
self.dataset = dataset
self.init_from = init_from
self.chat = chat
self.concept_memory_size = concept_memory_size
self.concept_memory_top_k = concept_memory_top_k
self.use_concept_attention = use_concept_attention
self.psyche_id_lr_scale = psyche_id_lr_scale
self.psyche_ego_lr_scale = psyche_ego_lr_scale
self.psyche_superego_lr_scale = psyche_superego_lr_scale
for k, v in kwargs.items():
setattr(self, k, v)
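# Example (illustrative; shows how overrides like those in config/custom_config.py map onto
# these keyword arguments, with everything else keeping the defaults above):
#
#   cfg = GPTConfig(n_layer=24, n_embd=1024)
#   cfg.n_kv_head        # 12 (defaults to n_head)
#   cfg.num_concept_ids  # 4096 == 2**12, i.e. a 12-dimensional hypercube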
def norm(x):
# Purely functional rmsnorm with no learnable params
@ -63,7 +174,7 @@ class CausalSelfAttention(nn.Module):
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
def forward(self, x, cos_sin, kv_cache, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None):
B, T, C = x.size()
# Project the input to get queries, keys, and values
@ -80,6 +191,14 @@ class CausalSelfAttention(nn.Module):
# Apply KV cache: insert current k,v into cache, get the full view so far
if kv_cache is not None:
k, v = kv_cache.insert_kv(self.layer_idx, k, v)
# If episodic_kv is provided, prepend it to the current k and v
if episodic_kv is not None:
episode_k_layer = episodic_kv[0][self.layer_idx] # episodic_kv is a (keys, values) tuple with per-layer tensors
episode_v_layer = episodic_kv[1][self.layer_idx]
k = torch.cat([episode_k_layer, k], dim=2)
v = torch.cat([episode_v_layer, v], dim=2)
Tq = q.size(2) # number of queries in this forward pass
Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass)
@ -135,15 +254,42 @@ class Block(nn.Module):
return x
class PsycheController(nn.Module):
def __init__(self, config):
super().__init__()
self.controller_head = nn.Linear(config.n_embd, 3) # 3 for id, ego, superego
def forward(self, x):
# Map the hidden state to softmax weights over (id, ego, superego).
# The head is trainable, so the blend can be learned rather than fixed at equal weights.
return torch.softmax(self.controller_head(x), dim=-1)
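# Example (illustrative sizes): the controller maps a hidden state to one softmax weight
# per psyche branch, so each row of the output sums to 1.
#
#   ctrl = PsycheController(GPTConfig(n_embd=16))
#   weights = ctrl(torch.randn(2, 16))   # shape (2, 3)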
class GPT(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.transformer = nn.ModuleDict({
    # The old "wte" token embedding table is gone: the model consumes concept embeddings directly
    "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]),
})
# The old lm_head is replaced by concept_head below
# Partition transformer layers into psyche layers
total_layers = config.n_layer
id_end = total_layers // 3
ego_end = 2 * total_layers // 3
self.id_layers = self.transformer.h[:id_end]
self.ego_layers = self.transformer.h[id_end:ego_end]
self.superego_layers = self.transformer.h[ego_end:]
self.psyche_registry = {
"id": self.id_layers,
"ego": self.ego_layers,
"superego": self.superego_layers
}
self.concept_head = nn.Linear(config.n_embd, config.num_concept_ids, bias=False) # New concept head
self.psyche_controller = PsycheController(config) # Initialize PsycheController
# To support meta device initialization, we init the rotary embeddings here, but it's fake
# As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory,
# so let's just over-compute them, but assert fail if we ever reach that amount.
@ -156,8 +302,8 @@ class GPT(nn.Module):
def init_weights(self):
self.apply(self._init_weights)
# zero out concept_head weights
torch.nn.init.zeros_(self.concept_head.weight)
# zero out c_proj weights in all blocks
for block in self.transformer.h:
torch.nn.init.zeros_(block.mlp.c_proj.weight)
@ -166,9 +312,6 @@ class GPT(nn.Module):
head_dim = self.config.n_embd // self.config.n_head
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
self.cos, self.sin = cos, sin
# The old bf16 cast of the wte embedding table is gone along with wte itself:
# if self.transformer.wte.weight.device.type == "cuda":
#     self.transformer.wte.to(dtype=torch.bfloat16)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
@ -186,7 +329,7 @@ class GPT(nn.Module):
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
# autodetect the device from model embeddings
if device is None:
device = self.concept_head.weight.device
# stride the channels
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
inv_freq = 1.0 / (base ** (channel_range / head_dim))
@ -200,12 +343,13 @@ class GPT(nn.Module):
return cos, sin
def get_device(self):
# Get device from concept_head weight (the wte embedding table no longer exists)
return self.concept_head.weight.device
def estimate_flops(self):
""" Return the estimated FLOPs per token for the model. Ref: https://arxiv.org/abs/2204.02311 """
nparams = sum(p.numel() for p in self.parameters())
nparams_embedding = 0 # No separate embedding layer now
l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
return num_flops_per_token
@ -213,95 +357,190 @@ class GPT(nn.Module):
def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
# (The old grouping into matrix / wte embedding / lm_head parameters no longer applies:)
# matrix_params = list(self.transformer.h.parameters())
# embedding_params = list(self.transformer.wte.parameters())
# lm_head_params = list(self.lm_head.parameters())
# assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params)
# Separate out all parameters into 3 groups (matrix, embedding, concept_head)
# matrix_params = list(self.transformer.h.parameters())
id_params = list(self.id_layers.parameters())
ego_params = list(self.ego_layers.parameters())
superego_params = list(self.superego_layers.parameters())
embedding_params = [] # No separate embedding layer now
concept_head_params = list(self.concept_head.parameters()) # New concept head params
psyche_controller_params = list(self.psyche_controller.parameters()) # Psyche controller params
# assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(concept_head_params)
# Create the AdamW optimizer for the embedding and concept_head
# Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model)
dmodel_lr_scale = (model_dim / 768) ** -0.5
if rank == 0:
print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}")
adam_groups = [
dict(params=concept_head_params, lr=unembedding_lr * dmodel_lr_scale), # Use unembedding_lr for concept_head
dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale),
dict(params=id_params, lr=3e-4 * dmodel_lr_scale), # Id layers learning rate
dict(params=ego_params, lr=1e-4 * dmodel_lr_scale), # Ego layers learning rate
dict(params=superego_params, lr=5e-5 * dmodel_lr_scale), # Superego layers learning rate
dict(params=psyche_controller_params, lr=1e-4 * dmodel_lr_scale), # Psyche controller learning rate
]
adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay)
AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True)
adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs)
# Create the Muon optimizer for the linear layers
# muon_kwargs = dict(lr=matrix_lr, momentum=0.95)
# MuonFactory = DistMuon if ddp else Muon
# muon_optimizer = MuonFactory(matrix_params, **muon_kwargs)
# Collect the optimizers into one list (Muon is disabled above, so only AdamW is returned)
optimizers = [adamw_optimizer]
for opt in optimizers:
for group in opt.param_groups:
group["initial_lr"] = group["lr"]
return optimizers
# (The old token-based forward(idx, targets=None, ...) is replaced by the embedding-based forward below.)
def _run_layers(self, layers, x, cos_sin, kv_cache, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None):
for block in layers:
x = block(x, cos_sin, kv_cache, episodic_kv)
return x
def forward(self, input_embeddings: torch.Tensor, kv_cache=None, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None):
B, T, C = input_embeddings.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert input_embeddings.device == self.cos.device, f"Rotary embeddings and input_embeddings are on different devices: {input_embeddings.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
# if kv cache exists, we need to offset the rotary embeddings to the current position in the cache
T0 = 0 if kv_cache is None else kv_cache.get_pos()
cos_sin = self.cos[:, :T, :, :], self.sin[:, :T, :, :]
x = input_embeddings
# Process Id layers
x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv)
# Process Ego layers
x_ego = x_id
if abacus_embedding is not None:
# Broadcast abacus_embedding to match the sequence length of x_ego
# Assuming abacus_embedding is (B, C) and x_ego is (B, T, C)
abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + abacus_broadcast # Inject abacus_embedding into ego layer
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_ego
# Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + long_term_memory_broadcast # Inject long_term_memory_embeddings into ego layer
x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv)
# Process Superego layers
x_superego = x_ego
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_superego
# Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
x_superego = x_superego + long_term_memory_broadcast # Inject long_term_memory_embeddings into superego layer
x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv)
# Dynamically blend the outputs based on psyche_weights
# Reshape psyche_weights for broadcasting: (B, 1, 3)
psyche_weights_reshaped = psyche_weights.unsqueeze(1)
# Stack the outputs and apply weighted sum
# Stack will result in (B, T, 3, C)
stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2)
# Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C)
x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2)
# Final concept head
return self.concept_head(x), kv_cache, x_id, x_ego, x_superego
def forward_prefill(self, input_embeddings: torch.Tensor, kv_cache=None, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None):
B, T, C = input_embeddings.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
cos_sin = self.cos[:, :T, :, :], self.sin[:, :T, :, :]
x = input_embeddings
# Process Id layers
x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv)
# Process Ego layers
x_ego = x_id
if abacus_embedding is not None:
# Broadcast abacus_embedding to match the sequence length of x_ego
# Assuming abacus_embedding is (B, C) and x_ego is (B, T, C)
abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + abacus_broadcast # Inject abacus_embedding into ego layer
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_ego
# Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + long_term_memory_broadcast # Inject long_term_memory_embeddings into ego layer
x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv)
# Process Superego layers
x_superego = x_ego
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_superego
# Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
x_superego = x_superego + long_term_memory_broadcast # Inject long_term_memory_embeddings into superego layer
x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv)
# Dynamically blend the outputs based on psyche_weights
# Reshape psyche_weights for broadcasting: (B, 1, 3)
psyche_weights_reshaped = psyche_weights.unsqueeze(1)
# Stack the outputs and apply weighted sum
# Stack will result in (B, T, 3, C)
stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2)
# Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C)
x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2)
# Final concept head
return self.concept_head(x), kv_cache, x_id, x_ego, x_superego
def forward_step(self, next_embedding: torch.Tensor, kv_cache, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None):
B, C = next_embedding.size()
T = kv_cache.get_pos() + 1 # Current sequence length after adding next_embedding
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
cos_sin = self.cos[:, T-1:T, :, :], self.sin[:, T-1:T, :, :]
x = next_embedding.unsqueeze(1) # Add sequence dimension for consistency
# Process Id layers
x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv)
# Process Ego layers
x_ego = x_id
if abacus_embedding is not None:
# Broadcast abacus_embedding to match the sequence length of x_ego
# Assuming abacus_embedding is (B, C) and x_ego is (B, T, C)
abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + abacus_broadcast # Inject abacus_embedding into ego layer
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_ego
# Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1)
x_ego = x_ego + long_term_memory_broadcast # Inject long_term_memory_embeddings into ego layer
x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv)
# Process Superego layers
x_superego = x_ego
if long_term_memory_embeddings is not None:
# Broadcast long_term_memory_embeddings to match the sequence length of x_superego
# Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C)
long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1)
x_superego = x_superego + long_term_memory_broadcast # Inject long_term_memory_embeddings into superego layer
x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv)
# Dynamically blend the outputs based on psyche_weights
# Reshape psyche_weights for broadcasting: (B, 1, 3)
psyche_weights_reshaped = psyche_weights.unsqueeze(1)
# Stack the outputs and apply weighted sum
# Stack will result in (B, T, 3, C)
stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2)
# Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C)
x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2)
# (The former token-based pipeline that used to follow here, the wte trunk, the lm_head with
# logit softcapping, the cross-entropy training path, and the old streaming generate() method,
# was removed in favor of the concept-head flow above.)
return self.concept_head(x.squeeze(1)), kv_cache, x_id, x_ego, x_superego

149
nanochat/hypercube.py Normal file

@ -0,0 +1,149 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class HypercubeTopology:
def __init__(self, n: int):
if not (1 <= n <= 16): # Limiting n for practical reasons, can be adjusted
raise ValueError("Hypercube dimension 'n' must be between 1 and 16.")
self.n = n
self.num_vertices = 2**n
def to_binary(self, vertex_id: int) -> str:
if not (0 <= vertex_id < self.num_vertices):
raise ValueError(f"Vertex ID {vertex_id} is out of range for a {self.n}-dimensional hypercube.")
return format(vertex_id, f'0{self.n}b')
def from_binary(self, binary_string: str) -> int:
if len(binary_string) != self.n:
raise ValueError(f"Binary string length must be {self.n} for a {self.n}-dimensional hypercube.")
return int(binary_string, 2)
def get_neighbors(self, vertex_id: int) -> list[int]:
if not (0 <= vertex_id < self.num_vertices):
raise ValueError(f"Vertex ID {vertex_id} is out of range for a {self.n}-dimensional hypercube.")
neighbors = []
for i in range(self.n):
# Flip the i-th bit
neighbor_id = vertex_id ^ (1 << i)
neighbors.append(neighbor_id)
return neighbors
def get_all_edges(self) -> list[tuple[int, int]]:
edges = []
for i in range(self.num_vertices):
for neighbor in self.get_neighbors(i):
# Ensure each edge is added only once (e.g., (0,1) not (1,0))
if i < neighbor:
edges.append((i, neighbor))
return edges
def get_random_vertex(self) -> int:
return torch.randint(0, self.num_vertices, (1,)).item()
def get_random_path(self, start_vertex: int, length: int) -> list[int]:
if not (0 <= start_vertex < self.num_vertices):
raise ValueError(f"Start vertex ID {start_vertex} is out of range.")
if length < 1:
raise ValueError("Path length must be at least 1.")
path = [start_vertex]
current_vertex = start_vertex
for _ in range(length - 1):
neighbors = self.get_neighbors(current_vertex)
if not neighbors:
break # Should not happen in a hypercube unless n=0
# Choose a uniformly random neighbor (the walk may revisit the previous vertex)
next_vertex = neighbors[torch.randint(0, len(neighbors), (1,)).item()]
path.append(next_vertex)
current_vertex = next_vertex
return path
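# Example (n=3 is an ordinary cube):
#
#   cube = HypercubeTopology(3)
#   cube.to_binary(5)           # '101'
#   cube.get_neighbors(5)       # [4, 7, 1]  (flip bit 0, 1, 2 respectively)
#   len(cube.get_all_edges())   # 12 == n * 2**(n - 1)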
class HypercubeEmbeddingLayer(nn.Module):
def __init__(self, num_raw_concepts, embedding_dim, hypercube_topology):
super().__init__()
self.raw_embedding_layer = nn.Embedding(num_raw_concepts, embedding_dim)
self.hypercube_topology = hypercube_topology
# Embeddings for hypercube vertices. The number of vertices is hypercube_topology.num_vertices
self.vertex_embeddings = nn.Embedding(hypercube_topology.num_vertices, embedding_dim)
def forward(self, concept_ids):
# Get initial embeddings for the input concept_ids
initial_embeddings = self.raw_embedding_layer(concept_ids) # (batch_size, embedding_dim)
# Find the nearest hypercube vertex for each initial embedding
# Get all hypercube vertex embeddings
# Ensure all_vertex_ids are on the same device as concept_ids
all_vertex_ids = torch.arange(self.hypercube_topology.num_vertices, device=concept_ids.device)
all_vertex_embeddings = self.vertex_embeddings(all_vertex_ids) # (num_vertices, embedding_dim)
# Calculate squared Euclidean distance
# initial_embeddings_expanded: (batch_size, 1, embedding_dim)
# all_vertex_embeddings_expanded: (1, num_vertices, embedding_dim)
initial_embeddings_expanded = initial_embeddings.unsqueeze(1)
all_vertex_embeddings_expanded = all_vertex_embeddings.unsqueeze(0)
distances = torch.sum((initial_embeddings_expanded - all_vertex_embeddings_expanded)**2, dim=2) # (batch_size, num_vertices)
# Find the index of the nearest vertex for each initial embedding
nearest_vertex_indices = torch.argmin(distances, dim=1) # (batch_size,)
# Retrieve the embeddings of the nearest vertices
final_embeddings = self.vertex_embeddings(nearest_vertex_indices)
return final_embeddings
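# Example (illustrative sizes): raw concept IDs are embedded and then snapped to the
# nearest of the 2**n learned vertex embeddings.
#
#   topo = HypercubeTopology(4)                        # 16 vertices
#   layer = HypercubeEmbeddingLayer(100, 8, topo)
#   layer(torch.tensor([3, 42, 7])).shape              # torch.Size([3, 8])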
class LongTermMemory(nn.Module):
def __init__(self, embedding_dim: int, max_memory_size: int = 1000, top_k_retrieval: int = 5):
super().__init__()
self.embedding_dim = embedding_dim
self.max_memory_size = max_memory_size
self.top_k_retrieval = top_k_retrieval
# Initialize an empty memory bank. We'll use a list of tensors for simplicity initially,
# but this could be replaced with a more efficient data structure or a learnable embedding layer.
self.memory_bank = []
self.memory_bank_tensor = None # Will store concatenated memories for efficient retrieval
def store(self, embedding: torch.Tensor):
# Store a new embedding in the memory bank
if len(self.memory_bank) >= self.max_memory_size:
# Simple FIFO eviction for now
self.memory_bank.pop(0)
self.memory_bank.append(embedding.detach().cpu()) # Store on CPU to save GPU memory
self.memory_bank_tensor = None # Invalidate cached tensor
def retrieve(self, query_embedding: torch.Tensor) -> torch.Tensor:
# Retrieve top-k most similar embeddings from the memory bank
if not self.memory_bank:
return torch.zeros(query_embedding.shape[0], self.top_k_retrieval, self.embedding_dim, device=query_embedding.device)
if self.memory_bank_tensor is None:
self.memory_bank_tensor = torch.stack(self.memory_bank).to(query_embedding.device)
# Normalize query and memory bank for cosine similarity
query_norm = F.normalize(query_embedding, p=2, dim=-1)
memory_norm = F.normalize(self.memory_bank_tensor, p=2, dim=-1)
# Calculate cosine similarity
# query_norm: (batch_size, embedding_dim)
# memory_norm: (num_memories, embedding_dim)
# similarities: (batch_size, num_memories)
similarities = torch.matmul(query_norm, memory_norm.transpose(0, 1))
# Get top-k similar memories
# top_k_values: (batch_size, top_k_retrieval)
# top_k_indices: (batch_size, top_k_retrieval)
top_k_values, top_k_indices = torch.topk(similarities, min(self.top_k_retrieval, len(self.memory_bank)), dim=-1)
# Retrieve the actual embeddings
retrieved_memories = self.memory_bank_tensor[top_k_indices]
return retrieved_memories
def forward(self, query_embedding: torch.Tensor) -> torch.Tensor:
return self.retrieve(query_embedding)
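# A minimal usage sketch of LongTermMemory (sizes are illustrative): store a handful of
# embeddings, then retrieve the top-k most similar to a query by cosine similarity.
if __name__ == "__main__":
    ltm = LongTermMemory(embedding_dim=8, max_memory_size=10, top_k_retrieval=2)
    for _ in range(5):
        ltm.store(torch.randn(8))
    query = torch.randn(1, 8)          # (batch_size, embedding_dim)
    memories = ltm.retrieve(query)     # (1, 2, 8): the two closest stored embeddings
    print(memories.shape)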

nanochat/memetic_learning.py Normal file

@ -0,0 +1,46 @@
import torch
import torch.nn as nn
class MemeticLearningLayer(nn.Module):
def __init__(self, config, abacus_encoder, internal_self_model):
super().__init__()
self.config = config
self.abacus_encoder = abacus_encoder
self.internal_self_model = internal_self_model
# Placeholder for memetic fitness evaluation mechanism
# This could be a simple linear layer or a more complex network
self.fitness_evaluator = nn.Linear(config.abacus_input_dim, 1)
# Placeholder for concept mapping expansion (analogy/metaphor)
# This might involve a transformation of embeddings or a retrieval mechanism
self.concept_mapper = nn.Sequential(
nn.Linear(config.n_embd * 2, config.n_embd), # Input: two concept embeddings
nn.ReLU(),
nn.Linear(config.n_embd, config.n_embd)
)
def evaluate_memetic_fitness(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
# Evaluate the fitness of a memetic pattern using the abacus encoder output
fitness_score = self.fitness_evaluator(abacus_pattern)
return fitness_score
def expand_concept_mapping(self, concept1_embedding: torch.Tensor, concept2_embedding: torch.Tensor) -> torch.Tensor:
# Expand concept mapping via analogy and metaphor
# This takes two concept embeddings and generates a new, related concept embedding
combined_concepts = torch.cat([concept1_embedding, concept2_embedding], dim=-1)
new_concept_embedding = self.concept_mapper(combined_concepts)
return new_concept_embedding
def forward(self, current_concept_embedding: torch.Tensor, abacus_pattern: torch.Tensor):
# Orchestrates the memetic learning process
# 1. Evaluate memetic fitness
fitness = self.evaluate_memetic_fitness(abacus_pattern)
# 2. Potentially update internal self-model based on fitness or new concepts
# self.internal_self_model.update_beliefs(current_concept_embedding) # Example interaction
# 3. Generate new concepts or analogies
# new_concept = self.expand_concept_mapping(current_concept_embedding, some_other_concept)
return fitness
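# A minimal usage sketch with a stand-in config (an assumption for the demo); only the
# fitness head is exercised by forward() for now, so the encoder and self-model arguments
# can be placeholders here.
if __name__ == "__main__":
    from types import SimpleNamespace
    cfg = SimpleNamespace(n_embd=16, abacus_input_dim=4)
    layer = MemeticLearningLayer(cfg, abacus_encoder=None, internal_self_model=None)
    fitness = layer(torch.randn(2, 16), torch.randn(2, 4))  # (concept embedding, abacus pattern)
    print(fitness.shape)  # torch.Size([2, 1])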

58
nanochat/self_model.py Normal file

@ -0,0 +1,58 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class InternalSelfModel(nn.Module):
def __init__(self, embedding_dim: int, config):
super().__init__()
self.embedding_dim = embedding_dim
self.config = config
# Placeholder for belief representation (e.g., a learned embedding or a set of parameters)
self.belief_embedding = nn.Parameter(torch.randn(embedding_dim))
# Placeholder for a simple prediction head to calculate prediction error
self.prediction_head = nn.Linear(embedding_dim, 1) # Predicts a scalar error signal
# Placeholder for conceptual growth mechanism (e.g., a small MLP or attention mechanism)
self.conceptual_growth_mlp = nn.Sequential(
nn.Linear(embedding_dim * 2, embedding_dim), # Input: current belief + new concept
nn.ReLU(),
nn.Linear(embedding_dim, embedding_dim)
)
def update_beliefs(self, current_concept_embedding: torch.Tensor):
# This method would update the internal belief representation based on new information.
# For now, a simple update rule (e.g., moving average or attention-based update)
# In a more advanced implementation, this could involve a recurrent neural network.
# Example: simple weighted average
alpha = 0.1 # Learning rate for belief update
self.belief_embedding.data = (1 - alpha) * self.belief_embedding.data + alpha * current_concept_embedding.mean(dim=0)
def calculate_prediction_error(self, predicted_output: torch.Tensor, actual_output: torch.Tensor) -> torch.Tensor:
# This method calculates the discrepancy between predicted and actual outcomes.
# For now, a simple mean squared error.
error = F.mse_loss(predicted_output, actual_output)
return error
def promote_conceptual_growth(self, current_belief_embedding: torch.Tensor, new_concept_embedding: torch.Tensor) -> torch.Tensor:
# This method facilitates the integration of new concepts into the existing conceptual framework.
# For now, a simple MLP that takes the concatenation of current belief and new concept.
combined_embedding = torch.cat([current_belief_embedding, new_concept_embedding], dim=-1)
updated_concept_embedding = self.conceptual_growth_mlp(combined_embedding)
return updated_concept_embedding
def forward(self, current_concept_embedding: torch.Tensor, predicted_output: torch.Tensor, actual_output: torch.Tensor):
# Update beliefs
self.update_beliefs(current_concept_embedding)
# Calculate prediction error
error = self.calculate_prediction_error(predicted_output, actual_output)
# Promote conceptual growth (example: using the current belief and concept embedding)
# This part would be more complex in a full implementation, potentially involving attention
# or other mechanisms to decide how to grow concepts based on error and new info.
# For demonstration, let's assume conceptual growth is triggered by new concept embeddings.
# updated_concept = self.promote_conceptual_growth(self.belief_embedding, current_concept_embedding)
return error, self.belief_embedding
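# A minimal usage sketch (the config argument is only stored, so None suffices for the demo).
if __name__ == "__main__":
    model = InternalSelfModel(embedding_dim=8, config=None)
    error, belief = model(
        current_concept_embedding=torch.randn(4, 8),
        predicted_output=torch.randn(4, 1),
        actual_output=torch.randn(4, 1),
    )
    print(error.item(), belief.shape)  # a non-negative MSE and torch.Size([8])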