move permute_reduces to uop movementops [run_process_replay] (#6272)

This commit is contained in:
qazal
2024-08-25 15:25:51 +08:00
committed by GitHub
parent b86907c6c7
commit 70015bd89c

View File

@@ -102,11 +102,6 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))
def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis
tmp = input_st.permute(permute_axis)
return tmp, tmp.shape[-len(axis):]
def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer],
reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]],
cache:Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]]) -> \
@@ -153,9 +148,14 @@ def reshape_uop(u:UOp, new_shape:Tuple[sint, ...], uop_sts:Dict[UOp, ShapeTracke
cache[u] = reshaped = u if new_srcs == u.src else replace(u, src=new_srcs)
return reshaped
def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis
tmp = input_st.permute(permute_axis)
return tmp, tmp.shape[-len(axis):]
def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[int, ...]]:
# push the movementop to the buffer uop
tmp, rshape = _permute_reduce(input_st, axis)
tmp, rshape = permute_reduce(input_st, axis)
prshape = prod(rshape)
strides = strides_for_shape(rshape)
nv: List[View] = []
@@ -164,7 +164,7 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None))
# update input_st and axis
new_input_st = tmp + ShapeTracker(tuple(nv))
_, new_rshape = _permute_reduce(new_input_st, axis)
_, new_rshape = permute_reduce(new_input_st, axis)
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
return new_input_st, new_axis