fallback to cpu on compute_init function

Sermet Pekin 2025-10-20 11:43:47 +03:00 committed by GitHub
parent 11e46b6439
commit cdb5a455ee


@@ -91,34 +91,26 @@ def get_dist_info():
 def compute_init():
     """Basic initialization that we keep doing over and over, so make common."""
-    # CUDA is currently required
-    assert torch.cuda.is_available(), "CUDA is needed for a distributed run atm"
     # Reproducibility
-    torch.manual_seed(42)
-    torch.cuda.manual_seed(42)
     # skipping full reproducibility for now, possibly investigate slowdown later
     # torch.use_deterministic_algorithms(True)
     # torch.backends.cudnn.deterministic = True
     # torch.backends.cudnn.benchmark = False
+    # Check if CUDA is available, otherwise fall back to CPU
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        torch.manual_seed(42)
+        torch.cuda.manual_seed(42)
+    else:
+        device = torch.device("cpu")
+        torch.manual_seed(42)
+        logger.warning("CUDA is not available. Falling back to CPU.")
     # Precision
     torch.set_float32_matmul_precision("high") # uses tf32 instead of fp32 for matmuls
     # Distributed setup: Distributed Data Parallel (DDP), optional
     ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
-    if ddp:
+    if ddp and torch.cuda.is_available():
         device = torch.device("cuda", ddp_local_rank)
         torch.cuda.set_device(device) # make "cuda" default to this device
         dist.init_process_group(backend="nccl", device_id=device)
         dist.barrier()
-    else:
-        device = torch.device("cuda")
     if ddp_rank == 0:
         logger.info(f"Distributed world size: {ddp_world_size}")
     return ddp, ddp_rank, ddp_local_rank, ddp_world_size, device
 def compute_cleanup():
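
For readers who want to exercise the new control flow in isolation, here is a minimal standalone sketch of the same pattern (the name init_device and the environment-variable handling are illustrative assumptions, not this project's API): prefer CUDA when present, seed whichever backend is used, and only set up NCCL-based DDP when CUDA is actually available.

import logging
import os

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)

def init_device(ddp: bool, ddp_local_rank: int) -> torch.device:
    # Illustrative sketch, not the project's compute_init itself.
    # Seed for reproducibility on whichever backend ends up being used
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
        logger.warning("CUDA is not available. Falling back to CPU.")
    # NCCL-backed DDP requires CUDA, so process-group setup is skipped otherwise
    if ddp and torch.cuda.is_available():
        device = torch.device("cuda", ddp_local_rank)
        torch.cuda.set_device(device)  # make "cuda" default to this device
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    return device

if __name__ == "__main__":
    # Single-process run: no RANK set by a launcher, so ddp is False
    ddp = int(os.environ.get("RANK", -1)) != -1
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    print(init_device(ddp, local_rank))

Run as a plain script on a CPU-only machine this prints "cpu" and logs the fallback warning; launched with torchrun on a CUDA machine it pins each rank to its local GPU before initializing the process group.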