update cast to uint tests (#14768)

Casts whose result is in the valid range should work. Add an intermediate cast to NIRRenderer, since the direct cast is UB for [128, 256).
This commit is contained in:
chenyu
2026-02-15 10:55:13 -05:00
committed by GitHub
parent ceccc8eb86
commit 352845d8cc
3 changed files with 4 additions and 3 deletions

View File

@@ -296,7 +296,6 @@ class TestDTypeALU(unittest.TestCase):
@given(ht.int32, strat.sampled_from(dtypes_float+dtypes_int+dtypes_bool))
def test_int32_cast(self, a, dtype): universal_test_cast(a, dtypes.int32, dtype)
@unittest.skip("relied on hacks")
@given(strat.floats(width=32, min_value=1.0, max_value=254.0, allow_subnormal=False),
strat.sampled_from(dtypes_float), strat.sampled_from((dtypes.uint8, dtypes.uint16)))
def test_float_cast_to_unsigned(self, a, float_dtype, unsigned_dtype):

View File

@@ -3295,9 +3295,8 @@ class TestOps(unittest.TestCase):
@unittest.skipUnless(is_dtype_supported(dtypes.uchar), f"no uint8 on {Device.DEFAULT}")
class TestOpsUint8(unittest.TestCase):
@unittest.skip("relied on hacks")
def test_cast(self):
helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True)
helper_test_op([(2,3,64,64)], lambda x: x.type(torch.uint8), lambda x: x.cast('uint8'), forward_only=True, low=0, high=255)
def test_cast_relu(self):
helper_test_op([(2,3,64,64)], lambda x: x.relu().type(torch.uint8), lambda x: x.relu().cast('uint8'), forward_only=True)

View File

@@ -131,6 +131,9 @@ class NIRRenderer(Renderer):
lambda x: x.replace(dtype=dtypes.uint8, src=x.src[0:1]+((x.src[1].cast(dtypes.uint8),) if len(x.src)>=2 else ())+x.src[2:]).cast(dtypes.bool)),
(UPat(Ops.STORE, src=(UPat(), UPat(dtype=dtypes.bool)), name="x", allow_any_len=True),
lambda x: x.replace(src=x.src[0:1] + (x.src[1].cast(dtypes.uint8),) + x.src[2:])),
# OpConvertFToU is undefined if Result Type is not wide enough, cast through int32
# ref: https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpConvertFToU
(UPat(Ops.CAST, (dtypes.uchar, dtypes.ushort), src=(UPat.var("x", dtypes.floats),), name="c"), lambda x,c: x.cast(dtypes.int32).cast(c.dtype)),
# load/store use pointer arithmetic, and the cast does nothing
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("off")), allow_any_len=True, name="x"), lambda x,buf,off: x.replace(
src=(buf,off.cast(dtypes.long))+x.src[2:]) if buf.dtype.addrspace != AddrSpace.REG and off.op not in (Ops.CAST, Ops.VECTORIZE) else None),