chat sft optim fix: thread device_type through setup_optimizer into MuonAdamW/DistMuonAdamW so optimizer scalar tensors land on the right device

This commit is contained in:
Sushrut Karnik 2026-03-13 01:11:24 +01:00
parent bf19cb325c
commit e9bf1a5a67
4 changed files with 20 additions and 17 deletions

View File

@ -355,7 +355,7 @@ class GPT(nn.Module):
'total': total,
}
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5, device_type="cuda"):
model_dim = self.config.n_embd
ddp, rank, local_rank, world_size = get_dist_info()
@ -390,7 +390,7 @@ class GPT(nn.Module):
))
Factory = DistMuonAdamW if ddp else MuonAdamW
optimizer = Factory(param_groups)
optimizer = Factory(param_groups, device_type)
for group in optimizer.param_groups:
group["initial_lr"] = group["lr"]
return optimizer

View File

@ -176,11 +176,11 @@ class MuonAdamW(torch.optim.Optimizer):
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
"""
def __init__(self, param_groups: list[dict]):
def __init__(self, param_groups: list[dict], device_type: str = "cpu"):
super().__init__(param_groups, defaults={})
# 0-D CPU tensors to avoid torch.compile recompilation when values change
# AdamW tensors
device="mps"
device=device_type
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device=device)
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device)
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device=device)
@ -358,19 +358,19 @@ class DistMuonAdamW(torch.optim.Optimizer):
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
"""
def __init__(self, param_groups: list[dict]):
def __init__(self, param_groups: list[dict], device_type: str = "cpu"):
super().__init__(param_groups, defaults={})
# 0-D CPU tensors to avoid torch.compile recompilation when values change
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""

View File

@ -83,6 +83,8 @@ user_config = vars(args).copy() # for logging
# Compute init and wandb logging
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
@ -312,6 +314,7 @@ optimizer = model.setup_optimizer(
# Muon hyperparameters
matrix_lr=args.matrix_lr * batch_lr_scale,
weight_decay=weight_decay_scaled,
device_type=device_type if device_type != "cuda" else "cpu" # NOTE(review): upstream keeps the optimizer's 0-D scalar tensors on CPU for CUDA runs, so only non-CUDA devices (e.g. mps) pass through — confirm intent
)
# optimizer.to(device)
if resuming:

View File

@ -117,7 +117,7 @@ for name, fallback, source in [
print0(f"Using {name}={arg_val}")
orig_model = model
model = torch.compile(model, dynamic=False)
# model = torch.compile(model, dynamic=False)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
@ -131,7 +131,7 @@ token_bytes = get_token_bytes(device=device)
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0, device_type=device_type)
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the