mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 12:22:18 +00:00
Fix Torch crash caused by pinning on CPU
This commit is contained in:
parent
2e938530ce
commit
defd1246aa
|
|
@ -38,7 +38,8 @@ def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokeniz
|
||||||
batch_index += 1
|
batch_index += 1
|
||||||
# Move tokens from the deque into the scratch buffer
|
# Move tokens from the deque into the scratch buffer
|
||||||
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
|
tokens = [token_buffer.popleft() for _ in range(needed_tokens)]
|
||||||
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=True)
|
# CUDA supports memory pinning for faster transfers between CPU and GPU:
|
||||||
|
scratch = torch.tensor(tokens, dtype=torch.int64, pin_memory=(device == "cuda"))
|
||||||
# Create the inputs/targets as 1D tensors
|
# Create the inputs/targets as 1D tensors
|
||||||
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
inputs_cpu = scratch[:-1].to(dtype=torch.int32)
|
||||||
targets_cpu = scratch[1:]
|
targets_cpu = scratch[1:]
|
||||||
|
|
|
||||||
|
|
@ -119,7 +119,8 @@ def mid_data_generator(split):
|
||||||
assert dataset_size > 0
|
assert dataset_size > 0
|
||||||
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
|
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
|
||||||
token_buffer = deque()
|
token_buffer = deque()
|
||||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
|
# CUDA supports memory pinning for faster transfers between CPU and GPU:
|
||||||
|
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=(device_type == "cuda"))
|
||||||
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
cursor = ddp_rank # increments by ddp_world_size each time, so each rank processes unique documents
|
||||||
it = 0 # iteration counter
|
it = 0 # iteration counter
|
||||||
while True:
|
while True:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user