Add new patterns to unfold division (#5139)

* Add new patterns to unfold division

* Create regression test and fix pattern
This commit is contained in:
Jhenner Tigreros
2024-06-25 20:07:47 -05:00
committed by GitHub
parent c4fdb9c725
commit fa78755f19
2 changed files with 34 additions and 4 deletions

View File

@@ -661,6 +661,27 @@ class TestLinearizer(unittest.TestCase):
with self.assertRaises(AssertionError):
get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 16),3,4), (16,16,16,), False,)
def test_div_collapse(self):
  # Checks how many UnaryOps.RECIP uops remain after linearizing division-shaped expressions.
  def helper(t, msg, max_ops=0):
    # keep only schedule items whose first AST op is not a LoadOps; exactly one is expected
    kernels = [item for item in create_schedule([t.lazydata]) if item.ast[0].op not in LoadOps]
    assert len(kernels) == 1
    uops = Linearizer(*kernels[0].ast).linearize().uops
    # the RECIP count must equal max_ops exactly (default: none at all)
    recip_count = sum(1 for u in uops if u.arg is UnaryOps.RECIP)
    assert recip_count == max_ops, msg

  a = Tensor.rand((4,4))
  b = Tensor.rand((4,4))
  d = Tensor.rand((4,4))

  # multiply then divide by the same operand: no RECIP expected
  helper((a*b)/b, "found UnaryOps.RECIP in (a*b)/b operation")
  # a value divided by itself: no RECIP expected
  helper(a/a, "found UnaryOps.RECIP in (a/a) operation")
  # chained division: exactly one RECIP expected
  helper((a/b)/d, "found multiple UnaryOps.RECIP in (a/b)/d operation", 1)
def test_sum_collapse(self):
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]

View File

@@ -53,6 +53,7 @@ class UOp:
def __mul__(self, x):
  # self * x: x is passed through ufix (defined elsewhere) with self's dtype, then a MUL alu node is built
  rhs = ufix(self.dtype, x)
  return UOp.alu(BinaryOps.MUL, self, rhs)
def __rmul__(self, x):
  # x * self (reflected): same as __mul__ but the ufix'd operand goes on the left-hand side
  lhs = ufix(self.dtype, x)
  return UOp.alu(BinaryOps.MUL, lhs, self)
def __floordiv__(self, x):
  # self // x: builds an IDIV alu node with x passed through ufix using self's dtype
  divisor = ufix(self.dtype, x)
  return UOp.alu(BinaryOps.IDIV, self, divisor)
def __truediv__(self, x):
  # self / x is expressed as self * RECIP(x): a RECIP alu node wrapped in a MUL alu node
  reciprocal = UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x))
  return UOp.alu(BinaryOps.MUL, self, reciprocal)
def __mod__(self, x):
  # self % x: builds a MOD alu node with x passed through ufix using self's dtype
  modulus = ufix(self.dtype, x)
  return UOp.alu(BinaryOps.MOD, self, modulus)
def lt(self, x):
  # self < x: builds a CMPLT alu node with x passed through ufix using self's dtype
  other = ufix(self.dtype, x)
  return UOp.alu(BinaryOps.CMPLT, self, other)
def ge(self, x):
  # self >= x is derived from lt: unary minus is applied to the self.lt(x) result
  # NOTE(review): relies on UOp.__neg__ (not visible here) acting as the negation of the comparison - confirm
  less_than = self.lt(x)
  return -less_than
@@ -210,8 +211,12 @@ constant_folder = PatternMatcher([
(UOp.var('x') + 0, lambda x: x), # x+0 -> x
(UOp.var('x') - 0, lambda x: x), # x-0 -> x
(UOp.var('x') * 1, lambda x: x), # x*1 -> x
(UOp.var('x') // 1, lambda x: x), # x/1 -> x
(UOp.var('x') // -1, lambda x: -x), # x/-1 -> -x
(UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
(UOp.var('x') // 1, lambda x: x), # x//1 -> x
(UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
(UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1
(UOp.var('x') / 1, lambda x: x), # x/1 -> x
(UOp.var('x') / -1, lambda x: -x), # x/-1 -> -x
(UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
# ** zero folding **
#x*0 -> 0 or 0*x -> 0
@@ -230,10 +235,14 @@ constant_folder = PatternMatcher([
(UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
# (x*c0)+(x*c1) -> x*(c0+c1)
(UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
# (x*c0)/c0 -> x
# (x*c0)//c0 -> x
((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None),
# (x/c0)/c1 -> x/(c0*c1)
# (x*x2)/x2 -> x
((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
# (x//c0)//c1 -> x//(c0*c1)
((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
# (x/x2)/x3 -> x/(x2*x3)
((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
# c0 + x < c1 -> x < c1 - c0
((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),