From d8be015b207ade6a8baa3ce9c6cba4139190222b Mon Sep 17 00:00:00 2001
From: Artemis Git Integration
Date: Wed, 5 Nov 2025 16:04:26 +0000
Subject: [PATCH] feat(chat_sft): add fixed-length padding for torch.compile
 compatibility

Replace variable-length padding with fixed 2048-token padding to create
constant batch shapes, enabling efficient torch.compile in subsequent
training steps. Conversations longer than the fixed length are truncated.
---
 scripts/chat_sft.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py
index b5ba49a..86415c9 100644
--- a/scripts/chat_sft.py
+++ b/scripts/chat_sft.py
@@ -54,6 +54,7 @@ eval_metrics_every = 200
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
 user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
+max_seq_len = 2048 # Maximum sequence length for fixed padding (enables torch.compile)
 
 # -----------------------------------------------------------------------------
 # Compute init
@@ -86,16 +87,16 @@ val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don
 
 # -----------------------------------------------------------------------------
 # DataLoader
-def sft_data_generator(dataset, batch_size):
+def sft_data_generator(dataset, batch_size, max_seq_len=2048):
     pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
     # prepares a list of tokenized conversations into a batch and yields
     def collate_and_yield(batch):
         nrows = len(batch)
-        ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
+        ncols = max_seq_len
         inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
         targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
         for i, (ids, mask) in enumerate(batch):
-            n = len(ids)
+            n = min(len(ids), max_seq_len + 1)
             ids_tensor = torch.tensor(ids, dtype=torch.long)
             inputs[i, :n-1] = ids_tensor[:-1]
             # recall -1 is the ignore index, so mask out targets where mask is 0
@@ -130,8 +131,8 @@ num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
 if max_iterations >= 0 and num_iterations > max_iterations:
     print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
     num_iterations = max_iterations
-train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
-build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)
+train_loader = sft_data_generator(train_ds, batch_size=device_batch_size, max_seq_len=max_seq_len)
+build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size, max_seq_len=max_seq_len)
 
 # -----------------------------------------------------------------------------
 # Initialize the Optimizer
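
Note: the point of the change is that every collated batch now has the same (batch_size, 2048) shape, so a torch.compile'd training step is traced once instead of recompiling whenever the longest conversation in a batch changes. Below is a minimal standalone sketch of that fixed-shape collation, assuming only PyTorch; PAD_ID, collate_fixed, and the toy inputs are illustrative placeholders (the script pads with the <|assistant_end|> token id), and the loss masking of non-assistant tokens done by the real collate_and_yield is omitted.

import torch

MAX_SEQ_LEN = 2048   # fixed width: every batch is (B, 2048), so a compiled step sees one static shape
PAD_ID = 0           # illustrative pad id only
IGNORE_INDEX = -1    # targets set to -1 are ignored by the loss

def collate_fixed(token_id_lists):
    # pack variable-length token id lists into fixed-shape next-token-prediction tensors
    B = len(token_id_lists)
    inputs = torch.full((B, MAX_SEQ_LEN), PAD_ID, dtype=torch.long)
    targets = torch.full((B, MAX_SEQ_LEN), IGNORE_INDEX, dtype=torch.long)
    for i, ids in enumerate(token_id_lists):
        ids = ids[:MAX_SEQ_LEN + 1]          # truncate: a sequence of n tokens yields n-1 input/target pairs
        t = torch.tensor(ids, dtype=torch.long)
        inputs[i, :len(ids) - 1] = t[:-1]    # inputs are ids[:-1]
        targets[i, :len(ids) - 1] = t[1:]    # targets are ids[1:]
    return inputs, targets

# shapes are constant regardless of the raw conversation lengths
x, y = collate_fixed([[5, 6, 7], list(range(3000))])
print(x.shape, y.shape)  # torch.Size([2, 2048]) torch.Size([2, 2048])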