This commit re-adds the `PYTORCH_CUDA_ALLOC_CONF` environment variable to the training scripts. The setting helps prevent memory fragmentation and benefits both CUDA and ROCm environments; it was inadvertently dropped during a previous refactoring.
This change adds the `PYTORCH_CUDA_ALLOC_CONF` environment variable to the main `speedrun.sh` execution script.
Setting `expandable_segments:True` is recommended by PyTorch to manage memory more efficiently and prevent fragmentation, addressing a `UserWarning` observed during execution.
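As a minimal sketch of the same idea in Python (the actual change is a one-line `export` in `speedrun.sh`), the variable must be set before PyTorch makes its first allocation:

```python
import os

# Must be set before the first CUDA/HIP allocation, so do it before
# importing torch. The shell-script equivalent is:
#   export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
```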
This commit adds the `HSA_OVERRIDE_GFX_VERSION` environment variable to the `speedrun.sh` script. This is a workaround to enable support for newer AMD GPU architectures (e.g., gfx1151) that are not yet officially supported in the pre-compiled PyTorch ROCm builds.
This change also includes an update to the `README.md` to explain this workaround to users.
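The override value is GPU-specific; `11.0.0` below is an illustrative value commonly used for RDNA3-class parts, not necessarily what the script sets. A Python sketch of the workaround (in `speedrun.sh` it is a plain `export`):

```python
import os

# Tell ROCm's HSA runtime to treat the GPU as a supported gfx target.
# Only needed until official PyTorch ROCm wheels ship kernels for the
# newer architecture; setdefault lets a user-provided value win.
os.environ.setdefault("HSA_OVERRIDE_GFX_VERSION", "11.0.0")
```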
This commit fixes a `torch.AcceleratorError: HIP error: invalid device function` that occurred during weight initialization on ROCm devices. It also improves the device detection logic to correctly identify and prioritize the ROCm backend.
The key changes are:
- Patched `nanochat/gpt.py` to initialize weights on the CPU before moving the model to the target device, which avoids the HIP kernel error.
- Simplified and corrected the device detection logic in `nanochat/common.py` to ensure the ROCm backend is properly selected when available.
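The initialize-on-CPU pattern can be sketched as follows; the module and sizes here are placeholders, not nanochat's actual model:

```python
import torch
import torch.nn as nn

def build_model(device: str) -> nn.Module:
    # Run weight init on the CPU: on some ROCm targets the HIP kernels
    # behind in-place init ops (e.g. normal_) raise
    # "HIP error: invalid device function".
    model = nn.Linear(64, 64)
    with torch.no_grad():
        for p in model.parameters():
            p.normal_(mean=0.0, std=0.02)
    # Only move to the accelerator once the weights are fully initialized.
    return model.to(device)
```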
This commit addresses several runtime errors encountered during the execution of the `speedrun.sh` script and improves the overall configuration of the project.
The key changes are:
- Patched `nanochat/configurator.py` to be more robust: it now handles flag-like arguments and skips unknown arguments instead of asserting, which resolves the `AssertionError`.
- Fixed the argument handling for `chat_eval.py` in `speedrun.sh` to prevent argument parsing errors.
- Updated `pyproject.toml` to correctly define optional dependencies for development.
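The tolerant override handling can be sketched like this (the function name and exact semantics are illustrative; nanochat's configurator applies `--key=value` pairs to existing config values):

```python
def apply_overrides(argv: list[str], config: dict) -> dict:
    """Apply --key=value overrides, skipping flags and unknown keys."""
    for arg in argv:
        if not arg.startswith("--") or "=" not in arg:
            continue  # bare flags and positionals no longer trip an assert
        key, val = arg[2:].split("=", 1)
        if key not in config:
            continue  # unknown keys are ignored instead of raising
        old = config[key]
        if isinstance(old, bool):
            config[key] = val.lower() in ("1", "true", "yes")
        else:
            config[key] = type(old)(val)  # preserve the original type
    return config

# e.g. apply_overrides(["--lr=0.01", "--verbose", "--bogus=1"], {"lr": 0.1})
```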
This change adds support for ROCm and makes the codebase device-agnostic, allowing it to run on different hardware backends including ROCm, CUDA, and CPU.
The key changes are:
- Modified `pyproject.toml` to use ROCm-compatible PyTorch wheels and added the `pytorch-triton-rocm` dependency.
- Refactored `nanochat/common.py` to dynamically detect the available hardware and set the device and distributed backend accordingly.
- Updated all training, evaluation, and inference scripts to be device-agnostic, removing hardcoded CUDA references.
- Adapted `speedrun.sh` for single-device execution by replacing `torchrun` with `python`.
- Updated `nanochat/report.py` to provide more generic GPU information.
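The detection logic above can be sketched as a pure function (the flags would come from checks such as `torch.version.hip` and `torch.cuda.is_available()`; the function name is illustrative):

```python
def select_backend(has_hip: bool, has_cuda: bool) -> tuple[str, str]:
    """Return (device, distributed_backend) for the available hardware.

    ROCm builds of PyTorch expose the CUDA API surface, so a HIP build
    still uses the "cuda" device string; "nccl" maps to RCCL on ROCm,
    and CPU-only runs fall back to the gloo backend.
    """
    if has_hip or has_cuda:
        return "cuda", "nccl"
    return "cpu", "gloo"
```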