mirror of
https://github.com/karpathy/nanochat.git
synced 2026-01-05 03:02:37 +00:00
24 lines
1.1 KiB
Python
24 lines
1.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
class AbacusEncoder(nn.Module):
|
|
def __init__(self, input_dim: int, embedding_dim: int):
|
|
super().__init__()
|
|
self.input_dim = input_dim
|
|
self.embedding_dim = embedding_dim
|
|
|
|
# Simple linear layer to encode abacus-like patterns into the embedding space
|
|
self.encoder_layer = nn.Linear(input_dim, embedding_dim)
|
|
|
|
def encode(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
|
|
# abacus_pattern is expected to be a tensor of shape (batch_size, input_dim)
|
|
if abacus_pattern.shape[-1] != self.input_dim:
|
|
raise ValueError(f"Expected abacus_pattern to have last dimension {self.input_dim}, but got {abacus_pattern.shape[-1]}")
|
|
return self.encoder_layer(abacus_pattern)
|
|
|
|
def decode(self, concept_vector: torch.Tensor) -> torch.Tensor:
|
|
# Placeholder for decoding functionality
|
|
raise NotImplementedError("Decoding from concept vector to abacus pattern is not yet implemented.")
|
|
|
|
def forward(self, abacus_pattern: torch.Tensor) -> torch.Tensor:
|
|
return self.encode(abacus_pattern) |