New schedule.py Order [pr] (#8874)

Author: qazal
Date: 2025-02-03 07:59:11 -05:00
Committed by: GitHub
parent b075aefc12
commit b6c617272a


@@ -13,25 +13,70 @@ from tinygrad.device import Buffer
# creation can recurse a lot
sys.setrecursionlimit(10000)
# **** ScheduleItem return type
# **** schedule simplifier
@dataclass(frozen=True)
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]
@property
def outputs(self) -> tuple[Buffer, ...]:
"""Read/write or write only buffers in the schedule."""
return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
@property
def inputs(self) -> tuple[Buffer, ...]:
"""Read only buffers in the schedule."""
return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
@functools.cached_property
def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
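As a rough picture of how the two properties split bufs, here is a minimal stand-alone sketch using placeholder strings instead of real Buffer objects (all values are made up):
# illustrative sketch only: placeholder strings stand in for Buffers
bufs = ("out0", "out1", "in0", "in1")
output_idxs = (0, 1)   # e.g. one index per STORE under a SINK ast
outputs = tuple(b for i, b in enumerate(bufs) if i in output_idxs)      # ("out0", "out1")
inputs = tuple(b for i, b in enumerate(bufs) if i not in output_idxs)   # ("in0", "in1")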
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
if not all_int(x.shape): return None
# remove reduce on unmasked const
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
ret = x.const_arg
match reduce.arg[0]:
case Ops.ADD: ret *= prshape
case Ops.MUL: ret **= prshape
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
case _: return None
return reduce.const_like(ret)
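The arithmetic folds a reduce over an unmasked constant into a single constant; a quick worked example with assumed toy values:
# toy values: a constant 2 reduced over axes whose sizes multiply to prshape=4
from math import prod
const_arg, reduced_sizes = 2, (2, 2)
prshape = prod(reduced_sizes)                    # 4
assert const_arg * prshape == 8                  # Ops.ADD: 2+2+2+2
assert const_arg ** prshape == 16                # Ops.MUL: 2*2*2*2
assert max([const_arg] * prshape) == const_arg   # Ops.MAX is passthrough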
# **** Schedule context and big graph
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
new_src = list(alu.src)
for i,s in enumerate(alu.src):
if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# reduce of const is collapsed (TODO: make this a generic rule for stride0)
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
# COPY(CONST) creates a new CONST on the destination device
(UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
# no COPY to same device, except clone (arg is True)
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
# remove cast to image when it's already a contiguous image
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# remove contiguous if we can just view the buffer
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
# contiguous/buffer is already contiguous
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]),
# support for using a contiguous permuted view instead of the parent view if one exists
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
# remove CONST/BIND/BUFFER from SINK
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
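Among these rules, the same-device COPY one is the easiest to see in isolation; a minimal sketch of just that condition, with plain values standing in for UOps (names are hypothetical):
# hypothetical stand-ins: fold a COPY when source and destination devices match, unless it is a clone
def fold_same_device_copy(copy_device, src_device, is_clone, copyin):
    return copyin if copy_device == src_device and is_clone is not True else None
assert fold_same_device_copy("GPU", "GPU", False, "x") == "x"   # same device: drop the COPY
assert fold_same_device_copy("GPU", "GPU", True, "x") is None   # clone: keep it
assert fold_same_device_copy("GPU", "CPU", False, "x") is None  # cross-device: keep it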
remove_movement_ops = merge_views+PatternMatcher([
# NOTE: movement ops are always applied to base
(UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
(UPat(Ops.VIEW, name="view"),
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
])
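The masked-view rule fires when some mask interval is empty, so the view can never read data; a toy check of that condition:
# toy (start, end) mask per axis; a zero-length interval means the VIEW folds to CONST 0
mask = ((0, 4), (3, 3))
assert any(hi - lo == 0 for lo, hi in mask)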
# **** UOp realization
@dataclass(frozen=True)
class ScheduleContext:
@@ -71,140 +116,6 @@ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
return ret
# **** AST graph rewrite
# ** movement ops
def apply_swizzle(u:UOp) -> UOp:
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
strides = strides_for_shape(rshape)
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
# update input_st and axis
new_input_st = tmp + ShapeTracker(tuple(nv))
new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
output_shape = swizzle_st.reduce(r.axis_arg)
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
def elementwise_view_right(root:UOp) -> UOp|None:
if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
# push the swizzle from src to root
output_swizzle = swizzles[0]
new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
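merge_double_reduce relies on two stacked reductions with the same ALU op being equal to one reduction over the union of their axes; a quick plain-Python check of that identity for ADD:
# reduce axis 1 then axis 0 vs. one reduce over both axes
data = [[1, 2, 3], [4, 5, 6]]
assert sum(sum(row) for row in data) == sum(v for row in data for v in row) == 21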
# push VIEW to children
view_right = merge_views+PatternMatcher([
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
# STORE is the last child, so we just merge the ShapeTrackers and store the base
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
# REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
# REDUCE(src.view()) -> REDUCE(src).view()
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
# ALU(src.view()) -> ALU(src).view()
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
# double reduce op collapses to a single reduce op
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
# ** ScheduleItem context builder
@dataclass(frozen=True)
class ScheduleItemContext:
var_vals: dict[Variable, int]
sts: set[ShapeTracker] = field(default_factory=set)
bufs: list[UOp] = field(default_factory=list)
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
if (st:=unwrap(x.st)) in ctx.sts: return None
st, var_vals = st.simplify().unbind()
ctx.var_vals.update(var_vals)
ctx.sts.add(st)
return st.to_uop() if st != x.st else None
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
ctx.bufs.append(x)
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
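Each matched BUFFER is appended to ctx.bufs and rewritten to a DEFINE_GLOBAL indexed by its append position; a stand-alone sketch of that numbering with strings in place of UOps:
# sketch only: strings stand in for BUFFER UOps
bufs: list[str] = []
def define_global_index(buf: str) -> int:
    bufs.append(buf)
    return len(bufs) - 1   # DEFINE_GLOBAL's arg is the buffer's position
assert [define_global_index(b) for b in ("out", "a", "b")] == [0, 1, 2]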
to_si = PatternMatcher([
# BUFFER -> DEFINE_GLOBAL
(UPat(Ops.BUFFER, name="x"), _append_buf),
# simplify and unbind the final VIEWs
(UPat(Ops.VIEW, name="x"), _append_st_vars),
# don't need SINK on COPY or BUFFER_VIEW
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
# don't need contiguous or assign anymore
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
# don't need DEVICE anymore
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
# PRELOAD becomes LOAD
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
# once images are loaded they become the base dtype
(UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
])
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
ctx[var.replace(src=())] = val.arg
return var
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
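unbind_vars replaces each BIND(var, val) with the bare variable while recording the concrete value in the context dict; a minimal sketch of that bookkeeping with plain names standing in for UOps:
# sketch only: a string name and an int stand in for BIND's var and val UOps
var_vals: dict[str, int] = {}
def unbind(var: str, val: int) -> str:
    var_vals[var] = val   # remember the bound value for the ScheduleItem
    return var            # the graph keeps only the unbound variable
assert unbind("i", 10) == "i" and var_vals == {"i": 10}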
def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> ScheduleItem:
# unbind_vars + push views to edges
sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
# remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
# deal with ASSIGN
if len(ctx.assigns) != 0:
assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
for x in list(sink.toposort)[::-1]:
# we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
# PRELOAD tells the toposort this kernel should run before ASSIGN
if x.op is Ops.PRELOAD:
assign_preloads[x.buf_uop] = None
# if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
# if it has a single view and it's equal when you shrink a contig, it's fine
elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
# otherwise, it's not fine
else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# capture process replay
if CAPTURE_PROCESS_REPLAY:
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)))
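The contiguity error above is what users hit when an in-place op reads a non-contiguous view of the buffer it writes; a rough usage sketch of the fix the message suggests (assuming the public Tensor API, exact triggering conditions depend on realization):
# rough sketch of the suggested fix, not part of the diff
from tinygrad import Tensor
a = Tensor.rand(4, 4).realize()
# a += a.T              # may raise: self operand of augmented assign must be contiguous
a += a.T.contiguous()   # materializing the transposed operand avoids the aliased load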
PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
@atexit.register
def save_process_replay() -> None:
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
# **** UOp realization
class UPatScheduled(UPat):
def __init__(self, *args, **kwargs):
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
@@ -352,69 +263,150 @@ break_sched = PatternMatcher([
(UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])
# **** schedule simplifier
# **** ScheduleItem creation
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
if not all_int(x.shape): return None
# remove reduce on unmasked const
prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
ret = x.const_arg
match reduce.arg[0]:
case Ops.ADD: ret *= prshape
case Ops.MUL: ret **= prshape
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
case _: return None
return reduce.const_like(ret)
@dataclass(frozen=True)
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...]
metadata: tuple[Metadata, ...]
@property
def outputs(self) -> tuple[Buffer, ...]:
"""Read/write or write only buffers in the schedule."""
return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
@property
def inputs(self) -> tuple[Buffer, ...]:
"""Read only buffers in the schedule."""
return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
@functools.cached_property
def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
new_src = list(alu.src)
for i,s in enumerate(alu.src):
if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
@dataclass(frozen=True)
class ScheduleItemContext:
var_vals: dict[Variable, int]
sts: set[ShapeTracker] = field(default_factory=set)
bufs: list[UOp] = field(default_factory=list)
sym = symbolic_simple+PatternMatcher([
# UOp with size 0 is zero
(UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
# reduce of size 0 is the identity element
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
# reduce of const is collapsed (TODO: make this a generic rule for stride0)
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
# COPY(CONST) creates a new CONST on the destination device
(UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
# no COPY to same device, except clone (arg is True)
(UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
# remove cast to image when it's already a contiguous image
(UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
# remove contiguous if we can just view the buffer
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
# contiguous/buffer is already contiguous
(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]),
# support for using a contiguous permuted view instead of the parent view if one exists
(UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), replace_contiguous),
# remove CONST/BIND/BUFFER from SINK
(UPat(Ops.SINK, name="root"),
lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
def apply_swizzle(u:UOp) -> UOp:
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
strides = strides_for_shape(rshape)
nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
# update input_st and axis
new_input_st = tmp + ShapeTracker(tuple(nv))
new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
output_shape = swizzle_st.reduce(r.axis_arg)
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
def elementwise_view_right(root:UOp) -> UOp|None:
if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
# push the swizzle from src to root
output_swizzle = swizzles[0]
new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
return ret.view(ShapeTracker.from_shape(output_swizzle.shape))
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
# push VIEW to children
view_right = merge_views+PatternMatcher([
# STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
# STORE is the last child, so we just merge the ShapeTrackers and store the base
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
# REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
# REDUCE(src.view()) -> REDUCE(src).view()
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
# ALU(src.view()) -> ALU(src).view()
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
# double reduce op collapses to a single reduce op
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
remove_movement_ops = merge_views+PatternMatcher([
# NOTE: movement ops are always applied to base
(UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
(UPat(Ops.VIEW, name="view"),
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
])
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
if (st:=unwrap(x.st)) in ctx.sts: return None
st, var_vals = st.simplify().unbind()
ctx.var_vals.update(var_vals)
ctx.sts.add(st)
return st.to_uop() if st != x.st else None
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
ctx.bufs.append(x)
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
to_si = PatternMatcher([
# BUFFER -> DEFINE_GLOBAL
(UPat(Ops.BUFFER, name="x"), _append_buf),
# simplify and unbind the final VIEWs
(UPat(Ops.VIEW, name="x"), _append_st_vars),
# don't need SINK on COPY or BUFFER_VIEW
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
# don't need contiguous or assign anymore
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
# don't need DEVICE anymore
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
# PRELOAD becomes LOAD
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
# once images are loaded they become the base dtype
(UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
])
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
ctx[var.replace(src=())] = val.arg
return var
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> ScheduleItem:
# unbind_vars + push views to edges
sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
# remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
# deal with ASSIGN
if len(ctx.assigns) != 0:
assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
for x in list(sink.toposort)[::-1]:
# we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
# PRELOAD tells the toposort this kernel should run before ASSIGN
if x.op is Ops.PRELOAD:
assign_preloads[x.buf_uop] = None
# if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
# if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
# if it has a single view and it's equal when you shrink a contig, it's fine
elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
# otherwise, it's not fine
else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# capture process replay
if CAPTURE_PROCESS_REPLAY:
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)))
PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
@atexit.register
def save_process_replay() -> None:
for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)
# **** schedule creation and toposort
@track_rewrites(named=True)