minimal diff for multioutput reduce pairs (#4030)
* simple fusion

* compiler cache patch

* Revert "compiler cache patch"

  This reverts commit fa18049597.

* Revert "Revert "compiler cache patch""

  This reverts commit 57f8d41f98.

* delete that

* early sort

* teeny renames

* spec

* .empty is great

* delete sort

* Update test_schedule.py

* this is one kernel now

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
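What this change does: when two outputs consume the same reduce, the scheduler used to force a kernel per output ("can only have one output buffer"); it now groups such pairs into a single multi-output kernel that computes the reduce once. A rough sketch of the effect, mirroring test_group_fuse below (API names as of this commit; the kernel-count filtering imitates check_schedule):

from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.engine.schedule import create_schedule

a = Tensor.empty((4, 4))
out0 = a.sum() + 2   # both outputs share the same a.sum()
out1 = a.sum() + 4
sched = create_schedule([out0.lazydata, out1.lazydata])
kernels = [si for si in sched if si.ast[0].op not in LoadOps]  # drop the EMPTY loadop for a
assert len(kernels) == 1             # was 2 kernels before this patch
assert len(kernels[0].outputs) == 2  # one STORE per output, single kernel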
test/test_schedule.py:

@@ -6,7 +6,7 @@ import unittest
 from typing import List, Optional, Union
 from tinygrad.tensor import Tensor
 from tinygrad.ops import LoadOps
-from tinygrad.helpers import DEBUG, GRAPH
+from tinygrad.helpers import DEBUG, GRAPH, flatten
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.graph import print_tree, realized_lazybuffer
 from tinygrad.engine.schedule import create_schedule
@@ -14,6 +14,7 @@ from tinygrad import nn, dtypes
 from test.helpers import is_dtype_supported
 
 def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
+  if isinstance(t, Tensor): t = [t]
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
@@ -21,7 +22,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
       for i,out in enumerate(s.outputs):
         if GRAPH: realized_lazybuffer(out, 0)
         seen.add(out)
-  sched = create_schedule([t_.lazydata for t_ in ([t] if isinstance(t, Tensor) else t)], seen)
+  sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
   if GRAPH:
     for i,s in enumerate(sched):
       for out in s.outputs: realized_lazybuffer(out, i+1)
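A note on the two check_schedule changes: normalizing t to a list up front removes the inline isinstance branch, and collecting lazybuffers via r.lazydata.lbs keeps the helper working for multi-device tensors, whose lazydata holds one LazyBuffer per shard (for an ordinary tensor, lbs is just [lazydata]). flatten is the plain list-of-lists flattener from tinygrad.helpers, roughly:

def flatten(l): return [item for sublist in l for item in sublist]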
@@ -432,6 +433,39 @@ class TestSchedule(unittest.TestCase):
     out = x.contiguous() + y.contiguous()
     check_schedule(out, 2)
 
+  def test_group_fuse(self):
+    a = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + 4
+    check_schedule([out0, out1], 1)
+
+  def test_group_inner_deps_fuse(self):
+    a = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + out0 + 4
+    check_schedule([out0, out1], 1)
+
+  def test_group_midreduce_nofuse(self):
+    a = Tensor.empty((4, 4))
+    b = Tensor.empty((4, 4))
+    out0 = a.sum() + 2
+    out1 = a.sum() + b.sum() + 4
+    check_schedule([out0, out1], 3)
+
+  def test_group_midexpand_nofuse(self):
+    a = Tensor.empty((32, 32, 32))
+    b = Tensor.empty((1, 16))
+    out0 = a.sum() + 2
+    out1 = a.sum() + b
+    check_schedule([out0, out1], 4)
+
+  def test_group_midshrink_fuse(self):
+    a = Tensor.empty(100, 100)
+    b = Tensor.empty(10,)
+    out0 = a.sum() + b[0]
+    out1 = a.sum() + 2
+    check_schedule([out0, out1], 1)
+
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
     x = Tensor.ones(4).contiguous().realize()
@@ -448,7 +482,7 @@ class TestSchedule(unittest.TestCase):
     shared = x.sum().half().float()
     a = shared * 2
     b = shared * 3
-    sched = check_schedule([a, b], 3)
+    sched = check_schedule([a, b], 1)
     for si in sched[:-2]: assert all(out.dtype is dtypes.half for out in si.outputs)
 
   # reduce
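Reading the new expectations: a pair fuses to one kernel whenever both outputs hang off the same single reduce, even if one output feeds the other (test_group_inner_deps_fuse) or a shrink sits between reduce and output (test_group_midshrink_fuse). Grouping is declined when it would pull a second ReduceOp into the kernel (test_group_midreduce_nofuse, 3 kernels) or an expand past the reduce shape (test_group_midexpand_nofuse, 4 kernels). test_prefer_half_buffer's expected count drops from 3 to 1 for the same reason: a and b both hang off the single reduce inside shared.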
tinygrad/engine/schedule.py:

@@ -75,19 +75,21 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[Laz
     LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outbufs, var_vals, st, realizes, cache, assign_to, assign_idx) for x in buf.srcs), buf.arg)
   return ret
 
-def _schedule_one(out:LazyBuffer, realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
+def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
   inputs: List[LazyBuffer] = []
-  var_vals: Dict[Variable, int] = out.st.var_vals.copy()
-  if out.op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY}:
-    op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
+  ast: List[LazyOp] = []
+  var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
+  if outs[0].op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY}:
+    ast, inputs = [LazyOp(outs[0].op, (), outs[0].arg)], list(outs[0].srcs)
   else:
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
-    output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
-    op = _recursive_lazyop(out, inputs, (out, ), var_vals, output_st, realizes, cache={})
-    output_view, vv = output_view.simplify().unbind()
-    if vv: var_vals.update(vv)
-    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_view))
-  return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)
+    for i, out in enumerate(outs):
+      output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+      output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
+      op = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, cache={})
+      output_view, vv = output_view.simplify().unbind()
+      if vv: var_vals.update(vv)
+      ast.append(LazyOp(BufferOps.STORE, (op, ), MemBuffer(i, out.dtype, output_view)))
+  return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)
 
 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
 def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
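_schedule_group is the old _schedule_one generalized to a tuple of outputs: var_vals are merged across all outputs, and the loop emits one BufferOps.STORE per output whose MemBuffer index i is the output's position (threading outs into _recursive_lazyop keeps input-load indices from colliding with the output slots). Schematically, for the test_group_fuse pair the AST comes out as (illustrative names; views and dtypes elided):

# ast == (
#   LazyOp(BufferOps.STORE, (<a.sum() + 2>,), MemBuffer(0, dtypes.float, view)),  # out0
#   LazyOp(BufferOps.STORE, (<a.sum() + 4>,), MemBuffer(1, dtypes.float, view)),  # out1
# )
# both stores reference the same reduce subtree, so the generated kernel computes a.sum() once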
@@ -157,13 +159,24 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
       for tr,st in child_set.items():
         if tr in realizes:
           realized_children[tr] = st
-          # can only have one output buffer
           # can only reduce contiguous
-          # max one reduceop per kernel
-          if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
+          if not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
             can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
             forced_realize = True
             break
+          if len(realized_children) > 1:
+            for rc in realized_children:
+              rc_parents = deque(x.base for x in rc.srcs)
+              while rc_parents:
+                if (p:=rc_parents.pop()).realized or p.op is LoadOps.CONST: continue
+                if p is r: continue
+                # max one reduceop per kernel
+                if p.op in ReduceOps:
+                  can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
+                  forced_realize = True
+                  break
+                for x in p.srcs: rc_parents.append(x.base)
           continue
         for tr_next in children[tr].keys():
           if not tr_next.realized:
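The blanket one-output-buffer rule is gone; the max-one-reduceop rule remains but is now checked precisely: once a reduce r has more than one realized child, each child's source cone is walked upward, stopping at realized buffers, consts, and r itself, and only a second ReduceOp found in that cone forces a separate realize. This is exactly what splits test_group_midreduce_nofuse while leaving test_group_fuse in one kernel. As a standalone predicate the walk amounts to (a sketch using the same names as the diff, not code from the commit):

from collections import deque
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import LoadOps, ReduceOps

def has_other_reduce(rc: LazyBuffer, r: LazyBuffer) -> bool:
  rc_parents = deque(x.base for x in rc.srcs)
  while rc_parents:
    # stop at realized buffers, consts, and the reduce being grouped
    if (p:=rc_parents.pop()).realized or p.op is LoadOps.CONST or p is r: continue
    if p.op in ReduceOps: return True   # a second reduce -> no grouping
    for x in p.srcs: rc_parents.append(x.base)
  return False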
@@ -195,12 +208,15 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
         tr = tr.srcs[0].base
       reduce_for_op[tr] = r
       realizes[tr] = None
-    else:
-      assert len(realized_children) == 1
-      reduce_for_op[next(iter(realized_children.keys()))] = r
+    else: reduce_for_op.update((tr, r) for tr in realized_children)
+
+  output_groups: DefaultDict[Tuple, List[LazyBuffer]] = defaultdict(list)
+  for r in realizes:
+    if r.realized is not None or r.op is LoadOps.CONST or r in seen: continue
+    output_groups[(reduce_for_op[r], ) if r in reduce_for_op else (r, )].append(r)
 
   # preschedule all buffers in realizes
-  prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST}
+  prescheduled = {group[0]:_schedule_group(tuple(group), realizes, reduce_for_op) for group in output_groups.values()}
+  schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
   assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}
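Grouping then falls out of reduce_for_op: each buffer to realize is keyed by its associated reduce if it has one, otherwise by itself, and every group becomes a single _schedule_group call keyed by its first member, with schedule_targets mapping any output buffer back to the item that produces it. For test_group_fuse the dict is, roughly (r is the a.sum() LazyBuffer; buffer names illustrative):

# output_groups == {
#   (r,): [<out0 buffer>, <out1 buffer>],   # the fused pair -> one _LBScheduleItem
#   (<a EMPTY buffer>,): [<a EMPTY buffer>],
# }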