mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-24 14:28:09 -05:00)
New schedule.py Order [pr] (#8874)
@@ -13,25 +13,70 @@ from tinygrad.device import Buffer
# creation can recurse a lot
sys.setrecursionlimit(10000)

# **** ScheduleItem return type
# **** schedule simplifier

@dataclass(frozen=True)
class ScheduleItem:
  ast: UOp
  bufs: tuple[Buffer, ...]
  metadata: tuple[Metadata, ...]
  @property
  def outputs(self) -> tuple[Buffer, ...]:
    """Read/write or write only buffers in the schedule."""
    return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
  @property
  def inputs(self) -> tuple[Buffer, ...]:
    """Read only buffers in the schedule."""
    return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
  @functools.cached_property
  def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
  if not all_int(x.shape): return None
  # remove reduce on unmasked const
  prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
  ret = x.const_arg
  match reduce.arg[0]:
    case Ops.ADD: ret *= prshape
    case Ops.MUL: ret **= prshape
    case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
    case _: return None
  return reduce.const_like(ret)
# **** Schedule context and big graph
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
  new_src = list(alu.src)
  for i,s in enumerate(alu.src):
    if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
  if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
sym = symbolic_simple+PatternMatcher([
  # UOp with size 0 is zero
  (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
    and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
  # reduce of size 0 is the identity element
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
  # reduce of const is collapsed (TODO: make this a generic rule for stride0)
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
  # COPY(CONST) creates a new CONST on the destination device
  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
  # no COPY to same device, except clone (arg is True)
  (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
   lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
  # remove cast to image when it's already a contiguous image
  (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
   lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
  # remove contiguous if we can just view the buffer
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
  # contiguous/buffer is already contiguous
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]),
  # support for using a contiguous permuted view instead of the parent view if one exists
  (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
  (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
  # remove CONST/BIND/BUFFER from SINK
  (UPat(Ops.SINK, name="root"),
   lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
    if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
])
remove_movement_ops = merge_views+PatternMatcher([
  # NOTE: movement ops are always applied to base
  (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
  # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
  (UPat(Ops.VIEW, name="view"),
   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
])

# **** UOp realization

@dataclass(frozen=True)
class ScheduleContext:
@@ -71,140 +116,6 @@ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp
  cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
  return ret

# **** AST graph rewrite

# ** movement ops

def apply_swizzle(u:UOp) -> UOp:
  with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)

def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
  input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
  tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
  prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
  strides = strides_for_shape(rshape)
  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
  # update input_st and axis
  new_input_st = tmp + ShapeTracker(tuple(nv))
  new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
  return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))

def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
  if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
  output_shape = swizzle_st.reduce(r.axis_arg)
  return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))
def elementwise_view_right(root:UOp) -> UOp|None:
  if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
  assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
  assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
  # push the swizzle from src to root
  output_swizzle = swizzles[0]
  new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
  ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
  return ret.view(ShapeTracker.from_shape(output_swizzle.shape))

def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
  assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
  return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))

# push VIEW to children
view_right = merge_views+PatternMatcher([
  # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
   lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
  # STORE is the last child, so we just merge the ShapeTrackers and store the base
  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
  # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
  # REDUCE(src.view()) -> REDUCE(src).view()
  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
  # ALU(src.view()) -> ALU(src).view()
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
  # double reduce op collapses to a single reduce op
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
# ** ScheduleItem context builder

@dataclass(frozen=True)
class ScheduleItemContext:
  var_vals: dict[Variable, int]
  sts: set[ShapeTracker] = field(default_factory=set)
  bufs: list[UOp] = field(default_factory=list)

def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
  if (st:=unwrap(x.st)) in ctx.sts: return None
  st, var_vals = st.simplify().unbind()
  ctx.var_vals.update(var_vals)
  ctx.sts.add(st)
  return st.to_uop() if st != x.st else None

def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
  ctx.bufs.append(x)
  return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)

to_si = PatternMatcher([
  # BUFFER -> DEFINE_GLOBAL
  (UPat(Ops.BUFFER, name="x"), _append_buf),
  # simplify and unbind the final VIEWs
  (UPat(Ops.VIEW, name="x"), _append_st_vars),
  # don't need SINK on COPY or BUFFER_VIEW
  (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
  # don't need contiguous or assign anymore
  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
  (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
  # don't need DEVICE anymore
  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
  # PRELOAD becomes LOAD
  (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
  # once images are loaded they become the base dtype
  (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
])
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
  ctx[var.replace(src=())] = val.arg
  return var
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])

def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> ScheduleItem:
  # unbind_vars + push views to edges
  sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
  # remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
  ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
  # deal with ASSIGN
  if len(ctx.assigns) != 0:
    assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
    for x in list(sink.toposort)[::-1]:
      # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
      if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
      # PRELOAD tells the toposort this kernel should run before ASSIGN
      if x.op is Ops.PRELOAD:
        assign_preloads[x.buf_uop] = None
        # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
        if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
          # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
          if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
          # if it has a single view and it's equal when you shrink a contig, it's fine
          elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
          # otherwise, it's not fine
          else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                                   +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
  # capture process replay
  if CAPTURE_PROCESS_REPLAY:
    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
  return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)))
PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
  @atexit.register
  def save_process_replay() -> None:
    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)

# **** UOp realization

class UPatScheduled(UPat):
  def __init__(self, *args, **kwargs):
    super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
@@ -352,69 +263,150 @@ break_sched = PatternMatcher([
  (UPat(Ops.VIEW, name="st", src=(UPat(Ops.BUFFER, name="b"), UPat.var("x"))), store_or_fuse),
])

# **** schedule simplifier
# **** ScheduleItem creation
def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
  if not all_int(x.shape): return None
  # remove reduce on unmasked const
  prshape = prod(unwrap(x.st).shape[i] for i in reduce.arg[1])
  ret = x.const_arg
  match reduce.arg[0]:
    case Ops.ADD: ret *= prshape
    case Ops.MUL: ret **= prshape
    case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
    case _: return None
  return reduce.const_like(ret)
@dataclass(frozen=True)
class ScheduleItem:
  ast: UOp
  bufs: tuple[Buffer, ...]
  metadata: tuple[Metadata, ...]
  @property
  def outputs(self) -> tuple[Buffer, ...]:
    """Read/write or write only buffers in the schedule."""
    return tuple(b for i,b in enumerate(self.bufs) if i in self.output_idxs)
  @property
  def inputs(self) -> tuple[Buffer, ...]:
    """Read only buffers in the schedule."""
    return tuple(b for i,b in enumerate(self.bufs) if i not in self.output_idxs)
  @functools.cached_property
  def output_idxs(self) -> tuple[int, ...]: return tuple(x.src[0].arg for x in self.ast.src) if self.ast.op is Ops.SINK else (0,)
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
def replace_contiguous(ctx:dict[UOp, UOp], alu:UOp):
  new_src = list(alu.src)
  for i,s in enumerate(alu.src):
    if (replace_src:=ctx.get(s, None)) is not None: new_src[i] = replace_src
  if tuple(new_src) != alu.src: return alu.replace(src=tuple(new_src))
@dataclass(frozen=True)
class ScheduleItemContext:
  var_vals: dict[Variable, int]
  sts: set[ShapeTracker] = field(default_factory=set)
  bufs: list[UOp] = field(default_factory=list)
sym = symbolic_simple+PatternMatcher([
  # UOp with size 0 is zero
  (UPat(set(Ops)-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 \
    and not (root.base.op is Ops.CONST and root.base.arg == 0) else None),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
  # reduce of size 0 is the identity element
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
  # reduce of const is collapsed (TODO: make this a generic rule for stride0)
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
  # COPY(CONST) creates a new CONST on the destination device
  (UPat(Ops.COPY, name="root", src=(UPat(), UPat.cvar("x"),)), lambda root,x: root.const_like(x.const_arg)),
  # no COPY to same device, except clone (arg is True)
  (UPat(Ops.COPY, src=(UPat(), UPat.var("copyin")), name="copy"),
   lambda copyin,copy: copyin if copyin.device == copy.device and copy.arg is not True else None),
  # remove cast to image when it's already a contiguous image
  (UPat(Ops.VIEW, name="vm1", src=(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm2", src=(UPat(Ops.CONTIGUOUS, name="base"))))),)),
   lambda cast,base,vm1,vm2: base.view(vm2.st+vm1.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
  # remove contiguous if we can just view the buffer
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
  # contiguous/buffer is already contiguous
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER)),)), lambda root: root.src[0]),
  # support for using a contiguous permuted view instead of the parent view if one exists
  (UPat(Ops.CONTIGUOUS, name="contig", src=(UPat(Ops.VIEW, name="src"),)), found_contiguous),
  (UPat(GroupOp.ALU, name="alu"), replace_contiguous),
  # remove CONST/BIND/BUFFER from SINK
  (UPat(Ops.SINK, name="root"),
   lambda root: UOp(Ops.SINK, root.dtype, new_src, root.arg)
    if (new_src:=tuple(x for x in root.src if not x.is_realized and x.base.op not in {Ops.CONST, Ops.BIND})) != root.src else None),
def apply_swizzle(u:UOp) -> UOp:
  with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)

def swizzle_r(r:UOp, src:UOp, st:ShapeTracker) -> UOp:
  input_st = ShapeTracker.from_shape(unwrap(src.st).shape)
  tmp = input_st.permute(tuple(i for i in range(len(input_st.shape)) if i not in r.axis_arg)+r.axis_arg)
  prshape = prod(rshape:=tmp.shape[-len(r.axis_arg):])
  strides = strides_for_shape(rshape)
  nv = [View.create(v.shape+rshape, tuple(x*prshape for x in v.strides)+strides,
                    v.offset*prshape, v.mask+tuple((0,s) for s in rshape) if v.mask is not None else None) for v in st.views]
  # update input_st and axis
  new_input_st = tmp + ShapeTracker(tuple(nv))
  new_axis = tuple(range(len(st.shape), len(st.shape) + len(r.axis_arg)))
  return apply_swizzle(src.view(new_input_st)).r(r.arg[0], new_axis).view(ShapeTracker.from_shape(st.shape))
def reduceop_view_right(r:UOp, v:UOp, src:UOp) -> UOp:
  if not (swizzle_st:=unwrap(v.st)).contiguous or v.size != src.size: raise AssertionError(f"can't push {v} down through {src}")
  output_shape = swizzle_st.reduce(r.axis_arg)
  return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, output_shape)) if s != u)).view(ShapeTracker.from_shape(output_shape))

def elementwise_view_right(root:UOp) -> UOp|None:
  if len(swizzles:=[x for x in root.src if x.base is not x]) == 0: return None
  assert all(x.base.st is not None for x in swizzles), f"found shapeless VIEW src in {root}"
  assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
  # push the swizzle from src to root
  output_swizzle = swizzles[0]
  new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
  ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
  return ret.view(ShapeTracker.from_shape(output_swizzle.shape))

def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
  assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
  assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
  return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
# push VIEW to children
view_right = merge_views+PatternMatcher([
  # STORE(.., ASSIGN(VIEW(BUFFER), new_val)) -> VIEW(STORE(.., new_val))
  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat.assign(UPat.var("target"), UPat.var("val")))),
   lambda b,target,st,val: apply_swizzle(UOp.store(b, st, val).view(target.st))),
  # STORE is the last child, so we just merge the ShapeTrackers and store the base
  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat.var("val"),)))), lambda b,st,val: UOp.store(b, st.view(val.st), val)),
  # REDUCE(src.view(contiguous=False)) -> REDUCE(src.view(contiguous=True)).view()
  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").view(name="v"), lambda v,r,src: None if v.st.contiguous else swizzle_r(r, src, v.st)),
  # REDUCE(src.view()) -> REDUCE(src).view()
  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("src").view(name="v"),), name="r"), reduceop_view_right),
  # ALU(src.view()) -> ALU(src).view()
  (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN, Ops.CONTIGUOUS, Ops.STORE), name="root"), elementwise_view_right),
  # double reduce op collapses to a single reduce op
  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
remove_movement_ops = merge_views+PatternMatcher([
  # NOTE: movement ops are always applied to base
  (UPat(GroupOp.Movement, name="mov", src=(UPat.var("x"),)), lambda x,mov: x.view(unwrap(mov.st))),
  # some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
  (UPat(Ops.VIEW, name="view"),
   lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
  if (st:=unwrap(x.st)) in ctx.sts: return None
  st, var_vals = st.simplify().unbind()
  ctx.var_vals.update(var_vals)
  ctx.sts.add(st)
  return st.to_uop() if st != x.st else None

def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
  ctx.bufs.append(x)
  return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
to_si = PatternMatcher([
  # BUFFER -> DEFINE_GLOBAL
  (UPat(Ops.BUFFER, name="x"), _append_buf),
  # simplify and unbind the final VIEWs
  (UPat(Ops.VIEW, name="x"), _append_st_vars),
  # don't need SINK on COPY or BUFFER_VIEW
  (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
  # don't need contiguous or assign anymore
  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
  (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
  # don't need DEVICE anymore
  (UPat(Ops.VIEW, name="view", src=(UPat(Ops.DEVICE),)), lambda view: view.replace(src=())),
  # PRELOAD becomes LOAD
  (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
  # once images are loaded they become the base dtype
  (UPat(set(Ops)-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
])
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
  ctx[var.replace(src=())] = val.arg
  return var
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])

def schedule_uop(pre:UOp, ctx:ScheduleContext, var_vals:dict[UOp, int]) -> ScheduleItem:
  # unbind_vars + push views to edges
  sink = graph_rewrite(graph_rewrite(pre, unbind_vars+view_left, ctx=var_vals), view_right)
  # remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
  ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(var_vals))
  # deal with ASSIGN
  if len(ctx.assigns) != 0:
    assign_preloads = ctx.preloads[si_ctx.bufs[0].buffer]
    for x in list(sink.toposort)[::-1]:
      # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
      if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
      # PRELOAD tells the toposort this kernel should run before ASSIGN
      if x.op is Ops.PRELOAD:
        assign_preloads[x.buf_uop] = None
        # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
        if x.buf_uop is pre.src[0].buf_uop and not (st:=x.st_arg).contiguous:
          # if it has a single view and it becomes contiguous when you shrink expanded axes, it's fine
          if len(st.views) == 1 and st.shrink(tuple((0,1) if st == 0 else (0,s) for s,st in zip(st.shape, st.views[0].strides))).contiguous: pass
          # if it has a single view and it's equal when you shrink a contig, it's fine
          elif len(st.views) == 1 and (mask:=st.views[0].mask) is not None and ShapeTracker.from_shape(st.shape).shrink(mask) == st.shrink(mask): pass
          # otherwise, it's not fine
          else: raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                                   +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
  # capture process replay
  if CAPTURE_PROCESS_REPLAY:
    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
  return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs), tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)))
PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
  @atexit.register
  def save_process_replay() -> None:
    for k,v in PROCESS_REPLAY_CAPTURE.items(): diskcache_put("schedule_process_replay", k, v, prepickled=True)

# **** schedule creation and toposort

@track_rewrites(named=True)