mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-07 12:52:16 +00:00
Refactor wandb initialization in chat_sft.py
This commit is contained in:
parent
523714b5c8
commit
b9f01eedd9
|
|
@ -12,12 +12,11 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft
|
||||||
import os
|
import os
|
||||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||||
|
|
||||||
import wandb
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, DummyWandb, autodetect_device_type
|
from nanochat.common import compute_init, compute_cleanup, get_base_dir, print0, get_wandb, autodetect_device_type
|
||||||
from nanochat.checkpoint_manager import load_model
|
from nanochat.checkpoint_manager import load_model
|
||||||
from nanochat.checkpoint_manager import save_checkpoint
|
from nanochat.checkpoint_manager import save_checkpoint
|
||||||
from nanochat.engine import Engine
|
from nanochat.engine import Engine
|
||||||
|
|
@ -69,8 +68,7 @@ ptdtype = torch.float32 if dtype == 'float32' else torch.bfloat16
|
||||||
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
|
||||||
|
|
||||||
# wandb logging init
|
# wandb logging init
|
||||||
use_dummy_wandb = run == "dummy" or not master_process
|
wandb_run = get_wandb("nanochat-sft" , run=run, master_process=master_process, user_config=user_config)
|
||||||
wandb_run = DummyWandb() if use_dummy_wandb else wandb.init(project="nanochat-sft", name=run, config=user_config, save_code=True)
|
|
||||||
|
|
||||||
# Load the model and tokenizer
|
# Load the model and tokenizer
|
||||||
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user