mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
simpler reduceop children chasing (#4350)
* simplest case * midreduce case * all tests * pending things * unify tests
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
import unittest
|
||||
from typing import List, Optional, Union
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps, ReduceOps
|
||||
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
|
||||
from tinygrad.helpers import DEBUG, GRAPH, flatten
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.features.graph import print_tree, realized_lazybuffer
|
||||
@@ -680,5 +680,37 @@ class TestSchedule(unittest.TestCase):
|
||||
# sched = check_schedule([b, c], 4)
|
||||
# doesn't store either in half because it doesn't chase
|
||||
|
||||
def test_reduce_simple_chase(self):
|
||||
a = Tensor.empty(4, 4, 4)
|
||||
r = a.sum(0) + 6
|
||||
b = r.sum(0) * 4
|
||||
c = r.sum(1) * 2
|
||||
schedule = check_schedule([b, c], 3)
|
||||
assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
|
||||
|
||||
def test_push_permute_chase(self):
|
||||
a = Tensor.empty(4, 4, 4)
|
||||
b = Tensor.empty(4, 4)
|
||||
r = a.sum(2) + b
|
||||
d = r.T * 4
|
||||
e = r * d
|
||||
schedule = check_schedule([d, e], 3)
|
||||
assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
|
||||
|
||||
def test_push_shrink_chase(self):
|
||||
a = Tensor.empty(16, 16)
|
||||
b = Tensor.empty(4)
|
||||
c = Tensor.empty(16, )
|
||||
r = a.sum(1) + c
|
||||
d = r[:4] * b
|
||||
schedule = check_schedule(d, 2)
|
||||
assert schedule[0].ast[0].src[0].op is BinaryOps.ADD
|
||||
|
||||
def test_midreduce_nochase(self):
|
||||
a = Tensor.empty(16, 16)
|
||||
b = (a.sum(0) + a.max(1)) + 2
|
||||
schedule = check_schedule(b, 2)
|
||||
assert schedule[0].ast[0].src[0].op is ReduceOps.MAX
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user