add regression test for the neg folding pattern (#4979)

This commit is contained in:
qazal
2024-06-15 20:08:28 +08:00
committed by GitHub
parent dfadf82e10
commit d91f0ee85b

View File

@@ -137,6 +137,22 @@ class TestNonFloatUOps(TestUOps):
def test_where_float16(self):
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))
def test_neg_fold(self):
data0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
data1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (1, False))
data2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (2, False))
idx = UOp.const(dtypes.int, 0)
x = UOp(UOps.LOAD, dtypes.int, (data1, idx,))
y = UOp(UOps.LOAD, dtypes.int, (data2, idx,))
value = x+(-y)
uops = UOpGraph([UOp(UOps.STORE, None, (data0, idx, value))])
assert uops[-1].vin[2].arg is BinaryOps.SUB
value = -y+x
uops = UOpGraph([UOp(UOps.STORE, None, (data0, idx, value))])
with self.assertRaises(AssertionError): assert uops[-1].vin[2].arg is BinaryOps.SUB
class TestBoolUOps(TestUOps):
def _test_uop_bool_fxn(self, op, fxn):
for f in [_test_single_value, _test_single_value_const]: