diff --git a/test/test_schedule.py b/test/test_schedule.py index 734e32b813..ba87a2c095 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1871,7 +1871,7 @@ class TestIndexing(unittest.TestCase): ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1))) r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(())) - r = r + 2 + r = r + r.const_like(2).replace(src=(unwrap(r.st).to_uop(),)) sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),)) rsink = graph_rewrite(sink, view_right) # this AST first needs to swizzle, but it doesn't have implicit movementops diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index a8d87a3453..6837c3387b 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -260,13 +260,9 @@ create_kernels = merge_views+PatternMatcher([ # ** create buffer ops + enumerate buffers -def load_buf(ctx:list[UOp], x:UOp): - if x not in ctx: ctx.append(x) - return UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), unwrap(x.st).to_uop())) - add_buffer_ops = PatternMatcher([ # LOAD - (UPat(Ops.BUFFER, name="x"), load_buf), + (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))), # STORE (except for COPY/BUFFER_VIEW) (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x), (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)), @@ -278,8 +274,9 @@ add_buffer_ops = PatternMatcher([ def apply_swizzle(u:UOp) -> UOp: with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left) -def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp: - input_st = ShapeTracker.from_shape(unwrap(src.st).shape) +def swizzle_reduceop(r:UOp, src:UOp, view:UOp): + if (st:=unwrap(view.st)).contiguous: return None + input_st = ShapeTracker.from_shape(src.shape) tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg) prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):]) strides = strides_for_shape(rshape) @@ -290,20 +287,18 @@ def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp: new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg))) return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape)) -def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp: - if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}") - output_shape = swizzle_st.reduce(r.axis_arg) - return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape)) +def reduceop_view_right(src:UOp, v:UOp, r:UOp): + assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}" + return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape)) def elementwise_view_right(root:UOp) -> UOp|None: - if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None - assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}" + if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW]): return None assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}" - # push the swizzle from src to root - output_swizzle = swizzles[0] - new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape) - ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src)) - return ret.view(ShapeTracker.from_shape(output_swizzle.shape)) + # place view after applying the elementwise op + new_shape = swizzles[0].base.shape + ret = root.replace(src=tuple(x.base if x.base.shape == new_shape else apply_swizzle(x.view(ShapeTracker.from_shape(new_shape))) for x in root.src)) + # reshape to match downstream shapes + return ret.reshape(root.shape) def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" @@ -317,12 +312,12 @@ view_right = merge_views+PatternMatcher([ lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))), # STORE is the last child, so we just merge the ShapeTrackers and store the base (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)), - # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view() - (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)), - # REDUCE(src.view()) -> REDUCE(src).view() - (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right), - # ALU(src.view()) -> ALU(src).view() - (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right), + # push a non contiguous ShapeTracker through reduceop + (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r"),), name="view"), swizzle_reduceop), + # apply view after reduceops + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.VIEW, src=(UPat.var("src"),), name="v"),), name="r"), reduceop_view_right), + # apply view after elementwise ops + (UPat(GroupOp.All-GroupOp.Buffer, name="root"), elementwise_view_right), # double reduce op collapses to a single reduce op (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), ]) @@ -372,7 +367,7 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp: # substitute kernel sources for the target buffer ast = k.arg.ast.substitute({s.src[1].arg.ast:s.src[0] for s in k.src if s.op is Ops.ASSIGN}).sink() # add buffer ops - ast = graph_rewrite(ast, add_buffer_ops, bufs:=[s.buf_uop for s in k.src], bottom_up=True) + ast = graph_rewrite(ast, add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True) if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}") # unbind_vars + push views to edges ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)