diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 4c0d53f3ff..e7c3447f4e 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -102,11 +102,6 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops)) return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op)) -def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]: - permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis) + axis - tmp = input_st.permute(permute_axis) - return tmp, tmp.shape[-len(axis):] - def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer, None], outs:List[LazyBuffer], reduce_info:Dict[Tuple[LazyBuffer, ShapeTracker], Tuple[ShapeTracker, Tuple[int, ...]]], cache:Dict[Tuple[LazyBuffer, ShapeTracker], Optional[Tuple[LazyBuffer, ShapeTracker]]]) -> \ @@ -153,9 +148,14 @@ def reshape_uop(u:UOp, new_shape:Tuple[sint, ...], uop_sts:Dict[UOp, ShapeTracke cache[u] = reshaped = u if new_srcs == u.src else replace(u, src=new_srcs) return reshaped +def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]: + permute_axis = tuple(i for i in range(len(input_st.shape)) if i not in axis)+axis + tmp = input_st.permute(permute_axis) + return tmp, tmp.shape[-len(axis):] + def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[int, ...]]: # push the movementop to the buffer uop - tmp, rshape = _permute_reduce(input_st, axis) + tmp, rshape = permute_reduce(input_st, axis) prshape = prod(rshape) strides = strides_for_shape(rshape) nv: List[View] = [] @@ -164,7 +164,7 @@ def swizzle_reduceop(input_st:ShapeTracker, swizzle:ShapeTracker, axis:Tuple[int v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None)) # update input_st and axis new_input_st = tmp + ShapeTracker(tuple(nv)) - _, new_rshape = _permute_reduce(new_input_st, axis) + _, new_rshape = permute_reduce(new_input_st, axis) new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape))) return new_input_st, new_axis