early update the reduceop axis [run_process_replay] (#5854)

This commit is contained in:
qazal
2024-08-01 19:08:40 +08:00
committed by GitHub
parent eb91423cb4
commit ba0a0008aa

View File

@@ -97,6 +97,7 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
if buf.op in ReduceOps and buf not in reduce_info:
axis = buf.arg
if not st.contiguous:
# push the movementop to the input
tmp, rshape = _permute_reduce(input_st, axis)
prshape = prod(rshape)
strides = strides_for_shape(rshape)
@@ -105,17 +106,17 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
nv.append(View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None))
input_st = tmp + ShapeTracker(tuple(nv))
# update the axis
_, new_rshape = _permute_reduce(input_st, axis)
axis = tuple(range(len(input_st.shape)-len(new_rshape), len(input_st.shape)))
elif reduce_info:
top_reduce, (top_reduce_input_st, top_reduce_axes) = deque(reduce_info.items(), 1).pop()
_, rshape = _permute_reduce(top_reduce_input_st, top_reduce_axes)
new_axis = tuple(range(len(top_reduce_input_st.shape)-len(rshape), len(top_reduce_input_st.shape)))
if buf.op is buf.srcs[0].base.op:
# merge this reduce with its parent
reduce_info[top_reduce] = (top_reduce_input_st, axis+new_axis)
reduce_info[top_reduce] = (top_reduce_input_st, top_reduce_axes+axis)
return
reduce_info[top_reduce] = (top_reduce_input_st, new_axis)
# reshape this reduce per its top axis
input_st = input_st.reshape(tuple(1 if i in new_axis else s for i,s in enumerate(top_reduce_input_st.shape)))
input_st = input_st.reshape(tuple(1 if i in top_reduce_axes else s for i,s in enumerate(top_reduce_input_st.shape)))
reduce_info[buf] = (input_st, axis)
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):