mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
where fold try 2 (#3748)
* where fold try 2 * assign fold * test_where_fold works * add gated store support to ops_python --------- Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
This commit is contained in:
@@ -198,6 +198,25 @@ class TestLinearizer(unittest.TestCase):
|
||||
lin = Linearizer(*sched[0].ast)
|
||||
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
|
||||
|
||||
def test_assign_fold(self):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
m = Tensor.ones(4, 4).shrink(((1, 2), None)).pad(((1, 2), None))
|
||||
a.assign(a+m)
|
||||
a.realize()
|
||||
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
|
||||
|
||||
def test_where_fold(self):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
b = a.shrink(((1, 2), None)).pad(((1, 2), None))
|
||||
a.assign(b.where(2, a))
|
||||
sched = create_schedule([a.lazydata])
|
||||
assert len(sched) == 1
|
||||
lin = Linearizer(*sched[-1].ast)
|
||||
lin.hand_coded_optimizations()
|
||||
lin.linearize()
|
||||
assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
|
||||
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
|
||||
|
||||
def test_simplify_uop(self):
|
||||
def helper_test_simplify(uop, dtype, vin, arg=None):
|
||||
ast = LazyOp(BufferOps.CONST, (),
|
||||
|
||||
Reference in New Issue
Block a user