fix: get safe autocast dtype

Kirk Lin 2025-10-15 11:04:57 +08:00
parent 662ff7eb7a
commit 1c5dd2b7ba


@@ -91,12 +91,49 @@ def get_device_type():
         return "cpu"
     return "cuda"
 
+def get_safe_autocast_dtype(device_type: str, preferred_dtype=None) -> torch.dtype:
+    """
+    Return a safe dtype for autocast on the given device.
+
+    Args:
+        device_type: "cuda" or "cpu"
+        preferred_dtype: Preferred dtype (torch.dtype, str, or None)
+
+    Returns:
+        A dtype that is safe for autocast on the device
+    """
+    # Parse the preferred dtype
+    if isinstance(preferred_dtype, torch.dtype):
+        dtype = preferred_dtype
+    elif isinstance(preferred_dtype, str):
+        dtype_map = {
+            "bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
+            "float16": torch.float16, "fp16": torch.float16, "half": torch.float16,
+            "float32": torch.float32, "fp32": torch.float32,
+        }
+        dtype = dtype_map.get(preferred_dtype.lower())
+        if dtype is None:
+            raise ValueError(f"Unknown dtype string: {preferred_dtype}")
+    elif preferred_dtype is None:
+        # Default: bfloat16 on CUDA, bfloat16 on CPU (both support it)
+        dtype = torch.bfloat16
+    else:
+        raise TypeError(f"Invalid dtype type: {type(preferred_dtype)}")
+
+    # Validate dtype compatibility with device
+    # CPU autocast only supports bfloat16 and float16
+    if device_type == "cpu" and dtype == torch.float32:
+        logger.warning(
+            f"CPU autocast doesn't support {dtype}, using bfloat16 instead"
+        )
+        dtype = torch.bfloat16
+
+    return dtype
+
 def get_default_dtype():
-    """Get the default dtype for training: bfloat16 on GPU, float32 on CPU."""
-    # bfloat16 is well-supported on modern GPUs but may have issues on CPU
-    if torch.cuda.is_available():
-        return torch.bfloat16
-    return torch.float32
+    """Get the default dtype for training based on available hardware."""
+    device_type = get_device_type()
+    return get_safe_autocast_dtype(device_type)
 
 def get_dist_info():
     if is_ddp():
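
For context, here is a minimal usage sketch. It is not part of this commit; it assumes the helpers above are in scope, and the model and batch are hypothetical placeholders. It shows how get_device_type() and get_safe_autocast_dtype() could be combined with torch.autocast in a training step.

import torch

device_type = get_device_type()                                # "cuda" or "cpu"
autocast_dtype = get_safe_autocast_dtype(device_type, "bf16")  # -> torch.bfloat16

# Hypothetical model and batch, purely for illustration.
model = torch.nn.Linear(16, 4).to(device_type)
x = torch.randn(8, 16, device=device_type)

with torch.autocast(device_type=device_type, dtype=autocast_dtype):
    # Eligible ops run in the selected low-precision dtype inside this block.
    out = model(x)
    loss = out.float().mean()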