@@ -1,15 +1,16 @@
from typing import Any
from typing import Any, cast
import functools, operator
from dataclasses import dataclass, field
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, RewriteNotReady, _substitute, ssimplify
from tinygrad.uop.symbolic import sym
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, RANGEIFY, Context, flatten, dedup
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.kernelize import Kernel
from tinygrad.uop.ops import track_rewrites, graph_rewrite_map, graph_rewrite, identity_element, sint, AxisType
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType

# *****************
# 0. do some cleanup rewrites, mostly copied from the old stuff

double_reshape = PatternMatcher([
@@ -19,30 +20,42 @@ double_reshape = PatternMatcher([
earliest_rewrites = double_reshape+PatternMatcher([
  # non shape changing RESHAPE is NOOP
  (UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
  #(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE
  #(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0].f(Ops.NOOP, tag=x.tag)),
  # just removing it works...
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
  # preserve tags?
  # UOp with size 0 is zero
  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
  # reduce of size 0 is the identity element
  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE
  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
  # copy reorder
  # TODO: this is causing many copies with the replace tag None
  # RESHAPE after COPY
  (UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)),
  (UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d), tag=None).reshape(r.arg)),
  # TODO: this should be BUFFER_VIEW
  (UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).shrink(r.arg)),
  (UPat(Ops.COPY, src=(UPat(Ops.SHRINK, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d), tag=None).shrink(r.arg)),
  # const hacks
  (UPat(Ops.CONST, name="x"), lambda x:
    x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if \
    len(x.src) and x.src[0].op is Ops.VIEW and not any(s == 0 for s in x.shape) else None),
  #(UPat(Ops.CONST, name="x"), lambda x:
  # x.replace(src=(x.src[0].src[0],)).reshape((1,)*len(x.shape)).expand(x.shape) if \
  # len(x.src) and x.src[0].op is Ops.VIEW and not any(s == 0 for s in x.shape) else None),
  # assign only to buffer
  (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x"))),
   lambda x,target: x if target.base.op is not Ops.BUFFER else None),
  (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.BUFFER}, name="target"), UPat(name="x")), name="assign"),
   lambda x,target,assign: x.f(Ops.NOOP, tag=assign.tag) if target.base.op is not Ops.BUFFER else None),
  # contiguous/buffer/copy/assign is already contiguous
  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
  #(UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
])
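
Each entry above pairs a UPat with a function that returns either a replacement UOp or None to decline; graph_rewrite applies the rules until a fixed point. A minimal sketch of that contract, with illustrative stand-in names rather than tinygrad's real classes:

# toy version of the (pattern, fn) contract used by PatternMatcher above.
# names and .op/.src/.replace here are illustrative stand-ins, not the real API.
def double_reshape_rule(node):
  # RESHAPE(RESHAPE(x)) -> RESHAPE(x): only the outermost shape survives
  if node.op == "RESHAPE" and node.src[0].op == "RESHAPE":
    return node.replace(src=(node.src[0].src[0],))
  return None  # decline: leave the node alone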

# 1. add contiguous where we have to
# *****************
# 1. add realize where we have to

ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
  Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
@@ -69,10 +82,12 @@ do_realize = PatternMatcher([
])

add_contiguous = PatternMatcher([
  (UPat(GroupOp.All, name="x"), lambda ctx,x: x.replace(tag=1).realize() if x in ctx and x.tag is None else None),
  (UPat(GroupOp.All, name="x"),
   lambda ctx,x: x.replace(tag=(x.tag,)).realize() if x in ctx and not isinstance(x.tag, tuple) else None),
])
remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
remove_tuple_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=x.tag[0]) if isinstance(x.tag, tuple) else None)])

# *****************
# 2. mark all children

@dataclass
@@ -99,7 +114,8 @@ pm_children = PatternMatcher([
  (UPat(GroupOp.All-{Ops.CHILD, Ops.CHILDREN}, name="x"), mark_children),
])
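
The children pass exists so step 3 can tell whether every consumer of a uop indexes it the same way. A toy of the bookkeeping, assuming nothing beyond the standard library:

# toy of "mark all children": find producers with more than one consumer;
# those are the uops that get wrapped in CHILDREN with per-edge CHILD handles
from collections import defaultdict
consumers: dict[str, list[str]] = defaultdict(list)
edges = [("add", "x"), ("mul", "x"), ("neg", "y")]  # (consumer, producer)
for c, p in edges: consumers[p].append(c)
multi = {p: cs for p, cs in consumers.items() if len(cs) > 1}
assert list(multi) == ["x"]  # x has two children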

# 3. rangeify
# *****************
# 3a. rangeify (movement)

@dataclass
class RangeifyContext:
@@ -175,13 +191,20 @@ pm_mops = PatternMatcher([
  (UPat(Ops.PAD, name="r").f(Ops.INDEX, allow_any_len=True, name="idx"), map_pad),
])
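
What "rangeify (movement)" means for a RESHAPE, in plain Python: an index into the new view is converted to an index into the source through the flat offset. A conceptual sketch, not the UOp-level implementation:

def remap_reshape(new_idx, new_shape, old_shape):
  flat = 0
  for i, n in zip(new_idx, new_shape): flat = flat * n + i  # row-major offset
  old_idx = []
  for n in reversed(old_shape):
    old_idx.append(flat % n); flat //= n
  return tuple(reversed(old_idx))

assert remap_reshape((1, 2), (3, 4), (2, 6)) == (1, 0)  # both are flat offset 6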

# *****************
# 3b. rangeify (ops)

# bufferization can happen in three ways
# 1. there's an explicit REALIZE in the graph
# 2. the ranges from the children don't match and we have to create a buffer (only on children)
# 3. might_end_axis triggers because we should be closing a loop to save compute
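
Way 1 corresponds to user code like the following (illustrative tinygrad usage; the .contiguous() is what shows up as the explicit CONTIGUOUS/REALIZE the matcher sees):

from tinygrad import Tensor
a = Tensor.empty(16, 16)
b = (a + 1).contiguous()  # forces a buffer between the two kernels
c = (b * 2).sum()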

@dataclass(frozen=True)
class BufferizeOpts:
  # on AddrSpace.LOCAL, device is the id
  device: str|tuple[str, ...]|int
  device: str|tuple[str, ...]|int|None
  addrspace: AddrSpace = AddrSpace.GLOBAL
  tags: tuple[int, ...] = ()
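
A hedged sketch of how these fields get used: a global bufferize names its device, while a local one carries an id in the device slot, per the note above.

opts_global = BufferizeOpts(device="CUDA")                        # device buffer
opts_local  = BufferizeOpts(device=0, addrspace=AddrSpace.LOCAL)  # id, not a device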

def map_partial_realize(ctx:RangeifyContext, x:UOp, idx:UOp):
  if x.arg is None: return None # map_contiguous can handle this
@@ -195,21 +218,17 @@ def map_partial_realize(ctx:RangeifyContext, x:UOp, idx:UOp):
      ranges.append(idx.src[1+i])
      continue
    passthrough_idx.append(idx.src[1+i])
    ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.index, 0))
    ranges.append(ctx.new_range(s))
    new_ranges.append(ranges[-1])
  ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST], arg=BufferizeOpts(device=x.device))
  # TODO: this should be able to be global or local
  ret = x.src[0].index(*ranges).bufferize(*[x for x in new_ranges if x.op is not Ops.CONST],
                                          arg=BufferizeOpts(device=None, addrspace=AddrSpace.LOCAL))
  return ret.index(*passthrough_idx)

def map_realize(ctx:RangeifyContext, x:UOp):
  if x.arg is not None: return None
  ranges = []
  for s in x.shape[len(x.src)-1:]:
    ranges.append(ctx.new_range(s) if resolve(s!=1) else UOp.const(dtypes.index, 0))
  ret = x.src[0].index(*ranges).bufferize(*x.src[1:], *[x for x in ranges if x.op is not Ops.CONST], arg=BufferizeOpts(device=x.device))
  # was there a shrink? move this before the bufferize?
  # TODO: do we need this?
  if resolve(prod(x.shape) != prod(ret.shape)): ret = ret.forced_reshape((prod(ret.shape),)).shrink(((0, prod(x.shape)),))
  return ret.forced_reshape(x.shape)
  ranges = [ctx.new_range(s) for s in x.shape]
  return x.src[0].index(*ranges).bufferize(*x.src[1:], *ranges, arg=BufferizeOpts(device=x.device, tags=(x.src[0].tag,)))

def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
  rngs = list(idx.src[1:])
@@ -218,7 +237,7 @@ def map_reduce(ctx:RangeifyContext, idx:UOp, red:UOp):
    if i in red.arg[1]:
      rngs[i] = ctx.new_range(s, axistype=AxisType.REDUCE)
      new_ranges.append(rngs[i])
  return UOp(Ops.REDUCE, red.dtype, src=(red.src[0].index(*rngs),)+tuple(new_ranges), arg=red.arg[0])
  return UOp(Ops.REDUCE, red.dtype, src=(red.src[0].index(*rngs),)+tuple(new_ranges), arg=red.arg[0], tag=red.tag)

def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
  if c not in ctx.seen_children: ctx.seen_children[c] = {}
@@ -256,7 +275,14 @@ def index_child(ctx:RangeifyContext, c:UOp, x:UOp, idx:UOp):
  # index based on the shared ranges
  ret = c.index(*out_rngs)
  # if all ranges aren't the same between children, we have to bufferize
  if len(idx_ranges) > 0: ret = ret.bufferize(*end_ranges, arg=BufferizeOpts(device=x.device)).index(*[idx.src[1+i] for i in idx_ranges])
  if len(idx_ranges) > 0:
    if len(idx_ranges) == len(out_rngs):
      # this is a global bufferize
      ret = ret.bufferize(*end_ranges, arg=BufferizeOpts(device=x.device))
    else:
      assert RANGEIFY > 1, "this isn't supported with RANGEIFY=1"
      ret = ret.bufferize(*end_ranges, arg=BufferizeOpts(device=None, addrspace=AddrSpace.LOCAL))
    ret = ret.index(*[idx.src[1+i] for i in idx_ranges])
  return ret
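
The mismatched-range case (way 2 above) shows up whenever two consumers index the same intermediate differently, e.g. a value used alongside its own transpose. Whether it stays a global buffer or goes LOCAL depends on how many ranges disagree, per the branch above. Illustrative tinygrad usage, not a guarantee of what the cost rules decide:

from tinygrad import Tensor
x = Tensor.empty(8, 8)
y = x + 1    # one producer...
z = y + y.T  # ...two children indexing it with swapped ranges -> bufferize y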

def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
@@ -266,7 +292,7 @@ def children_gate(ctx:RangeifyContext, idx:UOp, c:UOp):
def might_end_axis(idx:UOp):
  if idx.arg is None: return None
  # TODO: write a proper cost function here
  if all(x.op not in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.BUFFERIZE} for x in idx.toposort()): return None
  if all(x.op not in {Ops.BUFFER, Ops.REALIZE, Ops.BUFFERIZE} for x in idx.toposort()): return None
  if all(x.op not in {Ops.REDUCE_AXIS} for x in idx.toposort()): return None
  to_end_axis = []
  for i,a in enumerate(idx.src[1:]):
@@ -275,6 +301,8 @@ def might_end_axis(idx:UOp):
  if to_end_axis: return idx.replace(src=(idx.src[0].realize(arg=tuple(to_end_axis)),)+idx.src[1:], arg=None)
  return idx.replace(arg=None)
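
The compute-saving case this targets (way 3 above) is a reduce whose result is broadcast back against its own input, the softmax-denominator shape: keeping the loop open would recompute the reduce per output element, so the axis is ended and the reduce realized first. A tensor-level illustration of the trigger, assuming the cost check fires:

from tinygrad import Tensor
x = Tensor.empty(32, 32)
den = x.sum(axis=1, keepdim=True)  # REDUCE_AXIS
out = x / den                      # broadcast reuse of the reduce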

def unprocessed_index(x:UOp): raise RuntimeError(f"unprocessed index on {x.src[0].op}")

pm_rangeify = pm_mops+PatternMatcher([
  # sink contigs to kick it off
  (UPat(Ops.REALIZE, src=(UPat(),), name="x", allow_any_len=True), map_realize),
@@ -294,24 +322,30 @@ pm_rangeify = pm_mops+PatternMatcher([
  # handle arg on any op with weight. old endrange stuff
  (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),

  # handle assign
  (UPat(Ops.INDEX, src=(UPat(Ops.ASSIGN, name="assign"),), allow_any_len=True, name="x"),
   lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],))),

  # move MAP through elementwise ALU / reduce. these are the items with cost
  (UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union(
    {Ops.STORE, Ops.ASSIGN, Ops.COPY, Ops.DEVICE, Ops.BIND, Ops.CONTIGUOUS})),), allow_any_len=True, name="x"),
    {Ops.STORE, Ops.COPY, Ops.DEVICE, Ops.BIND, Ops.CONTIGUOUS, Ops.NOOP})),), allow_any_len=True, name="x"),
   lambda x: x.src[0].replace(src=tuple([s.index(*x.src[1:]) for s in x.src[0].src]))),
  (UPat(Ops.INDEX, src=(UPat(Ops.REDUCE_AXIS, name="red"),), allow_any_len=True, name="idx"), map_reduce),

  # assert if there's any index we didn't process
  (UPat(GroupOp.All-{Ops.REALIZE, Ops.BUFFERIZE}).f(Ops.INDEX, name="x"), unprocessed_index),
])

# *****************
# 3.5 cleanups

# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
# TODO: figure out how to reenable this
def cleanup_dead_axes(b:UOp):
  parents = b.src[0].toposort()
  new_rng = []
  hit = False
  reshape: list[sint] = []
  for s,rng in zip(b.shape, b.src[1:]):
    if rng not in parents and rng.op is Ops.RANGE:
    if rng not in b.src[0].sparents and rng.op is Ops.RANGE:
      reshape.append(1)
      hit = True
    else:
@@ -327,19 +361,20 @@ def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
  assert len(buf.src) == len(idx.src), "index on wrong bufferize"
  assert all(x.op is Ops.RANGE for x in buf.src[1:])

  # if it's user contiguous, we never remove it
  if src.op is Ops.CONTIGUOUS: return None

  # here is where we compute the cost
  # for now just no REDUCE, COPY, or ASSIGN
  # TODO: exclude fusion of user contiguous
  #ran = src.toposort(gate=lambda x: x.op not in {Ops.INDEX})
  #if any(x.op in {Ops.REDUCE, Ops.COPY, Ops.ASSIGN} for x in ran): return None
  ran = src.toposort(gate=lambda x: x.op not in {Ops.INDEX})
  if any(x.op in {Ops.REDUCE, Ops.COPY, Ops.ASSIGN} for x in ran): return None

  # simple, matching old behavior
  if src.op is not Ops.INDEX: return None
  #if src.op is not Ops.INDEX: return None

  # this is the ranges replaced
  return src.substitute(dict(zip(buf.src[1:], idx.src[1:])))
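
That final substitute is the actual fusion step: the bufferize's loop ranges are swapped for the consumer's index expressions, so the producer is recomputed inline at each consumed point. In miniature, with strings standing in for UOps:

buf_ranges = ("r0", "r1")      # buf.src[1:]
idx_exprs  = ("i0*2", "i1+5")  # idx.src[1:]
mapping = dict(zip(buf_ranges, idx_exprs))
assert mapping == {"r0": "i0*2", "r1": "i1+5"}  # what src.substitute receives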

pm_cleanups = double_reshape+pm_mops+PatternMatcher([
  #(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
  # remove noop buffers. if we look at the next index we can remove even more of these
@@ -352,6 +387,7 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
  #(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: c.reshape((1,)*len(b.shape)).expand(b.shape)),
])

# *****************
# 4. put in buffers for bufferize
# TODO: should BUFFERIZE look a lot more like STORE
# BUFFERIZE has device in arg
@@ -359,36 +395,54 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
# BUFFERIZE returns the BUFFER ready for INDEXing (doing this will make splitting a lot easier)
# NOTE: this has been fixed up a bit

def bufferize_to_store(x:UOp, locals_allowed=False):
def bufferize_to_store(x:UOp):
  rngs = x.src[1:]
  shape = tuple([int(r.vmax+1) for r in rngs])
  sym_shape = tuple([ssimplify(r.src[0]) for r in rngs])
  size = prod(shape)
  assert size > 0, f"no zero sized buffers {shape}"

  sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
  if x.src[0].op is Ops.ASSIGN:
    assign_target, assign_src = x.src[0].src
    assign_target, assign_src, assign_mops = x.src[0].src
    assert assign_target.op is Ops.INDEX
    return assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
    # in assign, this is the buffer size, not the bufferize size
    # TODO: assign_mops here
    ret = assign_target.replace(dtype=sdtype).store(assign_src, *rngs, dtype=x.dtype)
    mops = []
    walk = assign_mops
    while walk is not assign_mops.base:
      mops.append((walk.op, walk.arg))
      walk = walk.src[0]
    for m in mops[::-1]: ret = ret._mop(*m)
    return ret.forced_reshape(shape).replace(tag=x.arg.tags)

  # NOTE: the DEFINE_LOCAL needs to be disambiguated here
  if sdtype.addrspace == AddrSpace.GLOBAL:
    buf = UOp.new_buffer(x.arg.device, size, x.dtype)
  else:
    if not locals_allowed: return None
    buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=x.arg.device)
  return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)
  ret = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=x.dtype)
  ret = ret.forced_reshape(shape)
  # TODO: is this right? what if it's offset
  if shape is not sym_shape: ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
  return ret.replace(tag=x.arg.tags)
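
A worked example of the shape/sym_shape split, with illustrative values: a range whose extent is a variable v bound by vmax=9 allocates at the maximum size, and the visible view is shrunk back to the symbolic length. (Note the code takes the shrink path via `is not`, i.e. whenever the two tuples are distinct objects.)

vmax = 9
shape = (vmax + 1,)        # int(r.vmax+1): buffer is allocated at max size, (10,)
sym_shape = ("v",)         # ssimplify(r.src[0]): the symbolic extent
assert shape != sym_shape  # so ret gets shrunk to ((0, v),) after the store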

pm_add_buffers_local = pm_mops+PatternMatcher([
  (UPat(Ops.BUFFERIZE, name="x"), lambda x: bufferize_to_store(x, True)),
])
  # handle locals
  tag = x.arg.device
  if tag is None: tag = UOp.unique().arg # TODO: hack
  buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag)
  # store has the other dtype here
  # TODO: how is this unified?
  return buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], *rngs, dtype=sdtype).forced_reshape(shape, dtype=x.dtype)

pm_add_buffers = pm_mops+PatternMatcher([
  (UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),

  # move RESHAPEs through MSELECT/MSTACK
  (UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
   lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)),
  #(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
  # lambda m: m.replace(src=tuple([x.src[0] for x in m.src])).reshape(m.src[0].arg)),
])

# *****************
# 5. split into kernels

@dataclass
@@ -426,9 +480,12 @@ to_define_global = PatternMatcher([
])

rangeify_codegen = PatternMatcher([
  # no CONTIGUOUS in the kernel graph
  # no NOOP in the kernel graph
  # TODO: this can be moved into codegen?
  (UPat(Ops.CONTIGUOUS, name="x"), lambda x: x.src[0]),
  (UPat((Ops.NOOP, Ops.CONTIGUOUS), name="x"), lambda x: x.src[0]),

  # strip the arg from store
  (UPat(Ops.STORE, name="x"), lambda x: x.replace(arg=None) if x.arg is not None else None),

  # add loads to non ptr indexes
  # TODO: this can be moved into codegen?
@@ -444,41 +501,67 @@ rangeify_codegen = PatternMatcher([
   lambda src, barrier, gate: src.load(UOp(Ops.IF, src=(gate, barrier)))),
])

def split_store(x:UOp):
def split_store(ctx:list[UOp], x:UOp):
  if len(x.ranges): return None
  ctx = LocalAddBufferContext()
  ret = graph_rewrite(x, to_define_global+rangeify_codegen, ctx=ctx, name="kernel split", bottom_up=True)
  if x.src[0].ptrdtype.addrspace is AddrSpace.LOCAL: return None

  # local kernel rewrite
  lctx = LocalAddBufferContext()
  ret = graph_rewrite(x, to_define_global+rangeify_codegen, ctx=lctx, name="kernel split", bottom_up=True)

  # gather the metadata
  metadatas = [ctx[x.tag].metadata for x in ret.sparents if x.tag is not None]

  # NOTE: the hack for COPY is here
  ret = ret.sink() if ret.src[1].op is not Ops.COPY else ret.src[1]
  kernel = UOp(Ops.KERNEL, src=tuple(ctx.map.values())+tuple(ctx.vars.keys()), arg=Kernel(ret,()))
  kernel_arg = Kernel(ret,tuple(dedup(flatten([x for x in metadatas if x is not None]))))
  kernel = UOp(Ops.KERNEL, src=tuple(lctx.map.values())+tuple(lctx.vars.keys()), arg=kernel_arg)
  return x.as_buf().assign(kernel)
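
The metadata gather above, in miniature: per-uop metadata is looked up through the integer tags, then flattened and deduped into the Kernel arg (flatten/dedup are the tinygrad.helpers imported at the top; the values here are illustrative):

from tinygrad.helpers import flatten, dedup
metadatas = [("relu",), ("relu",), ("add",), None]
arg = tuple(dedup(flatten([m for m in metadatas if m is not None])))
assert arg == ("relu", "add")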

split_kernels = PatternMatcher([
  (UPat(Ops.STORE, name="x"), split_store),
])

@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
  tensor_map = graph_rewrite_map(sink, multi_pm+earliest_rewrites, name="earliest")
  realize_map: dict[UOp, UOp] = {}
  graph_rewrite(tensor_map[sink], do_realize, ctx=realize_map, name="Input Graph")
  tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add realize")
  tensor_map = graph_rewrite_map(tensor_map[sink], remove_tags, input_map=tensor_map, name="remove tags")
  tensor_map = graph_rewrite_map(tensor_map[sink], pm_children, ctx=ChildrenContext(), bottom_up=True, input_map=tensor_map, name="children")
  tensor_map = graph_rewrite_map(tensor_map[sink], pm_rangeify, ctx=RangeifyContext(), bottom_up=True, input_map=tensor_map, name="rangeify")
  # NOTE: running symbolic can break the graph, leaving RANGE/INDEX/BUFFERIZE in the final graph
  #tensor_map = graph_rewrite_map(tensor_map[sink], symbolic_simple, input_map=tensor_map, name="symbolic")
  tensor_map = graph_rewrite_map(tensor_map[sink], pm_cleanups, bottom_up=True, input_map=tensor_map, name="buffer cost")
  if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Rangeify Graph")
def tag_uop(ctx:list[UOp], x:UOp):
  if x.tag is not None: return None
  ctx.append(x)
  return x.replace(tag=len(ctx)-1)
add_tags = PatternMatcher([
  # don't tag BUFFERs, they are global
  (UPat(GroupOp.All-{Ops.BUFFER, Ops.DEVICE, Ops.UNIQUE, Ops.DEFINE_VAR, Ops.BIND}, name="x"), tag_uop),
])
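
tag_uop is just an identity-preserving numbering: each uop is appended to uop_list once and tagged with its index, so any surviving rewrite output can point back at the exact tensor-graph node it replaces. The same idea with plain objects:

nodes: list[object] = []
def tag(node):
  nodes.append(node)
  return len(nodes) - 1  # the tag is the index into nodes
a, b = object(), object()
ta, tb = tag(a), tag(b)
assert nodes[ta] is a and nodes[tb] is b  # tags recover the originals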

  tensor_map = graph_rewrite_map(tensor_map[sink], pm_add_buffers, bottom_up=True, input_map=tensor_map, name="add buffers")
  tensor_map = graph_rewrite_map(tensor_map[sink], split_kernels, input_map=tensor_map, name="split kernels")
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len([u for u in UOp.sink(*ret.values()).toposort() if u.op is Ops.KERNEL]))}", True)
def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
  uop_list: list[UOp] = []
  tsink = graph_rewrite(sink, add_tags, ctx=uop_list, bottom_up=True, name="number the uops")
  tsink = graph_rewrite(tsink, multi_pm+earliest_rewrites, name="earliest rewrites")
  realize_map: dict[UOp, UOp] = {}
  graph_rewrite(tsink, do_realize, ctx=realize_map, name="Input Graph")
  # NOTE: we don't use contiguous here, contiguous is a user op
  tsink = graph_rewrite(tsink, add_contiguous, ctx=realize_map, bottom_up=True, name="add realize")
  tsink = graph_rewrite(tsink, remove_tuple_tags, name="remove tuple tags")
  tsink = graph_rewrite(tsink, pm_children, ctx=ChildrenContext(), bottom_up=True, name="get children")

  # rangeify
  tsink = graph_rewrite(tsink, pm_rangeify, ctx=RangeifyContext(), bottom_up=True, name="rangeify")
  #tsink = graph_rewrite(tsink, symbolic_simple, bottom_up=True, name="symbolic") # this supports const folding
  tsink = graph_rewrite(tsink, pm_cleanups, bottom_up=True, name="remove costly buffers")

  # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
  # if it's not tagged by here, it's out
  tsink = UOp.sink(*[x for x in tsink.parents if x.op is Ops.BUFFERIZE and len(x.arg.tags)])

  if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")

  # bufferize -> store
  tsink = graph_rewrite(tsink, pm_add_buffers, bottom_up=True, name="bufferize to store")
  tsink = graph_rewrite(tsink, split_kernels, ctx=uop_list, name="split kernels")

  # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
  kernel_assign: dict[UOp, UOp] = {}
  assign_rep: dict[UOp, UOp] = {}
  for u in tensor_map[sink].toposort():
  for u in tsink.toposort():
    if u.op is not Ops.ASSIGN: continue
    kernel_assign[u.buf_uop] = u
    for s in u.src[1].src:
@@ -487,8 +570,14 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
      if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
        raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
      assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
  if assign_rep:
    tensor_map = graph_rewrite_map(tensor_map[sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")
  if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")

  if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Kernel Graph")
  return tensor_map
  if getenv("VIZ"): graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")

  becomes_map: dict[UOp, UOp] = {}
  for s in tsink.src:
    assert s.tag is not None
    for a in s.tag:
      if a is None: continue
      becomes_map[uop_list[cast(int, a)]] = s.replace(tag=None)
  return becomes_map
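
And the round trip those tags make possible, in miniature: each surviving sink source's tag is a tuple of indices into uop_list, and becomes_map sends the original tensor-graph uops to their tag-stripped replacements (illustrative values):

uop_list = ["u0", "u1", "u2"]  # filled by add_tags in the real flow
tag = (2, None)                # a sink source's tag; None entries are skipped
hits = [uop_list[a] for a in tag if a is not None]
assert hits == ["u2"]          # these keys map to the new uop in becomes_map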