mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
partial fusion spec (#4398)
This commit is contained in:
@@ -712,5 +712,40 @@ class TestSchedule(unittest.TestCase):
|
||||
schedule = check_schedule(b, 2)
|
||||
assert schedule[0].ast[0].src[0].op is ReduceOps.MAX
|
||||
|
||||
# pattern in test_transformer
|
||||
def test_partial_fuse1(self):
|
||||
a = Tensor.empty(16, 16)
|
||||
b = Tensor.empty(16, 16)
|
||||
c = a.sum() + 2
|
||||
d = (a.sum() - b.sum()) * 4
|
||||
check_schedule([c, d], 3)
|
||||
|
||||
# pattern in conv
|
||||
def test_partial_fuse2(self):
|
||||
a = Tensor.empty(16, 16)
|
||||
b = Tensor.empty(16, 16)
|
||||
c = a.sum() + 2
|
||||
d = b.sum() - c
|
||||
check_schedule([c, d], 2)
|
||||
|
||||
# pattern in adam
|
||||
def test_partial_fuse3(self):
|
||||
a = Tensor.empty(16, 16)
|
||||
b = Tensor.empty(16, 16)
|
||||
c = a.sum() + 2
|
||||
d = a.sum() * 2
|
||||
e = c * d
|
||||
f = b.sum() - e
|
||||
check_schedule([c, d, e, f], 3)
|
||||
|
||||
def test_partial_fuse4(self):
|
||||
a = Tensor.empty(16, 16)
|
||||
b = Tensor.empty(16, 16)
|
||||
c = a.sum() + 2
|
||||
d = a.sum() * 2
|
||||
e = c * d
|
||||
f = (b - d).sum() - e
|
||||
check_schedule([c, d, e, f], 3)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user