From b6c617272aba46bd055207faaa49099e85589ed7 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 3 Feb 2025 07:59:11 -0500 Subject: [PATCH] New schedule.py Order [pr] (#8874) --- tinygrad/engine/schedule.py | 406 ++++++++++++++++++------------------ 1 file changed, 199 insertions(+), 207 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index e0aebefb35..98d527b97a 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -13,25 +13,70 @@ from tinygrad.device import Buffer # creation can recurse a lot sys.setrecursionlimit(10000) -# **** ScheduleItem return type +# **** schedule simplifier -@dataclass(frozen=True) -class ScheduleItem: - ast: UOp - bufs: tuple[Buffer, ...] - metadata: tuple[Metadata, ...] - @property - def outputs(self) -> tuple[Buffer, ...]: - """Read/write or write only buffers in the schedule.""" - return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs) - @property - def inputs(self) -> tuple[Buffer, ...]: - """Read only buffers in the schedule.""" - return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs) - @functools.cached_property - def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,) +def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None: + if not all_int(x.shape): return None + # remove reduce on unmasked const + prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1]) + ret = x.const_arg + match reduce.arg[0]: + case Ops.ADD: ret *= prshape + case Ops.MUL: ret **= prshape + case Ops.MAX: pass # NOTE: Ops.MAX is passthrough + case _: return None + return reduce.const_like(ret) -# **** Schedule context and big graph +def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp): + if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti) +def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp): + new_src = list(alu.src) + for i,s in enumerate(alu.src): + if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src + if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src)) + +sym = symbolic_simple+PatternMatcher([ + # UOp with size 0 is zero + (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \ + and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), + # DETACH and CONTIGUOUS_BACKWARD are NOOPs here + (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), + # reduce of size 0 is the identity element + (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), + lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), + # reduce of const is collapsed (TODO: make this a generic rule for stride0) + (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop), + # COPY(CONST) creates a new CONST on the destination device + (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)), + # no COPY to same device, except clone (arg is True) + (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"), + lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None), + # remove cast to image when it's already a contiguous image + (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", 
src=(UPat(Ops.CONTIGUOUS, name="base"))))),)), + lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None), + # remove contiguous if we can just view the buffer + (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)), + lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None), + # contiguous/buffer is already contiguous + (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]), + # support for using a contiguous permuted view instead of the parent view if one exists + (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous), + (UPat(GroupOp.ALU, name="alu"), replace_contiguous), + # remove CONST/BIND/BUFFER from SINK + (UPat(Ops.SINK, name="root"), + lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg) + if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None), +]) + +remove_movement_ops = merge_views+PatternMatcher([ + # NOTE: movement ops are always applied to base + (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))), + # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW) + (UPat(Ops.VIEW, name="view"), + lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None), +]) + +# **** UOp realization @dataclass(frozen=True) class ScheduleContext: @@ -71,140 +116,6 @@ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st) return ret -# **** AST graph rewrite - -# ** movement ops - -def apply_swizzle(u:UOp) -> UOp: - with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left) - -def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp: - input_st = ShapeTracker.from_shape(unwrap(src.st).shape) - tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg) - prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):]) - strides = strides_for_shape(rshape) - nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides, - v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views] - # update input_st and axis - new_input_st = tmp + ShapeTracker(tuple(nv)) - new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg))) - return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape)) - -def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp: - if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}") - output_shape = swizzle_st.reduce(r.axis_arg) - return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape)) - -def elementwise_view_right(root:UOp) -> UOp|None: - if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None - assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}" - assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}" - # push the swizzle from src to root - output_swizzle = swizzles[0] - new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape) - ret 
= root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src)) - return ret.view(ShapeTracker.from_shape(output_swizzle.shape)) - -def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: - assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" - assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time" - return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)) - -# push VIEW to children -view_right = merge_views+PatternMatcher([ - # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val)) - (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))), - lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))), - # STORE is the last child, so we just merge the ShapeTrackers and store the base - (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)), - # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view() - (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)), - # REDUCE(src.view()) -> REDUCE(src).view() - (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right), - # ALU(src.view()) -> ALU(src).view() - (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right), - # double reduce op collapses to a single reduce op - (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce), -]) - -# ** ScheduleItem context builder - -@dataclass(frozen=True) -class ScheduleItemContext: - var_vals: dict[Variable, int] - sts: set[ShapeTracker] = field(default_factory=set) - bufs: list[UOp] = field(default_factory=list) - -def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None: - if (st:=unwrap(x.st)) in ctx.sts: return None - st, var_vals = st.simplify().unbind() - ctx.var_vals.update(var_vals) - ctx.sts.add(st) - return st.to_uop() if st != x.st else None - -def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp: - ctx.bufs.append(x) - return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1) - -to_si = PatternMatcher([ - # BUFFER -> DEFINE_GLOBAL - (UPat(Ops.BUFFER, name="x"), _append_buf), - # simplify and unbind the final VIEWs - (UPat(Ops.VIEW, name="x"), _append_st_vars), - # don't need SINK on COPY or BUFFER_VIEW - (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))), - # don't need contiguous or assign anymore - (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x), - (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x), - # don't need DEVICE anymore - (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())), - # PRELOAD becomes LOAD - (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)), - # once images are loaded they become the base dtype - (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None), -]) - -def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp): - ctx[var.replace(src=())] = 
val.arg
-  return var
-unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
-
-def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> ScheduleItem:
-  # unbind_vars + push views to edges
-  sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
-  # remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
-  ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
-  # deal with ASSIGN
-  if len(ctx.assigns) != 0:
-    assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
-    for x in list(sink.toposort)[::-1]:
-      # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
-      if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
-      # PRELOAD tells the toposort this kernel should run before ASSIGN
-      if x.op is Ops.PRELOAD:
-        assign_preloads[x.buf_uop] = None
-        # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
-        if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
-          # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
-          if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
-          # if it has a single view and it's equal when you shrink a contig, it's fine
-          elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
-          # otherwise, it's not fine
-          else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                                   +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-  # capture process replay
-  if CAPTURE_PROCESS_REPLAY:
-    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
-  return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)))
-
-PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
-if CAPTURE_PROCESS_REPLAY:
  @atexit.register
-  def save_process_replay() -> None:
-    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
-
-# **** UOp realization
-
 class UPatScheduled(UPat):
   def __init__(self, *args, **kwargs):
     super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
@@ -352,69 +263,150 @@ break_sched = PatternMatcher([
   (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
 ])
 
-# **** schedule simplifier
+# **** ScheduleItem creation
 
-def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
-  if not all_int(x.shape): return None
-  # remove reduce on unmasked const
-  prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
-  ret = x.const_arg
-  match reduce.arg[0]:
-    case Ops.ADD: ret *= prshape
-    case Ops.MUL: ret **= prshape
-    case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
-    case _: return None
-  return reduce.const_like(ret)
+@dataclass(frozen=True)
+class ScheduleItem:
+  ast: UOp
+  bufs: tuple[Buffer, ...]
+  metadata: tuple[Metadata, ...]
+ @property + def outputs(self) -> tuple[Buffer, ...]: + """Read/write or write only buffers in the schedule.""" + return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs) + @property + def inputs(self) -> tuple[Buffer, ...]: + """Read only buffers in the schedule.""" + return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs) + @functools.cached_property + def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,) -def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp): - if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti) -def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp): - new_src = list(alu.src) - for i,s in enumerate(alu.src): - if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src - if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src)) +@dataclass(frozen=True) +class ScheduleItemContext: + var_vals: dict[Variable, int] + sts: set[ShapeTracker] = field(default_factory=set) + bufs: list[UOp] = field(default_factory=list) -sym = symbolic_simple+PatternMatcher([ - # UOp with size 0 is zero - (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \ - and not (root.base.op is Ops.CONST and root.base.arg == 0) else None), - # DETACH and CONTIGUOUS_BACKWARD are NOOPs here - (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]), - # reduce of size 0 is the identity element - (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), - lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), - # reduce of const is collapsed (TODO: make this a generic rule for stride0) - (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop), - # COPY(CONST) creates a new CONST on the destination device - (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)), - # no COPY to same device, except clone (arg is True) - (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"), - lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None), - # remove cast to image when it's already a contiguous image - (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)), - lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None), - # remove contiguous if we can just view the buffer - (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)), - lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None), - # contiguous/buffer is already contiguous - (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]), - # support for using a contiguous permuted view instead of the parent view if one exists - (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous), - (UPat(GroupOp.ALU, name="alu"), replace_contiguous), - # remove CONST/BIND/BUFFER from SINK - (UPat(Ops.SINK, name="root"), - lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg) - if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src 
else None), +def apply_swizzle(u:UOp) -> UOp: + with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left) + +def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp: + input_st = ShapeTracker.from_shape(unwrap(src.st).shape) + tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg) + prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):]) + strides = strides_for_shape(rshape) + nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides, + v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views] + # update input_st and axis + new_input_st = tmp + ShapeTracker(tuple(nv)) + new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg))) + return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape)) + +def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp: + if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}") + output_shape = swizzle_st.reduce(r.axis_arg) + return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape)) + +def elementwise_view_right(root:UOp) -> UOp|None: + if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None + assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}" + assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}" + # push the swizzle from src to root + output_swizzle = swizzles[0] + new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape) + ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src)) + return ret.view(ShapeTracker.from_shape(output_swizzle.shape)) + +def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp: + assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu" + assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time" + return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)) + +# push VIEW to children +view_right = merge_views+PatternMatcher([ + # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val)) + (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))), + lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))), + # STORE is the last child, so we just merge the ShapeTrackers and store the base + (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)), + # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view() + (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)), + # REDUCE(src.view()) -> REDUCE(src).view() + (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right), + # ALU(src.view()) -> ALU(src).view() + (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right), + # double reduce op collapses to a single reduce op + (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), 
merge_double_reduce),
 ])
 
-remove_movement_ops = merge_views+PatternMatcher([
-  # NOTE: movement ops are always applied to base
-  (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
-  # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
-  (UPat(Ops.VIEW, name="view"),
-   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
+def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
+  if (st:=unwrap(x.st)) in ctx.sts: return None
+  st, var_vals = st.simplify().unbind()
+  ctx.var_vals.update(var_vals)
+  ctx.sts.add(st)
+  return st.to_uop() if st != x.st else None
+
+def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
+  ctx.bufs.append(x)
+  return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
+
+to_si = PatternMatcher([
+  # BUFFER -> DEFINE_GLOBAL
+  (UPat(Ops.BUFFER, name="x"), _append_buf),
+  # simplify and unbind the final VIEWs
+  (UPat(Ops.VIEW, name="x"), _append_st_vars),
+  # don't need SINK on COPY or BUFFER_VIEW
+  (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
+  # don't need contiguous or assign anymore
+  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
+  (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
+  # don't need DEVICE anymore
+  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
+  # PRELOAD becomes LOAD
+  (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
+  # once images are loaded they become the base dtype
+  (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
 ])
 
+def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
+  ctx[var.replace(src=())] = val.arg
+  return var
+unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
+
+def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> ScheduleItem:
+  # unbind_vars + push views to edges
+  sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
+  # remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
+  ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
+  # deal with ASSIGN
+  if len(ctx.assigns) != 0:
+    assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
+    for x in list(sink.toposort)[::-1]:
+      # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
+      if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
+      # PRELOAD tells the toposort this kernel should run before ASSIGN
+      if x.op is Ops.PRELOAD:
+        assign_preloads[x.buf_uop] = None
+        # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
+        if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
+          # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
+          if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
+          # if it has a single view and it's equal when you shrink a contig, it's fine
+          elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
+          # otherwise, it's not fine
+          else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                                   +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+  # capture process replay
+  if CAPTURE_PROCESS_REPLAY:
+    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
+  return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)))
+
+PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
+if CAPTURE_PROCESS_REPLAY:
+  @atexit.register
+  def save_process_replay() -> None:
+    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
+
 # **** schedule creation and toposort
 
 @track_rewrites(named=True)