#!/usr/bin/env python
# tinygrad/test/testextra/test_fp8_linear.py
# from "train bert with fp8" (#13874)
import unittest
import numpy as np
from tinygrad import Tensor, dtypes, Device
from tinygrad.nn import Linear
from extra.fp8.fp8_linear import FP8Linear, convert_to_float8_training
from tinygrad.device import is_dtype_supported
from test.helpers import not_support_multi_device, needs_second_gpu
BS, T, in_dim, out_dim = 16, 4, 128, 128
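# Every test builds an FP8Linear and a float32 nn.Linear with identical weights,
# then checks that the two stay within fp8-sized tolerances of each other.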
@unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), f"no fp8e4m3 on {Device.DEFAULT}")
class TestFP8Linear(unittest.TestCase):
  def setUp(self):
    Tensor.manual_seed(42)
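  # Forward: with inputs and weights scaled down to ~0.2, the fp8 output should
  # match the float32 reference within the rtol/atol=0.1 band.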
  def _test_forward(self, shape, in_features, out_features):
    fp8_layer = FP8Linear(in_features, out_features)
    normal_layer = Linear(in_features, out_features)
    weight = Tensor.randn(out_features, in_features, dtype=dtypes.float32) * 0.2
    bias = Tensor.randn(out_features, dtype=dtypes.float32) * 0.2
    fp8_layer.weight.assign(weight)
    normal_layer.weight.assign(weight)
    fp8_layer.bias.assign(bias)
    normal_layer.bias.assign(bias)
    x = Tensor.randn(*shape, dtype=dtypes.float32) * 0.2
    y_fp8, y_normal = fp8_layer(x), normal_layer(x)
    np.testing.assert_allclose(y_fp8.numpy(), y_normal.numpy(), rtol=0.1, atol=0.1)
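  # Backward: input and weight gradients are compared with a much looser rtol
  # (1.0), presumably because quantization noise compounds through backward.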
  def _test_backward(self, shape, in_features, out_features):
    fp8_layer = FP8Linear(in_features, out_features)
    normal_layer = Linear(in_features, out_features)
    weight = Tensor.randn(out_features, in_features, dtype=dtypes.float32) * 0.2
    bias = Tensor.randn(out_features, dtype=dtypes.float32) * 0.2
    fp8_layer.weight, normal_layer.weight = weight.detach(), weight.detach()
    fp8_layer.bias, normal_layer.bias = bias.detach(), bias.detach()
    fp8_layer.weight.requires_grad = normal_layer.weight.requires_grad = True
    x_fp8 = Tensor.randn(*shape, dtype=dtypes.float32, requires_grad=True) * 0.2
    x_normal = x_fp8.detach().requires_grad_(True)
    fp8_layer(x_fp8).sum().backward()
    normal_layer(x_normal).sum().backward()
    np.testing.assert_allclose(x_fp8.grad.numpy(), x_normal.grad.numpy(), rtol=1.0, atol=0.1)
    np.testing.assert_allclose(fp8_layer.weight.grad.numpy(), normal_layer.weight.grad.numpy(), rtol=1.0, atol=0.1)
  def test_forward_2d(self): self._test_forward((BS, in_dim), in_dim, out_dim)
  def test_forward_3d(self): self._test_forward((BS, T, in_dim), in_dim, out_dim)
  def test_backward_2d(self): self._test_backward((BS, in_dim), in_dim, out_dim)
  def test_backward_3d(self): self._test_backward((BS, T, in_dim), in_dim, out_dim)
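  # convert_to_float8_training swaps Linear submodules for FP8Linear in place;
  # module_filter_fn receives the fully-qualified name as its second argument,
  # so here only fc1 is converted and fc2 stays a plain Linear.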
  def test_filter(self):
    class Model:
      def __init__(self):
        self.fc1 = Linear(32, 16)
        self.fc2 = Linear(16, 8)
      def __call__(self, x):
        return self.fc2(self.fc1(x).relu())
    model = Model()
    x = Tensor.randn(16, 32)
    y_before = model(x).numpy()
    convert_to_float8_training(model, module_filter_fn=lambda _, fqn: "fc1" in fqn)
    self.assertIsInstance(model.fc1, FP8Linear)
    self.assertNotIsInstance(model.fc2, FP8Linear)
    y_after = model(x).numpy()
    np.testing.assert_allclose(y_after, y_before, rtol=0.1, atol=0.1)
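  # Multi-GPU: shard identical parameters and inputs across two devices, then
  # check the sharded fp8 forward against the sharded float32 forward.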
  @needs_second_gpu
  @unittest.skipIf(not_support_multi_device(), "no multi")
  def test_multi_gpu(self):
    GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
    fp8_layer = FP8Linear(in_dim, out_dim)
    normal_layer = Linear(in_dim, out_dim)
    weight = Tensor.randn(out_dim, in_dim, dtype=dtypes.float32) * 0.2
    bias = Tensor.randn(out_dim, dtype=dtypes.float32) * 0.2
    fp8_layer.weight.assign(weight)
    fp8_layer.bias.assign(bias)
    normal_layer.weight.assign(weight)
    normal_layer.bias.assign(bias)
    fp8_layer.weight.to_(GPUS)
    fp8_layer.bias.to_(GPUS)
    normal_layer.weight.to_(GPUS)
    normal_layer.bias.to_(GPUS)
    x = Tensor.randn(BS*2, in_dim, dtype=dtypes.float32) * 0.2
    x_sharded = x.detach()
    x = x.shard_(GPUS, axis=0)
    y_normal = normal_layer(x).realize()
    x_sharded.shard_(GPUS, axis=0)
    y_fp8 = fp8_layer(x_sharded).realize()
    np.testing.assert_allclose(y_fp8.numpy(), y_normal.numpy(), rtol=0.1, atol=0.1)
if __name__ == '__main__':
  unittest.main()