Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-23 05:48:08 -05:00)
infra for multi reduce asts (#5522)
* add reduce_info
* _recurse_reduceops base
* derive output shape
* refactor
* delete reduce_for_op
* save lines
* more line saving
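For orientation: the diff below drops the one-reduce-per-kernel `reduce_for_op` map and instead threads a `reduce_info` dict through AST construction, keyed by each ReduceOps LazyBuffer and holding the ShapeTracker of that reduce's input plus the reduced axes; the kernel's store shape is then derived from the first reduce in that dict. A minimal standalone sketch of that bookkeeping, with plain tuples standing in for LazyBuffer/ShapeTracker and made-up shapes:

from typing import Dict, Tuple

def reduced_shape(in_shape:Tuple[int, ...], axes:Tuple[int, ...]) -> Tuple[int, ...]:
  # a reduce keeps its rank and collapses the reduced axes to 1
  return tuple(1 if i in axes else s for i, s in enumerate(in_shape))

# stand-in for Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]]:
# each reduce in the AST maps to (shape of its input, axes it reduces over)
reduce_info:Dict[str, Tuple[Tuple[int, ...], Tuple[int, ...]]] = {
  "r0": ((4, 8, 16), (2,)),
  "r1": ((4, 8, 16), (2,)),
}

# mirrors `next(iter(reduce_info)).shape if reduce_info else out.shape` in _lower_lazybuffer:
# with no reduces the store keeps the output's own shape, otherwise the first reduce decides it
out_shape = (4, 8, 16)
output_shape = reduced_shape(*next(iter(reduce_info.values()))) if reduce_info else out_shape
print(output_shape)  # -> (4, 8, 1)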
@@ -1,8 +1,8 @@
 import sys, pickle, atexit
 from collections import defaultdict, deque
 from dataclasses import dataclass
-from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args
-from tinygrad.ops import MetaOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
+from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, cast, get_args
+from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
 from tinygrad.shape.symbolic import Variable
@@ -36,14 +36,12 @@ class ScheduleItem:
 # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
 
 def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
+                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
+                      reduce_info:Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> LazyOp:
   """recursively create a lazyop"""
   if (buf, st) in cache: return cache[(buf, st)]
-  if buf != buf.base:
-    st = buf.st + st
-    buf = buf.base
-  # all buffers here are base now
-  assert buf.op is not None
+  if buf is not buf.base: st, buf = buf.st+st, buf.base
+  arg = buf.arg
 
   # consts are always fused and generated
   if buf.op is MetaOps.CONST:
@@ -76,24 +74,34 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
   if buf.op is MetaOps.CONTIGUOUS:
     assert buf in outputs
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
+    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
   if buf.op is MetaOps.ASSIGN:
     assert buf in outputs
     assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
     assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
+    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
 
   # if it's a reduce, we have to change the shapetracker
   if buf.op in ReduceOps:
-    assert st.contiguous, "ReduceOps late fusion must be contiguous"
-    st = ShapeTracker.from_shape(buf.srcs[0].shape)
+    st, arg = reduce_info[buf]
 
   # otherwise we fuse it like normal
-  cache[(buf, st)] = ret = \
-    LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
-  return ret
+  return cache.setdefault((buf, st), LazyOp(cast(Op,buf.op), tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, \
+    reduce_info, cache) for x in buf.srcs), arg))
 
-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer]):
+def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer], reduce_info:Dict, cache):
+  if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs) or (buf, st) in cache: return
+  if buf is not buf.base: st, buf = buf.st+st, buf.base
+  if buf.op in ReduceOps:
+    reduce_input, axis = buf.srcs[0], buf.arg
+    assert st.contiguous
+    st = ShapeTracker.from_shape(reduce_input.shape)
+    reduce_info[buf] = (st, axis)
+  for x in buf.srcs: _recurse_reduceops(x, st, realizes, outs, reduce_info, cache)
+  cache.setdefault((buf, st))
+
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
   """describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
   if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
     rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
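The new `_recurse_reduceops` above walks the graph once per output and records every reduce it reaches. A toy, self-contained sketch of that traversal pattern (a toy Node class instead of LazyBuffer, no ShapeTracker handling), showing how two reduces feeding one output land in a single reduce_info dict:

from dataclasses import dataclass
from typing import Dict, Set, Tuple

@dataclass(frozen=True)
class Node:
  # toy stand-in for a LazyBuffer: an op name, sources, and for reduces the axes
  op:str
  srcs:Tuple["Node", ...] = ()
  axis:Tuple[int, ...] = ()
  shape:Tuple[int, ...] = ()

def recurse_reduceops(n:Node, reduce_info:Dict[Node, Tuple[Tuple[int, ...], Tuple[int, ...]]], cache:Set[Node]):
  # depth-first, memoized walk: record (input shape, axes) for every reduce passed through
  if n in cache: return
  if n.op in ("SUM", "MAX"): reduce_info[n] = (n.srcs[0].shape, n.axis)
  for x in n.srcs: recurse_reduceops(x, reduce_info, cache)
  cache.add(n)

x = Node("LOAD", shape=(4, 8))
r0 = Node("SUM", (x,), axis=(1,), shape=(4, 1))
r1 = Node("MAX", (x,), axis=(1,), shape=(4, 1))
out = Node("SUB", (r0, r1), shape=(4, 1))   # two reduces feeding one elementwise output

info:Dict[Node, Tuple[Tuple[int, ...], Tuple[int, ...]]] = {}
recurse_reduceops(out, info, set())
print(len(info))  # -> 2, both reduces are visible to the single kernel being built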
@@ -104,10 +112,12 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None], re
   cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {}
   ast: List[LazyOp] = []
   inputs: List[LazyBuffer] = []
+  reduce_info: Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]] = {}
   for i, out in enumerate(outs):
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+    _recurse_reduceops(out, out.st, realizes, outs, reduce_info, {})
+    output_st = ShapeTracker.from_shape(next(iter(reduce_info)).shape if reduce_info else out.shape)
     output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
-    lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, cache=cache)
+    lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, reduce_info, cache=cache)
     output_view, vv = output_view.simplify().unbind()
     if vv: var_vals.update(vv)
     ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
@@ -267,7 +277,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
         buf.buffer.options = None
 
   # preschedule all buffers in realizes
-  prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes, reduce_for_op)) for group in output_groups.values()}
+  prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes)) for group in output_groups.values()}
   schedule_targets = {out:ps for ps in prescheduled.values() for out in ps[0]}
 
   graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
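At the tensor level, the kind of expression this infrastructure targets is one containing more than one ReduceOps node, e.g. a variance-style computation. A small example using tinygrad's public Tensor API (whether the two reduces actually end up in one kernel depends on the scheduler's grouping, which this PR only lays groundwork for):

from tinygrad import Tensor

x = Tensor.randn(4, 8)
# two reduces over the same axis in one expression: a mean and a sum of squared deviations
var = ((x - x.mean(axis=1, keepdim=True)) ** 2).sum(axis=1, keepdim=True) / 8
print(var.shape)  # (4, 1)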