diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 6b404f49ff..35074c7399 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -106,9 +106,8 @@ def permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTr # ** movementops rewrite rules -def view_r(view:UOp, r:UOp, rsrc:UOp) -> Optional[UOp]: - if (st:=unwrap(view.st)).contiguous: return None - tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(rsrc.st).shape), r.axis_arg) +def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp: + tmp, rshape = permute_reduce(ShapeTracker.from_shape(unwrap(src.st).shape), r.axis_arg) prshape = prod(rshape) strides = strides_for_shape(rshape) nv: List[View] = [] @@ -119,10 +118,10 @@ def view_r(view:UOp, r:UOp, rsrc:UOp) -> Optional[UOp]: new_input_st = tmp + ShapeTracker(tuple(nv)) _, new_rshape = permute_reduce(new_input_st, r.axis_arg) new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape))) - return st_fixup(rsrc, lambda st:st+new_input_st, {}).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape)) + return st_fixup(src, lambda st:st+new_input_st, {}).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape)) -def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp: - swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st) +def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp, src:UOp) -> UOp: + swizzle_st, src_st = unwrap(swizzle.st), unwrap(src.st) assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE" assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE" output_shape = swizzle_st.reduce(root.axis_arg) @@ -136,8 +135,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]: assert all_same([(x, prod(x), prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}" new_shape, new_input_shape = swizzle_shapes[0] fixup_cache: Dict[UOp, UOp] = {} - new_srcs = [x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src] - ret = UOp(root.op, root.dtype, tuple(new_srcs), root.arg) + ret = root.replace(src=tuple(x.src[0] if x in swizzles else st_fixup(x, lambda st:st.reshape(new_input_shape), fixup_cache) for x in root.src)) return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape)) def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: @@ -161,10 +159,10 @@ view_right = merge_views+PatternMatcher([ # ASSIGN can override st (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.ASSIGN, name="a"))), lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None), - # VIEW on a reduce creates a new VIEW - (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r), + # non contiguous VIEW on a reduce creates a new VIEW + (UPat(Ops.REDUCE_AXIS, src=UPat.var("src"), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)), # push a VIEW down to STORE, through a reduce (ONLY reshapes) - (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, name="swizzle"),), name="root"), push_swizzle_down_through_reduce), + (UPat(Ops.REDUCE_AXIS, src=(UPat.var(name="src").view(name="swizzle"),), name="root"), push_swizzle_down_through_reduce), # push VIEW(s) down to STORE, through an elementwise op (ONLY reshapes) (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), push_swizzle_down_through_elementwise), (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),