diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index abc6e5918c..535cb93da0 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -1,8 +1,8 @@
 import sys, pickle, atexit
 from collections import defaultdict, deque
 from dataclasses import dataclass
-from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, get_args
-from tinygrad.ops import MetaOps, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
+from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Union, cast, get_args
+from tinygrad.ops import MetaOps, BufferOps, LazyOp, Op, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
 from tinygrad.engine.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, colored, prod, dedup, all_int, merge_dicts, getenv, Metadata
 from tinygrad.shape.symbolic import Variable
@@ -36,14 +36,12 @@ class ScheduleItem:
 # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
 
 def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer], cache) -> LazyOp:
+                      realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
+                      reduce_info:Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> LazyOp:
   """recursively create a lazyop"""
   if (buf, st) in cache: return cache[(buf, st)]
-  if buf != buf.base:
-    st = buf.st + st
-    buf = buf.base
-  # all buffers here are base now
-  assert buf.op is not None
+  if buf is not buf.base: st, buf = buf.st+st, buf.base
+  arg = buf.arg
 
   # consts are always fused and generated
   if buf.op is MetaOps.CONST:
@@ -76,24 +74,34 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
   if buf.op is MetaOps.CONTIGUOUS:
     assert buf in outputs
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
+    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
   if buf.op is MetaOps.ASSIGN:
     assert buf in outputs
     assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
     assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
-    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, cache)
+    return _recursive_lazyop(buf.srcs[0], inputs, outputs, var_vals, st, realizes, assign_targets, reduce_info, cache)
 
   # if it's a reduce, we have to change the shapetracker
   if buf.op in ReduceOps:
     assert st.contiguous, "ReduceOps late fusion must be contiguous"
-    st = ShapeTracker.from_shape(buf.srcs[0].shape)
+    st, arg = reduce_info[buf]
 
   # otherwise we fuse it like normal
-  cache[(buf, st)] = ret = \
-    LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
-  return ret
+  return cache.setdefault((buf, st), LazyOp(cast(Op,buf.op), tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, \
+    reduce_info, cache) for x in buf.srcs), arg))
 
-def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer]):
+def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer], reduce_info:Dict, cache):
+  if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs) or (buf, st) in cache: return
+  if buf is not buf.base: st, buf = buf.st+st, buf.base
+  if buf.op in ReduceOps:
+    reduce_input, axis = buf.srcs[0], buf.arg
+    assert st.contiguous
+    st = ShapeTracker.from_shape(reduce_input.shape)
+    reduce_info[buf] = (st, axis)
+  for x in buf.srcs: _recurse_reduceops(x, st, realizes, outs, reduce_info, cache)
+  cache.setdefault((buf, st))
+
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
   """describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
   if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
     rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
@@ -104,10 +112,12 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None], re
   cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {}
   ast: List[LazyOp] = []
   inputs: List[LazyBuffer] = []
+  reduce_info: Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]] = {}
   for i, out in enumerate(outs):
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+    _recurse_reduceops(out, out.st, realizes, outs, reduce_info, {})
+    output_st = ShapeTracker.from_shape(next(iter(reduce_info)).shape if reduce_info else out.shape)
     output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
-    lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, cache=cache)
+    lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, reduce_info, cache=cache)
     output_view, vv = output_view.simplify().unbind()
     if vv: var_vals.update(vv)
     ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
@@ -267,7 +277,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
         buf.buffer.options = None
 
   # preschedule all buffers in realizes
-  prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes, reduce_for_op)) for group in output_groups.values()}
+  prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes)) for group in output_groups.values()}
   schedule_targets = {out:ps for ps in prescheduled.values() for out in ps[0]}
 
   graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
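
For readers following the change: the patch drops the old reduce_for_op mapping and instead has _recurse_reduceops fill a reduce_info dict that maps each ReduceOp buffer to the ShapeTracker of its input plus the reduce axes, which _recursive_lazyop then reads back instead of recomputing the shape from buf.srcs[0]. The sketch below is illustrative only and not part of the patch: the names, shapes, and the use of plain tuples in place of ShapeTracker are invented for the example.

# Illustrative sketch, not part of the patch: what reduce_info conceptually holds.
# A reduce is recorded against its input's full shape plus the axes it reduces,
# and the reduced output shape follows from those two pieces of information.
from typing import Dict, Tuple

reduce_info: Dict[str, Tuple[Tuple[int, ...], Tuple[int, ...]]] = {}  # name -> (input shape, reduce axes)

def record_reduce(name: str, input_shape: Tuple[int, ...], axis: Tuple[int, ...]) -> None:
  # the patch stores ShapeTracker.from_shape(reduce_input.shape); a raw shape tuple stands in here
  reduce_info[name] = (input_shape, axis)

record_reduce("sum_over_axis_1", (4, 16, 32), (1,))
shape, axis = reduce_info["sum_over_axis_1"]
out_shape = tuple(1 if i in axis else s for i, s in enumerate(shape))  # reduced dims collapse to 1
assert out_shape == (4, 1, 32)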