From 22f8d0234595297e405787f1155b059b7d2db056 Mon Sep 17 00:00:00 2001 From: Hossein-Lakzaei Date: Wed, 15 Oct 2025 17:15:59 +0330 Subject: [PATCH] Enhance error handling in dataset and training scripts - Update file removal logic in dataset.py to log warnings for OSError and PermissionError. - Improve assertion messages in gpt.py, base_train.py, mid_train.py, chat_rl.py, tok_eval.py, tok_train.py, and test_rustbpe.py to provide clearer context on assertion failures. --- nanochat/dataset.py | 4 ++-- nanochat/gpt.py | 6 +++--- scripts/base_train.py | 4 ++-- scripts/chat_rl.py | 2 +- scripts/mid_train.py | 4 ++-- scripts/tok_eval.py | 4 ++-- scripts/tok_train.py | 2 +- tests/test_rustbpe.py | 4 ++-- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/nanochat/dataset.py b/nanochat/dataset.py index 602daed..1d27a68 100644 --- a/nanochat/dataset.py +++ b/nanochat/dataset.py @@ -95,8 +95,8 @@ def download_single_file(index): if os.path.exists(path): try: os.remove(path) - except: - pass + except OSError as e: + logger.warning(f"Could not remove file: {e}") # Try a few times with exponential backoff: 2^attempt seconds if attempt < max_attempts: wait_time = 2 ** attempt diff --git a/nanochat/gpt.py b/nanochat/gpt.py index 5a066b2..f85ad6f 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -39,7 +39,7 @@ def norm(x): def apply_rotary_emb(x, cos, sin): - assert x.ndim == 4 # multihead attention + assert x.ndim == 4, f"Expected 4D tensor for multihead attention, got {x.ndim}D tensor with shape {x.shape}" # multihead attention d = x.shape[3] // 2 x1, x2 = x[..., :d], x[..., d:] # split up last time into two halves y1 = x1 * cos + x2 * sin # rotate pairs of dims @@ -69,8 +69,8 @@ class CausalSelfAttention(nn.Module): self.n_kv_head = config.n_kv_head self.n_embd = config.n_embd self.head_dim = self.n_embd // self.n_head - assert self.n_embd % self.n_head == 0 - assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0 + assert 
self.n_embd % self.n_head == 0, f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})" + assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0, f"MQA constraints violated: n_kv_head ({self.n_kv_head}) must be <= n_head ({self.n_head}) and n_head must be divisible by n_kv_head" self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False) self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False) diff --git a/scripts/base_train.py b/scripts/base_train.py index b691ed4..86707ba 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -85,7 +85,7 @@ print0(f"num_kv_heads: {num_kv_heads}") # figure out the needed gradient accumulation to reach the desired total batch size tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert total_batch_size % world_tokens_per_fwdbwd == 0 +assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be divisible by world_tokens_per_fwdbwd ({world_tokens_per_fwdbwd})" grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") @@ -106,7 +106,7 @@ num_flops_per_token = model.estimate_flops() print0(f"Estimated FLOPs per token: {num_flops_per_token:e}") # Calculate number of iterations. 
Either it is given, or from target flops, or from target data:param ratio (in that order) -assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0 +assert num_iterations > 0 or target_param_data_ratio > 0 or target_flops > 0, "At least one of num_iterations, target_param_data_ratio, or target_flops must be > 0" if num_iterations > 0: print0(f"Using user-provided number of iterations: {num_iterations:,}") elif target_flops > 0: diff --git a/scripts/chat_rl.py b/scripts/chat_rl.py index af70bda..00446d1 100644 --- a/scripts/chat_rl.py +++ b/scripts/chat_rl.py @@ -160,7 +160,7 @@ def run_gsm8k_eval(task, tokenizer, engine, tokens = tokenizer.render_for_completion(conversation) prefix_length = len(tokens) # Generate k samples using batched generation inside the Engine - assert num_samples <= device_batch_size # usually this is true. we can add a loop if not... + assert num_samples <= device_batch_size, f"num_samples ({num_samples}) must be <= device_batch_size ({device_batch_size})" # usually this is true. we can add a loop if not... 
generated_token_sequences, masks = engine.generate_batch( tokens, num_samples=num_samples, diff --git a/scripts/mid_train.py b/scripts/mid_train.py index 202682d..2ae0bd1 100644 --- a/scripts/mid_train.py +++ b/scripts/mid_train.py @@ -70,7 +70,7 @@ depth = model.config.n_layer num_flops_per_token = model.estimate_flops() tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size # total tokens per iteration for all ranks -assert total_batch_size % world_tokens_per_fwdbwd == 0 +assert total_batch_size % world_tokens_per_fwdbwd == 0, f"total_batch_size ({total_batch_size}) must be divisible by world_tokens_per_fwdbwd ({world_tokens_per_fwdbwd})" grad_accum_steps = total_batch_size // world_tokens_per_fwdbwd print0(f"Tokens / micro-batch / rank: {device_batch_size} x {max_seq_len} = {tokens_per_fwdbwd:,}") print0(f"Tokens / micro-batch: {world_tokens_per_fwdbwd:,}") @@ -108,7 +108,7 @@ def mid_data_generator(split): assert split in {"train", "val"}, "split must be 'train' or 'val'" dataset = train_dataset if split == "train" else val_dataset dataset_size = len(dataset) - assert dataset_size > 0 + assert dataset_size > 0, f"Dataset size must be > 0, got {dataset_size} for split '{split}'" needed_tokens = device_batch_size * max_seq_len + 1 # to form one training batch of inputs,targets token_buffer = deque() scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True) diff --git a/scripts/tok_eval.py b/scripts/tok_eval.py index 9233d71..c79001f 100644 --- a/scripts/tok_eval.py +++ b/scripts/tok_eval.py @@ -37,7 +37,7 @@ class BasicTokenizer(Tokenizer): super().__init__() def train(self, text, vocab_size, verbose=False): - assert vocab_size >= 256 + assert vocab_size >= 256, f"vocab_size must be >= 256, got {vocab_size}" num_merges = vocab_size - 256 # input text preprocessing @@ -179,7 +179,7 @@ for tokenizer_name in ["gpt2", "gpt4", "ours"]: for name, text in 
all_text: encoded = tokenizer.encode(text) decoded = tokenizer.decode(encoded) - assert decoded == text + assert decoded == text, "Decode-encode roundtrip failed: decoded != original" encoded_bytes = text.encode('utf-8') ratio = len(encoded_bytes) / len(encoded) diff --git a/scripts/tok_train.py b/scripts/tok_train.py index c2faf17..7e674d2 100644 --- a/scripts/tok_train.py +++ b/scripts/tok_train.py @@ -66,7 +66,7 @@ Special chars: @#$%^&*() Unicode: 你好世界 🌍""" encoded = tokenizer.encode(test_text) decoded = tokenizer.decode(encoded) -assert decoded == test_text +assert decoded == test_text, "Decode-encode roundtrip failed: decoded != original" # ----------------------------------------------------------------------------- # One more thing: we wish to cache a mapping from token id to number of bytes of that token diff --git a/tests/test_rustbpe.py b/tests/test_rustbpe.py index 5f95721..90acaa7 100644 --- a/tests/test_rustbpe.py +++ b/tests/test_rustbpe.py @@ -84,7 +84,7 @@ class RegexTokenizer: return vocab def train(self, text, vocab_size, verbose=False): - assert vocab_size >= 256 + assert vocab_size >= 256, f"vocab_size must be >= 256, got {vocab_size}" num_merges = vocab_size - 256 # keep track of whether at any point during training the merge is ambiguous (counts of pairs are not unique) @@ -215,7 +215,7 @@ class FastRegexTokenizer: - collapse identical chunks to just the unique ones - update counts more cleverly - only around the affected chunks """ - assert vocab_size >= 256 + assert vocab_size >= 256, f"vocab_size must be >= 256, got {vocab_size}" num_merges = vocab_size - 256 # split the text up into text chunks