nanochat/nanochat/abacus_encoder.py
2025-10-31 03:27:16 +00:00

24 lines
1.1 KiB
Python

import torch
import torch.nn as nn
class AbacusEncoder(nn.Module):
def __init__(self, input_dim: int, embedding_dim: int):
super().__init__()
self.input_dim = input_dim
self.embedding_dim = embedding_dim
# Simple linear layer to encode abacus-like patterns into the embedding space
self.encoder_layer = nn.Linear(input_dim, embedding_dim)
def encode(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
# abacus_pattern is expected to be a tensor of shape (batch_size, input_dim)
if abacus_pattern.shape[-1] != self.input_dim:
raise ValueError(f"Expected abacus_pattern to have last dimension {self.input_dim}, but got {abacus_pattern.shape[-1]}")
return self.encoder_layer(abacus_pattern)
def decode(self, concept_vector: torch.Tensor) -> torch.Tensor:
# Placeholder for decoding functionality
raise NotImplementedError("Decoding from concept vector to abacus pattern is not yet implemented.")
def forward(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
return self.encode(abacus_pattern)