mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
@@ -94,7 +94,7 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
prg = k.to_program()
|
||||
print(prg.src)
|
||||
if_uops = [u for u in k.uops if u.op is Ops.IF]
|
||||
self.assertIn(len(if_uops), {1,3})
|
||||
self.assertIn(len(if_uops), {1,2,3})
|
||||
conditions = if_uops[0].src[0].sparents
|
||||
self.assertLessEqual(len(conditions), 9)
|
||||
|
||||
|
||||
@@ -1043,7 +1043,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
|
||||
assert k is not None
|
||||
ifs = [u for u in k.uops if u.op is Ops.IF]
|
||||
self.assertEqual(len(ifs), 4)
|
||||
self.assertEqual(len(ifs), 3)
|
||||
#for st in k.uops.sink.src: self.assertEqual(len(st.src), 4)
|
||||
self.assertLessEqual(len(ifs[0].src[0].sparents), 17)
|
||||
|
||||
|
||||
@@ -1055,6 +1055,8 @@ symbolic = symbolic_simple+PatternMatcher([
|
||||
(UPat(GroupOp.Commutative, name='x'), lambda x: x.replace(src=x.src[::-1]) if x.src[1].tuplize < x.src[0].tuplize else None),
|
||||
# group like
|
||||
((UPat.var("x") + UPat.var("y")) + UPat.var("x") * UPat.cvar("c"), lambda x,y,c: (x+x*c)+y),
|
||||
# ** boolean algebra **
|
||||
(UPat.var("x") | (UPat.var("x") & UPat.var()), lambda x: x), # x|(x&y) -> x
|
||||
# ** combine terms **
|
||||
(UPat.var("x") * UPat.cvar("c0") + UPat.var("x") * UPat.cvar("c1"), lambda x,c0,c1: x*(c0+c1)), # (x*c0)+(x*c1) -> x*(c0+c1)
|
||||
(UPat.var("x") + UPat.var("x") * UPat.cvar("c"), lambda x,c: x*(c+1)), # (x+x*c)-> x*(c+1)
|
||||
|
||||
Reference in New Issue
Block a user