diff --git a/test/test_schedule.py b/test/test_schedule.py
index 932add6943..a6045f0a5d 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -6,7 +6,7 @@ import unittest
 from typing import List, Optional, Union
 from tinygrad.tensor import Tensor
 from tinygrad.ops import LoadOps
-from tinygrad.helpers import DEBUG, GRAPH
+from tinygrad.helpers import DEBUG, GRAPH, flatten
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.graph import print_tree, realized_lazybuffer
 from tinygrad.engine.schedule import create_schedule
@@ -14,6 +14,7 @@ from tinygrad import nn, dtypes
 from test.helpers import is_dtype_supported
 
 def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
+  if isinstance(t, Tensor): t = [t]
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
@@ -21,7 +22,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
         for i,out in enumerate(s.outputs):
           if GRAPH: realized_lazybuffer(out, 0)
           seen.add(out)
-  sched = create_schedule([t_.lazydata for t_ in ([t] if isinstance(t, Tensor) else t)], seen)
+  sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
   if GRAPH:
     for i,s in enumerate(sched):
       for out in s.outputs: realized_lazybuffer(out, i+1)
@@ -432,6 +433,39 @@ class TestSchedule(unittest.TestCase):
     out = x.contiguous() + y.contiguous()
     check_schedule(out, 2)
 
+  def test_group_fuse(self):
+    a = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + 4
+    check_schedule([out0, out1], 1)
+
+  def test_group_inner_deps_fuse(self):
+    a = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + out0 + 4
+    check_schedule([out0, out1], 1)
+
+  def test_group_midreduce_nofuse(self):
+    a = Tensor.empty((4, 4))
+    b = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + b.sum() + 4
+    check_schedule([out0, out1], 3)
+
+  def test_group_midexpand_nofuse(self):
+    a = Tensor.empty((32, 32, 32))
+    b = Tensor.empty((1, 16))
+    out0 = a.sum() + 2
+    out1 = a.sum() + b
+    check_schedule([out0, out1], 4)
+
+  def test_group_midshrink_fuse(self):
+    a = Tensor.empty(100, 100)
+    b = Tensor.empty(10,)
+    out0 = a.sum() + b[0]
+    out1 = a.sum() + 2
+    check_schedule([out0, out1], 1)
+
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
     x = Tensor.ones(4).contiguous().realize()
@@ -448,7 +482,7 @@ class TestSchedule(unittest.TestCase):
     shared = x.sum().half().float()
     a = shared * 2
     b = shared * 3
-    sched = check_schedule([a, b], 3)
+    sched = check_schedule([a, b], 1)
     for si in sched[:-2]: assert all(out.dtype is dtypes.half for out in si.outputs)
 
     # reduce
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 087ca3d5ff..2ba0803f7c 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -75,19 +75,21 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz
     LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outbufs, var_vals, st, realizes, cache, assign_to, assign_idx) for x in buf.srcs), buf.arg)
   return ret
 
-def _schedule_one(out:LazyBuffer, realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
+def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
   inputs: List[LazyBuffer] = []
-  var_vals: Dict[Variable, int] = out.st.var_vals.copy()
-  if out.op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY}:
-    op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
+  ast: List[LazyOp] = []
+  var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
+  if outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY}:
+    ast, inputs = [LazyOp(outs[0].op, (), outs[0].arg)], list(outs[0].srcs)
   else:
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
-    output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
-    op = _recursive_lazyop(out, inputs, (out, ), var_vals, output_st, realizes, cache={})
-    output_view, vv = output_view.simplify().unbind()
-    if vv: var_vals.update(vv)
-    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_view))
-  return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)
+    for i, out in enumerate(outs):
+      output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+      output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
+      op = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, cache={})
+      output_view, vv = output_view.simplify().unbind()
+      if vv: var_vals.update(vv)
+      ast.append(LazyOp(BufferOps.STORE, (op, ), MemBuffer(i, out.dtype, output_view)))
+  return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)
 
 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
 def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
@@ -157,13 +159,24 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
       for tr,st in child_set.items():
         if tr in realizes:
           realized_children[tr] = st
-          # can only have one output buffer
           # can only reduce contiguous
           # max one reduceop per kernel
-          if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
+          if not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
             can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
             forced_realize = True
             break
+          if len(realized_children) > 1:
+            for rc in realized_children:
+              rc_parents = deque(x.base for x in rc.srcs)
+              while rc_parents:
+                if (p:=rc_parents.pop()).realized or p.op is LoadOps.CONST: continue
+                if p is r: continue
+                # max one reduceop per kernel
+                if p.op in ReduceOps:
+                  can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
+                  forced_realize = True
+                  break
+                for x in p.srcs: rc_parents.append(x.base)
           continue
         for tr_next in children[tr].keys():
           if not tr_next.realized:
@@ -195,12 +208,15 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
           tr = tr.srcs[0].base
       reduce_for_op[tr] = r
       realizes[tr] = None
-    else:
-      assert len(realized_children) == 1
-      reduce_for_op[next(iter(realized_children.keys()))] = r
+    else: reduce_for_op.update((tr, r) for tr in realized_children)
+
+  output_groups: DefaultDict[Tuple, List[LazyBuffer]] = defaultdict(list)
+  for r in realizes:
+    if r.realized is not None or r.op is LoadOps.CONST or r in seen: continue
+    output_groups[(reduce_for_op[r], ) if r in reduce_for_op else (r, )].append(r)
 
   # preschedule all buffers in realizes
-  prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST}
+  prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()}
   schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
 
   assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}
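
Note on the grouping rule (not part of the diff): a minimal standalone sketch of what the new output_groups dict does, assuming hypothetical stand-ins (Buffer, group_outputs) rather than tinygrad's real LazyBuffer API. Realized buffers that trace back to the same reduceop key into the same group, so _schedule_group emits one multi-output kernel for them; everything else stays in its own single-output group.

# Standalone illustration only; Buffer and group_outputs are hypothetical stand-ins,
# not tinygrad API. Mirrors the output_groups keying added in _graph_schedule above.
from collections import defaultdict
from typing import Dict, List, Tuple

class Buffer: pass  # stand-in for LazyBuffer

def group_outputs(realizes: List[Buffer], reduce_for_op: Dict[Buffer, Buffer]) -> List[Tuple[Buffer, ...]]:
  groups: Dict[Tuple[Buffer, ...], List[Buffer]] = defaultdict(list)
  for r in realizes:
    # buffers that share a reduceop share a group key, hence a kernel
    groups[(reduce_for_op[r],) if r in reduce_for_op else (r,)].append(r)
  return [tuple(g) for g in groups.values()]

# two realized outputs downstream of the same reduce -> one group (one kernel),
# matching test_group_fuse's expectation of a single schedule item
red, out0, out1 = Buffer(), Buffer(), Buffer()
assert group_outputs([out0, out1], {out0: red, out1: red}) == [(out0, out1)]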