Brian Edwards 2026-01-18 21:05:55 +02:00 committed by GitHub
commit d9a263bb5f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 17 additions and 17 deletions

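Every hunk in this commit makes the same substitution: guards written as `device_type == "cuda"` become `device_type != "cpu"`, so non-CUDA accelerators (e.g. MPS, per the `cpu|mps` comments below) also enter the bfloat16 autocast context instead of falling through to `nullcontext()`. A minimal sketch of the pattern, assuming a PyTorch build where `torch.amp.autocast` accepts the `"mps"` device type and that `autodetect_device_type()` returns one of `"cuda"`, `"mps"`, or `"cpu"`:

```python
from contextlib import nullcontext
import torch

def make_autocast_ctx(device_type: str, dtype=torch.bfloat16):
    # Before: only "cuda" got autocast; "mps" silently ran in full precision.
    # After: any non-CPU backend enters the mixed-precision context.
    if device_type != "cpu":
        return torch.amp.autocast(device_type=device_type, dtype=dtype)
    return nullcontext()

# Usage: wrap forward passes so supported ops run in bfloat16.
# with make_autocast_ctx(device_type):
#     loss = model(inputs, targets)
```

The hunks below apply this same guard to synchronization, memory accounting, worker setup, and data loading.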
View File

@@ -308,7 +308,7 @@ if __name__ == "__main__":
 # init compute
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init()
 device_type = autodetect_device_type()
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type != "cpu" else nullcontext()
 # load the model and tokenizer
 model, tokenizer, meta = load_model("base", device, phase="eval")

View File

@@ -156,7 +156,7 @@ def main():
 # distributed / precision setup
 device_type = autodetect_device_type()
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type != "cpu" else nullcontext()
 # Load model and tokenizer from command line or from file system
 if args.hf_path is not None:

View File

@@ -87,7 +87,7 @@ else:
     token_bytes = get_token_bytes(device=device)
     model_name = f"base_model (step {meta['step']})"
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type != "cpu" else nullcontext()
 print0(f"Evaluating model: {model_name}")

View File

@@ -79,9 +79,9 @@ user_config = vars(args).copy() # for logging
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type == "cuda" else nullcontext()
-synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
-get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=torch.bfloat16) if device_type != "cpu" else nullcontext()
+synchronize = torch.cuda.synchronize if device_type != "cpu" else lambda: None
+get_max_memory = torch.cuda.max_memory_allocated if device_type != "cpu" else lambda: 0
 if device_type == "cuda":
     gpu_device_name = torch.cuda.get_device_name(0)
     gpu_peak_flops = get_peak_flops(gpu_device_name)

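One caveat the hunk above highlights: `torch.cuda.synchronize` and `torch.cuda.max_memory_allocated` are CUDA APIs, so under the broader `!= "cpu"` guard an MPS run would bind the CUDA functions. A hedged sketch of per-backend helpers (an alternative for illustration, not what this commit does), assuming a recent PyTorch where `torch.mps.synchronize` and `torch.mps.current_allocated_memory` are available:

```python
import torch

def timing_helpers(device_type: str):
    # Return (synchronize, memory_query) for the active backend.
    if device_type == "cuda":
        return torch.cuda.synchronize, torch.cuda.max_memory_allocated
    if device_type == "mps":
        # MPS exposes current (not peak) allocation; report it as a stand-in.
        return torch.mps.synchronize, torch.mps.current_allocated_memory
    # CPU: nothing to synchronize, no device memory to account for.
    return (lambda: None), (lambda: 0)
```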
View File

@@ -27,7 +27,7 @@ args = parser.parse_args()
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type != "cpu" else nullcontext()
 model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
 # Special tokens for the chat state machine

View File

@@ -200,7 +200,7 @@ if __name__ == "__main__":
 device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type != "cpu" else nullcontext()
 model, tokenizer, meta = load_model(args.source, device, phase="eval", model_tag=args.model_tag, step=args.step)
 engine = Engine(model, tokenizer)

View File

@@ -69,7 +69,7 @@ device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0
 ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type != "cpu" else nullcontext()
 # wandb logging init
 use_dummy_wandb = args.run == "dummy" or not master_process

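The `use_dummy_wandb` context line above shows the logging pattern these scripts share: only the master process talks to wandb, every other rank gets a no-op stand-in. A hypothetical minimal shim (the names here are illustrative, not the repo's actual implementation):

```python
class DummyWandb:
    """No-op wandb stand-in so non-master ranks can call the same API."""
    def log(self, *args, **kwargs):
        pass  # swallow metrics silently
    def finish(self):
        pass

# run = DummyWandb() if use_dummy_wandb else wandb.init(project=..., name=args.run)
# run.log({"loss": loss.item()})  # safe to call on every rank
```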
View File

@@ -100,7 +100,7 @@ class WorkerPool:
     def __init__(self, num_gpus: Optional[int] = None):
         if num_gpus is None:
-            if device_type == "cuda":
+            if device_type != "cpu":
                 num_gpus = torch.cuda.device_count()
             else:
                 num_gpus = 1 # e.g. cpu|mps
@@ -112,11 +112,11 @@ class WorkerPool:
         """Load model on each GPU."""
         print(f"Initializing worker pool with {self.num_gpus} GPUs...")
         if self.num_gpus > 1:
-            assert device_type == "cuda", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
+            assert device_type != "cpu", "Only CUDA supports multiple workers/GPUs. cpu|mps does not."
         for gpu_id in range(self.num_gpus):
-            if device_type == "cuda":
+            if device_type != "cpu":
                 device = torch.device(f"cuda:{gpu_id}")
                 print(f"Loading model on GPU {gpu_id}...")
             else:
@@ -125,7 +125,7 @@ class WorkerPool:
             model, tokenizer, _ = load_model(source, device, phase="eval", model_tag=model_tag, step=step)
             engine = Engine(model, tokenizer)
-            autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
+            autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type != "cpu" else nullcontext()
             worker = Worker(
                 gpu_id=gpu_id,

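In the `WorkerPool` hunks above, the widened guard changes more than precision: on MPS the `!= "cpu"` branch now calls `torch.cuda.device_count()` (which returns 0 on builds without CUDA) and would construct `cuda:{gpu_id}` devices. A sketch of per-backend resolution that keeps the single-device path for `mps`/`cpu` explicit; this is an illustrative alternative, not the commit's code:

```python
import torch

def resolve_devices(device_type: str, num_gpus=None):
    # One worker per CUDA device; a single worker for mps/cpu.
    if num_gpus is None:
        num_gpus = torch.cuda.device_count() if device_type == "cuda" else 1
    if device_type == "cuda":
        return [torch.device(f"cuda:{i}") for i in range(num_gpus)]
    return [torch.device(device_type)]  # e.g. "mps" or "cpu"
```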
View File

@@ -67,9 +67,9 @@ device_type = autodetect_device_type() if args.device_type == "" else args.device_type
 ddp, ddp_rank, ddp_local_rank, ddp_world_size, device = compute_init(device_type)
 master_process = ddp_rank == 0
 ptdtype = torch.float32 if args.dtype == 'float32' else torch.bfloat16
-autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type == "cuda" else nullcontext()
-synchronize = torch.cuda.synchronize if device_type == "cuda" else lambda: None
-get_max_memory = torch.cuda.max_memory_allocated if device_type == "cuda" else lambda: 0
+autocast_ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype) if device_type != "cpu" else nullcontext()
+synchronize = torch.cuda.synchronize if device_type != "cpu" else lambda: None
+get_max_memory = torch.cuda.max_memory_allocated if device_type != "cpu" else lambda: 0
 # wandb logging init
 use_dummy_wandb = args.run == "dummy" or not master_process
@@ -209,7 +209,7 @@ def mid_data_generator_bos_bestfit(split, buffer_size=100):
     last_step = True
     # Build tensors
-    use_cuda = device_type == "cuda"
+    use_cuda = device_type != "cpu"
     batch_tensor = torch.tensor(rows, dtype=torch.long, pin_memory=use_cuda)
     inputs = batch_tensor[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_cuda)
     targets = batch_tensor[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_cuda)
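The final hunk's `use_cuda` flag drives both `pin_memory` on the host tensor and `non_blocking` on the copies: pinned (page-locked) host memory is what lets the host-to-device transfer run asynchronously and overlap with compute on CUDA. A minimal sketch of the pattern (CUDA semantics; whether MPS builds honor `pin_memory=True` depends on the PyTorch version):

```python
import torch

def batch_to_device(rows, device, use_async: bool):
    # Stage the batch in pinned host memory so .to(..., non_blocking=True)
    # can return before the copy completes (CUDA overlaps it with compute).
    batch = torch.tensor(rows, dtype=torch.long, pin_memory=use_async)
    # Shifted views: inputs at position t predict targets at position t+1.
    inputs = batch[:, :-1].to(device=device, dtype=torch.int32, non_blocking=use_async)
    targets = batch[:, 1:].to(device=device, dtype=torch.int64, non_blocking=use_async)
    return inputs, targets
```

On CPU both flags are pointless (the data is already host-resident), which is why the guard exists at all.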