@@ -39,8 +39,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float
   if DEBUG >= 2: print(tensor.numpy())
   try:
     assert tensor.dtype == target_dtype
-    np.testing.assert_allclose(tensor.numpy(), target,
-      rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2, dtypes.fp8e5m2: 1, dtypes.fp8e4m3: 1e-1}.get(target_dtype, tol_target_dtype))
+    np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype))
   except AssertionError as e:
     raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
 
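The tolerance table being deleted here tracks mantissa width: fp8e5m2 keeps only 2 mantissa bits and fp8e4m3 only 3, so a round-trip through fp8 can perturb a value by a large relative factor and exact comparison is hopeless. A back-of-the-envelope sketch of that spacing (illustrative Python, not part of the test file):

    # illustrative only: relative step between neighbors in a binary float format
    def max_rel_step(mantissa_bits: int) -> float:
      # adjacent representable values differ by at most 2**-mantissa_bits, relatively
      return 2.0 ** -mantissa_bits

    print(max_rel_step(10))  # float16 (10 bits): ~1e-3, matching rtol=1e-3
    print(max_rel_step(7))   # bfloat16 (7 bits): ~8e-3, matching rtol=1e-2
    print(max_rel_step(3))   # fp8e4m3 (3 bits):  0.125, matching rtol=1e-1
    print(max_rel_step(2))   # fp8e5m2 (2 bits):  0.25 (the removed entry used an even looser rtol=1)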
@@ -57,7 +56,6 @@ def _test_cast(a:Tensor, target_dtype:DType):
   _test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
 def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
   if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
-  if target_dtype in dtypes.fp8s: raise unittest.SkipTest("no test for fp8s bitcast yet")
   if getenv("PTX") and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize:
     raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX")
   _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())
@@ -111,7 +109,7 @@ class TestDType(unittest.TestCase):
     fields = dtypes.fields()
     self.assertIn("float", fields)
     self.assertIn("float32", fields)
-    self.assertEqual(len(fields), 26)
+    self.assertEqual(len(fields), 24)
     self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
     self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in fields.values() if _to_np_dtype(value) is not None))
 
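The expected field count drops by exactly the two fp8 entries leaving the dtype table; nothing else changes. As trivial arithmetic (illustrative only):

    # 26 named dtypes minus the two fp8 formats leaves 24
    assert 26 - len(["fp8e4m3", "fp8e5m2"]) == 24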
@@ -207,31 +205,6 @@ class TestBFloat16DTypeCast(unittest.TestCase):
     converted = random_values.cast(dtypes.bfloat16).cast(dtypes.float32)
     np.testing.assert_allclose(converted.numpy(), random_values.cast(dtypes.float32).numpy(), rtol=1e-2, atol=1e-3)
 
-class TestFp8sDType(unittest.TestCase):
-  def _float_to_fp8_conversion_test(self, dtype, input_values, expected_values):
-    test_tensor = Tensor(input_values).cast(dtype).realize()
-    back_to_float32 = test_tensor.cast(dtypes.float32)
-    np.testing.assert_equal(tuple(back_to_float32.numpy().tolist()), expected_values)
-
-  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3), "fp8e4m3 not supported")
-  def test_float_to_fp8e4m3_conversion(self):
-    self._float_to_fp8_conversion_test(dtypes.fp8e4m3,
-      [10000000.0, -1.0, 402.0, -300.0, -10000000.0, 20.0, 1.4123, 0.0, math.inf, math.nan],
-      [448.0, -1.0, 416.0, -288.0, -448.0, 20.0, 1.375, 0.0, 448.0, math.nan])
-
-  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e5m2), "fp8e5m2 not supported")
-  def test_float_to_fp8e5m2_conversion(self):
-    self._float_to_fp8_conversion_test(dtypes.fp8e5m2,
-      [10000000.0, -1.0, 402.0, -300.0, -10000000.0, 20.0, 1.4123, 0.0, math.inf, math.nan],
-      [57344.0, -1, 384, -320, -57344.0, 20, 1.5, 0.0, 57344.0, math.nan])
-
-  @unittest.skipUnless(is_dtype_supported(dtypes.fp8e4m3) and is_dtype_supported(dtypes.fp8e5m2), "fp8s not supported")
-  def test_fp8e4m3_plus_fp8e5m2_output_dtype(self):
-    a = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e4m3)
-    b = Tensor([1.0, 2.0, 3.0], dtype=dtypes.fp8e5m2)
-    result = a + b
-    self.assertEqual(result.dtype, dtypes.half)
-
 class TestHalfDType(TestDType): DTYPE = dtypes.half
 
 class TestFloatDType(TestDType):
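The expected values in the deleted conversion tests follow from the two standard OCP fp8 layouts, worth spelling out since the constants look arbitrary. e4m3 has no infinity encoding, so both 10000000.0 and math.inf saturate to the largest finite value, 448; e5m2 saturates to 57344 under the same policy. A worked decode of those constants (plain Python under the assumed OCP formats; not tinygrad code):

    # e4m3: 4 exponent bits (bias 7), 3 mantissa bits; S.1111.111 encodes NaN,
    # so the largest finite value is exponent 1111, mantissa 110:
    assert 2**8 * (1 + 6/8) == 448.0
    # e5m2: 5 exponent bits (bias 15), 2 mantissa bits; exponent 11111 is inf/NaN,
    # so the largest finite value is exponent 11110, mantissa 11:
    assert 2**15 * (1 + 3/4) == 57344.0
    # rounding near 1.0: e4m3 steps by 1/8 (..., 1.375, 1.5) so 1.4123 -> 1.375;
    # e5m2 steps by 1/4 (..., 1.25, 1.5) so 1.4123 -> 1.5

The deleted promotion test is consistent with this as well: neither fp8 format can represent the other's values exactly, so adding an e4m3 tensor to an e5m2 tensor promotes the result to the next wider common float, half.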
@@ -291,7 +264,6 @@ class TestBitCast(unittest.TestCase):
   def test_shape_change_bitcast(self, dt1, dt2):
     # NOTE: this has to be assume to prevent hypothesis from skipping all samples
     assume(dt2 != dtypes.bfloat16 and dt1 != dtypes.bfloat16) # no test for bf16 bitcast yet
-    assume(dt1 not in dtypes.fp8s and dt2 not in dtypes.fp8s) # no test for fp8 bitcast yet
     assume(not (getenv("PTX") and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX
     data = rand_for_dtype(dt1, 32).reshape(2, 2, 8)
     _test_op(lambda: Tensor(data, dtype=dt1).bitcast(dt2), dt2, data.view(_to_np_dtype(dt2)).tolist())
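A bitcast reinterprets bytes instead of converting values, which is why the reference result above comes from numpy's view rather than astype; a shape-changing bitcast additionally regroups elements, the case that is broken for int8 sources under PTX. A minimal numpy-only sketch of both behaviors:

    import numpy as np

    a = np.array([1.0], dtype=np.float32)
    print(a.astype(np.int32))  # [1]: numeric conversion
    print(a.view(np.int32))    # [1065353216]: the raw IEEE 754 bits of 1.0 (0x3F800000)

    # shape-changing reinterpretation: four bytes collapse into one float32
    b = np.array([0, 0, 128, 63], dtype=np.uint8)  # little-endian bytes of 1.0f
    print(b.view(np.float32))  # [1.]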
@@ -427,10 +399,6 @@ class TestHelpers(unittest.TestCase):
   def test_bf16_is_float(self):
     assert dtypes.is_float(dtypes.bfloat16)
 
-  def test_fp8s_are_float(self):
-    assert dtypes.is_float(dtypes.fp8e4m3)
-    assert dtypes.is_float(dtypes.fp8e5m2)
-
   @given(strat.sampled_from([d for d in DTYPES_DICT.values() if dtypes.is_float(d) or dtypes.is_int(d)]), strat.integers(min_value=2, max_value=8))
   def test_scalar(self, dtype, amt):
     assert dtype.vec(amt).scalar() == dtype
@@ -494,7 +462,7 @@ class TestTypeSpec(unittest.TestCase):
       dtypes.default_int = default_int
       assert dtypes.default_int == default_int
 
-    for default_float in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
+    for default_float in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
       dtypes.default_float = default_float
       assert dtypes.default_float == default_float
 
@@ -725,9 +693,7 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.uint32)).sum().dtype == dtypes.uint32
     assert (Tensor([0, 1], dtype=dtypes.uint64)).sum().dtype == dtypes.uint64
     assert (Tensor([0, 1], dtype=dtypes.float16)).sum().dtype == dtypes.float16
-    assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
-    assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).sum().dtype == dtypes.fp8e4m3
-    assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).sum().dtype == dtypes.fp8e5m2
+    #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).sum().dtype == dtypes.bfloat16
     assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64
 
@@ -759,9 +725,7 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.uint32)).mean().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.uint64)).mean().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float16)).mean().dtype == dtypes.float16
-    assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
-    assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).mean().dtype == dtypes.fp8e4m3
-    assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).mean().dtype == dtypes.fp8e5m2
+    #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).mean().dtype == dtypes.bfloat16
     assert (Tensor([0, 1], dtype=dtypes.float32)).mean().dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64)).mean().dtype == dtypes.float64
 
@@ -776,9 +740,7 @@ class TestAutoCastType(unittest.TestCase):
     assert (Tensor([0, 1], dtype=dtypes.uint32)).cumsum(0).dtype == dtypes.uint32
     assert (Tensor([0, 1], dtype=dtypes.uint64)).cumsum(0).dtype == dtypes.uint64
     assert (Tensor([0, 1], dtype=dtypes.float16)).cumsum(0).dtype == dtypes.float16
-    assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
-    assert (Tensor([0, 1], dtype=dtypes.fp8e4m3)).cumsum().dtype == dtypes.fp8e4m3
-    assert (Tensor([0, 1], dtype=dtypes.fp8e5m2)).cumsum().dtype == dtypes.fp8e5m2
+    #assert (Tensor([0, 1], dtype=dtypes.bfloat16)).cumsum(0).dtype == dtypes.bfloat16
     assert (Tensor([0, 1], dtype=dtypes.float32)).cumsum(0).dtype == dtypes.float32
     assert (Tensor([0, 1], dtype=dtypes.float64)).cumsum(0).dtype == dtypes.float64
 
@@ -839,10 +801,10 @@ class TestAutoCastType(unittest.TestCase):
   def test_gradient_dtype(self):
     old_default_float = dtypes.default_float
 
-    for default_dtype in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
+    for default_dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
       if not is_dtype_supported(default_dtype): continue
       dtypes.default_float = default_dtype
-      for dtype in [*dtypes.fp8s, dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
+      for dtype in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]:
         if not is_dtype_supported(dtype): continue
         if DEBUG >= 2:
           print(f"testing {default_dtype=}, {dtype=}")
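The surviving loop checks that gradients come back in each tensor's own dtype no matter what dtypes.default_float is set to. A minimal sketch of the property under test, assuming a working tinygrad install:

    from tinygrad import Tensor, dtypes

    x = Tensor([1.0, 2.0, 3.0], dtype=dtypes.float16, requires_grad=True)
    x.sum().backward()
    assert x.grad is not None and x.grad.dtype == dtypes.float16  # grad follows x, not default_float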