mirror of
https://github.com/karpathy/nanochat.git
synced 2025-12-06 04:12:13 +00:00
fix: get safe autocast dtype
This commit is contained in:
parent
662ff7eb7a
commit
1c5dd2b7ba
|
|
@ -91,12 +91,49 @@ def get_device_type():
|
|||
return "cpu"
|
||||
return "cuda"
|
||||
|
||||
def get_safe_autocast_dtype(device_type: str, preferred_dtype=None) -> torch.dtype:
|
||||
"""
|
||||
Return a safe dtype for autocast on the given device.
|
||||
|
||||
Args:
|
||||
device_type: "cuda" or "cpu"
|
||||
preferred_dtype: Preferred dtype (torch.dtype, str, or None)
|
||||
|
||||
Returns:
|
||||
A dtype that is safe for autocast on the device
|
||||
"""
|
||||
# Parse the preferred dtype
|
||||
if isinstance(preferred_dtype, torch.dtype):
|
||||
dtype = preferred_dtype
|
||||
elif isinstance(preferred_dtype, str):
|
||||
dtype_map = {
|
||||
"bfloat16": torch.bfloat16, "bf16": torch.bfloat16,
|
||||
"float16": torch.float16, "fp16": torch.float16, "half": torch.float16,
|
||||
"float32": torch.float32, "fp32": torch.float32,
|
||||
}
|
||||
dtype = dtype_map.get(preferred_dtype.lower())
|
||||
if dtype is None:
|
||||
raise ValueError(f"Unknown dtype string: {preferred_dtype}")
|
||||
elif preferred_dtype is None:
|
||||
# Default: bfloat16 on CUDA, bfloat16 on CPU (both support it)
|
||||
dtype = torch.bfloat16
|
||||
else:
|
||||
raise TypeError(f"Invalid dtype type: {type(preferred_dtype)}")
|
||||
|
||||
# Validate dtype compatibility with device
|
||||
# CPU autocast only supports bfloat16 and float16
|
||||
if device_type == "cpu" and dtype == torch.float32:
|
||||
logger.warning(
|
||||
f"CPU autocast doesn't support {dtype}, using bfloat16 instead"
|
||||
)
|
||||
dtype = torch.bfloat16
|
||||
|
||||
return dtype
|
||||
|
||||
def get_default_dtype():
|
||||
"""Get the default dtype for training: bfloat16 on GPU, float32 on CPU."""
|
||||
# bfloat16 is well-supported on modern GPUs but may have issues on CPU
|
||||
if torch.cuda.is_available():
|
||||
return torch.bfloat16
|
||||
return torch.float32
|
||||
"""Get the default dtype for training based on available hardware."""
|
||||
device_type = get_device_type()
|
||||
return get_safe_autocast_dtype(device_type)
|
||||
|
||||
def get_dist_info():
|
||||
if is_ddp():
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user