mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
delete grouper and kernelize (#12517)

* delete grouper and kernelize
* +sys.setrecursionlimit
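The caller-facing change is uniform across every hunk below: the RANGEIFY branch is gone and get_rangeify_map is the only scheduling entry point. A minimal sketch of the migration, where s is an Ops.SINK UOp rooting a tensor graph (as in the docs example in the first hunks):

from tinygrad.schedule.rangeify import get_rangeify_map

# before: becomes_map = get_rangeify_map(s) if RANGEIFY else get_kernelize_map(s)
becomes_map = get_rangeify_map(s)  # maps each tensor UOp to its kernelized form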
@@ -81,7 +81,6 @@ print("******** third, the UOp ***********")
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.schedule import create_schedule_with_vars
 from tinygrad.helpers import RANGEIFY
-from tinygrad.schedule.kernelize import get_kernelize_map
 from tinygrad.schedule.rangeify import get_rangeify_map

 # allocate some values + load in values
@@ -95,7 +94,7 @@ out = a + b
 s = UOp(Ops.SINK, dtypes.void, (out,))

 # group the computation into kernels
-becomes_map = get_rangeify_map(s) if RANGEIFY else get_kernelize_map(s)
+becomes_map = get_rangeify_map(s)

 # the compute maps to an assign
 assign = becomes_map[a+b].base
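For orientation: the returned map sends each realized tensor UOp to (a view of) an ASSIGN whose sources are the output buffer and the KERNEL that fills it; the deleted kernelize.py's create_kernel built exactly buffer.assign(kernel). A sketch continuing the docs example above:

assign = becomes_map[a+b].base
assert assign.op is Ops.ASSIGN
out_buf, kernel = assign.src[0], assign.src[1]  # ASSIGN(BUFFER-or-view, KERNEL)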
@@ -10,7 +10,7 @@ Directories are listed in order of how they are processed.

 Group UOps into kernels.

-::: tinygrad.schedule.kernelize.get_kernelize_map
+::: tinygrad.schedule.rangeify.get_rangeify_map
     options:
         members: false
         show_labels: false
@@ -2,9 +2,7 @@ import sys
 from tinygrad import Tensor, fetch, GlobalCounters, dtypes
 from tinygrad.uop.ops import UOp
 from tinygrad.nn.onnx import OnnxRunner
-from tinygrad.schedule.kernelize import get_kernelize_map
 from tinygrad.schedule.rangeify import get_rangeify_map
-from tinygrad.helpers import RANGEIFY
 from tinygrad.engine.schedule import create_schedule_with_vars
 from tinygrad.engine.realize import run_schedule

@@ -35,7 +33,7 @@ if __name__ == "__main__":
     if not in_target_path[s]:
       independent_set[s] = None
   independent = UOp.sink(*independent_set.keys())
-  kernelized = (get_rangeify_map if RANGEIFY else get_kernelize_map)(independent)
+  kernelized = get_rangeify_map(independent)
   independent = independent.substitute(kernelized)
   schedule, var_vals = create_schedule_with_vars(independent)
   run_schedule(schedule)
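Annotated for orientation, the now-unconditional pipeline this script runs (independent is any Ops.SINK UOp; all four calls are visible in the hunk above):

kernelized = get_rangeify_map(independent)                   # tensor UOps -> their KERNEL-graph forms
independent = independent.substitute(kernelized)             # rewrite the sink through the map
schedule, var_vals = create_schedule_with_vars(independent)  # linearize to ScheduleItems + bound vars
run_schedule(schedule)                                       # lower and execute each item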
@@ -12,10 +12,11 @@ from tinygrad import nn, dtypes, Device, Tensor, Variable
 from tinygrad.device import is_dtype_supported
 from tinygrad.dtype import DType, ImageDType
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
+from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, graph_rewrite
 from tinygrad.uop.symbolic import symbolic_simple
 from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp, RANGEIFY
-from tinygrad.schedule.kernelize import merge_views, get_kernelize_map, Kernel
+from tinygrad.codegen.opt.swizzler import merge_views
+from tinygrad.schedule.rangeify import get_rangeify_map, Kernel
 from tinygrad.engine.schedule import create_schedule_with_vars
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from test.helpers import expect_rangeify_fails, expect_nonrangeify_fails
@@ -29,7 +30,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
   else:
     assert isinstance(t, UOp), f"can't schedule {t}"
     sink = UOp.sink(t) if t.op is not Ops.SINK else t
-    becomes_map = get_kernelize_map(sink)
+    becomes_map = get_rangeify_map(sink)
     sched, _ = create_schedule_with_vars(sink.substitute(becomes_map))
   # test lowering all the ScheduleItems to ExecItems
   kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
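A usage sketch of this helper, taken from the copy-folding tests later in this diff: schedule a raw UOp and assert how many kernels it lowers to.

a = Tensor.empty(4).uop
b = a.copy_to_device(a.device)
check_schedule(b, 0, filter_sink=False)  # a same-device copy should fold to zero kernels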
@@ -68,9 +69,6 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   np.testing.assert_allclose(img.grad.numpy(), ref_img.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)
   np.testing.assert_allclose(w.grad.numpy(), ref_w.grad.detach().numpy(), atol=1e-6 if dtype == dtypes.float else 1e-2)

-@track_rewrites(name=True)
-def schedule_graph_rewrite(big_sink:UOp): return get_kernelize_map(big_sink)[big_sink]
-
 class TestSchedule(unittest.TestCase):
   def test_arange_avgpool2d(self, kcount=1):
     x = Tensor.arange(25).reshape(1,1,5,5).cast(dtypes.float32)
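The deleted helper was a one-line @track_rewrites wrapper around get_kernelize_map. If a test still needed the raw rewrite under rangeify, a hypothetical equivalent (illustrative only, not part of this commit) would be:

def schedule_graph_rewrite(big_sink:UOp): return get_rangeify_map(big_sink)[big_sink]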
@@ -2244,17 +2242,11 @@ class TestCopyFolding(unittest.TestCase):
     a = Tensor.empty(4).uop
     b = a.copy_to_device(a.device)
     check_schedule(b, 0, filter_sink=False)
-    b = schedule_graph_rewrite(b)
-    # NOTE: Tensor.empty(4) always creates a VIEW(BUFFER) with ShapeTracker((4,)), we simplify this to just a BUFFER
-    # in the scheduler because buffer already has shape (4,)
-    self.assertIs(b, a.base)

   def test_copy_to_same_device_alt(self):
     a = Tensor.empty(4, 4).uop
     b = a.copy_to_device(a.device)
     check_schedule(b, 0, filter_sink=False)
-    b = schedule_graph_rewrite(b)
-    self.assertIs(b.base, a.base)

   def test_copy_to_same_device_sched(self):
     a = Tensor.ones(4).contiguous().realize().uop.as_buf()
@@ -2,7 +2,7 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
 from tinygrad.helpers import all_same, prod, unwrap, colored
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View, strides_for_shape, get_contraction_with_reduce
-from tinygrad.schedule.grouper import ALWAYS_CONTIGUOUS
+from tinygrad.schedule.rangeify import ALWAYS_CONTIGUOUS
 from tinygrad.dtype import ImageDType, dtypes

 merge_views = PatternMatcher([
@@ -1,119 +0,0 @@
-from tinygrad.uop.ops import Ops, UOp, resolve, can_pad, GroupOp, UPat, PatternMatcher, graph_rewrite
-from tinygrad.helpers import all_int, prod, unwrap, dedup, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, FUSE_CONV_BW
-from tinygrad.shape.shapetracker import ShapeTracker
-
-ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
-                               Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.DEFINE_GLOBAL,
-                               Ops.DEFINE_LOCAL, Ops.DEFINE_REG, Ops.LOAD}
-
-# **** Grouper decides which of the UOps realize
-
-def realize(ctx:dict[UOp, None], tr:UOp) -> None: ctx[tr] = None
-
-def realize_parents(ctx:dict[UOp, None], rb:UOp) -> None:
-  for s in rb.src:
-    if s.op not in ALWAYS_CONTIGUOUS: ctx[s] = None
-
-def realize_before_view(ctx:dict[UOp, None], view:UOp, tr:UOp) -> None:
-  st = unwrap(view.st)
-  # always realize unsafe pad ops before masked view
-  if any(v.mask is not None for v in st.views) and not can_pad(tr, ctx): return realize(ctx, tr)
-  # fold simple pads
-  if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(tr.shape) and resolve(prod(tr.shape) >= prod([y-x for x,y in m])): return
-  # realize before expand
-  if resolve(prod(tr.shape) < prod(st.shape)) and not DONT_REALIZE_EXPAND: return realize(ctx, tr)
-
-do_realize = PatternMatcher([
-  # always realize SINK parents
-  (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
-  # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
-  (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
-  # realize before expand or unsafe pad ops
-  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="tr"),), name="view"), realize_before_view),
-  # realize parents of COPY, MSELECT, MSTACK
-  (UPat((Ops.COPY, Ops.MSELECT, Ops.MSTACK), name="rb"), realize_parents),
-])
-
-def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:dict[UOp, dict[UOp, None]], realizes:dict[UOp, None],
-                    reduce_for_op:dict[UOp, UOp], group:dict[UOp, None], cache:dict[tuple[UOp, ShapeTracker], None]) -> None:
-  if (tr, st) in cache: return
-  cache.setdefault((tr, st))
-  rsize = unwrap(r.st).size
-  if tr in realizes and tr is not r:
-    # can only fuse contiguous
-    # max one reduceop per kernel
-    if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r)
-    return group.setdefault(tr)
-  for tr_next in children.get(tr, {}):
-    # max one reduceop per kernel
-    if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r)
-    # can only fuse contiguous
-    if len(st_childs:=dedup(unwrap(x.st) for x in tr_next.src if x.base == tr)) > 1: return group.setdefault(r)
-    recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache)
-
-def group_realizes(sink:UOp) -> dict[UOp, None]:
-  # start by adding uops that always realize
-  realizes: dict[UOp, None] = {}
-  sink = graph_rewrite(sink, do_realize, ctx=realizes, name="do_realize")
-  if DONT_GROUP_REDUCES: return realizes
-
-  # construct children graph (only for bases)
-  children: dict[UOp, dict[UOp, None]] = {}
-  assigns: dict[UOp, None] = {}
-  for u in (toposort:=sink.toposort()):
-    if u.op in {Ops.VIEW, Ops.SINK}: continue
-    if u.op is Ops.ASSIGN: assigns[u.buf_uop] = None
-    for s in u.src: children.setdefault(s.base, {})[u] = None
-
-  # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
-  reduce_for_op: dict[UOp, UOp] = {}
-  double_reduces: list[UOp] = []
-  for r in toposort:
-    if r.op is not Ops.REDUCE_AXIS: continue
-    if len(r.arg) == 3 and r.arg[2] is True: continue
-    if FUSE_CONV_BW and r.src[0].base.op is Ops.REDUCE_AXIS and r.src[0] is not r.src[0].base: double_reduces.append(r)
-    if r in realizes: continue
-    group: dict[UOp, None] = {}
-    recursive_group(r, unwrap(r.st), r, children, realizes, reduce_for_op, group, cache={})
-    # max one reduceop per kernel
-    can_chase = all(tr not in reduce_for_op for tr in group)
-    for u in r.toposort(gate=lambda u: u not in realizes):
-      if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST:
-        can_chase = False
-        break
-    # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs
-    forced_realize = r in group
-    # can only have one output
-    if not forced_realize and len(group) > 1: forced_realize = True
-    # can only fuse assign if no other assign_target is used in the kernel
-    if not forced_realize and (assign_targets:={x.buf_uop for x in group if x.op is Ops.ASSIGN}):
-      parents = [r, *group]
-      while parents and not forced_realize:
-        p = parents.pop().base
-        if p.op is Ops.BUFFER and p in assigns and p not in assign_targets: forced_realize, can_chase = True, False
-        if p in realizes: continue
-        parents.extend(p.src)
-    if forced_realize or not group:
-      tr = r
-      if can_chase:
-        # can chase this down to contiguous children
-        st = unwrap(tr.st)
-        while len(lst:=children.get(tr, {})) == 1:
-          tr_next = next(iter(lst))
-          st_childs = dedup(unwrap(s.st) for s in tr_next.src if s.base is tr)
-          if len(st_childs) > 1: break
-          if st.size != st_childs[0].size: break
-          st = st + st_childs[0]
-          if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break
-          tr = tr_next
-        # don't cast to higher size before store (tr cannot be realized if forced_realize)
-        if tr.op is Ops.CAST and tr.dtype.itemsize > tr.src[0].dtype.itemsize:
-          tr = tr.src[0].base
-      group = {tr: None}
-      realizes[tr] = None
-    reduce_for_op.update((tr, r) for tr in group)
-  # fuse double reduces with no other child
-  for reduceop in double_reduces:
-    top_reduce = reduceop.src[0].base
-    if len(children.get(top_reduce, {})) == 1: del realizes[top_reduce]
-  return realizes
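One idiom worth noting in the deleted file: dict[UOp, None] serves as an insertion-ordered set (dicts preserve insertion order since Python 3.7), so realize() is just ctx[tr] = None, membership is "tr in realizes", and group.setdefault(r) adds without clobbering. A standalone illustration:

ordered_set: dict[str, None] = {}
for x in ["a", "b", "a"]: ordered_set.setdefault(x)  # dedups, keeps first-insertion order
assert list(ordered_set) == ["a", "b"]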
@@ -1,374 +0,0 @@
-from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, identity_element, resolve
-from tinygrad.uop.ops import track_rewrites, _substitute, KernelInfo
-from tinygrad.uop.spec import type_verify, tensor_uop_spec
-from tinygrad.uop.symbolic import symbolic_simple
-from tinygrad.helpers import all_int, all_same, prod, dedup, unwrap, getenv, pluralize, DEBUG, SPLIT_REDUCEOP
-from tinygrad.dtype import ImageDType
-from tinygrad.schedule.multi import multi_pm
-from tinygrad.schedule.grouper import group_realizes, ALWAYS_CONTIGUOUS
-from tinygrad.schedule.rangeify import Kernel
-from tinygrad.codegen.opt.swizzler import merge_views, apply_swizzle, swizzle_reduceop
-from tinygrad.codegen.opt import Opt
-
-# creation can recurse a lot
-import sys
-sys.setrecursionlimit(10000)
-
-# **** schedule simplifier
-
-def simplify_stride0_reduce(reduce:UOp, x:UOp):
-  # must be unmasked (NOTE: can be relaxed if not masked on stride 0 axis)
-  if any(v.mask is not None for v in unwrap(x.st).views): return None
-  # must have all stride 0 in the relevant axis (NOTE: can do partial)
-  if not all(unwrap(x.st).views[-1].strides[axis] == 0 for axis in reduce.arg[1]) or not all_int(x.shape): return None
-  prshape = prod(x.shape[i] for i in reduce.arg[1])
-  ret = x.shrink(tuple((0,s) if i not in reduce.arg[1] else (0,1) for i,s in enumerate(x.shape)))
-  match reduce.arg[0]:
-    case Ops.ADD: return ret*prshape
-    case Ops.MUL: return ret.pow(prshape)
-    case Ops.MAX: return ret  # NOTE: Ops.MAX is passthrough
-
-def split_reduceop(reduce:UOp, x:UOp):
-  if not SPLIT_REDUCEOP or not all_int(x.shape) or (prod(x.shape)//prod(reduce.shape))<getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return None
-  # if there are few globals, make some reduces into globals by splitting into two kernels
-  # cap output buffer to 2**22: heuristic number of global outputs to achieve max occupancy with enough locals+upcasts for gemm
-  # ~2**10 should be enough if GROUP is used
-  # 256 split maximum should be "negligible reduce" for low prod(reduce.shape), 8 split minimum.
-  # split is moved to the end to provide maximum locality for the second phase reduce.
-  real_strides = unwrap(x.st).real_strides(ignore_valid=True)
-  if not (split_candidates:=[(i,d) for i in reduce.arg[1] for d in range(min(256,2**getenv("REDUCEOP_SPLIT_SIZE",22)//prod(reduce.shape)),8-1,-1)
-                             if x.shape[i]%d==0 and real_strides[i]!=0]): return None
-  dim_to_split, divisor = split_candidates[0]
-  splitted_shape = x.shape[:dim_to_split]+(divisor,)+(x.shape[dim_to_split]//divisor,)+x.shape[dim_to_split+1:]
-  splitted = x.reshape(splitted_shape).permute(tuple([d for d in range(len(splitted_shape)) if d!=dim_to_split]+[dim_to_split]))
-  if DEBUG >= 3: print(f"split {divisor}: {x.shape} -> {splitted.shape} -> {reduce.shape}")
-  # reduce original axes, then split
-  return splitted.r(*reduce.arg).r(reduce.arg[0], (len(reduce.shape),)).reshape(reduce.shape)
-
-def copy_reorder_view(copy:UOp, view:UOp, base:UOp):
-  if prod(view.shape) < prod(base.shape): return view.contiguous().copy_to_device(copy.device)
-  return base.copy_to_device(copy.device).view(view.arg)
-
-kernelize_sym = symbolic_simple+PatternMatcher([
-  # UOp with size 0 is zero
-  (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
-  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here
-  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD), name="x"), lambda x: x.src[0]),
-  # reduce of size 0 is the identity element
-  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
-   lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
-  # reduce on stride 0 is collapsed
-  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), simplify_stride0_reduce),
-  # split_reduceop
-  (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), split_reduceop),
-  # COPY(CONST) creates a new CONST on the destination device
-  (UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)),
-  # non device changing COPY is a NOOP
-  (UPat(Ops.COPY, name="c", src=(UPat.var("x"), UPat(Ops.DEVICE))), lambda c,x: x if c.device == x.device else None),
-  # store a shrink before COPY, otherwise view after the COPY
-  (UPat(Ops.COPY, src=(UPat(Ops.VIEW, src=(UPat.var("base"),), name="view"), UPat(Ops.DEVICE)), name="copy"), copy_reorder_view),
-  # remove cast to image when it's already a contiguous image
-  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
-   lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
-  # CAST before masking constants
-  (UPat.cvar("x").view().cast(name="c"), lambda x,c: x.cast(c.dtype).view(c.src[0].arg)),
-  # make things that can't be images not images
-  (UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW, Ops.CONST, Ops.DEVICE}, name="u"), lambda u: u.replace(dtype=dt.base) if isinstance(dt:=u.dtype,ImageDType)
-    and (prod(u.shape) != prod(dt.shape) or not any(u.shape[x]%4 == 0 for x in u.st.unit_stride_axes())) else None),
-  # remove contiguous if we can just view the buffer
-  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER, name="buf"),)),)),
-   lambda root,view,buf: view if view.st.contiguous and view.size == buf.size else None),
-  # contiguous/buffer/copy/assign is already contiguous
-  (UPat(Ops.CONTIGUOUS, name="root", src=(UPat((Ops.CONTIGUOUS, Ops.BUFFER, Ops.COPY, Ops.ASSIGN)),)), lambda root: root.src[0]),
-  # substitute BITCAST/CONTIGUOUS with BUFFER_VIEW on DISK
-  (UPat((Ops.BITCAST, Ops.CONTIGUOUS), src=(UPat.var("x"),), name="t"), lambda x,t: UOp(Ops.BUFFER_VIEW, t.dtype, (x.base,),
-    (t.size, x.st.views[0].offset)).reshape(t.shape) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
-  # double ASSIGN to same target is one ASSIGN
-  (UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))))), lambda x,t: t.assign(x.contiguous())),
-  # ASSIGN to unrealized replaces the UOp
-  (UPat(Ops.ASSIGN, src=(UPat.var("t"), UPat.var("x"))), lambda x,t: x.contiguous() if t.base.op not in {Ops.BUFFER, Ops.BUFFER_VIEW} and
-    not (t.base.op is Ops.MSTACK and all(x.op is Ops.BUFFER for x in t.base.src)) else None),
-  # put CAST to smaller dtype before EXPAND
-  (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm"),)), lambda cast,vm: vm.base.cast(cast.dtype).view(vm.st)
-    if cast.dtype.itemsize <= vm.dtype.itemsize and resolve(prod(vm.shape) > vm.st.real_size()) else None),
-  # put UnaryOps before EXPANDs, if it can fuse with the input
-  (UPat(GroupOp.Unary, src=(UPat(Ops.VIEW, src=(UPat(GroupOp.All-ALWAYS_CONTIGUOUS, name="inp"),), name="v"),), name="alu"),
-   lambda inp,v,alu: inp.alu(alu.op).view(v.st) if resolve(prod(alu.shape) > v.st.real_size()) else None),
-])
-
-# support for using a contiguous permuted view instead of the parent view if one exists
-
-def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
-  if (sti:=unwrap(src.st).invert(src.base.shape)) is not None: ctx[src.base] = contig.view(sti)
-
-replace_contiguous = PatternMatcher([
-  (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.VIEW, name="src"),), name="contig"), found_contiguous),
-  (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
-])
-
-# **** create kernels
-
-def create_kernel(x:UOp, b:UOp|None=None):
-  if b is None: b = UOp.new_buffer(x.device, x.size, x.dtype)
-  kernel = UOp(Ops.KERNEL, src=(b,)+x.src, arg=Kernel(x.sink(), m if (m:=x.metadata) else ()))
-  buffer = b.base if b.size == b.base.size else UOp(Ops.BUFFER_VIEW, b.dtype, (b.base,), (b.size, b.arg.views[0].offset))
-  # we have to shrink the buffer back to the symbolic shape
-  return buffer.assign(kernel).reshape(tuple(d.vmax if isinstance(d, UOp) else d for d in x.shape)).shrink(tuple((0, d) for d in x.shape))
-
-DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER, Ops.MSELECT, Ops.MSTACK, Ops.MULTI, Ops.BIND}
-def append_to_kernel(x:UOp):
-  new_srcs: list[UOp] = []
-  metadata = x.arg.metadata
-  for s in x.src:
-    if s.op in DONT_PLACE_IN_KERNEL: new_srcs.append(s)
-    else:
-      new_srcs.extend(s.src)
-      # NOTE: because const and device are shared UOps they don't change metadata
-      # NOTE: if it's a reshape after ASSIGN we're not fusing that parent kernel
-      if s.base.op not in {Ops.CONST, Ops.DEVICE} and (not (s.op is Ops.RESHAPE and s.base.op is Ops.ASSIGN)) and (m:=s.metadata): metadata += m
-  if (new_src:=tuple(dedup(new_srcs))) != x.src: return x.replace(src=new_src, arg=Kernel(x.arg.ast, tuple(dedup(metadata))))
-
-create_kernels = PatternMatcher([
-  # always give assign/contiguous a kernel
-  (UPat.assign(UPat.var("b"), UPat(GroupOp.All-{Ops.KERNEL}), name="x"), create_kernel),
-  (UPat(Ops.CONTIGUOUS, name="x"), create_kernel),
-  # walk back the local graph until we reach a realized source
-  (UPat(Ops.KERNEL, name="x"), append_to_kernel),
-  # push RESHAPE through MSELECT
-  (UPat(Ops.MSELECT, src=(UPat(Ops.RESHAPE, name="r"),), name="ms"), lambda ms,r: r.src[0].mselect(ms.arg).reshape(r.arg)),
-  # push RESHAPE through MSTACK
-  (UPat(Ops.MSTACK, src=UPat(Ops.RESHAPE), name="ms"),
-   lambda ms: UOp(Ops.MSTACK, ms.dtype, tuple(x.src[0] for x in ms.src)).reshape(ms.src[0].arg)),
-])
-
-def add_stores(ctx, sink: UOp):
-  stores = []
-  for i,x in enumerate(sink.src):
-    gbl = UOp(Ops.DEFINE_GLOBAL, (s:=x.base).dtype.ptr(ctx[i].size), (), i)
-    # if this is an assign then we already have a buffer with a view that should be the target of the store
-    if x.op is Ops.ASSIGN: stores.append(UOp.store(gbl.view(unwrap(s.st)), s))
-    # otherwise we have to create the shapetracker and shrink it to the correct symbolic shape
-    else: stores.append(
-      UOp.store(gbl.reshape(tuple(int(d.vmax) if isinstance(d,UOp) else d for d in s.shape)).shrink(tuple((0,d) for d in s.shape)),s))
-  return UOp.sink(*stores, arg=sink.arg)
-# **** fix kernel AST
-
-def unbind_view(x:UOp):
-  if any(x.op is Ops.BIND for x in x.arg.vars()): return x.replace(arg=x.arg.unbind()[0])
-  return None
-
-replace_buffers = PatternMatcher([
-  # sink on contig creates a KernelInfo
-  (UPat(Ops.CONTIGUOUS, name="c").sink(name="s"),
-   lambda s,c: s.replace(src=(c.replace(arg=None),), arg=KernelInfo(opts_to_apply=c.arg)) \
-     if s.arg is None and c.arg is not None and isinstance(c.arg[0], Opt) else None),
-  # replace ASSIGN with the target BUFFER
-  (UPat(Ops.ASSIGN, src=(UPat((Ops.BUFFER, Ops.LOAD)), UPat(Ops.KERNEL)), name="assign", allow_any_len=True), lambda assign: assign.src[0]),
-  # HACK: select the 0 branch of MSTACK (the device is wrong after this, is that okay?)
-  (UPat(Ops.MSTACK, name="x"), lambda x: x.src[0]),
-  # LOAD
-  (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)).load()),
-  # no SINK for meta ops
-  (UPat(Ops.SINK, src=(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Meta, name="x"),),))), lambda x:x),
-  # STORE (except for meta ops)
-  (UPat(Ops.SINK, src=UPat(GroupOp.All-{Ops.STORE}), name="sink"), add_stores),
-  # remove CONTIGUOUS/DEVICE from kernel AST
-  (UPat((Ops.CONTIGUOUS, Ops.MSELECT), src=(UPat.var("x"),)), lambda x: x),
-  (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
-  # passthrough ASSIGN (but let MSTACK process first)
-  (UPat(Ops.ASSIGN, src=(UPat(GroupOp.All-{Ops.MSTACK}), UPat()), name="x"), lambda x: x.src[1]),
-  # remove any BINDs from VIEWS
-  (UPat(Ops.VIEW, src=(UPat(), UPat((Ops.BIND, Ops.DEFINE_VAR))), allow_any_len=True, name="x"), lambda x: x.replace(src=x.src[0:1])),
-  # remove any BINDs from DEFINE_VARs
-  (UPat(Ops.BIND, name="x"), lambda x: x.src[0]),
-  # remove BINDs from ShapeTrackers
-  (UPat(Ops.VIEW, name="x"), unbind_view),
-])
-
-def fix_kernel_ast(k:UOp) -> UOp|None:
-  if k.arg.ast.op in GroupOp.Meta or all(s.op is Ops.STORE for s in k.arg.ast.src): return None
-  # replace buffer with define_global + add load/store last
-  bufs = []
-  for s in k.src:
-    if s.op is Ops.BIND: continue
-    s = s.buf_uop
-    # traverse back through MSELECT and MSTACK. HACK: 0 branch of MSTACK only
-    while s.op in {Ops.MSELECT, Ops.MSTACK}: s = s.src[0]
-    bufs.append(s)
-  # replace global memory ops with the BUFFER they write to
-  # NOTE: merge_views is needed to unbind the reshapes
-  ast = graph_rewrite(k.arg.ast, merge_views+replace_buffers, bufs, bottom_up=True, name="replace buffers")
-  if ast.op is Ops.SINK and not all_same([x.device for x in k.src if x.op is not Ops.BIND]):
-    raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop.buffer for b in k.src)}")
-  return k.replace(arg=Kernel(ast, k.arg.metadata))
-
-create_ast = PatternMatcher([
-  (UPat(Ops.KERNEL, name="k"), fix_kernel_ast),
-  (UPat(Ops.DEFINE_VAR, src=(UPat(),), allow_any_len=True, name="x"), lambda x: x.replace(src=())),
-])
-
-# ** add metadata of KERNEL outputs
-
-def append_metadata(root:UOp, k:UOp):
-  if not root.metadata or (new_metadata:=tuple(dedup(k.arg.metadata+root.metadata))) == k.arg.metadata: return None
-  return root.replace(src=(root.src[0], k.replace(arg=Kernel(k.arg.ast, new_metadata)))+root.src[2:])
-
-replace_metadata = PatternMatcher([(UPat(Ops.ASSIGN, src=(UPat(), UPat(Ops.KERNEL, name="k")), name="root", allow_any_len=True), append_metadata),])
-
-pm_fuse = PatternMatcher([
-  # FUSE on CONTIGUOUS removes FUSE
-  (UPat(Ops.CONTIGUOUS, name="c").fuse(), lambda c: c),
-
-  # FUSE triggers swizzle on reduceop
-  (UPat(Ops.VIEW, src=(UPat(Ops.REDUCE_AXIS, src=(UPat.var("src"),), name="r").or_casted(),), name="view").fuse(),
-   lambda r,src,view: ret.cast(view.dtype) if (ret:=swizzle_reduceop(r, src, view, fuse=True)) is not None else None),

-  # FUSE on reduce (without view) adds fuse marker to grouper
-  (UPat(Ops.REDUCE_AXIS, name="r").fuse(),
-   lambda r: r.replace(src=(r.src[0].fuse(),), arg=r.arg+(True,)) if len(r.arg) == 2 else None),
-
-  # remove FUSE and insert CONTIGUOUS if it's an unsafe pad
-  (UPat(Ops.VIEW, src=(UPat(GroupOp.UnsafePad, name="alu"),), name="view").fuse(),
-   lambda alu, view: alu.contiguous().view(view.st) if any(v.mask is not None for v in view.st.views) else None),
-
-  # FUSE elementwise.
-  (UPat(Ops.VIEW, src=(UPat({*GroupOp.ALU, Ops.CAST}, name="alu"),), name="view").fuse(),
-   lambda alu, view: alu.replace(src=tuple(apply_swizzle(x.view(view.arg)).fuse() for x in alu.src))),
-
-  # push FUSE through to srcs
-  (UPat(Ops.FUSE, name="x"), lambda x: x.src[0].replace(src=tuple(y.fuse() for y in x.src[0].src))),
-])
-
-def do_fusion(x:UOp):
-  found_contiguous = {}
-  def gate_contiguous(x):
-    if is_contiguous:=(x.op is Ops.CONTIGUOUS): found_contiguous[x] = x.replace(src=(UOp(Ops.VIEW, arg=x.st), UOp.unique()))
-    return not is_contiguous
-  x.toposort(gate=gate_contiguous)
-  del gate_contiguous
-  return graph_rewrite(x.substitute(found_contiguous), pm_fuse, name="local fusion").substitute({v:k for k,v in found_contiguous.items()})
-
-def fuse_arange(root:UOp):
-  # skip if root is arange
-  if root.src[0].base.op is Ops.CONST: return None
-  # gather all local aranges (including any fused ones)
-  local_arange: list[UOp] = []
-  def gate_reduce(u):
-    if u.op is Ops.REDUCE_AXIS and u.src[0].base.op is Ops.CONST: local_arange.append(u)
-    return u.op not in {*ALWAYS_CONTIGUOUS, Ops.REDUCE_AXIS} or u is root
-  toposort = root.toposort(gate=gate_reduce)
-  if not local_arange: return None
-  # fuse the nearest expand child of arange
-  local_children: dict[UOp, list[UOp]] = {}
-  for u in toposort:
-    for s in u.src: local_children.setdefault(s, []).append(u)
-  fuse_rep: dict[UOp, UOp] = {}
-  for r in local_arange:
-    # skip if already fused
-    if len(r.arg) > 2: continue
-    q = list(local_children[r])
-    while q:
-      u = q.pop()
-      if not (curr_children:=local_children.get(u, [])): continue
-      for child in curr_children:
-        other_paths = {s for s in child.toposort() if s.op in {Ops.REDUCE_AXIS, Ops.BUFFER} and s not in {root, r}}
-        fuse_rep[child] = child.replace(src=tuple(s.fuse() if s is u else s for s in child.src))
-        if other_paths: break
-      else: q.extend(curr_children)
-  return root.substitute(fuse_rep, name="fuse_arange") if fuse_rep else None
-
-do_fuse = PatternMatcher([
-  (UPat(Ops.FUSE, name="x"), do_fusion),
-  (UPat(Ops.REDUCE_AXIS, name="root"), fuse_arange),
-])
-
-add_contiguous = PatternMatcher([(UPat(GroupOp.All-{Ops.CONTIGUOUS, Ops.ASSIGN}, name="x"),
-                                  lambda ctx,x: x.replace(tag=1).contiguous() if x in ctx and x.tag is None else None)])
-
-# TODO: get this from the device through GrouperOpts
-DEVICE_MAX_BUFS = {"METAL":32, "WEBGPU":8}
-
-def limit_bufs(root:UOp):
-  # check if backend has a buffer limit
-  device = root.device if isinstance(root.device, str) else root.device[0].split(":")[0]
-  if not (MAX_BUFS:=getenv("MAX_KERNEL_BUFFERS", DEVICE_MAX_BUFS.get(device, 0))): return None
-  # count number of unique buffers flowing into this op
-  bufs: set[UOp] = set()
-  def gate_input(u:UOp):
-    if (is_load:=(u.op in {Ops.BUFFER, Ops.CONTIGUOUS, Ops.ASSIGN, Ops.MSTACK, Ops.DEFINE_VAR})): bufs.add(u)
-    return not is_load
-  root.toposort(gate=gate_input)
-  # NOTE: this -1 is for the output buffer
-  if len(bufs)>=MAX_BUFS-1:
-    return root.replace(src=tuple(s if s.base in bufs else s.replace(tag=1).contiguous() for s in root.src))
-
-def view_add_srcs(x:UOp):
-  if len(avars:=x.arg.vars()) and len(x.src) == 1:
-    return x.replace(src=x.src+tuple(avars))
-  return None
-
-finalize_contiguous = PatternMatcher([
-  # if an op takes more than one input, check combined LOADs don't exceed device limits
-  (UPat(set.union(GroupOp.Binary, GroupOp.Ternary), name="root"), limit_bufs),
-  # merge contiguous
-  (UPat(Ops.CONTIGUOUS, src=(UPat(Ops.CONTIGUOUS),), name="x"), lambda x: x.src[0]),
-  # simplify views
-  (UPat(Ops.VIEW, src=(UPat.var('x')), name="v"), lambda x,v: x.view(new_st) if (new_st:=v.arg.simplify()) != v.arg else None),
-  # vars to views srcs
-  (UPat(Ops.VIEW, name="x"), view_add_srcs),
-])
-
-remove_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
-
-@track_rewrites(name=lambda sink,ret: f"Schedule {pluralize('Kernel',len([u for u in ret[sink].toposort() if u.op is Ops.KERNEL]))}", replay=True)
-def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
-  """
-  Function to transform the Tensor UOp graph into a version with Ops.KERNEL
-
-  Args:
-    sink: The Ops.SINK rooting the Tensor graph.
-
-  Returns:
-    Map transforming each UOp in the sink to the Ops.KERNEL graph.
-  """
-
-  # multi + merge_views + simplify
-  tensor_map = graph_rewrite_map(sink, multi_pm+do_fuse+merge_views+kernelize_sym+replace_contiguous, ctx={}, name="merge_views")
-
-  # display the cleaned up tensor graph
-  if getenv("VIZ"): graph_rewrite(tensor_map[sink], PatternMatcher([]), name="View Tensor Graph")
-
-  # insert contiguous in places determined by the realize map
-  realize_map = group_realizes(tensor_map[sink])
-  tensor_map = graph_rewrite_map(tensor_map[sink], add_contiguous, ctx=realize_map, bottom_up=True, input_map=tensor_map, name="add_contiguous")
-  tensor_map = graph_rewrite_map(tensor_map[sink], finalize_contiguous+remove_tags, input_map=tensor_map, name="finalize_contiguous")
-
-  # group into kernels (this is context-free)
-  tensor_map = graph_rewrite_map(tensor_map[sink], create_kernels, input_map=tensor_map, name="create_kernels")
-
-  # if a kernel depends on a buffer, and that buffer is later assigned to, make the assign depend on the kernel's assign
-  kernel_assign: dict[UOp, UOp] = {}
-  assign_rep: dict[UOp, UOp] = {}
-  for u in tensor_map[sink].toposort():
-    if u.op is not Ops.ASSIGN: continue
-    kernel_assign[u.buf_uop] = u
-    for s in u.src[1].src:
-      # TODO: this is probably broken for MSELECT/MSTACK
-      if s.op is not Ops.BUFFER or s is u.buf_uop or (a:=kernel_assign.get(s)) is None: continue
-      if any(x.op is Ops.ASSIGN and x.buf_uop is s for x in u.toposort()):
-        raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on ASSIGN or BUFFER")
-      assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
-  if assign_rep:
-    tensor_map = graph_rewrite_map(tensor_map[sink], _substitute, ctx=assign_rep, bottom_up=True, input_map=tensor_map, name="fix_assign")
-
-  # finally, create the AST for kernels
-  tensor_map = graph_rewrite_map(tensor_map[sink], create_ast+replace_metadata, bottom_up=True, input_map=tensor_map, name="create_ast")
-
-  # display the final graph
-  sched_sink = tensor_map[sink]
-  if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
-
-  # verify Kernels match the spec
-  if __debug__: type_verify(list(sched_sink.toposort()), tensor_uop_spec)
-
-  return tensor_map
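The deleted get_kernelize_map shows the pass structure rangeify inherits: a chain of graph_rewrite_map calls, each threading input_map so the returned dict still maps the original tensor UOps to their final forms. A toy sketch of that composition, assuming the UOp/PatternMatcher APIs behave as they do in this file:

from tinygrad import dtypes
from tinygrad.uop.ops import UOp, UPat, PatternMatcher, graph_rewrite_map

c = UOp.const(dtypes.int, 1)
sink = (c + c).sink()

pm1 = PatternMatcher([(UPat.var("x")+UPat.var("x"), lambda x: x*2)])  # fold x+x -> x*2
m1 = graph_rewrite_map(sink, pm1, name="pass1")
m2 = graph_rewrite_map(m1[sink], PatternMatcher([]), input_map=m1, name="pass2")
final = m2[sink]  # an entry for the original sink survives both passes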
@@ -10,6 +10,10 @@ from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, si
 from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_unparented
 from tinygrad.codegen.opt import Opt

+# creation can recurse a lot
+import sys
+sys.setrecursionlimit(10000)
+
 # *****************
 # 0. do some cleanup rewrites, mostly copied from the old stuff

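The recursion-limit bump (the "+sys.setrecursionlimit" in the commit message) moves with the code it protects: UOp graph construction recurses with graph depth, and CPython's default limit of 1000 frames is too low for large graphs. A stdlib-only illustration:

import sys
print(sys.getrecursionlimit())  # 1000 by default in CPython
sys.setrecursionlimit(10000)    # what rangeify.py now sets at import time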
@@ -18,7 +18,6 @@ from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
 from tinygrad.schedule.rangeify import get_rangeify_map
 from tinygrad.schedule.multi import get_multi_map
-from tinygrad.schedule.kernelize import get_kernelize_map

 # *** all in scope Tensors are here. this gets relevant UOps ***

@@ -232,7 +231,7 @@ class Tensor(MathTrait):
       _apply_map_to_tensors(get_multi_map(big_sink), "Apply Multi Map")
       big_sink = UOp.sink(*flatten([x.uop.src if x.uop.op is Ops.MULTI else [x.uop] for x in (self,)+lst]))

-    becomes_map = get_rangeify_map(big_sink) if RANGEIFY else get_kernelize_map(big_sink)
+    becomes_map = get_rangeify_map(big_sink)
     _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
     return self

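End to end, none of this is called directly in normal use: Tensor methods build the big sink and apply the map, as the hunk above shows. A sketch of the public-API path (realize() reaches this code indirectly):

from tinygrad import Tensor

out = Tensor.ones(4) + Tensor.ones(4)
out.realize()  # internally: build sink -> get_rangeify_map -> apply map -> schedule -> run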