diff --git a/nanochat/gpt.py b/nanochat/gpt.py index fb77d77..979ff47 100644 --- a/nanochat/gpt.py +++ b/nanochat/gpt.py @@ -355,7 +355,7 @@ class GPT(nn.Module): 'total': total, } - def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5): + def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02, weight_decay=0.0, scalar_lr=0.5, device_type="cuda"): model_dim = self.config.n_embd ddp, rank, local_rank, world_size = get_dist_info() @@ -390,7 +390,7 @@ class GPT(nn.Module): )) Factory = DistMuonAdamW if ddp else MuonAdamW - optimizer = Factory(param_groups) + optimizer = Factory(param_groups, device_type) for group in optimizer.param_groups: group["initial_lr"] = group["lr"] return optimizer diff --git a/nanochat/optim.py b/nanochat/optim.py index 42b4949..5f67417 100644 --- a/nanochat/optim.py +++ b/nanochat/optim.py @@ -176,11 +176,11 @@ class MuonAdamW(torch.optim.Optimizer): - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay' - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay' """ - def __init__(self, param_groups: list[dict]): + def __init__(self, param_groups: list[dict], device_type: str = "cpu"): super().__init__(param_groups, defaults={}) # 0-D CPU tensors to avoid torch.compile recompilation when values change # AdamW tensors - device="mps" + device=device_type self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device=device) self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device) self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device=device) @@ -358,19 +358,19 @@ class DistMuonAdamW(torch.optim.Optimizer): - For AdamW groups: 'lr', 'betas', 'eps', 'weight_decay' - For Muon groups: 'lr', 'momentum', 'ns_steps', 'beta2', 'weight_decay' """ - def __init__(self, param_groups: list[dict]): + def __init__(self, param_groups: list[dict], device_type: str = "cpu"): super().__init__(param_groups, defaults={}) # 0-D 
CPU tensors to avoid torch.compile recompilation when values change - self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") - self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu") + self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) + self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device=device_type) def _reduce_adamw(self, group: dict, world_size: int) -> dict: """Launch async reduce ops for AdamW group. 
Returns info dict with per-param infos.""" diff --git a/scripts/base_train.py b/scripts/base_train.py index 2f44196..a9a4534 100644 --- a/scripts/base_train.py +++ b/scripts/base_train.py @@ -83,6 +83,8 @@ user_config = vars(args).copy() # for logging # Compute init and wandb logging device_type = autodetect_device_type() if args.device_type == "" else args.device_type + + ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type) master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None @@ -312,6 +314,7 @@ optimizer = model.setup_optimizer( # Muon hyperparameters matrix_lr=args.matrix_lr * batch_lr_scale, weight_decay=weight_decay_scaled, + device_type=device_type if device_type != "cuda" else "cpu" # on CUDA keep optimizer scalar-state tensors on CPU (upstream behavior); pass device_type for other backends ) # optimizer.to(device) if resuming: diff --git a/scripts/chat_sft.py b/scripts/chat_sft.py index c1adbb6..c8382f9 100644 @@ -117,7 +117,7 @@ for name, fallback, source in [ print0(f"Using {name}={arg_val}") orig_model = model -model = torch.compile(model, dynamic=False) +# model = torch.compile(model, dynamic=False) depth = model.config.n_layer num_flops_per_token = model.estimate_flops() tokens_per_fwdbwd = args.device_batch_size * args.max_seq_len # tokens per iteration for a single rank @@ -131,7 +131,7 @@ token_bytes = get_token_bytes(device=device) # Initialize the Optimizer (combined MuonAdamW: Muon for matrix params, AdamW for rest) # Note that pretraining ramps weight_decay to zero by end of pretraining, so SFT continues with zero -optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0) +optimizer = model.setup_optimizer(unembedding_lr=args.unembedding_lr, embedding_lr=args.embedding_lr, matrix_lr=args.matrix_lr, weight_decay=0.0, device_type=device_type) # 
Optionally warm-start optimizer from pretrained checkpoint (momentum buffers etc.) # Note: load_state_dict overwrites param_group metadata (LRs, betas, etc.) with the