mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
Add new patterns to unfold division (#5139)
* Add new patterns to unfold division * Create regression test and fix pattern
This commit is contained in:
@@ -661,6 +661,27 @@ class TestLinearizer(unittest.TestCase):
|
||||
with self.assertRaises(AssertionError):
|
||||
get_grouped_dims("gidx", 0, (Variable("start_pos", 0, 16),3,4), (16,16,16,), False,)
|
||||
|
||||
# Regression test: counts UnaryOps.RECIP uops in the linearized kernel of several division expressions.
# NOTE(review): indentation was stripped by the page rendering; helper's body appears to be the four statements below it — confirm against the repository.
def test_div_collapse(self):
|
||||
# helper(t, msg, max_ops=0): schedules t.lazydata, linearizes the single remaining kernel and
# asserts that the number of uops whose arg is UnaryOps.RECIP equals max_ops; msg is the assertion message.
def helper(t, msg, max_ops=0):
|
||||
# keep only schedule items whose first AST op is not in LoadOps
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
|
||||
# the expression under test is expected to produce exactly one such item
assert len(sched) == 1
|
||||
|
||||
lin = Linearizer(*sched[0].ast)
|
||||
# True counts as 1 in sum(), so this is the number of RECIP uops after linearize()
assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
|
||||
|
||||
# three random 4x4 operands
a = Tensor.rand((4,4))
|
||||
b = Tensor.rand((4,4))
|
||||
d = Tensor.rand((4,4))
|
||||
|
||||
# (a*b)/b: default max_ops=0, so no RECIP uop may remain
c = (a*b)/b
|
||||
helper(c, "found UnaryOps.RECIP in (a*b)/b operation")
|
||||
|
||||
# a/a: default max_ops=0, so no RECIP uop may remain
c = a/a
|
||||
helper(c, "found UnaryOps.RECIP in (a/a) operation")
|
||||
|
||||
# (a/b)/d: exactly one RECIP uop is expected (max_ops=1)
c = (a/b)/d
|
||||
helper(c, "found multiple UnaryOps.RECIP in (a/b)/d operation", 1)
|
||||
|
||||
def test_sum_collapse(self):
|
||||
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
|
||||
|
||||
@@ -53,6 +53,7 @@ class UOp:
|
||||
def __mul__(self, x): return UOp.alu(BinaryOps.MUL, self, ufix(self.dtype, x))
|
||||
def __rmul__(self, x): return UOp.alu(BinaryOps.MUL, ufix(self.dtype, x), self)
|
||||
def __floordiv__(self, x): return UOp.alu(BinaryOps.IDIV, self, ufix(self.dtype, x))
|
||||
def __truediv__(self, x): return UOp.alu(BinaryOps.MUL, self, UOp.alu(UnaryOps.RECIP, ufix(self.dtype, x)))
|
||||
def __mod__(self, x): return UOp.alu(BinaryOps.MOD, self, ufix(self.dtype, x))
|
||||
def lt(self, x): return UOp.alu(BinaryOps.CMPLT, self, ufix(self.dtype, x))
|
||||
def ge(self, x): return -self.lt(x)
|
||||
@@ -210,8 +211,12 @@ constant_folder = PatternMatcher([
|
||||
(UOp.var('x') + 0, lambda x: x), # x+0 -> x
|
||||
(UOp.var('x') - 0, lambda x: x), # x-0 -> x
|
||||
(UOp.var('x') * 1, lambda x: x), # x*1 -> x
|
||||
(UOp.var('x') // 1, lambda x: x), # x/1 -> x
|
||||
(UOp.var('x') // -1, lambda x: -x), # x/-1 -> -x
|
||||
(UOp.var('x') // UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x//x -> 1
|
||||
(UOp.var('x') // 1, lambda x: x), # x//1 -> x
|
||||
(UOp.var('x') // -1, lambda x: -x), # x//-1 -> -x
|
||||
(UOp.var('x') / UOp.var('x'), lambda x: UOp.const(x.dtype, 1)), # x/x -> 1
|
||||
(UOp.var('x') / 1, lambda x: x), # x/1 -> x
|
||||
(UOp.var('x') / -1, lambda x: -x), # x/-1 -> -x
|
||||
(UOp.var('x', dtype=dtypes.bool).max(UOp.const(dtypes.bool, False)), lambda x: x), # max(x, False) -> x
|
||||
# ** zero folding **
|
||||
#x*0 -> 0 or 0*x -> 0
|
||||
@@ -230,10 +235,14 @@ constant_folder = PatternMatcher([
|
||||
(UOp.var("x") % UOp.const(None, 1), lambda x: UOp.const(x.dtype, 0)),
|
||||
# (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(UOp.var("x") * UOp.cvar("c0") + UOp.var("x") * UOp.cvar("c1"), lambda x,c0,c1: x*exec_alu(BinaryOps.ADD, x.dtype, [c0.arg, c1.arg])),
|
||||
# (x*c0)/c0 -> x
|
||||
# (x*c0)//c0 -> x
|
||||
((UOp.var("x") * UOp.cvar("c0")) // UOp.cvar("c0"), lambda x,c0: x if c0.arg != 0 else None),
|
||||
# (x/c0)/c1 -> x/(c0*c1)
|
||||
# (x*x2)/x2 -> x
|
||||
((UOp.var("x") * UOp.var("x2")) / UOp.var("x2"), lambda x,x2: x),
|
||||
# (x//c0)//c1 -> x//(c0*c1)
|
||||
((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
|
||||
# (x/x2)/x3 -> x/(x2*x3)
|
||||
((UOp.var("x") / UOp.var("x2")) / UOp.var("x3"), lambda x,x2,x3: x/(x2*x3)),
|
||||
# c0 + x < c1 -> x < c1 - c0
|
||||
((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
|
||||
lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.ADD, x.dtype, [c1.arg, -c0.arg])))),
|
||||
|
||||
Reference in New Issue
Block a user