From fa78755f198b35334b9c5397d98ffc9527cd8e1a Mon Sep 17 00:00:00 2001 From: Jhenner Tigreros <32320832+JhennerTigreros@users.noreply.github.com> Date: Tue, 25 Jun 2024 20:07:47 -0500 Subject: [PATCH] Add new patterns to unfold division (#5139) * Add new patterns to unfold division * Create regression test and fix pattern --- test/test_linearizer.py | 21 +++++++++++++++++++++ tinygrad/codegen/uops.py | 17 +++++++++++++---- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 7d33b26f7a..130badb496 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -661,6 +661,27 @@ class TestLinearizer(unittest.TestCase): with self.assertRaises(AssertionError): get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 16),3,4), (16,16,16,), False,) + def test_div_collapse(self): + def helper(t, msg, max_ops=0): + sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps] + assert len(sched) == 1 + + lin = Linearizer(*sched[0].ast) + assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg + + a = Tensor.rand((4,4)) + b = Tensor.rand((4,4)) + d = Tensor.rand((4,4)) + + c = (a*b)/b + helper(c, "found UnaryOps.RECIP in (a*b)/b operation") + + c = a/a + helper(c, "found UnaryOps.RECIP in (a/a) operation") + + c = (a/b)/d + helper(c, "found multiple UnaryOps.RECIP in (a/b)/d operation", 1) + def test_sum_collapse(self): t = Tensor([2]).reshape(1, 1).expand(256, 256).sum() sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps] diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 326265bb71..2078ed2d4d 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -53,6 +53,7 @@ class UOp: def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, ufix(self.dtype, x)) def __rmul__(self, x): return UOp.alu(BinaryOps.MUL, ufix(self.dtype, x), self) def __floordiv__(self, x): return UOp.alu(BinaryOps.IDIV, 
self, ufix(self.dtype, x)) + def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x))) def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x)) def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x)) def ge(self, x): return -self.lt(x) @@ -210,8 +211,12 @@ constant_folder = PatternMatcher([ (UOp.var('x') + 0, lambda x: x), # x+0 -> x (UOp.var('x') - 0, lambda x: x), # x-0 -> x (UOp.var('x') * 1, lambda x: x), # x*1 -> x - (UOp.var('x') // 1, lambda x: x), # x/1 -> x - (UOp.var('x') // -1, lambda x: -x), # x/-1 -> -x + (UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1 + (UOp.var('x') // 1, lambda x: x), # x//1 -> x + (UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x + (UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1 + (UOp.var('x') / 1, lambda x: x), # x/1 -> x + (UOp.var('x') / -1, lambda x: -x), # x/-1 -> -x (UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x # ** zero folding ** #x*0 -> 0 or 0*x -> 0 @@ -230,10 +235,14 @@ constant_folder = PatternMatcher([ (UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)), # (x*c0)+(x*c1) -> x*(c0+c1) (UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])), - # (x*c0)/c0 -> x + # (x*c0)//c0 -> x ((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None), - # (x/c0)/c1 -> x/(c0*c1) + # (x*x2)/x2 -> x + ((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x), + # (x//c0)//c1 -> x//(c0*c1) ((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))), + # (x/x2)/x3 -> x/(x2*x3) + ((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)), # c0 + x < c1 -> x < c1 - c0 ((UOp.cvar("c0") + 
UOp.var("x")).lt(UOp.cvar("c1")), lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),