mirror of
https://github.com/karpathy/nanochat.git
synced 2026-04-02 05:35:19 +00:00
chat sft optim fix
This commit is contained in:
parent
bf19cb325c
commit
e9bf1a5a67
|
|
@ -355,7 +355,7 @@ class GPT(nn.Module):
|
|||
'total': total,
|
||||
}
|
||||
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5):
|
||||
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5, device_type="cuda"):
|
||||
model_dim = self.config.n_embd
|
||||
ddp, rank, local_rank, world_size = get_dist_info()
|
||||
|
||||
|
|
@ -390,7 +390,7 @@ class GPT(nn.Module):
|
|||
))
|
||||
|
||||
Factory = DistMuonAdamW if ddp else MuonAdamW
|
||||
optimizer = Factory(param_groups)
|
||||
optimizer = Factory(param_groups, device_type)
|
||||
for group in optimizer.param_groups:
|
||||
group["initial_lr"] = group["lr"]
|
||||
return optimizer
|
||||
|
|
|
|||
|
|
@ -176,11 +176,11 @@ class MuonAdamW(torch.optim.Optimizer):
|
|||
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
|
||||
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
|
||||
"""
|
||||
def __init__(self, param_groups: list[dict]):
|
||||
def __init__(self, param_groups: list[dict], device_type: str = "cpu"):
|
||||
super().__init__(param_groups, defaults={})
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
# AdamW tensors
|
||||
device="mps"
|
||||
device=device_type
|
||||
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device=device)
|
||||
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device)
|
||||
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device=device)
|
||||
|
|
@ -358,19 +358,19 @@ class DistMuonAdamW(torch.optim.Optimizer):
|
|||
- For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay'
|
||||
- For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay'
|
||||
"""
|
||||
def __init__(self, param_groups: list[dict]):
|
||||
def __init__(self, param_groups: list[dict], device_type: str = "cpu"):
|
||||
super().__init__(param_groups, defaults={})
|
||||
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
||||
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
||||
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
|
||||
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
|
||||
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
|
||||
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
|
||||
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
|
||||
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
|
||||
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
|
||||
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
|
||||
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
|
||||
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device=device_type)
|
||||
|
||||
def _reduce_adamw(self, group: dict, world_size: int) -> dict:
|
||||
"""Launch async reduce ops for AdamW group. Returns info dict with per-param infos."""
|
||||
|
|
|
|||
|
|
@ -83,6 +83,8 @@ user_config = vars(args).copy() # for logging
|
|||
# Compute init and wandb logging
|
||||
|
||||
device_type = autodetect_device_type() if args.device_type == "" else args.device_type
|
||||
|
||||
|
||||
ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
|
||||
master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
|
||||
synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
|
||||
|
|
@ -312,6 +314,7 @@ optimizer = model.setup_optimizer(
|
|||
# Muon hyperparameters
|
||||
matrix_lr=args.matrix_lr * batch_lr_scale,
|
||||
weight_decay=weight_decay_scaled,
|
||||
device_type=device_type if device_type != "cuda" else "cpu" # since k keeps optim in cpu
|
||||
)
|
||||
# optimizer.to(device)
|
||||
if resuming:
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ for name, fallback, source in [
|
|||
print0(f"Using {name}={arg_val}")
|
||||
|
||||
orig_model = model
|
||||
model = torch.compile(model, dynamic=False)
|
||||
# model = torch.compile(model, dynamic=False)
|
||||
depth = model.config.n_layer
|
||||
num_flops_per_token = model.estimate_flops()
|
||||
tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank
|
||||
|
|
@ -131,7 +131,7 @@ token_bytes = get_token_bytes(device=device)
|
|||
|
||||
# Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest)
|
||||
# Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero
|
||||
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0)
|
||||
optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0, device_type=device_type)
|
||||
|
||||
# Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.)
|
||||
# Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user