diff --git a/config/custom_config.py b/config/custom_config.py new file mode 100644 index 0000000..4a66e49 --- /dev/null +++ b/config/custom_config.py @@ -0,0 +1,2 @@ +n_layer = 24 +n_embd = 1024 \ No newline at end of file diff --git a/nanochat/abacus_encoder.py b/nanochat/abacus_encoder.py new file mode 100644 index 0000000..8f86c81 --- /dev/null +++ b/nanochat/abacus_encoder.py @@ -0,0 +1,24 @@ +import torch +import torch.nn as nn + +class AbacusEncoder(nn.Module): + def __init__(self, input_dim: int, embedding_dim: int): + super().__init__() + self.input_dim = input_dim + self.embedding_dim = embedding_dim + + # Simple linear layer to encode abacus-like patterns into the embedding space + self.encoder_layer = nn.Linear(input_dim, embedding_dim) + + def encode(self, abacus_pattern: torch.Tensor) -> torch.Tensor: + # abacus_pattern is expected to be a tensor of shape (batch_size, input_dim) + if abacus_pattern.shape[-1] != self.input_dim: + raise ValueError(f"Expected abacus_pattern to have last dimension {self.input_dim}, but got {abacus_pattern.shape[-1]}") + return self.encoder_layer(abacus_pattern) + + def decode(self, concept_vector: torch.Tensor) -> torch.Tensor: + # Placeholder for decoding functionality + raise NotImplementedError("Decoding from concept vector to abacus pattern is not yet implemented.") + + def forward(self, abacus_pattern: torch.Tensor) -> torch.Tensor: + return self.encode(abacus_pattern) \ No newline at end of file diff --git a/nanochat/abacus_state_memory.py b/nanochat/abacus_state_memory.py new file mode 100644 index 0000000..e94b0a4 --- /dev/null +++ b/nanochat/abacus_state_memory.py @@ -0,0 +1,26 @@ +import torch +import torch.nn as nn + +class AbacusStateMemory(nn.Module): + def __init__(self, max_memory_size: int = 100, abacus_input_dim: int = 64): + super().__init__() + self.max_memory_size = max_memory_size + self.abacus_input_dim = abacus_input_dim + self.memory = [] # Stores abacus patterns (tensors) + + def store(self, abacus_pattern: torch.Tensor): + if len(self.memory) >= self.max_memory_size: + self.memory.pop(0) # Remove the oldest entry + self.memory.append(abacus_pattern.detach().cpu()) # Store detached CPU tensor + + def retrieve(self, num_to_retrieve: int = 1) -> list[torch.Tensor]: + # Retrieve the most recent abacus patterns + return self.memory[-num_to_retrieve:] + + def clear(self): + self.memory = [] + + def forward(self, abacus_pattern: torch.Tensor): + # For now, forward simply stores the pattern. More complex logic can be added later. 
+ self.store(abacus_pattern) + return abacus_pattern \ No newline at end of file diff --git a/nanochat/conscious_integration.py b/nanochat/conscious_integration.py new file mode 100644 index 0000000..c9241a8 --- /dev/null +++ b/nanochat/conscious_integration.py @@ -0,0 +1,48 @@ +import torch +import torch.nn as nn + +from nanochat.abacus_encoder import AbacusEncoder + +class ConsciousIntegrationLayer(nn.Module): + def __init__(self, config, abacus_encoder: AbacusEncoder): + super().__init__() + self.config = config + self.abacus_encoder = abacus_encoder + + # Linear layer to project the integrated state to the vocabulary size + self.concept_projection = nn.Linear(config.n_embd, config.vocab_size) + + def forward(self, id_output: torch.Tensor, ego_output: torch.Tensor, superego_output: torch.Tensor, long_term_memory_embeddings: torch.Tensor, memetic_fitness: torch.Tensor | None, abacus_state: torch.Tensor) -> torch.Tensor: + # Ensure all inputs are of the same shape for integration + # For simplicity, let's assume they are all (B, T, C) or (B, C) + # If they are (B, C), we might need to unsqueeze for sequence dimension if T > 1 + + # Example integration: simple summation. More complex mechanisms can be added here. + # The goal is to synthesize these into a unified conceptual state. + synthesized_state = id_output + ego_output + superego_output + + if long_term_memory_embeddings is not None: + # Ensure dimensions match for addition. long_term_memory_embeddings might be (B, C) + # If synthesized_state is (B, T, C), expand long_term_memory_embeddings + if synthesized_state.dim() == 3 and long_term_memory_embeddings.dim() == 2: + long_term_memory_embeddings = long_term_memory_embeddings.unsqueeze(1).expand(-1, synthesized_state.size(1), -1) + synthesized_state = synthesized_state + long_term_memory_embeddings + + if memetic_fitness is not None: + # Ensure dimensions match for addition + if synthesized_state.dim() == 3 and memetic_fitness.dim() == 2: + memetic_fitness = memetic_fitness.unsqueeze(1).expand(-1, synthesized_state.size(1), -1) + synthesized_state = synthesized_state + memetic_fitness + + # Abacus Encoder for logical consistency + # The abacus_state is already an encoded representation, so we might use it to gate or modulate + # the synthesized_state, or simply include it in the synthesis. + # For now, let's add it, assuming it's compatible. + if synthesized_state.dim() == 3 and abacus_state.dim() == 2: + abacus_state = abacus_state.unsqueeze(1).expand(-1, synthesized_state.size(1), -1) + synthesized_state = synthesized_state + abacus_state + + # Project the synthesized state to the vocabulary size + concept_logits_from_synthesis = self.concept_projection(synthesized_state) + + return concept_logits_from_synthesis \ No newline at end of file diff --git a/nanochat/engine.py b/nanochat/engine.py index 44ed16b..22e8e6f 100644 --- a/nanochat/engine.py +++ b/nanochat/engine.py @@ -11,7 +11,9 @@ Notes: The whole thing is made as efficient as possible. 
""" +import math import torch +import torch.nn as nn import torch.nn.functional as F import signal import warnings @@ -19,6 +21,15 @@ from contextlib import contextmanager from collections import deque from nanochat.common import compute_init from nanochat.checkpoint_manager import load_model +from nanochat.gpt import GPT +from nanochat.kv_cache import KVCache +from nanochat.hypercube import HypercubeTopology, HypercubeEmbeddingLayer, LongTermMemory # Import HypercubeTopology +from nanochat.abacus_encoder import AbacusEncoder +from nanochat.self_model import InternalSelfModel +from nanochat.abacus_state_memory import AbacusStateMemory +from nanochat.memetic_learning import MemeticLearningLayer +from nanochat.conscious_integration import ConsciousIntegrationLayer +from nanochat.gpt import GPT, PsycheController # ----------------------------------------------------------------------------- # Calculator tool helpers @@ -151,76 +162,302 @@ class KVCache: self.pos = t1 return key_view, value_view + def retrieve_episode(self, start_pos, end_pos): + """ + Retrieves a slice of the KV cache representing an 'episode' or a specific attention span. + """ + assert self.kv_cache is not None, "KV cache is empty, cannot retrieve episode." + assert start_pos >= 0 and end_pos <= self.pos, "Invalid start or end position for episode retrieval." + assert start_pos < end_pos, "Start position must be less than end position." + + # Return a view of the relevant part of the cache + # kv_cache shape: (num_layers, 2, batch_size, num_heads, seq_len, head_dim) + # We want to retrieve for all layers, and both k and v + episode_k = self.kv_cache[:, 0, :, :, start_pos:end_pos, :] + episode_v = self.kv_cache[:, 1, :, :, start_pos:end_pos, :] + return episode_k, episode_v + # ----------------------------------------------------------------------------- @torch.inference_mode() -def sample_next_token(logits, rng, temperature=1.0, top_k=None): - """Sample a single next token from given logits of shape (B, vocab_size). Returns (B, 1).""" +def sample_from_logits(logits, rng, temperature, top_k): + """Sample a single next concept ID from given logits of shape (B, num_concept_ids). 
Returns (B, 1)."""
     assert temperature >= 0.0, "temperature must be non-negative"
     if temperature == 0.0:
         return torch.argmax(logits, dim=-1, keepdim=True)
+
     if top_k is not None:
-        k = min(top_k, logits.size(-1))
-        vals, idx = torch.topk(logits, k, dim=-1)
-        vals = vals / temperature
-        probs = F.softmax(vals, dim=-1)
-        choice = torch.multinomial(probs, num_samples=1, generator=rng)
-        return idx.gather(1, choice)
-    else:
-        logits = logits / temperature
-        probs = F.softmax(logits, dim=-1)
-        return torch.multinomial(probs, num_samples=1, generator=rng)
+        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+        logits[logits < v[:, [-1]]] = -float('Inf')
-# -----------------------------------------------------------------------------
+    logits = logits / temperature
+    probs = F.softmax(logits, dim=-1)
+    idx_next = torch.multinomial(probs, num_samples=1, generator=rng)
+    return idx_next
-class RowState:
-    # Per-row state tracking during generation
-    def __init__(self, current_tokens=None):
-        self.current_tokens = current_tokens or [] # Current token sequence for this row
-        self.forced_tokens = deque() # Queue of tokens to force inject
-        self.in_python_block = False # Whether we are inside a python block
-        self.python_expr_tokens = [] # Tokens of the current python expression
-        self.completed = False # Whether this row has completed generation
 class Engine:
-    def __init__(self, model, tokenizer):
-        self.model = model
-        self.tokenizer = tokenizer # needed for tool use
+    def __init__(self, config):
+        self.config = config
+        self.model = GPT(config)
+        self.model.eval()
+        self.model.init_weights()
-    @torch.inference_mode()
-    def generate(self, tokens, num_samples=1, max_tokens=None, temperature=1.0, top_k=None, seed=42):
-        """Same as generate, but does single prefill and then clones the KV cache."""
-        assert isinstance(tokens, list) and isinstance(tokens[0], int), "expecting list of ints"
-        device = self.model.get_device()
-        rng = torch.Generator(device=device)
-        rng.manual_seed(seed)
+        # Initialize HypercubeTopology (must exist before the embedding layer that uses it)
+        n_hypercube_dim = int(math.log2(config.num_concept_ids))
+        if not (2**n_hypercube_dim == config.num_concept_ids):
+            raise ValueError("config.num_concept_ids must be a power of 2 for hypercube topology.")
+        self.hypercube_topology = HypercubeTopology(n_hypercube_dim)
+
+        # Initialize HypercubeEmbeddingLayer
+        self.concept_embedding_layer = HypercubeEmbeddingLayer(
+            num_raw_concepts=config.num_concept_ids,
+            embedding_dim=config.n_embd,
+            hypercube_topology=self.hypercube_topology
+        )
-        # Get the special tokens we need to coordinate the tool use state machine
-        get_special = lambda s: self.tokenizer.encode_special(s)
-        python_start = get_special("<|python_start|>")
-        python_end = get_special("<|python_end|>")
-        output_start = get_special("<|output_start|>")
-        output_end = get_special("<|output_end|>")
-        assistant_end = get_special("<|assistant_end|>") # if sampled, ends row
-        bos = self.tokenizer.get_bos_token_id() # if sampled, ends row
+        # Initialize AbacusEncoder
+        self.abacus_encoder = AbacusEncoder(
+            input_dim=config.abacus_input_dim,
+            embedding_dim=config.n_embd
+        )
+        # Re-initialize the raw concept embedding table of the hypercube embedding layer
+        nn.init.normal_(self.concept_embedding_layer.raw_embedding_layer.weight, mean=0.0, std=1.0)
-        # 1) Run a batch 1 prefill of the prompt tokens
+
+        # Initialize LongTermMemory
+        self.long_term_memory = LongTermMemory(
+            embedding_dim=config.n_embd,
+            max_memory_size=1000, # Default max memory size
+            top_k_retrieval=5 # Default top-k retrieval
+        )
+
+        # Initialize InternalSelfModel
+        self.internal_self_model = InternalSelfModel(
+            embedding_dim=config.n_embd,
+            config=config
+        )
+
+        # Initialize AbacusStateMemory
+        self.abacus_state_memory = AbacusStateMemory(
+            max_memory_size=100, # Default max memory size
+            abacus_input_dim=config.abacus_input_dim
+        )
+
+        # Initialize MemeticLearningLayer
+        self.memetic_learning_layer = MemeticLearningLayer(
+            config=config,
+            abacus_encoder=self.abacus_encoder,
+            internal_self_model=self.internal_self_model
+        )
+
+        # Initialize PsycheController
+        self.psyche_controller = PsycheController(
+            config=config
+        )
+
+        # Initialize ConsciousIntegrationLayer
+        self.conscious_integration_layer = ConsciousIntegrationLayer(
+            config=config,
+            abacus_encoder=self.abacus_encoder
+        )
+
+        # Currently unused projection reserved for future concept-attention fusion
+        self._concept_fusion_proj = nn.Linear(config.n_embd * 2, config.n_embd)
+
+        # Placeholder for concept encoder (bound method defined below)
+        self.concept_encoder = self._concept_encoder_placeholder
+
+    def _concept_encoder_placeholder(self, concept_list: list[int] | dict) -> torch.Tensor:
+        """
+        Placeholder for a more sophisticated concept encoder.
+        For now, it assumes concept_list directly contains integer concept IDs or a dict with 'abacus_pattern'.
+        """
+        if isinstance(concept_list, dict) and "abacus_pattern" in concept_list:
+            # If an abacus pattern is provided, encode it using the AbacusEncoder
+            abacus_pattern = concept_list["abacus_pattern"]
+            # Ensure abacus_pattern is a tensor and has the correct shape
+            if not isinstance(abacus_pattern, torch.Tensor):
+                abacus_pattern = torch.tensor(abacus_pattern, dtype=torch.float32)
+            # Add batch dimension if missing
+            if abacus_pattern.dim() == 1:
+                abacus_pattern = abacus_pattern.unsqueeze(0)
+            return self.abacus_encoder(abacus_pattern)
+        elif isinstance(concept_list, list):
+            # Otherwise, assume it's a list of integer concept IDs
+            return torch.tensor(concept_list, dtype=torch.long, device=self.model.get_device())
+        else:
+            raise ValueError("concept_list must be a list of integers or a dict with 'abacus_pattern'.")
+
+    def _concept_memory_retrieve_placeholder(self, current_embedding: torch.Tensor) -> torch.Tensor:
+        # Placeholder for concept memory retrieval logic
+        # This would typically involve searching a memory bank for relevant concepts
+        # based on the current_embedding and returning their embeddings.
+        # For now, it returns a zero tensor of appropriate shape.
+        return torch.zeros_like(current_embedding)
+
+    def _concept_attention_fusion_placeholder(self, transformer_output: torch.Tensor, retrieved_concepts: torch.Tensor) -> torch.Tensor:
+        # Placeholder for concept attention fusion logic
+        # This would combine the transformer's output with the retrieved concept embeddings
+        # using some attention mechanism.
+        # For now, it just returns the transformer_output unchanged.
+ return transformer_output + + def sample_from_logits(self, concept_logits: torch.Tensor, temperature: float = 1.0, top_k: int = None) -> torch.Tensor: + # Apply temperature + if temperature == 0.0: + next_concept_id = torch.argmax(concept_logits, dim=-1) + else: + concept_logits = concept_logits / temperature + # Apply top-k filtering + if top_k is not None: + v, _ = torch.topk(concept_logits, min(top_k, concept_logits.size(-1))) + concept_logits[concept_logits < v[:, [-1]]] = -float('Inf') + probs = torch.softmax(concept_logits, dim=-1) + next_concept_id = torch.multinomial(probs, num_samples=1).squeeze(-1) + return next_concept_id + + @torch.no_grad() + def generate(self, input_embeddings: torch.Tensor, max_new_concepts: int = 20, temperature: float = 1.0, abacus_embedding: torch.Tensor | None = None, working_memory_window: int = 0) -> list[int]: + B, T, C = input_embeddings.size() + generated_embeddings = [] + + if abacus_embedding is None: + abacus_embedding = torch.zeros(B, 1, self.config.abacus_input_dim, device=input_embeddings.device) + + # Long-Term Memory retrieval for prefill + prefill_long_term_memory_embeddings = self.long_term_memory.retrieve(input_embeddings[:, -1, :].squeeze(0)) + + # Get psyche weights from PsycheController for prefill + prefill_psyche_weights = self.psyche_controller(input_embeddings[:, -1, :]) + + # Prefill the model with input_embeddings + concept_logits, kv_cache, x_id_prefill, x_ego_prefill, x_superego_prefill = self.model.forward_prefill( + input_embeddings, + abacus_embedding=abacus_embedding, + long_term_memory_embeddings=prefill_long_term_memory_embeddings, + psyche_weights=prefill_psyche_weights + ) + + # Conscious Integration Layer for prefill + synthesized_state_prefill = self.conscious_integration_layer.forward( + id_output=x_id_prefill, + ego_output=x_ego_prefill, + superego_output=x_superego_prefill, + long_term_memory_embeddings=prefill_long_term_memory_embeddings, + memetic_fitness=None, # Memetic fitness is not available during prefill + abacus_state=abacus_embedding + ) + + # Combine concept_logits from GPT and synthesized_state_prefill + concept_logits = concept_logits + synthesized_state_prefill + + # Sample the first concept ID from the last token's logits + next_concept_id = self.sample_from_logits(concept_logits[:, -1, :], temperature) + next_embedding = self.concept_embedding_layer(next_concept_id) + generated_embeddings.append(next_embedding) + self.long_term_memory.store(next_embedding.squeeze(0)) # Store the generated embedding + + # Abacus State Memory and Encoder integration for the first step + abacus_pattern = self.abacus_encoder(next_embedding) + self.abacus_state_memory.store(abacus_pattern) + abacus_embedding = self.abacus_state_memory.retrieve() + + for _ in range(max_new_concepts - 1): + # Working Memory retrieval + episodic_kv = None + if working_memory_window > 0 and kv_cache.get_pos() > working_memory_window: + start_pos = kv_cache.get_pos() - working_memory_window + end_pos = kv_cache.get_pos() + episodic_kv = kv_cache.retrieve_episode(start_pos, end_pos) + + # Long-Term Memory retrieval + long_term_memory_embeddings = self.long_term_memory.retrieve(next_embedding.squeeze(0)) + + # Get psyche weights from PsycheController + psyche_weights = self.psyche_controller(next_embedding) + + concept_logits, kv_cache, x_id_step, x_ego_step, x_superego_step = self.model.forward_step( + next_embedding, + kv_cache, + abacus_embedding=abacus_embedding, + episodic_kv=episodic_kv, + 
long_term_memory_embeddings=long_term_memory_embeddings, + psyche_weights=psyche_weights + ) + + # Memetic Learning Layer integration + memetic_fitness = self.memetic_learning_layer.forward(next_embedding, abacus_pattern) + + # Conscious Integration Layer for step + synthesized_state_step = self.conscious_integration_layer.forward( + id_output=x_id_step, + ego_output=x_ego_step, + superego_output=x_superego_step, + long_term_memory_embeddings=long_term_memory_embeddings, + memetic_fitness=memetic_fitness, + abacus_state=abacus_embedding + ) + + # Combine concept_logits from GPT and synthesized_state_step + concept_logits = concept_logits + synthesized_state_step.squeeze(1) # Squeeze to match dimensions + + next_concept_id = self.sample_from_logits(concept_logits, temperature) + next_embedding = self.concept_embedding_layer(next_concept_id) + generated_embeddings.append(next_embedding) + self.long_term_memory.store(next_embedding.squeeze(0)) # Store the generated embedding + + # Abacus State Memory and Encoder integration for subsequent steps + abacus_pattern = self.abacus_encoder(next_embedding) + self.abacus_state_memory.store(abacus_pattern) + abacus_embedding = self.abacus_state_memory.retrieve() # Update abacus_embedding for the next step + + # Memetic Learning Layer integration + memetic_fitness = self.memetic_learning_layer.forward(next_embedding, abacus_pattern) + + return torch.stack(generated_embeddings, dim=1) + + def generate_from_concepts(self, concept_list: list[int] | dict, max_new_concepts: int = 20, temperature: float = 1.0) -> list[int]: + # Encode the concept_list into initial input embeddings + encoded_concepts = self.concept_encoder(concept_list) + + if encoded_concepts.dtype == torch.long: + # If it's concept IDs, get embeddings from the concept_embedding_layer + input_embeddings = self.concept_embedding_layer(encoded_concepts) + abacus_embedding = None # No abacus embedding in this case + elif encoded_concepts.dtype == torch.float: + # If it's an abacus embedding, use it directly and set abacus_embedding + input_embeddings = encoded_concepts + abacus_embedding = encoded_concepts # The abacus embedding is the input embedding itself + else: + raise TypeError("Unexpected return type from concept_encoder.") + + # Call the main generate method + return self.generate(input_embeddings, max_new_concepts, temperature, abacus_embedding=abacus_embedding) + + + # Special tokens are no longer directly used with concept embeddings. + # The tool use logic will need to be re-evaluated or removed if not applicable. 
+ # get_special = lambda s: self.tokenizer.encode_special(s) + # python_start = get_special("<|python_start|>") + # python_end = get_special("<|python_end|>") + # output_start = get_special("<|output_start|>") + # output_end = get_special("<|output_end|>") + # assistant_end = get_special("<|assistant_end|>") # if sampled, ends row + # bos = self.tokenizer.get_bos_token_id() # if sampled, ends row + + # 1) Run a batch 1 prefill of the prompt embeddings m = self.model.config kv_model_kwargs = {"num_heads": m.n_kv_head, "head_dim": m.n_embd // m.n_head, "num_layers": m.n_layer} kv_cache_prefill = KVCache( - batch_size=1, - seq_len=len(tokens), + batch_size=input_embeddings.size(0), + seq_len=input_embeddings.size(1), **kv_model_kwargs, ) - ids = torch.tensor([tokens], dtype=torch.long, device=device) - logits = self.model.forward(ids, kv_cache=kv_cache_prefill) - logits = logits[:, -1, :] - next_ids = sample_next_token(logits, rng, temperature, top_k) # (B, 1) - sampled_tokens = next_ids[:, 0].tolist() + logits = self.model.forward(input_embeddings, kv_cache=kv_cache_prefill) + # Removed token-based sampling logic (replace with embedding generation logic later) # 2) Replicate the KV cache for each sample/row - kv_length_hint = (len(tokens) + max_tokens) if max_tokens is not None else self.model.config.sequence_len + kv_length_hint = (input_embeddings.size(1) + max_tokens) if max_tokens is not None else self.model.config.sequence_len kv_cache_decode = KVCache( batch_size=num_samples, seq_len=kv_length_hint, @@ -230,7 +467,7 @@ class Engine: del kv_cache_prefill # no need to keep this memory around # 3) Initialize states for each sample - row_states = [RowState(tokens.copy()) for _ in range(num_samples)] + row_states = [RowState(input_embeddings[i].tolist()) for i in range(num_samples)] # Assuming input_embeddings is (B, T, C) # 4) Main generation loop num_generated = 0 diff --git a/nanochat/gpt.py b/nanochat/gpt.py index b640f1e..931dbe8 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -25,13 +25,124 @@ from nanochat.adamw import DistAdamW @dataclass class GPTConfig: - sequence_len: int = 1024 - vocab_size: int = 50304 - n_layer: int = 12 - n_head: int = 6 # number of query heads - n_kv_head: int = 6 # number of key/value heads (MQA) - n_embd: int = 768 - + def __init__(self, + n_layer=12, + n_head=12, + n_embd=768, + sequence_len=1024, + n_kv_head=None, + num_concept_ids=4096, # Updated to 4096 for n=12 hypercube + abacus_input_dim=64, # Default input dimension for the AbacusEncoder + dropout=0.0, + bias=False, + multiple_of=256, + norm_eps=1e-5, + rope_theta=10000, + + # For training + batch_size=1, + gradient_accumulation_steps=1, + max_iters=0, + lr=6e-4, + min_lr=6e-5, + weight_decay=1e-1, + beta1=0.9, + beta2=0.95, + grad_clip=1.0, + decay_lr=True, + warmup_iters=2000, + lr_decay_iters=600000, + + # For checkpointing + out_dir='out', + eval_interval=2000, + log_interval=1, + eval_iters=200, + eval_only=False, + always_save_checkpoint=True, + + # For distributed training + backend='nccl', + + # For system + device='cpu', + dtype='bfloat16', + compile=False, + + # For data + dataset='openwebtext', + + # For inference + init_from='scratch', + + # For chat + chat=False, + + # For concept + concept_memory_size=1000, + concept_memory_top_k=5, + use_concept_attention=False, + + # For psyche + psyche_id_lr_scale=1.0, + psyche_ego_lr_scale=1.0, + psyche_superego_lr_scale=1.0, + + **kwargs): + self.n_layer = n_layer + self.n_head = n_head + self.n_embd = n_embd + self.sequence_len = 
sequence_len + self.n_kv_head = n_kv_head if n_kv_head is not None else n_head + self.num_concept_ids = num_concept_ids + self.dropout = dropout + self.bias = bias + self.multiple_of = multiple_of + self.norm_eps = norm_eps + self.rope_theta = rope_theta + + self.batch_size = batch_size + self.gradient_accumulation_steps = gradient_accumulation_steps + self.max_iters = max_iters + self.lr = lr + self.min_lr = min_lr + self.weight_decay = weight_decay + self.beta1 = beta1 + self.beta2 = beta2 + self.grad_clip = grad_clip + self.decay_lr = decay_lr + self.warmup_iters = warmup_iters + self.lr_decay_iters = lr_decay_iters + + self.out_dir = out_dir + self.eval_interval = eval_interval + self.log_interval = log_interval + self.eval_iters = eval_iters + self.eval_only = eval_only + self.always_save_checkpoint = always_save_checkpoint + + self.backend = backend + + self.device = device + self.dtype = dtype + self.compile = compile + + self.dataset = dataset + + self.init_from = init_from + + self.chat = chat + + self.concept_memory_size = concept_memory_size + self.concept_memory_top_k = concept_memory_top_k + self.use_concept_attention = use_concept_attention + + self.psyche_id_lr_scale = psyche_id_lr_scale + self.psyche_ego_lr_scale = psyche_ego_lr_scale + self.psyche_superego_lr_scale = psyche_superego_lr_scale + + for k, v in kwargs.items(): + setattr(self, k, v) def norm(x): # Purely functional rmsnorm with no learnable params @@ -63,7 +174,7 @@ class CausalSelfAttention(nn.Module): self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False) - def forward(self, x, cos_sin, kv_cache): + def forward(self, x, cos_sin, kv_cache, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None): B, T, C = x.size() # Project the input to get queries, keys, and values @@ -80,6 +191,14 @@ class CausalSelfAttention(nn.Module): # Apply KV cache: insert current k,v into cache, get the full view so far if kv_cache is not None: k, v = kv_cache.insert_kv(self.layer_idx, k, v) + + # If episodic_kv is provided, prepend it to the current k and v + if episodic_kv is not None: + episode_k_layer = episodic_kv[self.layer_idx, 0] + episode_v_layer = episodic_kv[self.layer_idx, 1] + k = torch.cat([episode_k_layer, k], dim=2) + v = torch.cat([episode_v_layer, v], dim=2) + Tq = q.size(2) # number of queries in this forward pass Tk = k.size(2) # number of keys/values in total (in the cache + current forward pass) @@ -135,15 +254,42 @@ class Block(nn.Module): return x +class PsycheController(nn.Module): + def __init__(self, config): + super().__init__() + self.controller_head = nn.Linear(config.n_embd, 3) # 3 for id, ego, superego + + def forward(self, x): + # For now, just return equal weights for each psyche layer + # In the future, this will be trained to dynamically blend psyche outputs + return torch.softmax(self.controller_head(x), dim=-1) + + class GPT(nn.Module): def __init__(self, config): super().__init__() self.config = config self.transformer = nn.ModuleDict({ - "wte": nn.Embedding(config.vocab_size, config.n_embd), "h": nn.ModuleList([Block(config, layer_idx) for layer_idx in range(config.n_layer)]), }) - self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Partition transformer layers into psyche layers + total_layers = config.n_layer + id_end = total_layers // 3 + ego_end = 2 * total_layers // 3 + + self.id_layers = self.transformer.h[:id_end] + self.ego_layers = self.transformer.h[id_end:ego_end] + 
self.superego_layers = self.transformer.h[ego_end:] + + self.psyche_registry = { + "id": self.id_layers, + "ego": self.ego_layers, + "superego": self.superego_layers + } + + self.concept_head = nn.Linear(config.n_embd, config.num_concept_ids, bias=False) # New concept head + self.psyche_controller = PsycheController(config) # Initialize PsycheController # To support meta device initialization, we init the rotary embeddings here, but it's fake # As for rotary_seq_len, these rotary embeddings are pretty small/cheap in memory, # so let's just over-compute them, but assert fail if we ever reach that amount. @@ -156,8 +302,8 @@ class GPT(nn.Module): def init_weights(self): self.apply(self._init_weights) - # zero out classifier weights - torch.nn.init.zeros_(self.lm_head.weight) + # zero out concept_head weights + torch.nn.init.zeros_(self.concept_head.weight) # zero out c_proj weights in all blocks for block in self.transformer.h: torch.nn.init.zeros_(block.mlp.c_proj.weight) @@ -166,9 +312,6 @@ class GPT(nn.Module): head_dim = self.config.n_embd // self.config.n_head cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim) self.cos, self.sin = cos, sin - # Cast the embeddings from fp32 to bf16: optim can tolerate it and it saves memory: both in the model and the activations - if self.transformer.wte.weight.device.type == "cuda": - self.transformer.wte.to(dtype=torch.bfloat16) def _init_weights(self, module): if isinstance(module, nn.Linear): @@ -186,7 +329,7 @@ class GPT(nn.Module): def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None): # autodetect the device from model embeddings if device is None: - device = self.transformer.wte.weight.device + device = self.concept_head.weight.device # stride the channels channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) inv_freq = 1.0 / (base ** (channel_range / head_dim)) @@ -200,12 +343,13 @@ class GPT(nn.Module): return cos, sin def get_device(self): - return self.transformer.wte.weight.device + # Get device from concept_head weight + return self.concept_head.weight.device def estimate_flops(self): """ Return the estimated FLOPs per token for the model. 
Ref: https://arxiv.org/abs/2204.02311 """ nparams = sum(p.numel() for p in self.parameters()) - nparams_embedding = self.transformer.wte.weight.numel() + nparams_embedding = 0 # No separate embedding layer now l, h, q, t = self.config.n_layer, self.config.n_head, self.config.n_embd // self.config.n_head, self.config.sequence_len num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t return num_flops_per_token @@ -213,95 +357,190 @@ class GPT(nn.Module): def setup_optimizers(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() - # Separate out all parameters into 3 groups (matrix, embedding, lm_head) - matrix_params = list(self.transformer.h.parameters()) - embedding_params = list(self.transformer.wte.parameters()) - lm_head_params = list(self.lm_head.parameters()) - assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(lm_head_params) - # Create the AdamW optimizer for the embedding and lm_head + # Separate out all parameters into 3 groups (matrix, embedding, concept_head) + # matrix_params = list(self.transformer.h.parameters()) + id_params = list(self.id_layers.parameters()) + ego_params = list(self.ego_layers.parameters()) + superego_params = list(self.superego_layers.parameters()) + + embedding_params = [] # No separate embedding layer now + concept_head_params = list(self.concept_head.parameters()) # New concept head params + psyche_controller_params = list(self.psyche_controller.parameters()) # Psyche controller params + + # assert len(list(self.parameters())) == len(matrix_params) + len(embedding_params) + len(concept_head_params) + # Create the AdamW optimizer for the embedding and concept_head # Scale the LR for the AdamW parameters by ∝1/√dmodel (having tuned the LRs for 768 dim model) dmodel_lr_scale = (model_dim / 768) ** -0.5 if rank == 0: print(f"Scaling the LR for the AdamW parameters ∝1/√({model_dim}/768) = {dmodel_lr_scale:.6f}") adam_groups = [ - dict(params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale), + dict(params=concept_head_params, lr=unembedding_lr * dmodel_lr_scale), # Use unembedding_lr for concept_head dict(params=embedding_params, lr=embedding_lr * dmodel_lr_scale), + dict(params=id_params, lr=3e-4 * dmodel_lr_scale), # Id layers learning rate + dict(params=ego_params, lr=1e-4 * dmodel_lr_scale), # Ego layers learning rate + dict(params=superego_params, lr=5e-5 * dmodel_lr_scale), # Superego layers learning rate + dict(params=psyche_controller_params, lr=1e-4 * dmodel_lr_scale), # Psyche controller learning rate ] adamw_kwargs = dict(betas=(0.8, 0.95), eps=1e-10, weight_decay=weight_decay) AdamWFactory = DistAdamW if ddp else partial(torch.optim.AdamW, fused=True) adamw_optimizer = AdamWFactory(adam_groups, **adamw_kwargs) # Create the Muon optimizer for the linear layers - muon_kwargs = dict(lr=matrix_lr, momentum=0.95) - MuonFactory = DistMuon if ddp else Muon - muon_optimizer = MuonFactory(matrix_params, **muon_kwargs) + # muon_kwargs = dict(lr=matrix_lr, momentum=0.95) + # MuonFactory = DistMuon if ddp else Muon + # muon_optimizer = MuonFactory(matrix_params, **muon_kwargs) # Combine them the two optimizers into one list - optimizers = [adamw_optimizer, muon_optimizer] + optimizers = [adamw_optimizer] for opt in optimizers: for group in opt.param_groups: group["initial_lr"] = group["lr"] return optimizers - def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): - 
B, T = idx.size() + def _run_layers(self, layers, x, cos_sin, kv_cache, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None): + for block in layers: + x = block(x, cos_sin, kv_cache, episodic_kv) + return x + + def forward(self, input_embeddings: torch.Tensor, kv_cache=None, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None): + B, T, C = input_embeddings.size() # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim)) assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" - assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" - assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" - # if kv cache exists, we need to offset the rotary embeddings to the current position in the cache - T0 = 0 if kv_cache is None else kv_cache.get_pos() - cos_sin = self.cos[:, T0:T0+T], self.sin[:, T0:T0+T] # truncate cache to current sequence length + cos_sin = self.cos[:, :T, :, :], self.sin[:, :T, :, :] + + x = input_embeddings + + # Process Id layers + x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv) + + # Process Ego layers + x_ego = x_id + if abacus_embedding is not None: + # Broadcast abacus_embedding to match the sequence length of x_ego + # Assuming abacus_embedding is (B, C) and x_ego is (B, T, C) + abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1) + x_ego = x_ego + abacus_broadcast # Inject abacus_embedding into ego layer + if long_term_memory_embeddings is not None: + # Broadcast long_term_memory_embeddings to match the sequence length of x_ego + # Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C) + long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1) + x_ego = x_ego + long_term_memory_broadcast # Inject long_term_memory_embeddings into ego layer + x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv) + + # Process Superego layers + x_superego = x_ego + if long_term_memory_embeddings is not None: + # Broadcast long_term_memory_embeddings to match the sequence length of x_superego + # Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C) + long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1) + x_superego = x_superego + long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1) + x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv) + + # Dynamically blend the outputs based on psyche_weights + # Reshape psyche_weights for broadcasting: (B, 1, 3) + psyche_weights_reshaped = psyche_weights.unsqueeze(1) + + # Stack the outputs and apply weighted sum + # Stack will result in (B, T, 3, C) + stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2) + # Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C) + x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2) + + # Final concept head + return self.concept_head(x), kv_cache, x_id, x_ego, x_superego + + def forward_prefill(self, input_embeddings: torch.Tensor, kv_cache=None, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = 
None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None): + B, T, C = input_embeddings.size() + + # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim)) + assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" + cos_sin = self.cos[:, :T, :, :], self.sin[:, :T, :, :] + + x = input_embeddings + + # Process Id layers + x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv) + + # Process Ego layers + x_ego = x_id + if abacus_embedding is not None: + # Broadcast abacus_embedding to match the sequence length of x_ego + # Assuming abacus_embedding is (B, C) and x_ego is (B, T, C) + abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1) + x_ego = x_ego + abacus_broadcast # Inject abacus_embedding into ego layer + if long_term_memory_embeddings is not None: + # Broadcast long_term_memory_embeddings to match the sequence length of x_ego + # Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C) + long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1) + x_ego = x_ego + long_term_memory_broadcast # Inject long_term_memory_embeddings into ego layer + x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv) + + # Process Superego layers + x_superego = x_ego + if long_term_memory_embeddings is not None: + # Broadcast long_term_memory_embeddings to match the sequence length of x_superego + # Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C) + long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1) + x_superego = x_superego + long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1) + x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv) + + # Dynamically blend the outputs based on psyche_weights + # Reshape psyche_weights for broadcasting: (B, 1, 3) + psyche_weights_reshaped = psyche_weights.unsqueeze(1) + + # Stack the outputs and apply weighted sum + # Stack will result in (B, T, 3, C) + stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2) + # Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C) + x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2) + + # Final concept head + return self.concept_head(x), kv_cache, x_id, x_ego, x_superego + + def forward_step(self, next_embedding: torch.Tensor, kv_cache, abacus_embedding: torch.Tensor | None = None, episodic_kv: tuple[torch.Tensor, torch.Tensor] | None = None, long_term_memory_embeddings: torch.Tensor | None = None, psyche_weights: torch.Tensor | None = None): + B, C = next_embedding.size() + T = kv_cache[0].size(1) + 1 # Current sequence length after adding next_embedding + + # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim)) + assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" + cos_sin = self.cos[:, T-1:T, :, :], self.sin[:, T-1:T, :, :] + + x = next_embedding.unsqueeze(1) # Add sequence dimension for consistency + + # Process Id layers + x_id = self._run_layers(self.id_layers, x, cos_sin, kv_cache, episodic_kv) + + # Process Ego layers + x_ego = x_id + if abacus_embedding is not None: + # Broadcast abacus_embedding to match the sequence length of x_ego + # Assuming 
abacus_embedding is (B, C) and x_ego is (B, T, C) + abacus_broadcast = abacus_embedding.unsqueeze(1).expand(-1, x_ego.size(1), -1) + x_ego = x_ego + abacus_broadcast # Inject abacus_embedding into ego layer + if long_term_memory_embeddings is not None: + # Broadcast long_term_memory_embeddings to match the sequence length of x_ego + # Assuming long_term_memory_embeddings is (B, C) and x_ego is (B, T, C) + long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_ego.size(1), -1) + x_ego = x_ego + long_term_memory_broadcast # Inject long_term_memory_embeddings into ego layer + x_ego = self._run_layers(self.ego_layers, x_ego, cos_sin, kv_cache, episodic_kv) + + # Process Superego layers + x_superego = x_ego + if long_term_memory_embeddings is not None: + # Broadcast long_term_memory_embeddings to match the sequence length of x_superego + # Assuming long_term_memory_embeddings is (B, C) and x_superego is (B, T, C) + long_term_memory_broadcast = long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1) + x_superego = x_superego + long_term_memory_embeddings.unsqueeze(1).expand(-1, x_superego.size(1), -1) + x_superego = self._run_layers(self.superego_layers, x_superego, cos_sin, kv_cache, episodic_kv) + + # Dynamically blend the outputs based on psyche_weights + # Reshape psyche_weights for broadcasting: (B, 1, 3) + psyche_weights_reshaped = psyche_weights.unsqueeze(1) + + # Stack the outputs and apply weighted sum + # Stack will result in (B, T, 3, C) + stacked_outputs = torch.stack([x_id, x_ego, x_superego], dim=2) + # Weighted sum: (B, T, 1, C) after sum, then squeeze to (B, T, C) + x = (stacked_outputs * psyche_weights_reshaped.unsqueeze(-1)).sum(dim=2) - # Forward the trunk of the Transformer - x = self.transformer.wte(idx) x = norm(x) - for block in self.transformer.h: - x = block(x, cos_sin, kv_cache) - x = norm(x) - - # Forward the lm_head (compute logits) - softcap = 15 - if targets is not None: - # training mode: compute and return the loss - # TODO: experiment with Liger Kernels / chunked cross-entropy etc. - logits = self.lm_head(x) - logits = softcap * torch.tanh(logits / softcap) # logits softcap - logits = logits.float() # use tf32/fp32 for logits - loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1, reduction=loss_reduction) - return loss - else: - # inference mode: compute and return the logits - logits = self.lm_head(x) - logits = softcap * torch.tanh(logits / softcap) # logits softcap - return logits - - @torch.inference_mode() - def generate(self, tokens, max_tokens, temperature=1.0, top_k=None, seed=42): - """ - Naive autoregressive streaming inference. 
- To make it super simple, let's assume: - - batch size is 1 - - ids and the yielded tokens are simple Python lists and ints - """ - assert isinstance(tokens, list) - device = self.get_device() - rng = None - if temperature > 0: - rng = torch.Generator(device=device) - rng.manual_seed(seed) - ids = torch.tensor([tokens], dtype=torch.long, device=device) # add batch dim - for _ in range(max_tokens): - logits = self.forward(ids) # (B, T, vocab_size) - logits = logits[:, -1, :] # (B, vocab_size) - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float('Inf') - if temperature > 0: - logits = logits / temperature - probs = F.softmax(logits, dim=-1) - next_ids = torch.multinomial(probs, num_samples=1, generator=rng) - else: - next_ids = torch.argmax(logits, dim=-1, keepdim=True) - ids = torch.cat((ids, next_ids), dim=1) - token = next_ids.item() - yield token + return self.concept_head(x.squeeze(1)), kv_cache, x_id, x_ego, x_superego diff --git a/nanochat/hypercube.py b/nanochat/hypercube.py new file mode 100644 index 0000000..f0787e3 --- /dev/null +++ b/nanochat/hypercube.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn + +class HypercubeTopology: + def __init__(self, n: int): + if not (1 <= n <= 16): # Limiting n for practical reasons, can be adjusted + raise ValueError("Hypercube dimension 'n' must be between 1 and 16.") + self.n = n + self.num_vertices = 2**n + + def to_binary(self, vertex_id: int) -> str: + if not (0 <= vertex_id < self.num_vertices): + raise ValueError(f"Vertex ID {vertex_id} is out of range for a {self.n}-dimensional hypercube.") + return format(vertex_id, f'0{self.n}b') + + def from_binary(self, binary_string: str) -> int: + if len(binary_string) != self.n: + raise ValueError(f"Binary string length must be {self.n} for a {self.n}-dimensional hypercube.") + return int(binary_string, 2) + + def get_neighbors(self, vertex_id: int) -> list[int]: + if not (0 <= vertex_id < self.num_vertices): + raise ValueError(f"Vertex ID {vertex_id} is out of range for a {self.n}-dimensional hypercube.") + + neighbors = [] + for i in range(self.n): + # Flip the i-th bit + neighbor_id = vertex_id ^ (1 << i) + neighbors.append(neighbor_id) + return neighbors + + def get_all_edges(self) -> list[tuple[int, int]]: + edges = [] + for i in range(self.num_vertices): + for neighbor in self.get_neighbors(i): + # Ensure each edge is added only once (e.g., (0,1) not (1,0)) + if i < neighbor: + edges.append((i, neighbor)) + return edges + + def get_random_vertex(self) -> int: + return torch.randint(0, self.num_vertices, (1,)).item() + + def get_random_path(self, start_vertex: int, length: int) -> list[int]: + if not (0 <= start_vertex < self.num_vertices): + raise ValueError(f"Start vertex ID {start_vertex} is out of range.") + if length < 1: + raise ValueError("Path length must be at least 1.") + + path = [start_vertex] + current_vertex = start_vertex + + for _ in range(length - 1): + neighbors = self.get_neighbors(current_vertex) + if not neighbors: + break # Should not happen in a hypercube unless n=0 + + # Choose a random neighbor that is not the previous vertex if possible + next_vertex = neighbors[torch.randint(0, len(neighbors), (1,)).item()] + path.append(next_vertex) + current_vertex = next_vertex + return path + +class HypercubeEmbeddingLayer(nn.Module): + def __init__(self, num_raw_concepts, embedding_dim, hypercube_topology): + super().__init__() + self.raw_embedding_layer = nn.Embedding(num_raw_concepts, embedding_dim) 
+ self.hypercube_topology = hypercube_topology + # Embeddings for hypercube vertices. The number of vertices is hypercube_topology.num_vertices + self.vertex_embeddings = nn.Embedding(hypercube_topology.num_vertices, embedding_dim) + + def forward(self, concept_ids): + # Get initial embeddings for the input concept_ids + initial_embeddings = self.raw_embedding_layer(concept_ids) # (batch_size, embedding_dim) + + # Find the nearest hypercube vertex for each initial embedding + + # Get all hypercube vertex embeddings + # Ensure all_vertex_ids are on the same device as concept_ids + all_vertex_ids = torch.arange(self.hypercube_topology.num_vertices, device=concept_ids.device) + all_vertex_embeddings = self.vertex_embeddings(all_vertex_ids) # (num_vertices, embedding_dim) + + # Calculate squared Euclidean distance + # initial_embeddings_expanded: (batch_size, 1, embedding_dim) + # all_vertex_embeddings_expanded: (1, num_vertices, embedding_dim) + initial_embeddings_expanded = initial_embeddings.unsqueeze(1) + all_vertex_embeddings_expanded = all_vertex_embeddings.unsqueeze(0) + + distances = torch.sum((initial_embeddings_expanded - all_vertex_embeddings_expanded)**2, dim=2) # (batch_size, num_vertices) + + # Find the index of the nearest vertex for each initial embedding + nearest_vertex_indices = torch.argmin(distances, dim=1) # (batch_size,) + + # Retrieve the embeddings of the nearest vertices + final_embeddings = self.vertex_embeddings(nearest_vertex_indices) + + return final_embeddings + + +class LongTermMemory(nn.Module): + def __init__(self, embedding_dim: int, max_memory_size: int = 1000, top_k_retrieval: int = 5): + super().__init__() + self.embedding_dim = embedding_dim + self.max_memory_size = max_memory_size + self.top_k_retrieval = top_k_retrieval + + # Initialize an empty memory bank. We'll use a list of tensors for simplicity initially, + # but this could be replaced with a more efficient data structure or a learnable embedding layer. 
+ self.memory_bank = [] + self.memory_bank_tensor = None # Will store concatenated memories for efficient retrieval + + def store(self, embedding: torch.Tensor): + # Store a new embedding in the memory bank + if len(self.memory_bank) >= self.max_memory_size: + # Simple FIFO eviction for now + self.memory_bank.pop(0) + self.memory_bank.append(embedding.detach().cpu()) # Store on CPU to save GPU memory + self.memory_bank_tensor = None # Invalidate cached tensor + + def retrieve(self, query_embedding: torch.Tensor) -> torch.Tensor: + # Retrieve top-k most similar embeddings from the memory bank + if not self.memory_bank: + return torch.zeros(query_embedding.shape[0], self.top_k_retrieval, self.embedding_dim, device=query_embedding.device) + + if self.memory_bank_tensor is None: + self.memory_bank_tensor = torch.stack(self.memory_bank).to(query_embedding.device) + + # Normalize query and memory bank for cosine similarity + query_norm = F.normalize(query_embedding, p=2, dim=-1) + memory_norm = F.normalize(self.memory_bank_tensor, p=2, dim=-1) + + # Calculate cosine similarity + # query_norm: (batch_size, embedding_dim) + # memory_norm: (num_memories, embedding_dim) + # similarities: (batch_size, num_memories) + similarities = torch.matmul(query_norm, memory_norm.transpose(0, 1)) + + # Get top-k similar memories + # top_k_values: (batch_size, top_k_retrieval) + # top_k_indices: (batch_size, top_k_retrieval) + top_k_values, top_k_indices = torch.topk(similarities, min(self.top_k_retrieval, len(self.memory_bank)), dim=-1) + + # Retrieve the actual embeddings + retrieved_memories = self.memory_bank_tensor[top_k_indices] + + return retrieved_memories + + def forward(self, query_embedding: torch.Tensor) -> torch.Tensor: + return self.retrieve(query_embedding) \ No newline at end of file diff --git a/nanochat/memetic_learning.py b/nanochat/memetic_learning.py new file mode 100644 index 0000000..3a68c6a --- /dev/null +++ b/nanochat/memetic_learning.py @@ -0,0 +1,46 @@ +import torch +import torch.nn as nn + +class MemeticLearningLayer(nn.Module): + def __init__(self, config, abacus_encoder, internal_self_model): + super().__init__() + self.config = config + self.abacus_encoder = abacus_encoder + self.internal_self_model = internal_self_model + + # Placeholder for memetic fitness evaluation mechanism + # This could be a simple linear layer or a more complex network + self.fitness_evaluator = nn.Linear(config.abacus_input_dim, 1) + + # Placeholder for concept mapping expansion (analogy/metaphor) + # This might involve a transformation of embeddings or a retrieval mechanism + self.concept_mapper = nn.Sequential( + nn.Linear(config.n_embd * 2, config.n_embd), # Input: two concept embeddings + nn.ReLU(), + nn.Linear(config.n_embd, config.n_embd) + ) + + def evaluate_memetic_fitness(self, abacus_pattern: torch.Tensor) -> torch.Tensor: + # Evaluate the fitness of a memetic pattern using the abacus encoder output + fitness_score = self.fitness_evaluator(abacus_pattern) + return fitness_score + + def expand_concept_mapping(self, concept1_embedding: torch.Tensor, concept2_embedding: torch.Tensor) -> torch.Tensor: + # Expand concept mapping via analogy and metaphor + # This takes two concept embeddings and generates a new, related concept embedding + combined_concepts = torch.cat([concept1_embedding, concept2_embedding], dim=-1) + new_concept_embedding = self.concept_mapper(combined_concepts) + return new_concept_embedding + + def forward(self, current_concept_embedding: torch.Tensor, abacus_pattern: 
torch.Tensor): + # Orchestrates the memetic learning process + # 1. Evaluate memetic fitness + fitness = self.evaluate_memetic_fitness(abacus_pattern) + + # 2. Potentially update internal self-model based on fitness or new concepts + # self.internal_self_model.update_beliefs(current_concept_embedding) # Example interaction + + # 3. Generate new concepts or analogies + # new_concept = self.expand_concept_mapping(current_concept_embedding, some_other_concept) + + return fitness \ No newline at end of file diff --git a/nanochat/self_model.py b/nanochat/self_model.py new file mode 100644 index 0000000..5a3a560 --- /dev/null +++ b/nanochat/self_model.py @@ -0,0 +1,58 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class InternalSelfModel(nn.Module): + def __init__(self, embedding_dim: int, config): + super().__init__() + self.embedding_dim = embedding_dim + self.config = config + + # Placeholder for belief representation (e.g., a learned embedding or a set of parameters) + self.belief_embedding = nn.Parameter(torch.randn(embedding_dim)) + + # Placeholder for a simple prediction head to calculate prediction error + self.prediction_head = nn.Linear(embedding_dim, 1) # Predicts a scalar error signal + + # Placeholder for conceptual growth mechanism (e.g., a small MLP or attention mechanism) + self.conceptual_growth_mlp = nn.Sequential( + nn.Linear(embedding_dim * 2, embedding_dim), # Input: current belief + new concept + nn.ReLU(), + nn.Linear(embedding_dim, embedding_dim) + ) + + def update_beliefs(self, current_concept_embedding: torch.Tensor): + # This method would update the internal belief representation based on new information. + # For now, a simple update rule (e.g., moving average or attention-based update) + # In a more advanced implementation, this could involve a recurrent neural network. + # Example: simple weighted average + alpha = 0.1 # Learning rate for belief update + self.belief_embedding.data = (1 - alpha) * self.belief_embedding.data + alpha * current_concept_embedding.mean(dim=0) + + def calculate_prediction_error(self, predicted_output: torch.Tensor, actual_output: torch.Tensor) -> torch.Tensor: + # This method calculates the discrepancy between predicted and actual outcomes. + # For now, a simple mean squared error. + error = F.mse_loss(predicted_output, actual_output) + return error + + def promote_conceptual_growth(self, current_belief_embedding: torch.Tensor, new_concept_embedding: torch.Tensor) -> torch.Tensor: + # This method facilitates the integration of new concepts into the existing conceptual framework. + # For now, a simple MLP that takes the concatenation of current belief and new concept. + combined_embedding = torch.cat([current_belief_embedding, new_concept_embedding], dim=-1) + updated_concept_embedding = self.conceptual_growth_mlp(combined_embedding) + return updated_concept_embedding + + def forward(self, current_concept_embedding: torch.Tensor, predicted_output: torch.Tensor, actual_output: torch.Tensor): + # Update beliefs + self.update_beliefs(current_concept_embedding) + + # Calculate prediction error + error = self.calculate_prediction_error(predicted_output, actual_output) + + # Promote conceptual growth (example: using the current belief and concept embedding) + # This part would be more complex in a full implementation, potentially involving attention + # or other mechanisms to decide how to grow concepts based on error and new info. 
+ # For demonstration, let's assume conceptual growth is triggered by new concept embeddings. + # updated_concept = self.promote_conceptual_growth(self.belief_embedding, current_concept_embedding) + + return error, self.belief_embedding \ No newline at end of file
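
A minimal usage sketch (not part of the patch) showing how the new AbacusEncoder and AbacusStateMemory added above fit together in isolation; the dimensions (input_dim=64, embedding_dim=768) are illustrative values matching the defaults used elsewhere in this diff.

import torch
from nanochat.abacus_encoder import AbacusEncoder
from nanochat.abacus_state_memory import AbacusStateMemory

encoder = AbacusEncoder(input_dim=64, embedding_dim=768)
memory = AbacusStateMemory(max_memory_size=100, abacus_input_dim=64)

pattern = torch.randn(1, 64)                 # (batch_size, input_dim) abacus-like pattern
embedding = encoder(pattern)                 # (1, 768) projection into the concept space
memory.store(pattern)                        # FIFO buffer; oldest entry evicted past 100
recent = memory.retrieve(num_to_retrieve=1)  # list containing the most recent stored pattern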
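
A small sketch (not part of the patch) of the bit-flip neighbor structure that HypercubeTopology exposes, which is why num_concept_ids must be a power of two; the printed values are for a 3-dimensional cube.

from nanochat.hypercube import HypercubeTopology

topo = HypercubeTopology(3)          # 2**3 = 8 vertices
print(topo.to_binary(5))             # '101'
print(topo.get_neighbors(5))         # [4, 7, 1] -- one neighbor per flipped bit
print(len(topo.get_all_edges()))     # n * 2**(n-1) = 12 edges for n = 3
print(topo.get_random_path(0, 4))    # e.g. [0, 2, 6, 7], a random walk of length 4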
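
A standalone sketch (not part of the patch) of the cosine-similarity top-k lookup that LongTermMemory.retrieve performs over its memory bank; it assumes torch.nn.functional is imported as F (which the retrieve code in nanochat/hypercube.py relies on as well), and the tensor sizes are illustrative.

import torch
import torch.nn.functional as F

memory_bank = torch.randn(10, 768)       # 10 stored embeddings of size embedding_dim
query = torch.randn(2, 768)              # a batch of 2 query embeddings

q = F.normalize(query, p=2, dim=-1)
m = F.normalize(memory_bank, p=2, dim=-1)
similarities = q @ m.t()                 # (2, 10) cosine similarities
top_vals, top_idx = torch.topk(similarities, k=5, dim=-1)
retrieved = memory_bank[top_idx]         # (2, 5, 768) top-k memories per query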
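
A sketch (not part of the patch) of the weighted blending that GPT.forward applies to the id/ego/superego streams using the (B, 3) softmax produced by PsycheController; the shapes follow the comments in the patched code, and the random weights stand in for the controller output.

import torch

B, T, C = 2, 4, 8
x_id, x_ego, x_superego = (torch.randn(B, T, C) for _ in range(3))
psyche_weights = torch.softmax(torch.randn(B, 3), dim=-1)   # stand-in for PsycheController(x)

stacked = torch.stack([x_id, x_ego, x_superego], dim=2)     # (B, T, 3, C)
weights = psyche_weights.unsqueeze(1).unsqueeze(-1)         # (B, 1, 3, 1), broadcasts over T and C
blended = (stacked * weights).sum(dim=2)                    # (B, T, C) mixture of the three streams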