diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index f29bdea5bb..8aa49f6112 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -223,7 +223,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
     if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
   return ctx.realizes
 
-# break the SINK into kernels
+# **** create kernels
 
 @dataclass(frozen=True)
 class Kernel:
@@ -243,6 +243,7 @@ def create_kernel(ctx:KernelContext, x:UOp, b:UOp):
   return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
 
 DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
+
 def append_to_kernel(ctx:KernelContext, x:UOp):
   new_srcs: list[UOp] = []
   metadata = dict.fromkeys(x.arg.metadata)
@@ -268,30 +269,7 @@ create_kernels = merge_views+PatternMatcher([
   (UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None),
 ])
 
-# **** fix kernel AST
-
-# ** create buffer ops + enumerate buffers
-
-add_buffer_ops = PatternMatcher([
-  # LOAD
-  (UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)),
-  # STORE (except for COPY/BUFFER_VIEW)
-  (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
-  # partial assign can store to a non-contiguous ShapeTracker
-  (UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
-   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
-  # otherwise the store is contiguous
-  (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
-   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
-  # if the last child is a VIEW we merge the ShapeTrackers and store the base
-  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
-   lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
-  # remove CONTIGUOUS/DEVICE from kernel AST
-  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
-  (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
-])
-
-# ** push views to buffer ops
+# **** swizzler
 
 def apply_swizzle(u:UOp) -> UOp:
   with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
@@ -314,7 +292,7 @@ def reduceop_view_right(src:UOp, v:UOp, r:UOp):
   assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
   return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape))
 
-def elementwise_view_right(root:UOp) -> UOp|None:
+def elementwise_view_right(root:UOp):
   if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in DONT_PUSH_VIEWS]): return None
   assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
   # place view after applying the elementwise op
@@ -323,7 +301,7 @@
   # reshape to match downstream shapes
   return root.replace(src=tuple(new_src)).reshape(root.shape)
 
-def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
+def merge_double_reduce(root:UOp, first_reduce:UOp):
   assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
   assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
   return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
@@ -340,9 +318,9 @@ view_right = merge_views+PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
 
-# ** unbind variables
+# **** unbind variables
 
-def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp) -> UOp|None:
+def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp):
   st = unwrap(x.st).simplify()
   if any(x.op is Ops.BIND for x in st.vars()):
     st, var_vals = st.unbind()
@@ -354,7 +332,26 @@ def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
   return var
 unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
 
-# ** fix_kernel_ops
+# **** fix kernel AST
+
+add_buffer_ops = PatternMatcher([
+  # LOAD
+  (UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)),
+  # STORE (except for COPY/BUFFER_VIEW)
+  (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
+  # partial assign can store to a non-contiguous ShapeTracker
+  (UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
+   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
+  # otherwise the store is contiguous
+  (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
+   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
+  # if the last child is a VIEW we merge the ShapeTrackers and store the base
+  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
+   lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
+  # remove CONTIGUOUS/DEVICE from kernel AST
+  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
+  (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
+])
 
 def check_load_st(glbl:UOp, view:UOp):
   if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
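
For context, not part of the patch: every PatternMatcher in this diff follows the same scheme of pairing a structural pattern with a rewrite function and rewriting the graph until no rule fires. A minimal self-contained sketch of that idea, using hypothetical toy types rather than tinygrad's actual UOp/UPat/graph_rewrite API:

    # toy sketch of pattern-match-and-rewrite, all names hypothetical
    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Node:
      op: str
      src: tuple["Node", ...] = ()

    def rewrite_fixpoint(n: Node, rules) -> Node:
      # rewrite sources bottom-up first, then try every rule on this node
      n = Node(n.op, tuple(rewrite_fixpoint(s, rules) for s in n.src))
      for matches, fn in rules:
        if matches(n) and (new := fn(n)) is not None:
          return rewrite_fixpoint(new, rules)  # a rewrite may enable more rules
      return n

    # one rule in the spirit of "remove CONTIGUOUS from kernel AST":
    # a CONTIGUOUS wrapper with a single source is replaced by that source
    rules = [(lambda n: n.op == "CONTIGUOUS" and len(n.src) == 1, lambda n: n.src[0])]

    sink = Node("SINK", (Node("CONTIGUOUS", (Node("BUFFER"),)),))
    assert rewrite_fixpoint(sink, rules) == Node("SINK", (Node("BUFFER"),))

The real rules above additionally thread a context through the rewrite (e.g. the buffer list in add_buffer_ops, the variable dict in unbind_vars) and return None to signal "no match", which this sketch mimics with the `is not None` check.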