Condense the swizzle push-down helpers in tinygrad/engine/schedule.py.

Summary of what this hunk changes (old side 23 lines, new side 18 lines):

* push_swizzle_down_through_reduce
  - The two separate asserts (view is contiguous; element counts match) are
    replaced by one guard that raises AssertionError with a message naming
    both `v` and `src`. Unlike `assert`, an explicit `raise` is not removed
    when Python runs with -O.
  - The element-count check now compares `v.size` with `src.size`, and the
    reduce axes are computed inline from `src.shape`.
    NOTE(review): this presumes UOp.size / UOp.shape agree with
    prod(unwrap(x.st).shape) / unwrap(x.st).shape; not visible in this hunk.

* push_swizzle_down_through_elementwise
  - Drops the `swizzle_shapes` temporary and reads `x.shape` /
    `x.src[0].shape` directly.
  - Builds `new_input_st` once and reuses it for apply_swizzle and for the
    ASSIGN arg update; the final view uses `swizzles[0].shape`.

NOTE(review): the line structure of this patch was reconstructed from a
flattened copy. Record counts match the hunk header, but the two-space
indentation is inferred; confirm against the target file before applying.

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index ecea4e2d1b..b4be3e8bbe 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -106,23 +106,18 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
   return apply_swizzle(src, new_input_st).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
 
 def push_swizzle_down_through_reduce(r:UOp, v:UOp, src:UOp) -> UOp:
-  swizzle_st, src_st = unwrap(v.st), unwrap(src.st)
-  assert swizzle_st.contiguous, "can't push a non contiguous VIEW down to STORE"
-  assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
+  if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
   output_shape = swizzle_st.reduce(r.axis_arg)
-  new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
-  return src.r(r.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
+  return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
 
 def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
   if not (swizzles := [x for x in root.src if x.base is not x]): return None
-  swizzle_shapes = [(unwrap(x.st).shape, unwrap(x.src[0].st).shape) for x in swizzles]
-  assert all_same([(x, prod(y)) for x,y in swizzle_shapes]), f"swizzles must have the same size {swizzle_shapes}"
-  new_shape, new_input_shape = swizzle_shapes[0]
-  new_src = tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, ShapeTracker.from_shape(new_input_shape)) for x in root.src)
-  ret = root.replace(src=new_src)
+  assert all_same([(x.shape, prod(x.src[0].shape)) for x in swizzles]), f"swizzles must have the same size {swizzles}"
+  new_input_st = ShapeTracker.from_shape(swizzles[0].src[0].shape)
+  ret = root.replace(src=tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x, new_input_st) for x in root.src))
   # update the ASSIGN offset to match the new shape
-  if ret.op is Ops.ASSIGN and ret.arg is not None: ret = ret.replace(arg=ret.arg+ShapeTracker.from_shape(new_input_shape),)
-  return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(new_shape))
+  if ret.op is Ops.ASSIGN and ret.arg is not None: ret = ret.replace(arg=ret.arg+new_input_st,)
+  return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(swizzles[0].shape))
 
 def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
   assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"