From 3e95e2bb0ba8a2297bed8e513344666503a5dd2c Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Thu, 1 Aug 2024 20:00:05 +0800
Subject: [PATCH] mutate reduceop shapes pre ast creation [run_process_replay]
 (#5855)

---
 tinygrad/engine/schedule.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 9a174e755c..efe019ba8d 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -88,7 +88,8 @@ def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeT
   tmp = input_st.permute(permute_axis)
   return tmp, tmp.shape[-len(axis):]
 
-def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer], reduce_info:Dict, cache):
+def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],\
+                       reduce_info:Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]], cache):
   if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs) or (buf, st) in cache: return cache.setdefault((buf, st))
   if buf is not buf.base: st, buf = buf.st+st, buf.base
@@ -125,21 +126,23 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
     rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
     return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
   if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
-  var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
+  # unify the kernel dims
+  reduce_info: Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]] = {}
+  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
+  for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
+  # create the stores
+  var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
   assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
   cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {}
   ast: List[LazyOp] = []
   inputs: Dict[LazyBuffer, int] = {}
-  reduce_info: Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]] = {}
-  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
   for i, out in enumerate(outs):
-    _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
     output_st = ShapeTracker.from_shape(reduce_st(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
-    output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
     lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, reduce_info, cache=cache)
+    output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
     output_view, vv = output_view.simplify().unbind()
     if vv: var_vals.update(vv)
-    ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
+    ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_view)))
   return LazyOp(MetaOps.KERNEL, tuple(ast)), list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
 
 # *** DAG creation: decide which LazyBuffers should realize ***
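
The substance of the patch is an ordering change: _recurse_reduceops now runs over every output before any STORE op is created, so reduce_info is already complete when the per-output loop builds the AST. Below is a minimal standalone sketch of that two-pass structure, using hypothetical stand-in names and plain strings/tuples instead of tinygrad's LazyBuffer/ShapeTracker; it illustrates the pattern, not tinygrad's actual API.

# Illustrative sketch only: simplified two-pass lowering in the spirit of this patch.
from typing import Dict, List, Set, Tuple

ReduceInfo = Dict[str, Tuple[int, ...]]  # reduce buffer -> shape it should be computed at

def discover_reduces(buf: str, srcs: Dict[str, List[str]], shapes: Dict[str, Tuple[int, ...]],
                     reduce_info: ReduceInfo, cache: Set[str]) -> None:
  # recursive pass: record the shape of every reduce reachable from an output
  if buf in cache: return
  cache.add(buf)
  if buf.startswith("reduce"): reduce_info[buf] = shapes[buf]
  for src in srcs.get(buf, []): discover_reduces(src, srcs, shapes, reduce_info, cache)

def lower(outs: List[str], srcs: Dict[str, List[str]], shapes: Dict[str, Tuple[int, ...]]) -> List[str]:
  # pass 1: unify the kernel dims across all outputs, before any AST is built
  reduce_info: ReduceInfo = {}
  cache: Set[str] = set()
  for out in outs: discover_reduces(out, srcs, shapes, reduce_info, cache)
  # pass 2: create the stores; reduce_info is already complete here
  stores = []
  for i, out in enumerate(outs):
    output_shape = next(iter(reduce_info.values())) if reduce_info else shapes[out]
    stores.append(f"STORE {i}: {out} at {output_shape}")
  return stores

if __name__ == "__main__":
  srcs = {"elem": ["reduce0"], "reduce0": ["input"]}
  shapes = {"elem": (4, 1), "reduce0": (4, 1), "input": (4, 4)}
  print(lower(["elem"], srcs, shapes))  # -> ['STORE 0: elem at (4, 1)']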