diff --git a/nanochat/tokenizer.py b/nanochat/tokenizer.py
index a2146c2..1f062d4 100644
--- a/nanochat/tokenizer.py
+++ b/nanochat/tokenizer.py
@@ -191,6 +191,21 @@ class RustBPETokenizer:
 
     @classmethod
     def from_directory(cls, tokenizer_dir):
+        # Try JSON format first (portable, avoids pickle of native Rust objects)
+        json_path = os.path.join(tokenizer_dir, "tokenizer.json")
+        if os.path.exists(json_path):
+            import json
+            with open(json_path, "r") as f:
+                data = json.load(f)
+            mergeable_ranks = {bytes.fromhex(k): v for k, v in data["mergeable_ranks"].items()}
+            enc = tiktoken.Encoding(
+                name=data.get("name", "rustbpe"),
+                pat_str=data["pattern"],
+                mergeable_ranks=mergeable_ranks,
+                special_tokens=data["special_tokens"],
+            )
+            return cls(enc, "<|bos|>")
+        # Fall back to legacy pickle format
         pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
         with open(pickle_path, "rb") as f:
             enc = pickle.load(f)
@@ -256,12 +271,20 @@ class RustBPETokenizer:
         return self.enc.decode(ids)
 
     def save(self, tokenizer_dir):
-        # save the encoding object to disk
+        # Save as portable JSON (avoids pickling tiktoken's native Rust objects,
+        # which causes heap corruption / SIGSEGV on ARM Docker with glibc malloc)
+        import json
         os.makedirs(tokenizer_dir, exist_ok=True)
-        pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
-        with open(pickle_path, "wb") as f:
-            pickle.dump(self.enc, f)
-        print(f"Saved tokenizer encoding to {pickle_path}")
+        data = {
+            "name": self.enc.name,
+            "pattern": self.enc._pat_str,
+            "mergeable_ranks": {k.hex(): v for k, v in self.enc._mergeable_ranks.items()},
+            "special_tokens": self.enc._special_tokens,
+        }
+        json_path = os.path.join(tokenizer_dir, "tokenizer.json")
+        with open(json_path, "w") as f:
+            json.dump(data, f)
+        print(f"Saved tokenizer encoding to {json_path}")
 
     def render_conversation(self, conversation, max_tokens=2048):
         """
@@ -402,5 +425,5 @@ def get_token_bytes(device="cpu"):
     token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
     assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
     with open(token_bytes_path, "rb") as f:
-        token_bytes = torch.load(f, map_location=device)
+        token_bytes = torch.load(f, map_location=device, weights_only=True)
     return token_bytes