Revert "FP8 support on NVIDIA (#8631)"

This reverts commit 2c8e4ea865.
George Hotz
2025-04-09 12:27:41 +08:00
parent d1505137ad
commit 78caf55154
13 changed files with 45 additions and 203 deletions
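
For context, the TestFp8sDType class deleted below pinned down the float-to-fp8 round-trip semantics: fp8e4m3 saturates at ±448 and fp8e5m2 at ±57344, finite values round to the nearest representable value, inf clamps to the largest finite value, and nan is preserved. A minimal sketch of that behavior, using the inputs and expectations copied from the deleted tests; it only runs against the parent commit d1505137ad, since this revert removes the fp8 dtypes:

import math
from tinygrad import Tensor, dtypes

# inputs and expected outputs taken verbatim from the deleted test_float_to_fp8e4m3_conversion
vals = [10000000.0, -1.0, 402.0, -300.0, -10000000.0, 20.0, 1.4123, 0.0, math.inf, math.nan]
out = Tensor(vals).cast(dtypes.fp8e4m3).realize().cast(dtypes.float32).numpy()
print(out)  # [448.0, -1.0, 416.0, -288.0, -448.0, 20.0, 1.375, 0.0, 448.0, nan]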


@@ -39,8 +39,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
   if DEBUG >= 2: print(tensor.numpy())
   try:
     assert tensor.dtype == target_dtype
-    np.testing.assert_allclose(tensor.numpy(), target,
-                               rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2, dtypes.fp8e5m2: 1, dtypes.fp8e4m3: 1e-1}.get(target_dtype, tol_target_dtype))
+    np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype))
   except AssertionError as e:
     raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
@@ -57,7 +56,6 @@ def _test_cast(a:Tensor, target_dtype:DType):
   _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
 def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
   if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
-  if target_dtype in dtypes.fp8s: raise unittest.SkipTest("no test for fp8s bitcast yet")
   if getenv("PTX") and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
     raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
   _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())
@@ -111,7 +109,7 @@ class TestDType(unittest.TestCase):
     fields = dtypes.fields()
     self.assertIn("float", fields)
     self.assertIn("float32", fields)
-    self.assertEqual(len(fields), 26)
+    self.assertEqual(len(fields), 24)
     self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
     self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in fields.values() if _to_np_dtype(value) is not None))
@@ -207,31 +205,6 @@ class TestBFloat16DTypeCast(unittest.TestCase):
     converted = random_values.cast(dtypes.bfloat16).cast(dtypes.float32)
     np.testing.assert_allclose(converted.numpy(), random_values.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3)
-class TestFp8sDType(unittest.TestCase):
-  def _float_to_fp8_conversion_test(self, dtype, input_values, expected_values):
-    test_tensor = Tensor(input_values).cast(dtype).realize()
-    back_to_float32 = test_tensor.cast(dtypes.float32)
-    np.testing.assert_equal(tuple(back_to_float32.numpy().tolist()), expected_values)
-  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "fp8e4m3 not supported")
-  def test_float_to_fp8e4m3_conversion(self):
-    self._float_to_fp8_conversion_test(dtypes.fp8e4m3,
-      [10000000.0, -1.0, 402.0, -300.0, -10000000.0, 20.0, 1.4123, 0.0, math.inf, math.nan],
-      [448.0, -1.0, 416.0, -288.0, -448.0, 20.0, 1.375, 0.0, 448.0, math.nan])
-  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), "fp8e5m2 not supported")
-  def test_float_to_fp8e5m2_conversion(self):
-    self._float_to_fp8_conversion_test(dtypes.fp8e5m2,
-      [10000000.0, -1.0, 402.0, -300.0, -10000000.0, 20.0, 1.4123, 0.0, math.inf, math.nan],
-      [57344.0, -1, 384, -320, -57344.0, 20, 1.5, 0.0, 57344.0, math.nan])
-  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3) and is_dtype_supported(dtypes.fp8e5m2), "fp8s not supported")
-  def test_fp8e4m3_plus_fp8e5m2_output_dtype(self):
-    a = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e4m3)
-    b = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e5m2)
-    result = a + b
-    self.assertEqual(result.dtype, dtypes.half)
 class TestHalfDType(TestDType): DTYPE = dtypes.half
 class TestFloatDType(TestDType):
@@ -291,7 +264,6 @@ class TestBitCast(unittest.TestCase):
   def test_shape_change_bitcast(self, dt1, dt2):
     # NOTE: this has to be assume to prevent hypothesis from skipping all samples
     assume(dt2 != dtypes.bfloat16 and dt1 != dtypes.bfloat16) # no test for bf16 bitcast yet
-    assume(dt1 not in dtypes.fp8s and dt2 not in dtypes.fp8s) # no test for fp8 bitcast yet
     assume(not (getenv("PTX") and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
     data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
     _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, data.view(_to_np_dtype(dt2)).tolist())
@@ -427,10 +399,6 @@ class TestHelpers(unittest.TestCase):
   def test_bf16_is_float(self):
     assert dtypes.is_float(dtypes.bfloat16)
-  def test_fp8s_are_float(self):
-    assert dtypes.is_float(dtypes.fp8e4m3)
-    assert dtypes.is_float(dtypes.fp8e5m2)
   @given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), strat.integers(min_value=2, max_value=8))
   def test_scalar(self, dtype, amt):
     assert dtype.vec(amt).scalar() == dtype
@@ -494,7 +462,7 @@ class TestTypeSpec(unittest.TestCase):
       dtypes.default_int = default_int
       assert dtypes.default_int == default_int
-    for default_float in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
+    for default_float in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
       dtypes.default_float = default_float
       assert dtypes.default_float == default_float
@@ -725,9 +693,7 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
     assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
     assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
-    assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
-    assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3
-    assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2
+    #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
     assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
@@ -759,9 +725,7 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
-    assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
-    assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3
-    assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == dtypes.fp8e5m2
+    #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
     assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
@@ -776,9 +740,7 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
     assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
     assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
-    assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
-    assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum().dtype == dtypes.fp8e4m3
-    assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum().dtype == dtypes.fp8e5m2
+    #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
     assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64
@@ -839,10 +801,10 @@ class TestAutoCastType(unittest.TestCase):
   def test_gradient_dtype(self):
     old_default_float = dtypes.default_float
-    for default_dtype in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
+    for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
       if not is_dtype_supported(default_dtype): continue
       dtypes.default_float = default_dtype
-      for dtype in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
+      for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
         if not is_dtype_supported(dtype): continue
         if DEBUG >= 2:
           print(f"testing {default_dtype=}, {dtype=}")