minimal diff for multioutput reduce pairs (#4030)

* simple fusion

* compiler cache patch

* Revert "compiler cache patch"

This reverts commit fa18049597.

* Revert "Revert "compiler cache patch""

This reverts commit 57f8d41f98.

* delete that

* early sort

* teeny renames

* spec

* .empty is great

* delete sort

* Update test_schedule.py

* this is one kernel now

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: qazal
Date: 2024-04-17 17:55:44 +03:00
Committed by: GitHub
Parent: 8564e28a1b
Commit: f75020a903
2 changed files with 70 additions and 20 deletions
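
For context, the change targets pairs of outputs that consume the same reduce. Below is a rough sketch of the intended behavior, written against the tinygrad API used in the test diff that follows (create_schedule, flatten, .lazydata.lbs); the exact call shapes here are an illustration, not part of the commit:

# sketch, not part of the diff: two outputs sharing one a.sum() reduce
# should now be scheduled as a single fused kernel instead of two.
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten
from tinygrad.engine.schedule import create_schedule

a = Tensor.empty((4, 4))
out0 = a.sum() + 2
out1 = a.sum() + 4
sched = create_schedule(flatten([o.lazydata.lbs for o in (out0, out1)]), set())
print(len(sched))  # the non-LoadOps portion is expected to be a single item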

test/test_schedule.py

@@ -6,7 +6,7 @@ import unittest
 from typing import List, Optional, Union
 from tinygrad.tensor import Tensor
 from tinygrad.ops import LoadOps
-from tinygrad.helpers import DEBUG, GRAPH
+from tinygrad.helpers import DEBUG, GRAPH, flatten
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.graph import print_tree, realized_lazybuffer
 from tinygrad.engine.schedule import create_schedule
@@ -14,6 +14,7 @@ from tinygrad import nn, dtypes
 from test.helpers import is_dtype_supported
 
 def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
+  if isinstance(t, Tensor): t = [t]
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
@@ -21,7 +22,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
         for i,out in enumerate(s.outputs):
           if GRAPH: realized_lazybuffer(out, 0)
           seen.add(out)
-  sched = create_schedule([t_.lazydata for t_ in ([t] if isinstance(t, Tensor) else t)], seen)
+  sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
   if GRAPH:
     for i,s in enumerate(sched):
       for out in s.outputs: realized_lazybuffer(out, i+1)
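
The helper change above normalizes t to a list up front and gathers LazyBuffers through .lbs, assumed here to be [self] for a single-device LazyBuffer and the per-device list for a MultiLazyBuffer, so sharded outputs go through the same path. A small illustrative sketch:

# illustrative sketch of the new gathering step (assumes .lbs behaves as described above)
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten

t = [Tensor.empty(4), Tensor.empty(4)]
lbs = flatten([r.lazydata.lbs for r in t])  # one flat list of LazyBuffers
assert len(lbs) == 2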
@@ -432,6 +433,39 @@ class TestSchedule(unittest.TestCase):
     out = x.contiguous() + y.contiguous()
     check_schedule(out, 2)
 
+  def test_group_fuse(self):
+    a = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + 4
+    check_schedule([out0, out1], 1)
+
+  def test_group_inner_deps_fuse(self):
+    a = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + out0 + 4
+    check_schedule([out0, out1], 1)
+
+  def test_group_midreduce_nofuse(self):
+    a = Tensor.empty((4, 4))
+    b = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + b.sum() + 4
+    check_schedule([out0, out1], 3)
+
+  def test_group_midexpand_nofuse(self):
+    a = Tensor.empty((32, 32, 32))
+    b = Tensor.empty((1, 16))
+    out0 = a.sum() + 2
+    out1 = a.sum() + b
+    check_schedule([out0, out1], 4)
+
+  def test_group_midshrink_fuse(self):
+    a = Tensor.empty(100, 100)
+    b = Tensor.empty(10,)
+    out0 = a.sum() + b[0]
+    out1 = a.sum() + 2
+    check_schedule([out0, out1], 1)
+
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
     x = Tensor.ones(4).contiguous().realize()
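
The new tests above pin down when a pair groups: outputs sharing the same reduce fuse (even with a dependency between them or a shrink in the middle), while a second reduce or an expand between the outputs keeps them apart. A sketch of the non-fusing midreduce case, under the same assumptions as the earlier snippet:

# sketch: b.sum() is a second, independent reduce between the two outputs,
# so the pair is not grouped into one kernel (the test expects 3).
from tinygrad.tensor import Tensor
from tinygrad.helpers import flatten
from tinygrad.engine.schedule import create_schedule

a, b = Tensor.empty((4, 4)), Tensor.empty((4, 4))
out0 = a.sum() + 2
out1 = a.sum() + b.sum() + 4
sched = create_schedule(flatten([o.lazydata.lbs for o in (out0, out1)]), set())
print(len(sched))  # includes LoadOps items; check_schedule filters those before counting 3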
@@ -448,7 +482,7 @@ class TestSchedule(unittest.TestCase):
     shared = x.sum().half().float()
     a = shared * 2
     b = shared * 3
-    sched = check_schedule([a, b], 3)
+    sched = check_schedule([a, b], 1)
     for si in sched[:-2]: assert all(out.dtype is dtypes.half for out in si.outputs)
     # reduce
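
The last hunk only updates the expected count in test_prefer_half_buffer: the shared x.sum().half().float() value and both of its consumers are now expected to land in one kernel, matching the grouping behavior tested above. A sketch of that expression, same assumptions as before:

# sketch: both consumers of the shared half-rounded reduce are expected to
# schedule as one kernel after this change.
from tinygrad.tensor import Tensor

x = Tensor.ones(4).contiguous().realize()
shared = x.sum().half().float()
a, b = shared * 2, shared * 3
# scheduling [a, b] as in the snippets above should yield a single non-LoadOps kernel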