Even when calling `torch.set_float32_matmul_precision('high')`, the `torch.compile` (Inductor) backend on some ROCm versions may still warn that TensorFloat32 is available but not enabled. This change explicitly sets `torch.backends.cuda.matmul.allow_tf32 = True` to ensure the setting is active and to silence the warning.
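The two settings together can be sketched as below. This is a minimal illustration, not the exact patch; the import is guarded so the snippet is a no-op where PyTorch is absent, and the `TF32_CONFIGURED` name is only for demonstration:

```python
# Enable TF32 matmuls explicitly; some ROCm builds keep warning unless the
# backend flag is set in addition to the precision hint.
try:
    import torch

    torch.set_float32_matmul_precision("high")
    torch.backends.cuda.matmul.allow_tf32 = True  # silences the Inductor warning
    TF32_CONFIGURED = True
except ImportError:
    # torch not installed; nothing to configure
    TF32_CONFIGURED = False
```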
When running on AMD ROCm using `uv`-installed packages (`rocm-sdk-core`), the `ld.lld` linker is not in the default `/opt/rocm/llvm/bin/` location expected by `pytorch-triton-rocm`. This causes an `InductorError` during `torch.compile`.
This change updates `speedrun.sh` to dynamically find the `ld.lld` binary within the active python environment's site-packages (`_rocm_sdk_core`) and export the `TRITON_HIP_LLD_PATH` environment variable, allowing Triton to locate the linker correctly.
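The lookup logic can be sketched in Python. This is a hedged illustration, not the shell code from `speedrun.sh` itself; `find_ld_lld` is a hypothetical helper name:

```python
import os
import sysconfig
from pathlib import Path


def find_ld_lld(site_packages):
    """Search a site-packages tree for an ld.lld binary (e.g. the one
    shipped under the `_rocm_sdk_core` package). Returns the first match
    as a string, or None if nothing is found."""
    for candidate in Path(site_packages).glob("**/ld.lld"):
        return str(candidate)
    return None


# Hypothetical usage: export the result so Triton can locate the linker.
lld = find_ld_lld(sysconfig.get_paths()["purelib"])
if lld:
    os.environ["TRITON_HIP_LLD_PATH"] = lld
```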
On consumer AMD hardware (like APUs or gaming GPUs) running ROCm, the default `nccl` backend (which wraps RCCL) often fails with `invalid device function` due to architecture mismatches or kernel issues.
This change detects the presence of `torch.version.hip` and forces the `gloo` backend for `torch.distributed.init_process_group`. While `gloo` is slower for data transfer, it is CPU-based and significantly more robust for these setups, ensuring the training script can run without crashing.
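The backend choice reduces to a small decision on `torch.version.hip`, sketched here as a pure function (`pick_dist_backend` is a hypothetical name; in the script it would be called with `torch.version.hip` before `torch.distributed.init_process_group(backend=...)`):

```python
def pick_dist_backend(hip_version):
    """Choose a torch.distributed backend name.

    `hip_version` mirrors torch.version.hip: a version string on ROCm
    builds, None on CUDA/CPU builds. gloo is CPU-based and slower, but
    avoids RCCL 'invalid device function' crashes on consumer AMD hardware.
    """
    return "gloo" if hip_version is not None else "nccl"
```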
The `speedrun.sh` script was hardcoding `NPROC_PER_NODE=8` if any GPU capability was detected, causing crashes on systems with fewer than 8 GPUs. Additionally, `nanochat/common.py` was autodetecting "cuda" even if `torch.cuda.device_count()` was 0 on some ROCm builds, leading to "invalid device ordinal" errors.
Changes:
- `speedrun.sh`: Dynamically set `NPROC_PER_NODE` using `torch.cuda.device_count()`.
- `nanochat/common.py`: Ensure `autodetect_device_type` only returns "cuda" if devices are actually present.
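The guard in `autodetect_device_type` can be sketched as a pure function over the two relevant signals (the inputs mirror `torch.cuda.is_available()` and `torch.cuda.device_count()`; the exact signature in `nanochat/common.py` may differ):

```python
def autodetect_device_type(cuda_available, device_count):
    """Return "cuda" only when usable devices are actually present.

    Some ROCm builds report is_available() == True while device_count()
    is 0, which previously led to "invalid device ordinal" errors.
    """
    if cuda_available and device_count > 0:
        return "cuda"
    return "cpu"
```

In `speedrun.sh`, `NPROC_PER_NODE` would be derived from the same device count rather than hardcoded to 8.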
Uninstalling the conflicting `triton` package (upstream) on AMD systems often removes the `triton` directory shared with `pytorch-triton-rocm`, breaking the latter. This caused `ImportError: cannot import name 'Config' from 'triton'`.
This change adds a step to force reinstall `pytorch-triton-rocm` immediately after uninstalling `triton`, ensuring the correct package is present and intact for the runtime.
On AMD ROCm environments, `uv run` was detecting that the manually uninstalled `triton` package was missing (since it's a transitive dependency of `torch`) and reinstalling it during the tokenizer build step. This caused `ImportError: cannot import name 'Config' from 'triton'` due to conflict with `pytorch-triton-rocm`.
This change adds `--no-sync` to the `uv run` command for building the tokenizer, preventing `uv` from undoing the manual uninstallation of `triton`.
Explicitly uninstall `triton` when AMD GPU is detected.
The standard `triton` package (often pulled in by NVIDIA dependencies or installed by accident)
conflicts with `pytorch-triton-rocm` on AMD systems, causing
`ImportError: cannot import name 'Config' from 'triton'`.
This change ensures a clean ROCm environment by removing the conflicting package.
Also retains the `uv run --extra $EXTRAS` fix from the previous step.
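The cleanup sequence above can be sketched as follows. The function name is hypothetical and the `uv pip` flags are assumptions about the installer; `dry_run=True` only builds the command list without executing anything:

```python
import shutil
import subprocess


def ensure_rocm_triton(dry_run=True):
    """Build (and optionally run) the package-fix steps described above:
    remove upstream `triton`, then force-reinstall `pytorch-triton-rocm`
    so the shared `triton` module directory is restored intact."""
    cmds = [
        ["uv", "pip", "uninstall", "triton"],
        ["uv", "pip", "install", "--reinstall", "pytorch-triton-rocm"],
    ]
    if not dry_run and shutil.which("uv"):
        for cmd in cmds:
            subprocess.run(cmd, check=True)
    return cmds
```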
Explicitly pass `--extra $EXTRAS` to `uv run` when building the tokenizer.
This prevents `uv` from reverting to the default (NVIDIA) dependency set
during the `maturin` build step, ensuring the correct PyTorch version
(ROCm) is preserved on AMD hardware.
- Explicitly add a `device` parameter to `tokenizing_distributed_data_loader` in `nanochat/dataloader.py`, preventing `TypeError: unexpected keyword argument 'device'` when it is called from `scripts/base_train.py`.
- Update `nanochat/configurator.py` to ignore command-line flags starting with `--` (e.g. `--help`) instead of raising `AssertionError`, improving robustness when running with various launchers or flags.
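The tolerant override parsing can be sketched as below. This is a simplified, hypothetical version of the nanoGPT-style configurator behavior described above (it omits config-file handling and the helper name is invented); the point is that bare flags such as `--help` are skipped instead of tripping an assertion:

```python
import ast


def apply_overrides(argv, config):
    """Apply --key=value overrides to a config dict, ignoring other flags."""
    for arg in argv:
        if not arg.startswith("--") or "=" not in arg:
            continue  # tolerate launcher flags like --help instead of asserting
        key, raw = arg[2:].split("=", 1)
        if key not in config:
            continue  # unknown keys are ignored rather than fatal
        try:
            config[key] = ast.literal_eval(raw)  # numbers, bools, lists, ...
        except (ValueError, SyntaxError):
            config[key] = raw  # keep plain strings as-is
    return config
```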