Compare commits

...

9 Commits

4 changed files with 10 additions and 8 deletions

View File

@ -244,7 +244,7 @@ class GPT(nn.Module):
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'): def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
B, T = idx.size() B, T = idx.size()
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim)) # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}" assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}" assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16" assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"

View File

@ -13,3 +13,4 @@ pyo3-log = "0.12.4"
ahash = "0.8.12" ahash = "0.8.12"
rayon = "1.11.0" rayon = "1.11.0"
compact_str = "0.9.0" compact_str = "0.9.0"
num-format = "0.4"

View File

@ -6,6 +6,7 @@ use fancy_regex::Regex;
use pyo3::prelude::*; use pyo3::prelude::*;
use ahash::{AHashMap, AHashSet}; use ahash::{AHashMap, AHashSet};
use num_format::{Locale, ToFormattedString};
use compact_str::CompactString; use compact_str::CompactString;
use rayon::prelude::*; use rayon::prelude::*;
@ -164,15 +165,15 @@ impl Tokenizer {
fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) { fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
assert!(vocab_size >= 256, "vocab_size must be at least 256"); assert!(vocab_size >= 256, "vocab_size must be at least 256");
let num_merges = vocab_size - 256; let num_merges = vocab_size - 256;
log::info!("Starting BPE training: {} merges to compute", num_merges); log::info!("Starting BPE training: {} merges to compute", num_merges.to_formatted_string(&Locale::en));
self.merges.clear(); self.merges.clear();
// ---- Initial pair_counts and where_to_update (parallel) ---- // ---- Initial pair_counts and where_to_update (parallel) ----
log::info!("Computing initial pair counts from {} unique sequences", words.len()); log::info!("Computing initial pair counts from {} unique sequences", words.len().to_formatted_string(&Locale::en));
let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts); let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);
// ---- Build heap ---- // ---- Build heap ----
log::info!("Building heap with {} unique pairs", pair_counts.len()); log::info!("Building heap with {} unique pairs", pair_counts.len().to_formatted_string(&Locale::en));
let mut heap = OctonaryHeap::with_capacity(pair_counts.len()); let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
for (pair, pos) in where_to_update.drain() { for (pair, pos) in where_to_update.drain() {
let c = *pair_counts.get(&pair).unwrap_or(&0); let c = *pair_counts.get(&pair).unwrap_or(&0);
@ -375,7 +376,7 @@ impl Tokenizer {
break; break;
} }
} }
log::info!("Processed {} sequences total, {} unique", total_sequences, counts.len()); log::info!("Processed {} sequences total, {} unique", total_sequences.to_formatted_string(&Locale::en), counts.len().to_formatted_string(&Locale::en));
// Materialize words & counts // Materialize words & counts
let mut words = Vec::with_capacity(counts.len()); let mut words = Vec::with_capacity(counts.len());

View File

@ -196,9 +196,9 @@ RESET = '\033[0m'
# Print vocab sizes # Print vocab sizes
print(f"\nVocab sizes:") print(f"\nVocab sizes:")
print(f"GPT-2: {vocab_sizes['gpt2']}") print(f"GPT-2: {vocab_sizes['gpt2']:,}")
print(f"GPT-4: {vocab_sizes['gpt4']}") print(f"GPT-4: {vocab_sizes['gpt4']:,}")
print(f"Ours: {vocab_sizes['ours']}") print(f"Ours: {vocab_sizes['ours']:,}")
def print_comparison(baseline_name, baseline_results, ours_results, all_text): def print_comparison(baseline_name, baseline_results, ours_results, all_text):
"""Print comparison table between baseline tokenizer and ours.""" """Print comparison table between baseline tokenizer and ours."""