mutate reduceop shapes pre ast creation [run_process_replay] (#5855)

qazal
2024-08-01 20:00:05 +08:00
committed by GitHub
parent ba0a0008aa
commit 3e95e2bb0b


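In brief, the change splits lowering into two phases: _recurse_reduceops now runs as its own pass over all outputs, filling reduce_info before any LazyOp is built, so the store-creation loop only reads reduce shapes instead of resolving them mid-flight. Below is a minimal, self-contained sketch of that two-phase structure; Buf, collect_reduceops, and to_ast are toy stand-ins for illustration, not tinygrad's real LazyBuffer/ShapeTracker API.

    from dataclasses import dataclass
    from typing import Dict, Tuple

    @dataclass(frozen=True)
    class Buf:                       # toy stand-in for LazyBuffer
      op: str                        # "LOAD", "ADD", "SUM", ...
      shape: Tuple[int, ...]
      srcs: Tuple["Buf", ...] = ()
      axis: Tuple[int, ...] = ()     # reduce axes, only meaningful for "SUM"

    def collect_reduceops(buf: Buf, reduce_info: Dict[Buf, Tuple[Tuple[int, ...], Tuple[int, ...]]]):
      # phase 1: record (input shape, axes) for every reduce, before any AST exists
      if buf.op == "SUM": reduce_info[buf] = (buf.srcs[0].shape, buf.axis)
      for s in buf.srcs: collect_reduceops(s, reduce_info)

    def to_ast(buf: Buf, reduce_info) -> str:
      # phase 2: lower to a (string) AST; reduce shapes are read from reduce_info,
      # so nothing is mutated while the AST is being built
      if buf.op == "LOAD": return f"LOAD{buf.shape}"
      if buf.op == "SUM":
        in_shape, axis = reduce_info[buf]
        return f"SUM(axis={axis}, in={in_shape}, src={to_ast(buf.srcs[0], reduce_info)})"
      return f"{buf.op}({', '.join(to_ast(s, reduce_info) for s in buf.srcs)})"

    x = Buf("LOAD", (4, 4))
    red = Buf("SUM", (4, 1), (x,), axis=(1,))
    info: Dict[Buf, Tuple[Tuple[int, ...], Tuple[int, ...]]] = {}
    collect_reduceops(red, info)
    print(to_ast(red, info))   # SUM(axis=(1,), in=(4, 4), src=LOAD(4, 4))
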
@@ -88,7 +88,8 @@ def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeT
   tmp = input_st.permute(permute_axis)
   return tmp, tmp.shape[-len(axis):]

-def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer], reduce_info:Dict, cache):
+def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],\
+    reduce_info:Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]], cache):
   if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs) or (buf, st) in cache: return
   cache.setdefault((buf, st))
   if buf is not buf.base: st, buf = buf.st+st, buf.base
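Apart from the line split, the first hunk only tightens typing: reduce_info was annotated as a bare Dict and now carries its full key/value types. Written as a rough type alias (ReduceInfo is a hypothetical name for illustration, not one tinygrad defines):

    from typing import Dict, Tuple

    # hypothetical alias for the annotation added above; the real code inlines it
    ReduceInfo = Dict["LazyBuffer", Tuple["ShapeTracker", Tuple[int, ...]]]

That is, each reduce LazyBuffer maps to a ShapeTracker plus a tuple of ints describing its reduce.
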
@@ -125,21 +126,23 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
     rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
     return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
   if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
-  var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
+  # unify the kernel dims
+  reduce_info: Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]] = {}
+  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
+  for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
+  # create the stores
+  var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
   assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
   cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {}
   ast: List[LazyOp] = []
   inputs: Dict[LazyBuffer, int] = {}
-  reduce_info: Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]] = {}
-  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
   for i, out in enumerate(outs):
-    _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
     output_st = ShapeTracker.from_shape(reduce_st(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
+    output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
     lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, reduce_info, cache=cache)
-    output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
     output_view, vv = output_view.simplify().unbind()
     if vv: var_vals.update(vv)
-    ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
+    ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_view)))
   return LazyOp(MetaOps.KERNEL, tuple(ast)), list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])

 # *** DAG creation: decide which LazyBuffers should realize ***
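
One idiom worth noting in the second hunk's output_st line: deque(reduce_info.values(), 1).pop() fetches the dict's last value, since constructing a deque with maxlen=1 drains the iterator while keeping only its final element. A standalone illustration with toy stand-in values (the real entries are ShapeTracker/axes pairs):

    from collections import deque

    # toy stand-in entries, shaped like (ShapeTracker, reduce axes)
    reduce_info = {"r1": ("st1", (0,)), "r2": ("st2", (1,))}
    last = deque(reduce_info.values(), 1).pop()   # maxlen=1 keeps only the last value
    print(last)   # ('st2', (1,))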