simpler reduceop children chasing (#4350)

* simplest case

* midreduce case

* all tests

* pending things

* unify tests
This commit is contained in:
qazal
2024-05-02 15:15:30 +03:00
committed by GitHub
parent 22376e53b7
commit 0b47818e0f
2 changed files with 35 additions and 4 deletions

View File

@@ -5,7 +5,7 @@
import unittest
from typing import List, Optional, Union
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps, ReduceOps
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
from tinygrad.helpers import DEBUG, GRAPH, flatten
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.graph import print_tree, realized_lazybuffer
@@ -680,5 +680,37 @@ class TestSchedule(unittest.TestCase):
# sched = check_schedule([b, c], 4)
# doesn't store either in half because it doesn't chase
def test_reduce_simple_chase(self):
a = Tensor.empty(4, 4, 4)
r = a.sum(0) + 6
b = r.sum(0) * 4
c = r.sum(1) * 2
schedule = check_schedule([b, c], 3)
assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
def test_push_permute_chase(self):
a = Tensor.empty(4, 4, 4)
b = Tensor.empty(4, 4)
r = a.sum(2) + b
d = r.T * 4
e = r * d
schedule = check_schedule([d, e], 3)
assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
def test_push_shrink_chase(self):
a = Tensor.empty(16, 16)
b = Tensor.empty(4)
c = Tensor.empty(16, )
r = a.sum(1) + c
d = r[:4] * b
schedule = check_schedule(d, 2)
assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
def test_midreduce_nochase(self):
a = Tensor.empty(16, 16)
b = (a.sum(0) + a.max(1)) + 2
schedule = check_schedule(b, 2)
assert schedule[0].ast[0].src[0].op is ReduceOps.MAX
if __name__ == '__main__':
unittest.main(verbosity=2)