diff --git a/test/backend/test_dtype_alu.py b/test/backend/test_dtype_alu.py index f3ce0c7732..58f2f7d61d 100644 --- a/test/backend/test_dtype_alu.py +++ b/test/backend/test_dtype_alu.py @@ -296,7 +296,6 @@ class TestDTypeALU(unittest.TestCase): @given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool)) def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype) - @unittest.skip("relied on hacks") @given(strat.floats(width=32, min_value=1.0, max_value=254.0, allow_subnormal=False), strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16))) def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype): diff --git a/test/backend/test_ops.py b/test/backend/test_ops.py index 1c86012701..56606b7f0d 100644 --- a/test/backend/test_ops.py +++ b/test/backend/test_ops.py @@ -3295,9 +3295,8 @@ class TestOps(unittest.TestCase): @unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}") class TestOpsUint8(unittest.TestCase): - @unittest.skip("relied on hacks") def test_cast(self): - helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True) + helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True, low=0, high=255) def test_cast_relu(self): helper_test_op([(2,3,64,64)], lambda x: x.relu().type(torch.uint8), lambda x: x.relu().cast('uint8'), forward_only=True) diff --git a/tinygrad/renderer/nir.py b/tinygrad/renderer/nir.py index 8ad9bf9fbd..de257f3c16 100644 --- a/tinygrad/renderer/nir.py +++ b/tinygrad/renderer/nir.py @@ -131,6 +131,9 @@ class NIRRenderer(Renderer): lambda x: x.replace(dtype=dtypes.uint8, src=x.src[0:1]+((x.src[1].cast(dtypes.uint8),) if len(x.src)>=2 else ())+x.src[2:]).cast(dtypes.bool)), (UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True), lambda x: x.replace(src=x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])), + # OpConvertFToU is undefined if Result Type is not wide enough, cast through int32 + # ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU + (UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)), # load/store use pointer arithmetic, and the cast does nothing (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), lambda x,buf,off: x.replace( src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.VECTORIZE) else None),