diff --git a/nanochat/fp8.py b/nanochat/fp8.py index 3f056d1..19e15ae 100644 --- a/nanochat/fp8.py +++ b/nanochat/fp8.py @@ -271,6 +271,7 @@ class _Float8MatmulND(torch.autograd.Function): return grad_input, grad_weight + class Float8Linear(nn.Linear): """Drop-in nn.Linear replacement that does FP8 compute.