s/UOps/Ops (#7500)

* s/UOps/Ops [pr]

* fix
This commit is contained in:
George Hotz
2024-11-03 11:26:10 +08:00
committed by GitHub
parent d078dcd0c8
commit c8bf09b7d4
58 changed files with 3003 additions and 3002 deletions

View File

@@ -12,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import Kernel, verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
@@ -29,7 +29,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
if to_prerealize:
for pre in to_prerealize: pre.schedule()
sched = create_schedule(outs)
if filter_sink: sched = [s for s in sched if s.ast.op is UOps.SINK]
if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
for i, s in enumerate(sched):
@@ -38,7 +38,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}")
# test the (sink) ops linearize
for s in sched:
if s.ast.op is not UOps.SINK: continue
if s.ast.op is not Ops.SINK: continue
l = Kernel(s.ast)
l.hand_coded_optimizations()
l.linearize()
@@ -58,7 +58,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
dtypes.default_float = old_default_float
with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
run_schedule(s.copy())
cnt = len([si for si in s if si.ast.op is UOps.SINK])
cnt = len([si for si in s if si.ast.op is Ops.SINK])
assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
if getenv("CHECK", 1):
import torch
@@ -191,7 +191,7 @@ class TestSchedule(unittest.TestCase):
r1 = (x - r0).sum(axis=0).div(2)
out = r0 + r1
schedule = check_schedule(out, 2)
reduceops = [x for si in schedule for x in si.ast.parents if x.op is UOps.REDUCE_AXIS]
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
assert len(reduceops) == 2
def test_cache_reduce_multiple_children(self):
@@ -202,7 +202,7 @@ class TestSchedule(unittest.TestCase):
out0 = r0 + y
out1 = r1 + y
schedule = check_schedule([out0, out1], 4)
reduceops = [x for si in schedule for x in si.ast.parents if x.op is UOps.REDUCE_AXIS]
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
assert len(reduceops) == 2
def test_fold_double_unary(self):
@@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
a = Tensor.empty(16, 16)
b = (a.sum(0) + a.max(1)) + 2
schedule = check_schedule(b, 2)
self.assertIs(schedule[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
# multireduce spec
def test_multireduce_midreduce_nochase(self):
@@ -1117,7 +1117,7 @@ class TestSchedule(unittest.TestCase):
b = (a.sum(0)+a.max(0) + a.max(1)+a.sum(1)) + 2
# schedule = check_schedule(b, 2)
schedule = check_schedule(b, 4)
self.assertIs(schedule[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
run_schedule(schedule)
np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4)
@@ -1352,7 +1352,7 @@ class TestIndexing(unittest.TestCase):
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
lst = [xt] if isinstance(xt, Tensor) else xt
s = Tensor.schedule(*lst)
kernels = [si for si in s if si.ast.op is UOps.SINK]
kernels = [si for si in s if si.ast.op is Ops.SINK]
for si in kernels: verify_ast(si.ast)
run_schedule(s)
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
@@ -1607,20 +1607,20 @@ class TestIndexing(unittest.TestCase):
self.assertLess(et, 1200)
def test_no_rewrite_elementwise(self):
# Builds SINK(STORE(bufs[0], <(32, 32) view>, ld1+ld2)) from two int LOADs shaped (32, 32),
# runs graph_rewrite with view_right, and asserts the result has the same .key as the input sink.
# NOTE(review): this is a rendered diff with +/- markers and indentation stripped; the four UOps.*
# lines look like the removed side and the four Ops.* lines after them the added side -- confirm
# against the raw patch before treating this as runnable code.
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
rsink = graph_rewrite(sink, view_right)
# equal keys are the pass condition: the rewritten sink must match the original
self.assertEqual(rsink.key, sink.key)
def test_simple_store_reshape(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
@@ -1628,21 +1628,21 @@ class TestIndexing(unittest.TestCase):
verify_ast(rsink)
def test_no_reshape_reduceop(self):
# Builds a REDUCE_AXIS with arg (BinaryOps.ADD, (0, 1)) over a (32, 32) int LOAD and stores it
# through a (1, 1) ShapeTracker, then asserts graph_rewrite with view_right leaves the sink key unchanged.
# NOTE(review): rendered diff with +/- markers and indentation stripped; the UOps.* lines look like
# the removed side and the Ops.* lines after them the added side -- confirm against the raw patch.
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
# verify_ast is called on the original (un-rewritten) sink here, not on rsink
verify_ast(sink)
self.assertEqual(sink.key, rsink.key)
def test_reshape_many(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
for _ in range(24): r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink, et = timeit(graph_rewrite, sink, view_right)
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
@@ -1656,11 +1656,11 @@ class TestIndexing(unittest.TestCase):
sizes = [10*(i+1) for i in range(SZ)]
tms: List[float] = []
for sz in sizes:
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink, et = timeit(graph_rewrite, sink, view_right)
with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
verify_ast(rsink)
@@ -1676,20 +1676,20 @@ class TestIndexing(unittest.TestCase):
def test_swizzle_rewrite(self):
# graph rewrite
sink = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(UOps.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
x8:=UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
UOp(Ops.STORE, dtypes.void, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(Ops.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
x8:=UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
x8,
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
# verify output
k = Kernel(sink)
@@ -1705,13 +1705,13 @@ class TestIndexing(unittest.TestCase):
a = Tensor.randint(4,).realize()
expected_out = a.numpy().sum(0)+1
# LazyBuffer to pre-rewrite AST
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
swizzle_r = UOp(UOps.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
const = ast_const(dtypes.int, 1, ())
alu = swizzle_r+const
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
# graph rewrite
sink = graph_rewrite(sink, view_right)
# verify output
@@ -1728,13 +1728,13 @@ class TestIndexing(unittest.TestCase):
b = Tensor.randint(4,).realize()
expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2
# LazyBuffer to pre-rewrite AST
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r1 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
r2 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
alu = UOp(UOps.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(UOps.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
# graph rewrite
sink = graph_rewrite(sink, view_right)
# verify output
@@ -1745,51 +1745,51 @@ class TestIndexing(unittest.TestCase):
np.testing.assert_equal(c.numpy(), expected_out)
def test_swizzle_rewrite_alt(self):
# Builds a three-view VIEW over REDUCE_AXIS(LOAD), rewrites it with view_left then view_right, and
# checks the element counts (prod of shapes) before and after plus the strides of the resulting view.
# ld_st is bound with := inside the LOAD's VIEW so the original ShapeTracker can be compared below.
# NOTE(review): rendered diff with +/- markers and indentation stripped; the UOps.* lines look like
# the removed side and the Ops.* lines after them the added side -- confirm against the raw patch.
swizzle = UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
# there's an EXPAND pushing through the REDUCE_AXIS
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
ret = graph_rewrite(graph_rewrite(swizzle, view_left), view_right)
# EXPAND is rewritten
self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
# and pushed to the LOAD
# (the first VIEW found in ret.parents is the one inspected)
new_load_st = unwrap([x for x in ret.parents if x.op is UOps.VIEW][0].st)
new_load_st = unwrap([x for x in ret.parents if x.op is Ops.VIEW][0].st)
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
def test_permute_rewrite(self):
# Builds a STORE-rooted graph over BUFFER UOps, rewrites it with view_left then view_right, and
# asserts that no VIEW op with a non-empty src remains among ret.sparents.
# x1, x2 and x11 are bound with := where they are first built and reused later in the same expression.
# NOTE(review): rendered diff with +/- markers and indentation stripped; each run of UOps.* lines looks
# like the removed side and the run of Ops.* lines after it the added side -- confirm against the raw patch.
sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
x1:=UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
sink = UOp(Ops.STORE, dtypes.void, arg=None, src=(
x1:=UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=(
x1,
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
x2,)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
x11,)),)),)),)),))
# local helper: view_left pass first, then view_right, wrapped in the track_rewrites() decorator
@track_rewrites()
def rewrite(sink): return graph_rewrite(graph_rewrite(sink, view_left), view_right)
ret = rewrite(sink)
# pass condition: every VIEW left in ret.sparents has an empty src tuple
assert len([x for x in ret.sparents if x.op is UOps.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
assert len([x for x in ret.sparents if x.op is Ops.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
# Script entry point: run the unittest runner with verbosity=2 when this file is executed directly.
if __name__ == '__main__':
unittest.main(verbosity=2)