Mirror of https://github.com/karpathy/nanochat.git (synced 2025-12-07 12:52:16 +00:00)
Merge pull request #9 from Dianababaei/feat/enable-torch-compile-chat-sft-fixed-shapes
Update chat SFT training script configuration and parameters
Commit 0af8c8af68
@@ -54,7 +54,6 @@ eval_metrics_every = 200
 config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))]
 exec(open(os.path.join('nanochat', 'configurator.py')).read()) # overrides from command line or config file
 user_config = {k: globals()[k] for k in config_keys} # possibly useful for logging
-max_seq_len = 2048 # Maximum sequence length for fixed padding (enables torch.compile)
 # -----------------------------------------------------------------------------

 # Compute init
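Note: the config block in this hunk follows the nanoGPT-style "poor man's configurator" pattern: every plain int/float/bool/str global is collected into config_keys and can be overridden from the command line before training begins. The sketch below illustrates that pattern in isolation; the default values and the --key=value parsing are assumptions for the example, not the repo's actual configurator.py.

```python
import sys
from ast import literal_eval

# hypothetical defaults, mirroring the style of the config block above
device_batch_size = 4
num_epochs = 1
run_name = "chat_sft"

# collect the simple globals as overridable config keys
config_keys = [k for k, v in globals().items()
               if not k.startswith('_') and isinstance(v, (int, float, bool, str))]

# override from --key=value command-line arguments
for arg in sys.argv[1:]:
    if arg.startswith('--') and '=' in arg:
        key, val = arg[2:].split('=', 1)
        if key in config_keys:
            try:
                globals()[key] = literal_eval(val)  # numbers, booleans, quoted strings
            except (ValueError, SyntaxError):
                globals()[key] = val                # fall back to a plain string

user_config = {k: globals()[k] for k in config_keys}  # possibly useful for logging
print(user_config)
```

Running this with, e.g., `--device_batch_size=8` would override the default before training starts, which is the same effect the exec(...) line above achieves via configurator.py.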
@@ -70,7 +69,7 @@ wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sf
 # Load the model and tokenizer
 model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
 orig_model = model # original, uncompiled model
-# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
+model = torch.compile(model, dynamic=False) # Fixed shapes enable efficient compilation
 engine = Engine(model, tokenizer) # will be used for inline model evaluation only

 # -----------------------------------------------------------------------------
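Note: the compile change only pays off if input shapes stay constant. With dynamic=False, torch.compile specializes the graph to the concrete (batch, seq_len) it first sees, so a batch with a different sequence length triggers a recompile. The toy module below is only a stand-in to illustrate that behavior; it is not the nanochat GPT.

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    # stand-in for the real model; just an embedding plus a linear head
    def __init__(self, vocab_size=256, dim=64):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, idx):
        return self.head(self.emb(idx))

model = TinyLM()
# dynamic=False asks the compiler to specialize on concrete shapes;
# it is efficient only if every batch arrives with the same (B, T)
model = torch.compile(model, dynamic=False)

B, T = 4, 128  # fixed shapes, analogous to padding every batch to max_seq_len
for _ in range(3):
    idx = torch.randint(0, 256, (B, T))
    logits = model(idx)  # first call compiles, later calls reuse the cached graph

# a batch with a different sequence length would trigger a recompilation:
# model(torch.randint(0, 256, (B, 200)))
```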
@@ -87,16 +86,16 @@ val_ds = SmolTalk(split="test") # general conversations, 24K rows (though we don
 # -----------------------------------------------------------------------------
 # DataLoader

-def sft_data_generator(dataset, batch_size, max_seq_len=2048):
+def sft_data_generator(dataset, batch_size):
     pad_token_id = tokenizer.encode_special("<|assistant_end|>") # use <|assistant_end|> as the pad token is ok, these positions are masked in the loss
     # prepares a list of tokenized conversations into a batch and yields
     def collate_and_yield(batch):
         nrows = len(batch)
-        ncols = max_seq_len
+        ncols = max(len(ids) for ids, mask in batch) - 1 # seq of n creates inputs/targets of n-1
         inputs = torch.full((nrows, ncols), pad_token_id, dtype=torch.long)
         targets = torch.full((nrows, ncols), -1, dtype=torch.long) # -1 is ignore index
         for i, (ids, mask) in enumerate(batch):
-            n = min(len(ids), max_seq_len + 1)
+            n = len(ids)
             ids_tensor = torch.tensor(ids, dtype=torch.long)
             inputs[i, :n-1] = ids_tensor[:-1]
             # recall -1 is the ignore index, so mask out targets where mask is 0
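Note: the two collate variants above differ only in how ncols is chosen. The variable-length version sizes each batch to its longest conversation minus one (a sequence of n tokens yields n-1 input/target pairs), while the fixed-shape version always pads, and if necessary crops, to max_seq_len so every batch tensor has the same shape. The sketch below reconstructs the fixed-shape variant end to end; the pad id and the target-masking lines that fall after the visible hunk are assumptions made for the example.

```python
import torch

PAD_ID = 0  # stands in for tokenizer.encode_special("<|assistant_end|>")

def collate_fixed(batch, max_seq_len=2048):
    # batch is a list of (ids, mask) pairs; mask[j] == 1 where token j is an
    # assistant token that should contribute to the loss
    nrows = len(batch)
    ncols = max_seq_len  # fixed width: every batch has the same shape
    inputs = torch.full((nrows, ncols), PAD_ID, dtype=torch.long)
    targets = torch.full((nrows, ncols), -1, dtype=torch.long)  # -1 = ignore index
    for i, (ids, mask) in enumerate(batch):
        n = min(len(ids), max_seq_len + 1)  # crop overly long conversations
        ids_t = torch.tensor(ids[:n], dtype=torch.long)
        mask_t = torch.tensor(mask[:n], dtype=torch.long)
        inputs[i, :n-1] = ids_t[:-1]          # inputs are tokens 0..n-2
        row_targets = ids_t[1:].clone()       # targets are tokens 1..n-1
        row_targets[mask_t[1:] == 0] = -1     # mask out non-assistant positions
        targets[i, :n-1] = row_targets
    return inputs, targets

# tiny usage example with two fake conversations
batch = [([5, 6, 7, 8], [0, 0, 1, 1]), ([9, 10, 11], [0, 1, 1])]
x, y = collate_fixed(batch, max_seq_len=8)
print(x.shape, y.shape)  # torch.Size([2, 8]) torch.Size([2, 8])
```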
@@ -131,8 +130,8 @@ num_iterations = (len(train_ds) // target_examples_per_step) * num_epochs
 if max_iterations >= 0 and num_iterations > max_iterations:
     print0(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
     num_iterations = max_iterations
-train_loader = sft_data_generator(train_ds, batch_size=device_batch_size, max_seq_len=max_seq_len)
-build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size, max_seq_len=max_seq_len)
+train_loader = sft_data_generator(train_ds, batch_size=device_batch_size)
+build_val_loader = lambda: sft_data_generator(val_ds, batch_size=device_batch_size)

 # -----------------------------------------------------------------------------
 # Initialize the Optimizer
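Note: the surrounding code derives the optimizer step count as roughly len(train_ds) // target_examples_per_step steps per epoch, times num_epochs, optionally capped by max_iterations (a negative cap meaning no limit, per the >= 0 check). A small worked example of that arithmetic with made-up numbers:

```python
# hypothetical numbers, only to make the formula concrete
len_train_ds = 24_000          # conversations in the SFT training set
target_examples_per_step = 32  # examples consumed per optimizer step
num_epochs = 1
max_iterations = -1            # negative means "no cap"

num_iterations = (len_train_ds // target_examples_per_step) * num_epochs  # 750
if max_iterations >= 0 and num_iterations > max_iterations:
    print(f"Number of iterations is too high: {num_iterations}, capping to {max_iterations}")
    num_iterations = max_iterations

print(num_iterations)  # 750
```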