diff --git a/rustbpe/src/lib.rs b/rustbpe/src/lib.rs index b43fb6c..c3696f5 100644 --- a/rustbpe/src/lib.rs +++ b/rustbpe/src/lib.rs @@ -27,6 +27,13 @@ pub struct Tokenizer { // ------------------------ internal helpers ------------------------ +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[repr(i8)] +enum Delta { + Rem = -1, + Ins = 1, +} + #[derive(Clone, Debug)] struct Word { ids: Vec, @@ -48,7 +55,7 @@ impl Word { /// -1 for removed pairs, +1 for newly created pairs. /// /// NOTE: this version deliberately avoids a HashMap in the hot loop. - fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, i32)> { + fn merge_pair(&mut self, pair: Pair, new_id: u32) -> Vec<(Pair, Delta)> { let (a, b) = pair; let n = self.ids.len(); if n < 2 { @@ -56,7 +63,7 @@ impl Word { } let mut out: Vec = Vec::with_capacity(n); - let mut deltas: Vec<(Pair, i32)> = Vec::with_capacity(6); + let mut deltas: Vec<(Pair, Delta)> = Vec::with_capacity(6); let mut i = 0; while i < n { @@ -66,13 +73,13 @@ impl Word { // remove old pairs if let Some(x) = left { - deltas.push(((x, a), -1)); - deltas.push(((x, new_id), 1)); + deltas.push(((x, a), Delta::Rem)); + deltas.push(((x, new_id), Delta::Ins)); } - deltas.push(((a, b), -1)); + deltas.push(((a, b), Delta::Rem)); if let Some(y) = right { - deltas.push(((b, y), -1)); - deltas.push(((new_id, y), 1)); + deltas.push(((b, y), Delta::Rem)); + deltas.push(((new_id, y), Delta::Ins)); } // write merged token @@ -112,12 +119,10 @@ impl PartialOrd for MergeJob { impl Ord for MergeJob { fn cmp(&self, other: &Self) -> Ordering { // Max-heap by count; tie-break to ascending pair order (deterministic) - if self.count != other.count { - self.count.cmp(&other.count) - } else { + self.count.cmp(&other.count).then_with(|| // ascending order on the pair when counts tie other.pair.cmp(&self.pair) - } + ) } } @@ -217,10 +222,10 @@ impl Tokenizer { let changes = words[word_idx].merge_pair(top.pair, new_id); // Update global pair counts based on this word's count for (pair, delta) in changes { - let delta_total = delta * counts[word_idx]; + let delta_total = (delta as i32) * counts[word_idx]; if delta_total != 0 { *pair_counts.entry(pair).or_default() += delta_total; - if delta > 0 { + if delta == Delta::Ins { local_pos_updates.entry(pair).or_default().insert(word_idx); } }