Enhance error handling in dataset and training scripts

- Update file removal logic in dataset.py to log warnings for OSError and PermissionError.
- Improve assertion messages in gpt.py, base_train.py, mid_train.py, chat_rl.py, tok_eval.py, tok_train.py, and test_rustbpe.py to provide clearer context on assertion failures.
This commit is contained in:
Hossein-Lakzaei 2025-10-15 17:15:59 +03:30
parent 5777e51288
commit 22f8d02345
8 changed files with 15 additions and 15 deletions

View File

@@ -95,8 +95,8 @@ def download_single_file(index):
if os.path.exists(path):
try:
os.remove(path)
except:
pass
except (OSError, PermissionError) as e:
logger.warning(f"Could not remove file: {e}")
# Try a few times with exponential backoff: 2^attempt seconds
if attempt < max_attempts:
wait_time = 2 ** attempt

View File

@@ -39,7 +39,7 @@ def norm(x):
def apply_rotary_emb(x, cos, sin):
assert x.ndim == 4 # multihead attention
assert x.ndim == 4, f"Expected 4D tensor for multihead attention, got {x.ndim}D tensor with shape {x.shape}" # multihead attention
d = x.shape[3] // 2
x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves
y1 = x1 * cos + x2 * sin # rotate pairs of dims
@@ -69,8 +69,8 @@ class CausalSelfAttention(nn.Module):
self.n_kv_head = config.n_kv_head
self.n_embd = config.n_embd
self.head_dim = self.n_embd // self.n_head
assert self.n_embd % self.n_head == 0
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
assert self.n_embd % self.n_head == 0, f"n_embd must be divisible by n_head"
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0, f"MQA constraints violated: n_kv_head must be <= n_head and n_head must be divisible by n_kv_head"
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)

View File

@@ -85,7 +85,7 @@ print0(f"num_kv_heads: {num_kv_heads}")
# figure out the needed gradient accumulation to reach the desired total batch size
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be divisible by world_tokens_per_fwdbwd ({world_tokens_per_fwdbwd})"
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
@@ -106,7 +106,7 @@ num_flops_per_token = model.estimate_flops()
print0(f"Estimated FLOPs per token: {num_flops_per_token:e}")
# Calculate number of iterations. Either it is given, or from target flops, or from target data:param ratio (in that order)
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0
assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0, "At least one of num_iterations, target_param_data_ratio, or target_flops must be > 0"
if num_iterations > 0:
print0(f"Using user-provided number of iterations: {num_iterations:,}")
elif target_flops > 0:

View File

@@ -160,7 +160,7 @@ def run_gsm8k_eval(task, tokenizer, engine,
tokens = tokenizer.render_for_completion(conversation)
prefix_length = len(tokens)
# Generate k samples using batched generation inside the Engine
assert num_samples <= device_batch_size # usually this is true. we can add a loop if not...
assert num_samples <= device_batch_size, f"num_samples ({num_samples}) must be <= device_batch_size ({device_batch_size})" # usually this is true. we can add a loop if not...
generated_token_sequences, masks = engine.generate_batch(
tokens,
num_samples=num_samples,

View File

@@ -70,7 +70,7 @@ depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks
assert total_batch_size % world_tokens_per_fwdbwd == 0
assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be divisible by world_tokens_per_fwdbwd ({world_tokens_per_fwdbwd})"
grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd
print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}")
print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}")
@@ -108,7 +108,7 @@ def mid_data_generator(split):
assert split in {"train", "val"}, "split must be 'train' or 'val'"
dataset = train_dataset if split == "train" else val_dataset
dataset_size = len(dataset)
assert dataset_size > 0
assert dataset_size > 0, f"Dataset size must be > 0, got {dataset_size} for split '{split}'"
needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets
token_buffer = deque()
scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)

View File

@@ -37,7 +37,7 @@ class BasicTokenizer(Tokenizer):
super().__init__()
def train(self, text, vocab_size, verbose=False):
assert vocab_size >= 256
assert vocab_size >= 256, f"vocab_size must be >= 256, got {vocab_size}"
num_merges = vocab_size - 256
# input text preprocessing
@@ -179,7 +179,7 @@ for tokenizer_name in ["gpt2", "gpt4", "ours"]:
for name, text in all_text:
encoded = tokenizer.encode(text)
decoded = tokenizer.decode(encoded)
assert decoded == text
assert decoded == text, f"Decode-encode roundtrip failed: decoded != original"
encoded_bytes = text.encode('utf-8')
ratio = len(encoded_bytes) / len(encoded)

View File

@@ -66,7 +66,7 @@ Special chars: @#$%^&*()
Unicode: 你好世界 🌍"""
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
assert decoded == test_text
assert decoded == test_text, f"Decode-encode roundtrip failed: decoded != original"
# -----------------------------------------------------------------------------
# One more thing: we wish to cache a mapping from token id to number of bytes of that token

View File

@@ -84,7 +84,7 @@ class RegexTokenizer:
return vocab
def train(self, text, vocab_size, verbose=False):
assert vocab_size >= 256
assert vocab_size >= 256, f"vocab_size must be >= 256, got {vocab_size}"
num_merges = vocab_size - 256
# keep track of whether at any point during training the merge is ambiguous (counts of pairs are not unique)
@@ -215,7 +215,7 @@ class FastRegexTokenizer:
- collapse identical chunks to just the unique ones
- update counts more cleverly - only around the affected chunks
"""
assert vocab_size >= 256
assert vocab_size >= 256, f"vocab_size must be >= 256, got {vocab_size}"
num_merges = vocab_size - 256
# split the text up into text chunks