mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
update cast to uint tests (#14768)
result in valid range should work, add intermediate cast to NIRRenderer since it's UB for [128, 256)
This commit is contained in:
@@ -296,7 +296,6 @@ class TestDTypeALU(unittest.TestCase):
|
||||
@given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
|
||||
def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)
|
||||
|
||||
@unittest.skip("relied on hacks")
|
||||
@given(strat.floats(width=32, min_value=1.0, max_value=254.0, allow_subnormal=False),
|
||||
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
|
||||
def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype):
|
||||
|
||||
@@ -3295,9 +3295,8 @@ class TestOps(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
|
||||
class TestOpsUint8(unittest.TestCase):
|
||||
@unittest.skip("relied on hacks")
|
||||
def test_cast(self):
|
||||
helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True)
|
||||
helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True, low=0, high=255)
|
||||
|
||||
def test_cast_relu(self):
|
||||
helper_test_op([(2,3,64,64)], lambda x: x.relu().type(torch.uint8), lambda x: x.relu().cast('uint8'), forward_only=True)
|
||||
|
||||
@@ -131,6 +131,9 @@ class NIRRenderer(Renderer):
|
||||
lambda x: x.replace(dtype=dtypes.uint8, src=x.src[0:1]+((x.src[1].cast(dtypes.uint8),) if len(x.src)>=2 else ())+x.src[2:]).cast(dtypes.bool)),
|
||||
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
|
||||
lambda x: x.replace(src=x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
|
||||
# OpConvertFToU is undefined if Result Type is not wide enough, cast through int32
|
||||
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
|
||||
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
|
||||
# load/store use pointer arithmetic, and the cast does nothing
|
||||
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), lambda x,buf,off: x.replace(
|
||||
src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.VECTORIZE) else None),
|
||||
|
||||
Reference in New Issue
Block a user