mirror of https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00

Compare commits: 9 commits, c33619b8f4 ... de354b6fad
Commits in this range:

- de354b6fad
- 4a87a0d19f
- 11e68bf442
- 4715fdcf52
- 1a428f2b0b
- 520bdfe081
- 355cc60089
- 671e8d9fc9
- 52382d58c5
@@ -244,7 +244,7 @@ class GPT(nn.Module):
     def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
         B, T = idx.size()
 
-        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
+        # Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
         assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
         assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
         assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"

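The comment fix above reflects how rotary embeddings are cached: channels are rotated in pairs, so the cos/sin tables hold one angle per pair, giving a last dimension of head_dim/2 rather than head_dim. A minimal sketch of building such a cache (hypothetical helper, not nanochat's actual code):

```python
import torch

def build_rotary_cache(seq_len, head_dim, base=10000.0):
    # One frequency per channel pair, hence head_dim/2 angles per position.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim/2)
    # Broadcastable against (B, T, n_heads, head_dim/2) activations, stored in bfloat16.
    cos = angles.cos()[None, :, None, :].bfloat16()
    sin = angles.sin()[None, :, None, :].bfloat16()
    return cos, sin

cos, sin = build_rotary_cache(seq_len=8, head_dim=64)
print(cos.shape)  # torch.Size([1, 8, 1, 32])
```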
@@ -13,3 +13,4 @@ pyo3-log = "0.12.4"
 ahash = "0.8.12"
 rayon = "1.11.0"
 compact_str = "0.9.0"
+num-format = "0.4"

@@ -6,6 +6,7 @@ use fancy_regex::Regex;
 use pyo3::prelude::*;
 
 use ahash::{AHashMap, AHashSet};
+use num_format::{Locale, ToFormattedString};
 use compact_str::CompactString;
 use rayon::prelude::*;
 

@@ -164,15 +165,15 @@ impl Tokenizer {
     fn train_core_incremental(&mut self, mut words: Vec<Word>, counts: Vec<i32>, vocab_size: u32) {
         assert!(vocab_size >= 256, "vocab_size must be at least 256");
         let num_merges = vocab_size - 256;
-        log::info!("Starting BPE training: {} merges to compute", num_merges);
+        log::info!("Starting BPE training: {} merges to compute", num_merges.to_formatted_string(&Locale::en));
         self.merges.clear();
 
         // ---- Initial pair_counts and where_to_update (parallel) ----
-        log::info!("Computing initial pair counts from {} unique sequences", words.len());
+        log::info!("Computing initial pair counts from {} unique sequences", words.len().to_formatted_string(&Locale::en));
         let (mut pair_counts, mut where_to_update) = count_pairs_parallel(&words, &counts);
 
         // ---- Build heap ----
-        log::info!("Building heap with {} unique pairs", pair_counts.len());
+        log::info!("Building heap with {} unique pairs", pair_counts.len().to_formatted_string(&Locale::en));
         let mut heap = OctonaryHeap::with_capacity(pair_counts.len());
         for (pair, pos) in where_to_update.drain() {
             let c = *pair_counts.get(&pair).unwrap_or(&0);

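For orientation, the log lines changed above report sizes coming out of the initial pair-counting pass of BPE training. A rough Python sketch of that idea (a toy stand-in for count_pairs_parallel, not the parallel Rust implementation):

```python
from collections import Counter

def count_pairs(words, counts):
    """words: token-id sequences; counts: how often each sequence occurs in the corpus."""
    pair_counts = Counter()
    where_to_update = {}  # pair -> indices of the words that contain it
    for i, (word, c) in enumerate(zip(words, counts)):
        for pair in zip(word, word[1:]):
            pair_counts[pair] += c
            where_to_update.setdefault(pair, set()).add(i)
    return pair_counts, where_to_update

pairs, where = count_pairs([[104, 105, 105], [104, 105]], [3, 2])
print(pairs)  # Counter({(104, 105): 5, (105, 105): 3})
```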
@@ -375,7 +376,7 @@ impl Tokenizer {
                 break;
             }
         }
-        log::info!("Processed {} sequences total, {} unique", total_sequences, counts.len());
+        log::info!("Processed {} sequences total, {} unique", total_sequences.to_formatted_string(&Locale::en), counts.len().to_formatted_string(&Locale::en));
 
         // Materialize words & counts
         let mut words = Vec::with_capacity(counts.len());

@@ -196,9 +196,9 @@ RESET = '\033[0m'
 
     # Print vocab sizes
    print(f"\nVocab sizes:")
-    print(f"GPT-2: {vocab_sizes['gpt2']}")
-    print(f"GPT-4: {vocab_sizes['gpt4']}")
-    print(f"Ours: {vocab_sizes['ours']}")
+    print(f"GPT-2: {vocab_sizes['gpt2']:,}")
+    print(f"GPT-4: {vocab_sizes['gpt4']:,}")
+    print(f"Ours: {vocab_sizes['ours']:,}")
 
 def print_comparison(baseline_name, baseline_results, ours_results, all_text):
     """Print comparison table between baseline tokenizer and ours."""

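Both the Rust and Python changes in this range share one goal: printing large counts with thousands separators. Rust goes through num-format's `to_formatted_string(&Locale::en)`, while Python only needs the `,` option in the f-string format spec, as in this quick sketch:

```python
n = 65536
print(f"{n:,}")     # 65,536
print(f"{n:>10,}")  # right-aligned in a 10-char field:     65,536
```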