mirror of
https://github.com/karpathy/nanochat.git
synced 2026-03-12 04:05:30 +00:00
203 lines
8.2 KiB
Python
203 lines
8.2 KiB
Python
"""
|
||
Train a tokenizer using our own BPE Tokenizer library.
|
||
In the style of GPT-4 tokenizer.
|
||
"""
|
||
import os
|
||
import time
|
||
from typing import Iterator, Tuple
|
||
import argparse
|
||
import torch
|
||
from nanochat.tokenizer import RustBPETokenizer
|
||
from nanochat.common import get_base_dir
|
||
from nanochat.dataset import parquets_iter_batched
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Parse command line arguments
|
||
|
||
# Build the command line interface. The help strings quote the real default
# values so that `--help` agrees with what the script actually does.
parser = argparse.ArgumentParser(description='Train a BPE tokenizer')
parser.add_argument('--max-chars', type=int, default=2_000_000_000, help='Maximum characters to train on (default: 2B)')
parser.add_argument('--doc-cap', type=int, default=10_000, help='Maximum characters per document (default: 10,000)')
parser.add_argument('--vocab-size', type=int, default=2**16, help='Vocabulary size (default: 65536 = 2^16)')
args = parser.parse_args()
# Echo the effective settings so they show up in the run log.
print(f"max_chars: {args.max_chars:,}")
print(f"doc_cap: {args.doc_cap:,}")
print(f"vocab_size: {args.vocab_size:,}")
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Text iterator
|
||
|
||
def text_iterator() -> Iterator[str]:
    """Yield training documents one at a time, clipped to the per-document cap.

    Walks the batches returned by ``parquets_iter_batched(split="train")`` and
    yields each document in order. A document longer than ``args.doc_cap``
    characters is cut down to its first ``args.doc_cap`` characters. A running
    total of yielded characters is kept, and iteration ends right after the
    document that pushes the total strictly above ``args.max_chars`` (that
    document is still yielded).

    Yields:
        str: The (possibly truncated) text of one document.
    """
    chars_seen = 0
    for batch in parquets_iter_batched(split="train"):
        for doc in batch:
            # Clip over-long documents to the configured per-document limit.
            text = doc[:args.doc_cap] if len(doc) > args.doc_cap else doc
            chars_seen += len(text)
            yield text
            # Stop as soon as the overall character budget has been exceeded.
            if chars_seen > args.max_chars:
                return
|
||
# Create the document stream consumed by the trainer.
text_iter = text_iterator()

# -----------------------------------------------------------------------------
# Train the tokenizer
# Use perf_counter() rather than time(): it is monotonic, so the measured
# duration cannot be skewed by system clock adjustments during a long run.
start = time.perf_counter()
tokenizer = RustBPETokenizer.train_from_iterator(text_iter, args.vocab_size)
end = time.perf_counter()
train_time = end - start
print(f"Training time: {train_time:.2f}s")
|
||
|
||
def train(iterator: Iterator[str], vocab_size: int) -> Tuple[RustBPETokenizer, float]:
    """Train a BPE tokenizer and report how long training took.

    Args:
        iterator: Stream of text documents handed to
            ``RustBPETokenizer.train_from_iterator``.
        vocab_size: Vocabulary size passed through to the trainer.

    Returns:
        A ``(tokenizer, train_time)`` pair: the trained tokenizer and the
        elapsed training time in seconds.
    """
    # perf_counter() is monotonic, so the elapsed time is unaffected by
    # wall-clock adjustments that may happen while training runs.
    start = time.perf_counter()
    tokenizer = RustBPETokenizer.train_from_iterator(iterator, vocab_size)
    train_time = time.perf_counter() - start
    return tokenizer, train_time
|
||
# -----------------------------------------------------------------------------
|
||
# Save the tokenizer to disk
|
||
# The tokenizer is stored in the "tokenizer" subdirectory of the base directory.
base_dir = get_base_dir()
tokenizer_dir = os.path.join(base_dir, "tokenizer")
tokenizer.save(tokenizer_dir)

# -----------------------------------------------------------------------------
# Quick inline sanity check: encoding then decoding must reproduce the input.
test_text = """Hello world! This is a test.
Numbers: 123, 4567, 89
Contractions: I'm, you're, it's
Special chars: @#$%^&*()
Unicode: 你好世界 🌍"""
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
# Raise explicitly instead of using a bare `assert`: assert statements are
# stripped under `python -O`, which would silently disable this check.
if decoded != test_text:
    raise AssertionError("Tokenizer round-trip mismatch: decode(encode(text)) != text")
|
||
|
||
def sanity_check(tokenizer: RustBPETokenizer):
    """Check that the tokenizer round-trips a small sample text.

    Encodes a fixed sample (ASCII text, numbers, contractions, punctuation and
    non-ASCII characters) and verifies that decoding the result gives back
    exactly the same string.

    Args:
        tokenizer: The tokenizer to check; its ``encode`` and ``decode``
            methods are used.

    Raises:
        AssertionError: If the decoded text differs from the original text.
    """
    test_text = """Hello world! This is a test.
Numbers: 123, 4567, 89
Contractions: I'm, you're, it's
Special chars: @#$%^&*()
Unicode: 你好世界 🌍"""
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded)
    # Raise explicitly rather than via a bare `assert`, which is removed under
    # `python -O` and would turn this check into a no-op.
    if decoded != test_text:
        raise AssertionError("Tokenizer round-trip mismatch: decode(encode(text)) != text")
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# One more thing: we wish to cache a mapping from token id to number of bytes of that token
|
||
# for efficient evaluation of bits per byte. Unlike the typical mean loss, this
|
||
# allows us to report a loss that is invariant to the vocab size of the tokenizer.
|
||
# The bits per byte on the validation set is then one of the primary metrics we care about.
|
||
vocab_size = tokenizer.get_vocab_size()
special_set = set(tokenizer.get_special_tokens())
# Python string representation of every token id, in id order.
token_strings = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
# Per-token UTF-8 byte length; special tokens are not counted and get 0.
token_bytes = torch.tensor(
    [
        0 if token_str in special_set else len(token_str.encode("utf-8"))
        for token_str in token_strings
    ],
    dtype=torch.int32,
    device='cpu',
)
token_bytes_path = os.path.join(tokenizer_dir, "token_bytes.pt")
with open(token_bytes_path, "wb") as f:
    torch.save(token_bytes, f)
print(f"Saved token_bytes to {token_bytes_path}")
|
||
|
||
def generate_token_bytes(tokenizer: RustBPETokenizer, save_path: str) -> torch.Tensor:
    """Build the token-id to byte-count tensor and write it to disk.

    Every token id in ``range(tokenizer.get_vocab_size())`` is decoded on its
    own. If the decoded string is one of the tokenizer's special tokens the
    entry is 0; otherwise it is the length of the decoded string's UTF-8
    encoding.

    Args:
        tokenizer: A trained tokenizer.
        save_path: File path the tensor is written to with ``torch.save``.

    Returns:
        An ``int32`` CPU tensor with one entry per token id.
    """
    n_tokens = tokenizer.get_vocab_size()
    specials = set(tokenizer.get_special_tokens())
    pieces = [tokenizer.decode([tid]) for tid in range(n_tokens)]
    # Special tokens are not counted, so they contribute zero bytes.
    byte_counts = [
        0 if piece in specials else len(piece.encode("utf-8"))
        for piece in pieces
    ]
    result = torch.tensor(byte_counts, dtype=torch.int32, device='cpu')
    with open(save_path, "wb") as fh:
        torch.save(result, fh)
    print(f"Saved token_bytes to {save_path}")
    return result
|
||
|
||
# Log to report
|
||
from nanochat.report import get_report
# Only entries with a positive byte count take part in the statistics.
token_bytes_nonzero = token_bytes[token_bytes > 0].to(dtype=torch.float32)
report = get_report()
token_bytes_stats = {
    "token_bytes_min": int(token_bytes_nonzero.min().item()),
    "token_bytes_max": int(token_bytes_nonzero.max().item()),
    "token_bytes_mean": token_bytes_nonzero.mean().item(),
    "token_bytes_std": token_bytes_nonzero.std().item(),
}
report.log(
    section="Tokenizer training",
    data=[
        vars(args),  # argparse command line arguments
        {"train_time": train_time},
        {"num_special_tokens": len(special_set)},
        token_bytes_stats,
    ],
)
|
||
|
||
def log_tokenizer_training(args: argparse.Namespace, train_time: float, tokenizer: RustBPETokenizer):
    """Write a summary of the tokenizer training run to the report.

    Recomputes the per-token byte counts (0 for special tokens, otherwise the
    UTF-8 length of the decoded token) and logs the command line arguments,
    the training time, the number of special tokens, and the min / max / mean /
    std of the positive byte counts under the "Tokenizer training" section.

    Args:
        args: Parsed command line arguments; logged via ``vars(args)``.
        train_time: Training time in seconds.
        tokenizer: A trained tokenizer.
    """
    # Per-token byte counts.
    n_tokens = tokenizer.get_vocab_size()
    specials = set(tokenizer.get_special_tokens())
    pieces = [tokenizer.decode([tid]) for tid in range(n_tokens)]
    byte_counts = torch.tensor(
        [0 if piece in specials else len(piece.encode("utf-8")) for piece in pieces],
        dtype=torch.int32,
        device='cpu',
    )
    # Only entries with a positive byte count take part in the statistics.
    positive_counts = byte_counts[byte_counts > 0].to(dtype=torch.float32)

    # Log to report
    report = get_report()
    stats = {
        "token_bytes_min": int(positive_counts.min().item()),
        "token_bytes_max": int(positive_counts.max().item()),
        "token_bytes_mean": positive_counts.mean().item(),
        "token_bytes_std": positive_counts.std().item(),
    }
    report.log(
        section="Tokenizer training",
        data=[
            vars(args),  # argparse command line arguments
            {"train_time": train_time},
            {"num_special_tokens": len(specials)},
            stats,
        ],
    )