mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-01 21:25:21 +00:00
Enhance error handling in dataset and training scripts
- Update file removal logic in dataset.py to log warnings for OSError and PermissionError.
- Improve assertion messages in gpt.py, base_train.py, mid_train.py, chat_rl.py, tok_eval.py, tok_train.py, and test_rustbpe.py to provide clearer context on assertion failures.
This commit is contained in:
parent
bfd8d21313
commit
f9dd11fefe
|
|
@ -95,8 +95,8 @@ def download_single_file(index):
|
|||
if os.path.exists(path):
|
||||
try:
|
||||
os.remove(path)
|
||||
except:
|
||||
pass
|
||||
except (OSError, PermissionError) as e:
|
||||
logger.warning(f"Could not remove file: {e}")
|
||||
# Try a few times with exponential backoff: 2^attempt seconds
|
||||
if attempt < max_attempts:
|
||||
wait_time = 2 ** attempt
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ def norm(x):
|
|||
|
||||
|
||||
def apply_rotary_emb(x, cos, sin):
|
||||
assert x.ndim == 4 # multihead attention
|
||||
assert x.ndim == 4, f"Expected 4D tensor for multihead attention, got {x.ndim}D tensor with shape {x.shape}" # multihead attention
|
||||
d = x.shape[3] // 2
|
||||
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
|
||||
y1 = x1 * cos + x2 * sin # rotate pairs of dims
|
||||
|
|
@ -69,8 +69,8 @@ class CausalSelfAttention(nn.Module):
|
|||
self.n_kv_head = config.n_kv_head
|
||||
self.n_embd = config.n_embd
|
||||
self.head_dim = self.n_embd // self.n_head
|
||||
assert self.n_embd % self.n_head == 0
|
||||
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
||||
assert self.n_embd % self.n_head == 0, f"n_embd must be divisible by n_head"
|
||||
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0, f"MQA constraints violated: n_kv_head must be <= n_head and n_head must be divisible by n_kv_head"
|
||||
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
||||
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ print0(f"num_kv_heads: {num_kv_heads}")
|
|||
# figure out the needed gradient accumulation to reach the desired total batch size
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be divisible by world_tokens_per_fwdbwd ({world_tokens_per_fwdbwd})"
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
|
|
@ -106,7 +106,7 @@ num_flops_per_token = model.estimate_flops()
|
|||
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
||||
|
||||
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
|
||||
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
|
||||
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0, "At least one of num_iterations, target_param_data_ratio, or target_flops must be > 0"
|
||||
if num_iterations > 0:
|
||||
print0(f"Using user-provided number of iterations: {num_iterations:,}")
|
||||
elif target_flops > 0:
|
||||
|
|
|
|||
|
|
@ -160,7 +160,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
|
|||
tokens = tokenizer.render_for_completion(conversation)
|
||||
prefix_length = len(tokens)
|
||||
# Generate k samples using batched generation inside the Engine
|
||||
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
|
||||
assert num_samples <= device_batch_size, f"num_samples ({num_samples}) must be <= device_batch_size ({device_batch_size})" # usually this is true. we can add a loop if not...
|
||||
generated_token_sequences, masks = engine.generate_batch(
|
||||
tokens,
|
||||
num_samples=num_samples,
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ depth = model.config.n_layer
|
|||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
|
||||
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0
|
||||
assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be divisible by world_tokens_per_fwdbwd ({world_tokens_per_fwdbwd})"
|
||||
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
|
||||
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
|
||||
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
|
||||
|
|
@ -108,7 +108,7 @@ def mid_data_generator(split):
|
|||
assert split in {"train", "val"}, "split must be 'train' or 'val'"
|
||||
dataset = train_dataset if split == "train" else val_dataset
|
||||
dataset_size = len(dataset)
|
||||
assert dataset_size > 0
|
||||
assert dataset_size > 0, f"Dataset size must be > 0, got {dataset_size} for split '{split}'"
|
||||
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
|
||||
token_buffer = deque()
|
||||
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class BasicTokenizer(Tokenizer):
|
|||
super().__init__()
|
||||
|
||||
def train(self, text, vocab_size, verbose=False):
|
||||
assert vocab_size >= 256
|
||||
assert vocab_size >= 256, f"vocab_size must be >= 256, got {vocab_size}"
|
||||
num_merges = vocab_size - 256
|
||||
|
||||
# input text preprocessing
|
||||
|
|
@ -179,7 +179,7 @@ for tokenizer_name in ["gpt2", "gpt4", "ours"]:
|
|||
for name, text in all_text:
|
||||
encoded = tokenizer.encode(text)
|
||||
decoded = tokenizer.decode(encoded)
|
||||
assert decoded == text
|
||||
assert decoded == text, f"Decode-encode roundtrip failed: decoded != original"
|
||||
|
||||
encoded_bytes = text.encode('utf-8')
|
||||
ratio = len(encoded_bytes) / len(encoded)
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ Special chars: @#$%^&*()
|
|||
Unicode: 你好世界 🌍"""
|
||||
encoded = tokenizer.encode(test_text)
|
||||
decoded = tokenizer.decode(encoded)
|
||||
assert decoded == test_text
|
||||
assert decoded == test_text, f"Decode-encode roundtrip failed: decoded != original"
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# One more thing: we wish to cache a mapping from token id to number of bytes of that token
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ class RegexTokenizer:
|
|||
return vocab
|
||||
|
||||
def train(self, text, vocab_size, verbose=False):
|
||||
assert vocab_size >= 256
|
||||
assert vocab_size >= 256, f"vocab_size must be >= 256, got {vocab_size}"
|
||||
num_merges = vocab_size - 256
|
||||
|
||||
# keep track of whether at any point during training the merge is ambiguous (counts of pairs are not unique)
|
||||
|
|
@ -215,7 +215,7 @@ class FastRegexTokenizer:
|
|||
- collapse identical chunks to just the unique ones
|
||||
- update counts more cleverly - only around the affected chunks
|
||||
"""
|
||||
assert vocab_size >= 256
|
||||
assert vocab_size >= 256, f"vocab_size must be >= 256, got {vocab_size}"
|
||||
num_merges = vocab_size - 256
|
||||
|
||||
# split the text up into text chunks
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user