infra for multi reduce asts (#5522)

* add reduce_info

* _recurse_reduceops base

* derive output shape

* refactor

* delete reduce_for_op

* save lines

* more line saving

Author: qazal
Date: 2024-07-17 22:23:46 +08:00
Committed by: GitHub
Parent: dcd462860f
Commit: fbe0233be3

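The diff below (it appears to be against tinygrad/engine/schedule.py) replaces the old per-output reduce_for_op back-pointer with a reduce_info dict keyed by each ReduceOps LazyBuffer, so a single AST can describe several reduces. The following self-contained sketch is illustrative only, not tinygrad code: ToyBuf and recurse_reduceops are hypothetical stand-ins that show the shape of that mapping, i.e. walk the op tree once from the output and record, per reduce node, the input shape it reduces over and its axis.

# illustrative sketch only -- ToyBuf stands in for LazyBuffer, strings for ops
from dataclasses import dataclass
from typing import Dict, Tuple

@dataclass(eq=False)  # eq=False keeps identity hashing, so nodes can be dict keys
class ToyBuf:
  op: str                      # "LOAD", "ADD", "SUM", ...
  shape: Tuple[int, ...]
  srcs: Tuple["ToyBuf", ...] = ()
  arg: Tuple[int, ...] = ()    # reduce axis for "SUM"

def recurse_reduceops(buf:ToyBuf, reduce_info:Dict[ToyBuf, Tuple[Tuple[int, ...], Tuple[int, ...]]]):
  # record (input shape, axis) for every reduce in the tree, keyed by the reduce node itself
  if buf.op == "SUM": reduce_info[buf] = (buf.srcs[0].shape, buf.arg)
  for x in buf.srcs: recurse_reduceops(x, reduce_info)

# two reduces feeding one output, e.g. a.sum(axis=1) + b.sum(axis=1)
a, b = ToyBuf("LOAD", (4, 8)), ToyBuf("LOAD", (4, 8))
r1, r2 = ToyBuf("SUM", (4, 1), (a,), (1,)), ToyBuf("SUM", (4, 1), (b,), (1,))
out = ToyBuf("ADD", (4, 1), (r1, r2))

info: Dict[ToyBuf, Tuple[Tuple[int, ...], Tuple[int, ...]]] = {}
recurse_reduceops(out, info)
print(len(info))  # 2 -> both reduces end up in one reduce_info-style dict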
@@ -1,8 +1,8 @@
 import sys, pickle, atexit
 from collections import defaultdict, deque
 from dataclasses import dataclass
-from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args
-from tinygrad.ops import MetaOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
+from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, cast, get_args
+from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
 from tinygrad.shape.symbolic import Variable
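The import hunk pulls in cast and Op because the rewritten return of _recursive_lazyop builds its LazyOp with cast(Op, buf.op): the standalone assert buf.op is not None is gone, and cast keeps the type checker satisfied without adding a runtime check. A minimal reminder of the semantics (plain typing, nothing tinygrad-specific):

from typing import Optional, cast

def bump(x:Optional[int]) -> int:
  # cast() does no conversion and no checking at runtime; it only tells the checker to treat x as int
  return cast(int, x) + 1

print(bump(41))  # 42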
@@ -36,14 +36,12 @@ class ScheduleItem:
 # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
 def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
+                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
+                      reduce_info:Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> LazyOp:
   """recursively create a lazyop"""
   if (buf, st) in cache: return cache[(buf, st)]
-  if buf != buf.base:
-    st = buf.st + st
-    buf = buf.base
-  # all buffers here are base now
-  assert buf.op is not None
+  if buf is not buf.base: st, buf = buf.st+st, buf.base
+  arg = buf.arg
   # consts are always fused and generated
   if buf.op is MetaOps.CONST:
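This hunk folds the five-line view handling into one line: if buf is not its base, compose the view's ShapeTracker onto the accumulated one (buf.st + st) and keep walking from the base buffer, same behaviour as before. A small sketch of how that composition reads, assuming a tinygrad checkout from around this commit (ShapeTracker.from_shape, reshape and the + operator):

from tinygrad.shape.shapetracker import ShapeTracker

view_st = ShapeTracker.from_shape((4, 8)).reshape((8, 4))  # the view's own tracker (buf.st)
acc_st = ShapeTracker.from_shape((8, 4))                   # tracker accumulated so far (st)
# buf.st + st applies the accumulated tracker on top of the view's, so indexing resolves view -> base
print((view_st + acc_st).shape)  # (8, 4)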
@@ -76,24 +74,34 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
   if buf.op is MetaOps.CONTIGUOUS:
     assert buf in outputs
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
+    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
   if buf.op is MetaOps.ASSIGN:
     assert buf in outputs
     assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
     assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
+    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
   # if it's a reduce, we have to change the shapetracker
   if buf.op in ReduceOps:
-    assert st.contiguous, "ReduceOps late fusion must be contiguous"
-    st = ShapeTracker.from_shape(buf.srcs[0].shape)
+    st, arg = reduce_info[buf]
   # otherwise we fuse it like normal
-  cache[(buf, st)] = ret = \
-    LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
-  return ret
+  return cache.setdefault((buf, st), LazyOp(cast(Op,buf.op), tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, \
+    reduce_info, cache) for x in buf.srcs), arg))
-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer]):
+def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer], reduce_info:Dict, cache):
+  if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs) or (buf, st) in cache: return
+  if buf is not buf.base: st, buf = buf.st+st, buf.base
+  if buf.op in ReduceOps:
+    reduce_input, axis = buf.srcs[0], buf.arg
+    assert st.contiguous
+    st = ShapeTracker.from_shape(reduce_input.shape)
+    reduce_info[buf] = (st, axis)
+  for x in buf.srcs: _recurse_reduceops(x, st, realizes, outs, reduce_info, cache)
+  cache.setdefault((buf, st))
+
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
   """describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
   if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
     rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
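The new _recurse_reduceops pre-pass walks the graph before lowering: it stops at realized buffers, at realize boundaries that are not among this kernel's outputs, and at already-visited (buf, st) pairs, and for every ReduceOps buffer it records the input ShapeTracker and axis that _recursive_lazyop later reads back via st, arg = reduce_info[buf]. The "save lines" commits also lean on dict.setdefault, which inserts and returns in a single expression; a quick refresher:

cache: dict = {}
v = cache.setdefault("k", 42)  # missing key: stores 42 and returns it (replaces cache[k] = ret; return ret)
w = cache.setdefault("k", 99)  # existing key: returns the stored 42 and does not overwrite
cache.setdefault("seen")       # no default given: stores None, a cheap "already visited" marker
print(v, w, "seen" in cache)   # 42 42 True

Note that setdefault still evaluates its value argument even when the key already exists; the early if (buf, st) in cache: return cache[(buf, st)] guard in _recursive_lazyop is what prevents the recursion from being redone.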
@@ -104,10 +112,12 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None], re
   cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {}
   ast: List[LazyOp] = []
   inputs: List[LazyBuffer] = []
+  reduce_info: Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]] = {}
   for i, out in enumerate(outs):
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+    _recurse_reduceops(out, out.st, realizes, outs, reduce_info, {})
+    output_st = ShapeTracker.from_shape(next(iter(reduce_info)).shape if reduce_info else out.shape)
     output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
-    lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, cache=cache)
+    lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, reduce_info, cache=cache)
     output_view, vv = output_view.simplify().unbind()
     if vv: var_vals.update(vv)
     ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
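In _lower_lazybuffer the store shape no longer comes from a reduce_for_op lookup: if _recurse_reduceops recorded any reduce, next(iter(reduce_info)).shape (the first recorded reduce buffer, since dicts preserve insertion order) supplies the output shape, otherwise the output buffer's own shape is used. A sketch of just that rule, with derive_output_shape as a hypothetical helper name:

from typing import List, Tuple

def derive_output_shape(reduce_shapes:List[Tuple[int, ...]], out_shape:Tuple[int, ...]) -> Tuple[int, ...]:
  # mirrors: next(iter(reduce_info)).shape if reduce_info else out.shape
  return reduce_shapes[0] if reduce_shapes else out_shape

print(derive_output_shape([(4, 1)], (4, 1)))  # (4, 1) -- the kernel contains a reduce
print(derive_output_shape([], (4, 8)))        # (4, 8) -- purely elementwise kernel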
@@ -267,7 +277,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
           buf.buffer.options = None
   # preschedule all buffers in realizes
-  prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes, reduce_for_op)) for group in output_groups.values()}
+  prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes)) for group in output_groups.values()}
   schedule_targets = {out:ps for ps in prescheduled.values() for out in ps[0]}
   graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)