mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
Compare commits
5 Commits
81f6a358cb
...
4450ddf07e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4450ddf07e | ||
|
|
4a87a0d19f | ||
|
|
11e68bf442 | ||
|
|
09f1f4283e | ||
|
|
18590307ae |
|
|
@ -244,7 +244,7 @@ class GPT(nn.Module):
|
||||||
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
def forward(self, idx, targets=None, kv_cache=None, loss_reduction='mean'):
|
||||||
B, T = idx.size()
|
B, T = idx.size()
|
||||||
|
|
||||||
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim))
|
# Grab the rotary embeddings for the current sequence length (they are of shape (1, seq_len, 1, head_dim/2))
|
||||||
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
assert T <= self.cos.size(1), f"Sequence length grew beyond the rotary embeddings cache: {T} > {self.cos.size(1)}"
|
||||||
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
assert idx.device == self.cos.device, f"Rotary embeddings and idx are on different devices: {idx.device} != {self.cos.device}"
|
||||||
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
assert self.cos.dtype == torch.bfloat16, "Rotary embeddings must be in bfloat16"
|
||||||
|
|
|
||||||
30
rustbpe/Cargo.lock
generated
30
rustbpe/Cargo.lock
generated
|
|
@ -230,11 +230,10 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyo3"
|
name = "pyo3"
|
||||||
version = "0.23.5"
|
version = "0.26.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7778bffd85cf38175ac1f545509665d0b9b92a198ca7941f131f85f7a4f9a872"
|
checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
|
||||||
"indoc",
|
"indoc",
|
||||||
"libc",
|
"libc",
|
||||||
"memoffset",
|
"memoffset",
|
||||||
|
|
@ -248,19 +247,18 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyo3-build-config"
|
name = "pyo3-build-config"
|
||||||
version = "0.23.5"
|
version = "0.26.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "94f6cbe86ef3bf18998d9df6e0f3fc1050a8c5efa409bf712e661a4366e010fb"
|
checksum = "4fc6ddaf24947d12a9aa31ac65431fb1b851b8f4365426e182901eabfb87df5f"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"once_cell",
|
|
||||||
"target-lexicon",
|
"target-lexicon",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyo3-ffi"
|
name = "pyo3-ffi"
|
||||||
version = "0.23.5"
|
version = "0.26.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e9f1b4c431c0bb1c8fb0a338709859eed0d030ff6daa34368d3b152a63dfdd8d"
|
checksum = "025474d3928738efb38ac36d4744a74a400c901c7596199e20e45d98eb194105"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"pyo3-build-config",
|
"pyo3-build-config",
|
||||||
|
|
@ -268,9 +266,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyo3-log"
|
name = "pyo3-log"
|
||||||
version = "0.12.4"
|
version = "0.13.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "45192e5e4a4d2505587e27806c7b710c231c40c56f3bfc19535d0bb25df52264"
|
checksum = "d359e20231345f21a3b5b6aea7e73f4dc97e1712ef3bfe2d88997ac6a308d784"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"arc-swap",
|
"arc-swap",
|
||||||
"log",
|
"log",
|
||||||
|
|
@ -279,9 +277,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyo3-macros"
|
name = "pyo3-macros"
|
||||||
version = "0.23.5"
|
version = "0.26.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "fbc2201328f63c4710f68abdf653c89d8dbc2858b88c5d88b0ff38a75288a9da"
|
checksum = "2e64eb489f22fe1c95911b77c44cc41e7c19f3082fc81cce90f657cdc42ffded"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"pyo3-macros-backend",
|
"pyo3-macros-backend",
|
||||||
|
|
@ -291,9 +289,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyo3-macros-backend"
|
name = "pyo3-macros-backend"
|
||||||
version = "0.23.5"
|
version = "0.26.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "fca6726ad0f3da9c9de093d6f116a93c1a38e417ed73bf138472cf4064f72028"
|
checksum = "100246c0ecf400b475341b8455a9213344569af29a3c841d29270e53102e0fcf"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"heck",
|
"heck",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
|
|
@ -400,9 +398,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "target-lexicon"
|
name = "target-lexicon"
|
||||||
version = "0.12.16"
|
version = "0.13.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
|
checksum = "df7f62577c25e07834649fc3b39fafdc597c0a3527dc1c60129201ccfcbaa50c"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-ident"
|
name = "unicode-ident"
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ dary_heap = "0.3"
|
||||||
indexmap = "2.2"
|
indexmap = "2.2"
|
||||||
fancy-regex = "0.16.1"
|
fancy-regex = "0.16.1"
|
||||||
log = "0.4.28"
|
log = "0.4.28"
|
||||||
pyo3 = { version = "0.23.3", features = ["extension-module"] }
|
pyo3 = { version = "0.26", features = ["extension-module"] }
|
||||||
pyo3-log = "0.12.4"
|
pyo3-log = "0.13"
|
||||||
ahash = "0.8.12"
|
ahash = "0.8.12"
|
||||||
rayon = "1.11.0"
|
rayon = "1.11.0"
|
||||||
compact_str = "0.9.0"
|
compact_str = "0.9.0"
|
||||||
|
|
|
||||||
|
|
@ -277,7 +277,7 @@ impl Tokenizer {
|
||||||
pub fn train_from_iterator(
|
pub fn train_from_iterator(
|
||||||
&mut self,
|
&mut self,
|
||||||
py: pyo3::Python<'_>,
|
py: pyo3::Python<'_>,
|
||||||
iterator: &pyo3::Bound<'_, pyo3::PyAny>,
|
iterator: pyo3::Py<pyo3::types::PyIterator>,
|
||||||
vocab_size: u32,
|
vocab_size: u32,
|
||||||
buffer_size: usize,
|
buffer_size: usize,
|
||||||
pattern: Option<String>,
|
pattern: Option<String>,
|
||||||
|
|
@ -290,11 +290,6 @@ impl Tokenizer {
|
||||||
self.compiled_pattern = Regex::new(&pattern_str)
|
self.compiled_pattern = Regex::new(&pattern_str)
|
||||||
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
|
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("Invalid regex pattern: {}", e)))?;
|
||||||
|
|
||||||
// Prepare a true Python iterator object
|
|
||||||
let py_iter: pyo3::Py<pyo3::PyAny> = unsafe {
|
|
||||||
pyo3::Py::from_owned_ptr_or_err(py, pyo3::ffi::PyObject_GetIter(iterator.as_ptr()))?
|
|
||||||
};
|
|
||||||
|
|
||||||
// Global chunk counts
|
// Global chunk counts
|
||||||
let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
|
let mut counts: AHashMap<CompactString, i32> = AHashMap::new();
|
||||||
|
|
||||||
|
|
@ -307,28 +302,21 @@ impl Tokenizer {
|
||||||
// Helper: refill `buf` with up to `buffer_size` strings from the Python iterator.
|
// Helper: refill `buf` with up to `buffer_size` strings from the Python iterator.
|
||||||
// Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise.
|
// Returns Ok(true) if the iterator is exhausted, Ok(false) otherwise.
|
||||||
let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
|
let refill = |buf: &mut Vec<String>| -> PyResult<bool> {
|
||||||
pyo3::Python::with_gil(|py| {
|
pyo3::Python::attach(|py| {
|
||||||
buf.clear();
|
buf.clear();
|
||||||
let it = py_iter.bind(py);
|
let mut it = iterator.bind(py).clone();
|
||||||
loop {
|
loop {
|
||||||
if buf.len() >= buffer_size {
|
if buf.len() >= buffer_size {
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
// next(it)
|
match it.next() {
|
||||||
let next_obj = unsafe {
|
Some(Ok(obj)) => {
|
||||||
pyo3::Bound::from_owned_ptr_or_opt(py, pyo3::ffi::PyIter_Next(it.as_ptr()))
|
|
||||||
};
|
|
||||||
match next_obj {
|
|
||||||
Some(obj) => {
|
|
||||||
let s: String = obj.extract()?;
|
let s: String = obj.extract()?;
|
||||||
buf.push(s);
|
buf.push(s);
|
||||||
}
|
},
|
||||||
|
Some(Err(e)) => return Err(e),
|
||||||
None => {
|
None => {
|
||||||
if pyo3::PyErr::occurred(py) {
|
return Ok(true); // exhausted
|
||||||
return Err(pyo3::PyErr::fetch(py));
|
|
||||||
} else {
|
|
||||||
return Ok(true); // exhausted
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -345,7 +333,7 @@ impl Tokenizer {
|
||||||
total_sequences += buf.len() as u64;
|
total_sequences += buf.len() as u64;
|
||||||
|
|
||||||
let pattern = self.compiled_pattern.clone();
|
let pattern = self.compiled_pattern.clone();
|
||||||
let local: AHashMap<CompactString, i32> = py.allow_threads(|| {
|
let local: AHashMap<CompactString, i32> = py.detach(|| {
|
||||||
buf.par_iter()
|
buf.par_iter()
|
||||||
.map(|s| {
|
.map(|s| {
|
||||||
let mut m: AHashMap<CompactString, i32> = AHashMap::new();
|
let mut m: AHashMap<CompactString, i32> = AHashMap::new();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user