minimal diff for multioutput reduce pairs (#4030)

* simple fusion

* compiler cache patch

* Revert "compiler cache patch"

This reverts commit fa18049597.

* Revert "Revert "compiler cache patch""

This reverts commit 57f8d41f98.

* delete that

* early sort

* teeny renames

* spec

* .empty is great

* delete sort

* Update test_schedule.py

* this is one kernel now

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: qazal
Date: 2024-04-17 17:55:44 +03:00
Committed by: GitHub
Commit: f75020a903 (parent 8564e28a1b)
2 changed files with 70 additions and 20 deletions
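What the change does in practice: two consumers of the same reduce now schedule as a single multi-output kernel. A minimal sketch, assuming the public tinygrad API at this commit (it mirrors test_group_fuse in the diff below):

# both outputs consume a.sum(), so they now land in one multi-output kernel
from tinygrad.tensor import Tensor
from tinygrad.engine.schedule import create_schedule

a = Tensor.empty(4, 4)
out0 = a.sum() + 2
out1 = a.sum() + 4
sched = create_schedule([out0.lazydata, out1.lazydata])
print(len(sched))  # expected: 1, a single schedule item with two outputs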

--- a/test/test_schedule.py
+++ b/test/test_schedule.py

@@ -6,7 +6,7 @@ import unittest
 from typing import List, Optional, Union
 from tinygrad.tensor import Tensor
 from tinygrad.ops import LoadOps
-from tinygrad.helpers import DEBUG, GRAPH
+from tinygrad.helpers import DEBUG, GRAPH, flatten
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.graph import print_tree, realized_lazybuffer
 from tinygrad.engine.schedule import create_schedule
@@ -14,6 +14,7 @@ from tinygrad import nn, dtypes
 from test.helpers import is_dtype_supported
 
 def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
+  if isinstance(t, Tensor): t = [t]
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
@@ -21,7 +22,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
         for i,out in enumerate(s.outputs):
           if GRAPH: realized_lazybuffer(out, 0)
           seen.add(out)
-  sched = create_schedule([t_.lazydata for t_ in ([t] if isinstance(t, Tensor) else t)], seen)
+  sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
   if GRAPH:
     for i,s in enumerate(sched):
       for out in s.outputs: realized_lazybuffer(out, i+1)
@@ -432,6 +433,39 @@ class TestSchedule(unittest.TestCase):
     out = x.contiguous() + y.contiguous()
     check_schedule(out, 2)
 
+  def test_group_fuse(self):
+    a = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + 4
+    check_schedule([out0, out1], 1)
+
+  def test_group_inner_deps_fuse(self):
+    a = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + out0 + 4
+    check_schedule([out0, out1], 1)
+
+  def test_group_midreduce_nofuse(self):
+    a = Tensor.empty((4, 4))
+    b = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + b.sum() + 4
+    check_schedule([out0, out1], 3)
+
+  def test_group_midexpand_nofuse(self):
+    a = Tensor.empty((32, 32, 32))
+    b = Tensor.empty((1, 16))
+    out0 = a.sum() + 2
+    out1 = a.sum() + b
+    check_schedule([out0, out1], 4)
+
+  def test_group_midshrink_fuse(self):
+    a = Tensor.empty(100, 100)
+    b = Tensor.empty(10,)
+    out0 = a.sum() + b[0]
+    out1 = a.sum() + 2
+    check_schedule([out0, out1], 1)
+
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
     x = Tensor.ones(4).contiguous().realize()
@@ -448,7 +482,7 @@ class TestSchedule(unittest.TestCase):
     shared = x.sum().half().float()
     a = shared * 2
     b = shared * 3
-    sched = check_schedule([a, b], 3)
+    sched = check_schedule([a, b], 1)
     for si in sched[:-2]: assert all(out.dtype is dtypes.half for out in si.outputs)
 
     # reduce

--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py

@@ -75,19 +75,21 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz
     LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outbufs, var_vals, st, realizes, cache, assign_to, assign_idx) for x in buf.srcs), buf.arg)
   return ret
 
-def _schedule_one(out:LazyBuffer, realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
+def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
   inputs: List[LazyBuffer] = []
-  var_vals: Dict[Variable, int] = out.st.var_vals.copy()
-  if out.op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY}:
-    op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
+  ast: List[LazyOp] = []
+  var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
+  if outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY}:
+    ast, inputs = [LazyOp(outs[0].op, (), outs[0].arg)], list(outs[0].srcs)
   else:
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
-    output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
-    op = _recursive_lazyop(out, inputs, (out, ), var_vals, output_st, realizes, cache={})
-    output_view, vv = output_view.simplify().unbind()
-    if vv: var_vals.update(vv)
-    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_view))
-  return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)
+    for i, out in enumerate(outs):
+      output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+      output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
+      op = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, cache={})
+      output_view, vv = output_view.simplify().unbind()
+      if vv: var_vals.update(vv)
+      ast.append(LazyOp(BufferOps.STORE, (op, ), MemBuffer(i, out.dtype, output_view)))
+  return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)
 
 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
 def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
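_schedule_group generalizes _schedule_one from one output to a group: a single _LBScheduleItem now carries one BufferOps.STORE per output, with MemBuffer indices 0..n-1 and merged var_vals. A hedged sketch of inspecting that structure (assumes the schedule items returned by create_schedule expose the ast tuple built above):

# the fused kernel's ast holds two STOREs, into MemBuffer idx 0 and idx 1
from tinygrad.tensor import Tensor
from tinygrad.engine.schedule import create_schedule

a = Tensor.empty(4, 4)
si = create_schedule([(a.sum() + 2).lazydata, (a.sum() + 4).lazydata])[0]
for op in si.ast: print(op.op, op.arg)  # BufferOps.STORE with MemBuffer(0, ...), then MemBuffer(1, ...)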
@@ -157,13 +159,24 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
       for tr,st in child_set.items():
         if tr in realizes:
           realized_children[tr] = st
-          # can only have one output buffer
           # can only reduce contiguous
-          # max one reduceop per kernel
-          if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
+          if not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
             can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
             forced_realize = True
             break
+          if len(realized_children) > 1:
+            for rc in realized_children:
+              rc_parents = deque(x.base for x in rc.srcs)
+              while rc_parents:
+                if (p:=rc_parents.pop()).realized or p.op is LoadOps.CONST: continue
+                if p is r: continue
+                # max one reduceop per kernel
+                if p.op in ReduceOps:
+                  can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
+                  forced_realize = True
+                  break
+                for x in p.srcs: rc_parents.append(x.base)
           continue
         for tr_next in children[tr].keys():
           if not tr_next.realized:
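The new parent walk fires only when a reduce has more than one realized child: if any of those children also depends on a different reduceop, fusing the group would put two reduces in one kernel, so it is force-realized instead. This is exactly what test_group_midreduce_nofuse checks; a sketch under the same API assumptions as above:

# b.sum() is a second reduce feeding out1, so out0/out1 cannot share a kernel
from tinygrad.tensor import Tensor
from tinygrad.engine.schedule import create_schedule

a, b = Tensor.empty(4, 4), Tensor.empty(4, 4)
out0, out1 = a.sum() + 2, a.sum() + b.sum() + 4
print(len(create_schedule([out0.lazydata, out1.lazydata])))  # expected: 3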
@@ -195,12 +208,15 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
           tr = tr.srcs[0].base
       reduce_for_op[tr] = r
       realizes[tr] = None
-    else:
-      assert len(realized_children) == 1
-      reduce_for_op[next(iter(realized_children.keys()))] = r
+    else: reduce_for_op.update((tr, r) for tr in realized_children)
+
+  output_groups: DefaultDict[Tuple, List[LazyBuffer]] = defaultdict(list)
+  for r in realizes:
+    if r.realized is not None or r.op is LoadOps.CONST or r in seen: continue
+    output_groups[(reduce_for_op[r], ) if r in reduce_for_op else (r, )].append(r)
 
   # preschedule all buffers in realizes
-  prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST}
+  prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()}
   schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
 
   assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}
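The grouping key is the reduce a buffer was paired with, falling back to the buffer itself, so all realized outputs of one reduce become a single schedule group. A standalone sketch of the idea in plain Python (hypothetical string "buffers", not tinygrad objects):

# outputs paired to the same reduce key are scheduled together
from collections import defaultdict

realizes = ["out0", "out1", "other"]        # buffers to realize (hypothetical)
reduce_for_op = {"out0": "r", "out1": "r"}  # out0 and out1 both consume reduce r

output_groups = defaultdict(list)
for buf in realizes:
  output_groups[(reduce_for_op.get(buf, buf),)].append(buf)
print(dict(output_groups))  # {('r',): ['out0', 'out1'], ('other',): ['other']}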