mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-22 13:28:06 -05:00
UOp if min==max folds to CONST (#5828)
* UOp if min==max folds to CONST * fix test
This commit is contained in:
@@ -78,17 +78,17 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
self.assertEqual(nout.src[1].arg, 3.0)
|
||||
|
||||
def test_consts_go_last(self):
|
||||
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('a', 0, 1))
|
||||
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('b', 0, 1))
|
||||
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('c', 0, 1))
|
||||
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('d', 0, 1))
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2),a), arg=BinaryOps.ADD)]
|
||||
a = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('a', 0, 1))
|
||||
b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1))
|
||||
c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1))
|
||||
d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1))
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const(2), a), arg=BinaryOps.ADD)]
|
||||
for out in outs:
|
||||
sink = graph_rewrite(out, constant_folder)
|
||||
print(sink)
|
||||
self.assertEqual(sink.op, UOps.ALU)
|
||||
self.assertEqual(sink.src[1].op, UOps.CONST)
|
||||
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1)
|
||||
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)
|
||||
|
||||
class TestUOpGraph(TestUOps):
|
||||
def test_add_constant_fold(self):
|
||||
@@ -213,13 +213,13 @@ class TestUOpGraph(TestUOps):
|
||||
self.assertEqual(len([x for x in g.uops if x.op is UOps.CAST]), 1)
|
||||
|
||||
def test_depth_2_const_fold(self):
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('tmp', 0, 1))
|
||||
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
|
||||
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
|
||||
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
|
||||
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
|
||||
g = UOpGraph([out])
|
||||
self.assertEqual(len(g.uops), 3)
|
||||
self.assertEqual(len(g.uops), 5)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.op, UOps.ALU)
|
||||
self.assertEqual(out.arg, BinaryOps.ADD)
|
||||
|
||||
@@ -209,6 +209,7 @@ constant_folder = PatternMatcher([
|
||||
# NOTE: this can be wrong for loaded NaN
|
||||
(NOp.var('x') * 0, lambda x: x.const(float('nan') if isinstance(x.arg, float) and (math.isnan(x.arg) or math.isinf(x.arg)) else 0)),
|
||||
(NOp.var('x') - NOp.var('x'), lambda x: x.const(0)), # x-x -> 0
|
||||
(UPat(op=UOps.ALU, name='x'), lambda x: x.const(x.vmin.arg) if x.op is not UOps.CONST and x.vmin.arg == x.vmax.arg else None),
|
||||
# lt folding
|
||||
(NOp.var('x').lt(NOp.var('y')),
|
||||
lambda x,y: NOp.const(dtypes.bool, True) if x.vmax.arg < y.vmin.arg else NOp.const(dtypes.bool, False) if x.vmin.arg >= y.vmax.arg else None),
|
||||
|
||||
Reference in New Issue
Block a user