Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
minimal diff for multioutput reduce pairs (#4030)
* simple fusion
* compiler cache patch
* Revert "compiler cache patch". This reverts commit fa18049597.
* Revert "Revert "compiler cache patch"". This reverts commit 57f8d41f98.
* delete that
* early sort
* teeny renames
* spec
* .empty is great
* delete sort
* Update test_schedule.py
* this is one kernel now

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
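In short: two outputs that share the same reduce can now be grouped into one multi-output kernel. A minimal sketch of that behavior, not part of the diff, assuming the create_schedule signature and the ScheduleItem.outputs attribute used in the test changes below:

from tinygrad import Tensor
from tinygrad.helpers import flatten
from tinygrad.engine.schedule import create_schedule

a = Tensor.empty(4, 4)
out0 = a.sum() + 2  # both outputs read the same reduce, a.sum()
out1 = a.sum() + 4
sched = create_schedule(flatten([t.lazydata.lbs for t in [out0, out1]]))
# besides the LoadOps item that allocates the empty buffer, the compute
# should land in a single schedule item carrying both outputs
assert len(sched[-1].outputs) == 2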
test/test_schedule.py:

@@ -6,7 +6,7 @@ import unittest
 from typing import List, Optional, Union
 from tinygrad.tensor import Tensor
 from tinygrad.ops import LoadOps
-from tinygrad.helpers import DEBUG, GRAPH
+from tinygrad.helpers import DEBUG, GRAPH, flatten
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.graph import print_tree, realized_lazybuffer
 from tinygrad.engine.schedule import create_schedule
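flatten is pulled in for the reworked create_schedule call in the next hunks.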
@@ -14,6 +14,7 @@ from tinygrad import nn, dtypes
 from test.helpers import is_dtype_supported
 
 def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
+  if isinstance(t, Tensor): t = [t]
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
@@ -21,7 +22,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
         for i,out in enumerate(s.outputs):
           if GRAPH: realized_lazybuffer(out, 0)
           seen.add(out)
-  sched = create_schedule([t_.lazydata for t_ in ([t] if isinstance(t, Tensor) else t)], seen)
+  sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
   if GRAPH:
     for i,s in enumerate(sched):
       for out in s.outputs: realized_lazybuffer(out, i+1)
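Since t is normalized to a list up front, the create_schedule call drops its inline isinstance branch, and lazydata.lbs lets the same expression cover tensors whose lazydata spans multiple devices. A small illustration of what the flattening produces, hypothetical and assuming the single-device case:

outs = [Tensor.empty(2), Tensor.empty(3)]
# on a single device each lazydata.lbs is a one-element list, so
# flatten yields exactly one LazyBuffer per output tensor
lbs = flatten([r.lazydata.lbs for r in outs])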
@@ -432,6 +433,39 @@ class TestSchedule(unittest.TestCase):
     out = x.contiguous() + y.contiguous()
     check_schedule(out, 2)
 
+  def test_group_fuse(self):
+    a = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + 4
+    check_schedule([out0, out1], 1)
+
+  def test_group_inner_deps_fuse(self):
+    a = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + out0 + 4
+    check_schedule([out0, out1], 1)
+
+  def test_group_midreduce_nofuse(self):
+    a = Tensor.empty((4, 4))
+    b = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + b.sum() + 4
+    check_schedule([out0, out1], 3)
+
+  def test_group_midexpand_nofuse(self):
+    a = Tensor.empty((32, 32, 32))
+    b = Tensor.empty((1, 16))
+    out0 = a.sum() + 2
+    out1 = a.sum() + b
+    check_schedule([out0, out1], 4)
+
+  def test_group_midshrink_fuse(self):
+    a = Tensor.empty(100, 100)
+    b = Tensor.empty(10,)
+    out0 = a.sum() + b[0]
+    out1 = a.sum() + 2
+    check_schedule([out0, out1], 1)
+
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
     x = Tensor.ones(4).contiguous().realize()
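Read together, the new tests pin down the grouping rule: outputs that share a single reduce fuse into one multi-output kernel, even when one output feeds the other (inner_deps), while a second unshared reduce (midreduce) or an expand between the shared reduce and an output (midexpand) keeps the kernels separate; a shrink on the path (midshrink) does not block fusion.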
@@ -448,7 +482,7 @@ class TestSchedule(unittest.TestCase):
     shared = x.sum().half().float()
     a = shared * 2
     b = shared * 3
-    sched = check_schedule([a, b], 3)
+    sched = check_schedule([a, b], 1)
     for si in sched[:-2]: assert all(out.dtype is dtypes.half for out in si.outputs)
 
   # reduce
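This is the payoff the last commit bullet ("this is one kernel now") refers to: a and b both derive from the shared x.sum() chain, so the expected kernel count for this pair drops from 3 to 1.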