mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
early update the reduceop axis [run_process_replay] (#5854)
This commit is contained in:
@@ -97,6 +97,7 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
|
||||
if buf.op in ReduceOps and buf not in reduce_info:
|
||||
axis = buf.arg
|
||||
if not st.contiguous:
|
||||
# push the movementop to the input
|
||||
tmp, rshape = _permute_reduce(input_st, axis)
|
||||
prshape = prod(rshape)
|
||||
strides = strides_for_shape(rshape)
|
||||
@@ -105,17 +106,17 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
|
||||
nv.append(View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
|
||||
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None))
|
||||
input_st = tmp + ShapeTracker(tuple(nv))
|
||||
# update the axis
|
||||
_, new_rshape = _permute_reduce(input_st, axis)
|
||||
axis = tuple(range(len(input_st.shape)-len(new_rshape), len(input_st.shape)))
|
||||
elif reduce_info:
|
||||
top_reduce, (top_reduce_input_st, top_reduce_axes) = deque(reduce_info.items(), 1).pop()
|
||||
_, rshape = _permute_reduce(top_reduce_input_st, top_reduce_axes)
|
||||
new_axis = tuple(range(len(top_reduce_input_st.shape)-len(rshape), len(top_reduce_input_st.shape)))
|
||||
if buf.op is buf.srcs[0].base.op:
|
||||
# merge this reduce with its parent
|
||||
reduce_info[top_reduce] = (top_reduce_input_st, axis+new_axis)
|
||||
reduce_info[top_reduce] = (top_reduce_input_st, top_reduce_axes+axis)
|
||||
return
|
||||
reduce_info[top_reduce] = (top_reduce_input_st, new_axis)
|
||||
# reshape this reduce per its top axis
|
||||
input_st = input_st.reshape(tuple(1 if i in new_axis else s for i,s in enumerate(top_reduce_input_st.shape)))
|
||||
input_st = input_st.reshape(tuple(1 if i in top_reduce_axes else s for i,s in enumerate(top_reduce_input_st.shape)))
|
||||
reduce_info[buf] = (input_st, axis)
|
||||
|
||||
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
|
||||
|
||||
Reference in New Issue
Block a user