mutate reduceop shapes pre ast creation [run_process_replay] (#5855)
@@ -88,7 +88,8 @@ def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeT
   tmp = input_st.permute(permute_axis)
   return tmp, tmp.shape[-len(axis):]
 
-def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer], reduce_info:Dict, cache):
+def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],\
+                       reduce_info:Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]], cache):
   if buf.base.realized is not None or (buf.base in realizes and buf.base not in outs) or (buf, st) in cache: return
   cache.setdefault((buf, st))
   if buf is not buf.base: st, buf = buf.st+st, buf.base
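For context on the new annotation: `reduce_info` pairs each reduce `LazyBuffer` with the `ShapeTracker` it is computed under and the axes being reduced, which the next hunk consumes via `reduce_st`. A minimal sketch of that bookkeeping, with hypothetical stand-ins (plain shape tuples in place of tinygrad's `LazyBuffer`/`ShapeTracker`):

from typing import Dict, Tuple

# Hypothetical stand-in for illustration only; tinygrad's LazyBuffer and
# ShapeTracker carry far more state than a bare shape tuple.
Shape = Tuple[int, ...]

def reduce_shape(shape: Shape, axis: Tuple[int, ...]) -> Shape:
  # a reduce keeps the rank but collapses each reduced axis to 1
  return tuple(1 if i in axis else s for i, s in enumerate(shape))

# reduce_info pairs each reduce buffer (keyed here by an int id) with the
# shape it is computed under and the axes being reduced
reduce_info: Dict[int, Tuple[Shape, Tuple[int, ...]]] = {0: ((4, 8, 16), (2,))}

for _, (shape, axis) in reduce_info.items():
  assert reduce_shape(shape, axis) == (4, 8, 1)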
@@ -125,21 +126,23 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
     rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
     return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
   if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
+  var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
+  # unify the kernel dims
+  reduce_info: Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]] = {}
+  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
+  for out in outs: _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
   # create the stores
-  var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
   assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
   cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {}
   ast: List[LazyOp] = []
   inputs: Dict[LazyBuffer, int] = {}
-  reduce_info: Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]] = {}
-  seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
   for i, out in enumerate(outs):
-    _recurse_reduceops(out, out.st, realizes, outs, reduce_info, seen_ops)
     output_st = ShapeTracker.from_shape(reduce_st(*deque(reduce_info.values(), 1).pop()) if reduce_info else out.shape)
-    output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
     lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, reduce_info, cache=cache)
+    output_view = out.arg[0] if out.op is MetaOps.ASSIGN and out.arg else output_st
     output_view, vv = output_view.simplify().unbind()
     if vv: var_vals.update(vv)
-    ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
+    ast.append(LazyOp(BufferOps.STORE, (lop,), MemBuffer(i, out.dtype, output_view)))
   return LazyOp(MetaOps.KERNEL, tuple(ast)), list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
 
 # *** DAG creation: decide which LazyBuffers should realize ***
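Net effect of the two hunks: the reduce walk moves out of the per-output store loop into a single pre-pass, so `reduce_info` is fully populated, and the kernel dims unified, before the first STORE node is built; previously it grew while stores were being emitted. A toy sketch of that reordering, with hypothetical names rather than tinygrad's API:

from typing import Dict, List, Tuple

Shape = Tuple[int, ...]

def collect_reduces(outs: List[dict], reduce_info: Dict[int, Shape]) -> None:
  # pre-pass over every output, mirroring the new
  # `for out in outs: _recurse_reduceops(...)` loop above
  for out in outs:
    for rid, shape in out.get("reduces", {}).items():
      reduce_info.setdefault(rid, shape)

def lower(outs: List[dict]) -> List[Tuple[int, Shape]]:
  reduce_info: Dict[int, Shape] = {}
  collect_reduces(outs, reduce_info)  # shapes are now fixed pre AST creation
  stores = []
  for i, out in enumerate(outs):
    # every store sees the same, fully-populated reduce_info; the real code
    # picks the last recorded entry via deque(reduce_info.values(), 1).pop()
    shape = list(reduce_info.values())[-1] if reduce_info else out["shape"]
    stores.append((i, shape))
  return stores

# one output whose graph contains a single reduce over axis 1
print(lower([{"shape": (4, 8), "reduces": {0: (4, 1)}}]))  # -> [(0, (4, 1))]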