This commit is contained in:
Abhaya Pattanaik 2026-03-25 18:22:17 +05:00 committed by GitHub
commit 95135c8a67
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -191,6 +191,21 @@ class RustBPETokenizer:
@classmethod
def from_directory(cls, tokenizer_dir):
# Try JSON format first (portable, avoids pickle of native Rust objects)
json_path = os.path.join(tokenizer_dir, "tokenizer.json")
if os.path.exists(json_path):
import json
with open(json_path, "r") as f:
data = json.load(f)
mergeable_ranks = {bytes.fromhex(k): v for k, v in data["mergeable_ranks"].items()}
enc = tiktoken.Encoding(
name=data.get("name", "rustbpe"),
pat_str=data["pattern"],
mergeable_ranks=mergeable_ranks,
special_tokens=data["special_tokens"],
)
return cls(enc, "<|bos|>")
# Fall back to legacy pickle format
pickle_path = os.path.join(tokenizer_dir, "tokenizer.pkl")
with open(pickle_path, "rb") as f:
enc = pickle.load(f)
@@ -256,12 +271,20 @@ class RustBPETokenizer:
return self.enc.decode(ids)
def save(self, tokenizer_dir):
    """Serialize this tokenizer's encoding to `tokenizer_dir` as portable JSON.

    JSON is used instead of pickle because pickling tiktoken's native Rust
    objects causes heap corruption / SIGSEGV on ARM Docker with glibc malloc.
    Byte keys of the merge table are hex-encoded so the mapping survives the
    round trip through JSON string keys (decoded again with `bytes.fromhex`
    in `from_directory`).

    Writes `tokenizer_dir/tokenizer.json`, creating the directory if needed.
    """
    import json
    os.makedirs(tokenizer_dir, exist_ok=True)
    # NOTE: the legacy tokenizer.pkl pickle write is intentionally gone;
    # from_directory still accepts it as a read-only fallback.
    data = {
        "name": self.enc.name,
        "pattern": self.enc._pat_str,
        # bytes keys -> hex strings (JSON object keys must be strings)
        "mergeable_ranks": {k.hex(): v for k, v in self.enc._mergeable_ranks.items()},
        "special_tokens": self.enc._special_tokens,
    }
    json_path = os.path.join(tokenizer_dir, "tokenizer.json")
    with open(json_path, "w") as f:
        json.dump(data, f)
    print(f"Saved tokenizer encoding to {json_path}")
def render_conversation(self, conversation, max_tokens=2048):
"""
@@ -402,5 +425,5 @@ def get_token_bytes(device="cpu"):
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
assert os.path.exists(token_bytes_path), f"Token bytes not found at {token_bytes_path}? It gets written by tok_train.py"
with open(token_bytes_path, "rb") as f:
token_bytes = torch.load(f, map_location=device)
token_bytes = torch.load(f, map_location=device, weights_only=True)
return token_bytes