Mirror of https://github.com/karpathy/nanochat.git
fix silly issue in dataloader, this version is much faster and more portable to mps too
Parent: c9ea7a91e2
Commit: bb71c64579
@@ -16,7 +16,6 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
     bos_token = tokenizer.get_bos_token_id()
     # scratch buffer holds the tokens for one iteration
     token_buffer = deque() # we stream tokens on the right and pop from the left
-    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
 
     # infinite iterator over document batches
     def document_batches():
@@ -38,8 +37,8 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
                 token_buffer.extend(tokens)
             batch_index += 1
         # Move tokens from the deque into the scratch buffer
-        for i in range(needed_tokens):
-            scratch[i] = token_buffer.popleft()
+        tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
+        scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True)
         # Create the inputs/targets as 1D tensors
         inputs_cpu = scratch[:-1].to(dtype=torch.int32)
         targets_cpu = scratch[1:]
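
The change replaces a per-token copy into a preallocated pinned tensor with a single bulk tensor construction per step. Below is a minimal, hypothetical micro-benchmark (not part of the commit) contrasting the two patterns; the buffer size is made up and pin_memory is left out so it runs on any machine.

    import time
    from collections import deque

    import torch

    needed_tokens = 64 * 1024 + 1  # hypothetical stand-in for B * T + 1
    source = list(range(needed_tokens))

    # Old pattern: pop one token at a time and assign it into a preallocated
    # tensor; each scratch_old[i] = ... is a separate Python-to-C++ call.
    buf = deque(source)
    scratch_old = torch.empty(needed_tokens, dtype=torch.int64)
    t0 = time.perf_counter()
    for i in range(needed_tokens):
        scratch_old[i] = buf.popleft()
    t_old = time.perf_counter() - t0

    # New pattern: drain the deque into a Python list, then build the tensor
    # with one torch.tensor() call.
    buf = deque(source)
    t0 = time.perf_counter()
    tokens = [buf.popleft() for _ in range(needed_tokens)]
    scratch_new = torch.tensor(tokens, dtype=torch.int64)
    t_new = time.perf_counter() - t0

    assert torch.equal(scratch_old, scratch_new)
    print(f"per-element assignment: {t_old:.3f}s, list -> tensor: {t_new:.3f}s")

The speedup comes from collapsing needed_tokens separate tensor item assignments, each with interpreter and dispatch overhead, into one bulk constructor call; the MPS portability claim is taken from the commit message and not demonstrated by this sketch.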