diff --git a/test/test_dtype.py b/test/test_dtype.py index 6cb1aea187..7762166b03 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -7,7 +7,7 @@ from tinygrad.helpers import getenv, DEBUG, CI from tinygrad.dtype import DType, DTYPES_DICT, least_upper_dtype, fp8_to_float, float_to_fp8, _to_np_dtype, _to_torch_dtype, truncate from tinygrad.renderer.ptx import PTXRenderer from tinygrad import Device, Tensor, dtypes -from hypothesis import assume, given, settings, strategies as strat +from hypothesis import given, settings, strategies as strat from test.helpers import rand_for_dtype from test.unit.test_dtype_spec import _assert_eq, core_dtypes, dtype_ints, dtype_floats, FP8E4M3_MAX, FP8E5M2_MAX import pytest @@ -52,8 +52,6 @@ def _test_cast(a:Tensor, target_dtype:DType): if target_dtype in dtypes.fp8s: expected = list(map(lambda x: truncate[target_dtype](x), expected)) _test_op(lambda: a.cast(target_dtype), target_dtype, expected) def _test_bitcast(a:Tensor, target_dtype:DType, target=None): - if isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and a.dtype == dtypes.int8 and target_dtype.itemsize != a.dtype.itemsize: - raise unittest.SkipTest("shape changing bitcast of int8 broken on PTX") expected = torch.tensor(a.tolist(), dtype=_to_torch_storage_type(a.dtype)).view(_to_torch_dtype(target_dtype)).tolist() if target_dtype in dtypes.fp8s: expected = list(map(lambda x: fp8_to_float(x, target_dtype), expected)) _test_op(lambda: a.bitcast(target_dtype), target_dtype, target or expected) @@ -294,7 +292,6 @@ class TestInt8DType(TestDType): def test_int8_to_uint16_negative(self): _test_op(lambda: Tensor([-1, -2, -3, -4], dtype=dtypes.int8).cast(dtypes.uint16), dtypes.uint16, [2**16-1, 2**16-2, 2**16-3, 2**16-4]) - @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, PTXRenderer), "broken in ptx") def test_bitcast_alt(self): a = Tensor([72, -90, 27, 40, -53, 70, 96, 51], dtype=dtypes.int8).bitcast(dtypes.short) self.assertListEqual(a.tolist(), [-22968, 10267, 18123, 13152]) @@ -308,8 +305,6 @@ class TestUint8DType(TestDType): class TestBitCast(unittest.TestCase): @given(strat.sampled_from(dtype_ints + dtype_floats), strat.sampled_from(dtype_ints + dtype_floats)) def test_shape_change_bitcast(self, dt1, dt2): - # NOTE: this has to be assume to prevent hypothesis from skipping all samples - assume(not (isinstance(Device[Device.DEFAULT].renderer, PTXRenderer) and dt1 == dtypes.int8)) # TODO: bitcasting int8 fails in PTX data = rand_for_dtype(dt1, 32).reshape(2, 2, 8) expected = torch.tensor(data.tolist(), dtype=_to_torch_storage_type(dt1)).view(_to_torch_dtype(dt2)) if dt2 in dtypes.fp8s: