mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
add regression test for the neg folding pattern (#4979)
This commit is contained in:
@@ -137,6 +137,22 @@ class TestNonFloatUOps(TestUOps):
|
||||
def test_where_float16(self):
|
||||
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))
|
||||
|
||||
def test_neg_fold(self):
|
||||
data0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (0, True))
|
||||
data1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (1, False))
|
||||
data2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), (2, False))
|
||||
idx = UOp.const(dtypes.int, 0)
|
||||
x = UOp(UOps.LOAD, dtypes.int, (data1, idx,))
|
||||
y = UOp(UOps.LOAD, dtypes.int, (data2, idx,))
|
||||
|
||||
value = x+(-y)
|
||||
uops = UOpGraph([UOp(UOps.STORE, None, (data0, idx, value))])
|
||||
assert uops[-1].vin[2].arg is BinaryOps.SUB
|
||||
|
||||
value = -y+x
|
||||
uops = UOpGraph([UOp(UOps.STORE, None, (data0, idx, value))])
|
||||
with self.assertRaises(AssertionError): assert uops[-1].vin[2].arg is BinaryOps.SUB
|
||||
|
||||
class TestBoolUOps(TestUOps):
|
||||
def _test_uop_bool_fxn(self, op, fxn):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
|
||||
Reference in New Issue
Block a user