This commit is contained in:
Hossein-Lakzaei 2025-11-14 15:02:27 -08:00 committed by GitHub
commit f77590ab9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 7 deletions

View File

@ -13,3 +13,4 @@ pyo3-log = "0.12.4"
ahash = "0.8.12"
rayon = "1.11.0"
compact_str = "0.9.0"
num-format = "0.4"

View File

@ -6,6 +6,7 @@ use fancy_regex::Regex;
use pyo3::prelude::*;
use ahash::{AHashMap, AHashSet};
use num_format::{Locale, ToFormattedString};
use compact_str::CompactString;
use rayon::prelude::*;
@ -164,15 +165,15 @@ impl Tokenizer {
fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
assert!(vocab_size >= 256, "vocab_size must be at least 256");
let num_merges = vocab_size - 256;
log::info!("Starting BPE training: {} merges to compute", num_merges);
log::info!("Starting BPE training: {} merges to compute", num_merges.to_formatted_string(&Locale::en));
self.merges.clear();
// ---- Initial pair_counts and where_to_update (parallel) ----
log::info!("Computing initial pair counts from {} unique sequences", words.len());
log::info!("Computing initial pair counts from {} unique sequences", words.len().to_formatted_string(&Locale::en));
let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);
// ---- Build heap ----
log::info!("Building heap with {} unique pairs", pair_counts.len());
log::info!("Building heap with {} unique pairs", pair_counts.len().to_formatted_string(&Locale::en));
let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
for (pair, pos) in where_to_update.drain() {
let c = *pair_counts.get(&pair).unwrap_or(&0);
@ -375,7 +376,7 @@ impl Tokenizer {
break;
}
}
log::info!("Processed {} sequences total, {} unique", total_sequences, counts.len());
log::info!("Processed {} sequences total, {} unique", total_sequences.to_formatted_string(&Locale::en), counts.len().to_formatted_string(&Locale::en));
// Materialize words & counts
let mut words = Vec::with_capacity(counts.len());

View File

@ -196,9 +196,9 @@ RESET = '\033[0m'
# Print vocab sizes
print(f"\nVocab sizes:")
print(f"GPT-2: {vocab_sizes['gpt2']}")
print(f"GPT-4: {vocab_sizes['gpt4']}")
print(f"Ours: {vocab_sizes['ours']}")
print(f"GPT-2: {vocab_sizes['gpt2']:,}")
print(f"GPT-4: {vocab_sizes['gpt4']:,}")
print(f"Ours: {vocab_sizes['ours']:,}")
def print_comparison(baseline_name, baseline_results, ours_results, all_text):
"""Print comparison table between baseline tokenizer and ours."""