mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -12,7 +12,7 @@ from tinygrad import nn, dtypes, Device, Tensor
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
|
||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, st_fixup, view_left
|
||||
@@ -29,7 +29,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
|
||||
if to_prerealize:
|
||||
for pre in to_prerealize: pre.schedule()
|
||||
sched = create_schedule(outs)
|
||||
if filter_sink: sched = [s for s in sched if s.ast.op is UOps.SINK]
|
||||
if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK]
|
||||
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if len(sched) != allowed or DEBUG >= 3:
|
||||
for i, s in enumerate(sched):
|
||||
@@ -38,7 +38,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
|
||||
if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}")
|
||||
# test the (sink) ops linearize
|
||||
for s in sched:
|
||||
if s.ast.op is not UOps.SINK: continue
|
||||
if s.ast.op is not Ops.SINK: continue
|
||||
l = Kernel(s.ast)
|
||||
l.hand_coded_optimizations()
|
||||
l.linearize()
|
||||
@@ -58,7 +58,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
|
||||
dtypes.default_float = old_default_float
|
||||
with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
|
||||
run_schedule(s.copy())
|
||||
cnt = len([si for si in s if si.ast.op is UOps.SINK])
|
||||
cnt = len([si for si in s if si.ast.op is Ops.SINK])
|
||||
assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
|
||||
if getenv("CHECK", 1):
|
||||
import torch
|
||||
@@ -191,7 +191,7 @@ class TestSchedule(unittest.TestCase):
|
||||
r1 = (x - r0).sum(axis=0).div(2)
|
||||
out = r0 + r1
|
||||
schedule = check_schedule(out, 2)
|
||||
reduceops = [x for si in schedule for x in si.ast.parents if x.op is UOps.REDUCE_AXIS]
|
||||
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
|
||||
assert len(reduceops) == 2
|
||||
|
||||
def test_cache_reduce_multiple_children(self):
|
||||
@@ -202,7 +202,7 @@ class TestSchedule(unittest.TestCase):
|
||||
out0 = r0 + y
|
||||
out1 = r1 + y
|
||||
schedule = check_schedule([out0, out1], 4)
|
||||
reduceops = [x for si in schedule for x in si.ast.parents if x.op is UOps.REDUCE_AXIS]
|
||||
reduceops = [x for si in schedule for x in si.ast.parents if x.op is Ops.REDUCE_AXIS]
|
||||
assert len(reduceops) == 2
|
||||
|
||||
def test_fold_double_unary(self):
|
||||
@@ -1108,7 +1108,7 @@ class TestSchedule(unittest.TestCase):
|
||||
a = Tensor.empty(16, 16)
|
||||
b = (a.sum(0) + a.max(1)) + 2
|
||||
schedule = check_schedule(b, 2)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
|
||||
|
||||
# multireduce spec
|
||||
def test_multireduce_midreduce_nochase(self):
|
||||
@@ -1117,7 +1117,7 @@ class TestSchedule(unittest.TestCase):
|
||||
b = (a.sum(0)+a.max(0) + a.max(1)+a.sum(1)) + 2
|
||||
# schedule = check_schedule(b, 2)
|
||||
schedule = check_schedule(b, 4)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, UOps.REDUCE_AXIS)
|
||||
self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
|
||||
run_schedule(schedule)
|
||||
np.testing.assert_allclose(b.numpy(), a.numpy().sum(0)+a.numpy().max(0) + a.numpy().max(1)+a.numpy().sum(1)+2, atol=1e-4, rtol=1e-4)
|
||||
|
||||
@@ -1352,7 +1352,7 @@ class TestIndexing(unittest.TestCase):
|
||||
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
|
||||
lst = [xt] if isinstance(xt, Tensor) else xt
|
||||
s = Tensor.schedule(*lst)
|
||||
kernels = [si for si in s if si.ast.op is UOps.SINK]
|
||||
kernels = [si for si in s if si.ast.op is Ops.SINK]
|
||||
for si in kernels: verify_ast(si.ast)
|
||||
run_schedule(s)
|
||||
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
|
||||
@@ -1607,20 +1607,20 @@ class TestIndexing(unittest.TestCase):
|
||||
self.assertLess(et, 1200)
|
||||
|
||||
def test_no_rewrite_elementwise(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
self.assertEqual(rsink.key, sink.key)
|
||||
|
||||
def test_simple_store_reshape(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
|
||||
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
@@ -1628,21 +1628,21 @@ class TestIndexing(unittest.TestCase):
|
||||
verify_ast(rsink)
|
||||
|
||||
def test_no_reshape_reduceop(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
|
||||
rsink = graph_rewrite(sink, view_right)
|
||||
verify_ast(sink)
|
||||
self.assertEqual(sink.key, rsink.key)
|
||||
|
||||
def test_reshape_many(self):
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||
for _ in range(24): r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink, et = timeit(graph_rewrite, sink, view_right)
|
||||
# NOTE: this AST is always correct in the entire lifecycle of graph_rewrite!
|
||||
# with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
@@ -1656,11 +1656,11 @@ class TestIndexing(unittest.TestCase):
|
||||
sizes = [10*(i+1) for i in range(SZ)]
|
||||
tms: List[float] = []
|
||||
for sz in sizes:
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
|
||||
rsink, et = timeit(graph_rewrite, sink, view_right)
|
||||
with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink)
|
||||
verify_ast(rsink)
|
||||
@@ -1676,20 +1676,20 @@ class TestIndexing(unittest.TestCase):
|
||||
|
||||
def test_swizzle_rewrite(self):
|
||||
# graph rewrite
|
||||
sink = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
x8:=UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
|
||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
sink = UOp(Ops.SINK, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(Ops.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
x8:=UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
|
||||
UOp(Ops.LOAD, dtypes.int, arg=None, src=(
|
||||
x8,
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) # noqa E501
|
||||
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
# verify output
|
||||
k = Kernel(sink)
|
||||
@@ -1705,13 +1705,13 @@ class TestIndexing(unittest.TestCase):
|
||||
a = Tensor.randint(4,).realize()
|
||||
expected_out = a.numpy().sum(0)+1
|
||||
# LazyBuffer to pre-rewrite AST
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
|
||||
swizzle_r = UOp(UOps.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
|
||||
swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
|
||||
const = ast_const(dtypes.int, 1, ())
|
||||
alu = swizzle_r+const
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),))
|
||||
# graph rewrite
|
||||
sink = graph_rewrite(sink, view_right)
|
||||
# verify output
|
||||
@@ -1728,13 +1728,13 @@ class TestIndexing(unittest.TestCase):
|
||||
b = Tensor.randint(4,).realize()
|
||||
expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2
|
||||
# LazyBuffer to pre-rewrite AST
|
||||
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r1 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
|
||||
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r2 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
|
||||
alu = UOp(UOps.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(UOps.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
|
||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
|
||||
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
|
||||
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
|
||||
r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,)))
|
||||
alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(()))
|
||||
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501
|
||||
# graph rewrite
|
||||
sink = graph_rewrite(sink, view_right)
|
||||
# verify output
|
||||
@@ -1745,51 +1745,51 @@ class TestIndexing(unittest.TestCase):
|
||||
np.testing.assert_equal(c.numpy(), expected_out)
|
||||
|
||||
def test_swizzle_rewrite_alt(self):
|
||||
swizzle = UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
|
||||
swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
|
||||
# there's an EXPAND pushing through the REDUCE_AXIS
|
||||
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
|
||||
ret = graph_rewrite(graph_rewrite(swizzle, view_left), view_right)
|
||||
# EXPAND is rewritten
|
||||
self.assertEqual(prod(ret.st.shape), prod(ret.src[0].st.shape))
|
||||
# and pushed to the LOAD
|
||||
new_load_st = unwrap([x for x in ret.parents if x.op is UOps.VIEW][0].st)
|
||||
new_load_st = unwrap([x for x in ret.parents if x.op is Ops.VIEW][0].st)
|
||||
self.assertGreater(prod(new_load_st.shape), prod(ld_st.shape))
|
||||
self.assertEqual(new_load_st.views[0].strides, (0, 9, 3, 0, 1, 0, 27))
|
||||
|
||||
def test_permute_rewrite(self):
|
||||
sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||
x1:=UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
|
||||
x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
|
||||
sink = UOp(Ops.STORE, dtypes.void, arg=None, src=(
|
||||
x1:=UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
|
||||
x2:=UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(Ops.CONTIGUOUS, dtypes.float, arg=None, src=(
|
||||
x1,
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 32, 1, 1024), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8)), src=(
|
||||
UOp(Ops.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
x11:=UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
|
||||
x2,)),)),
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
|
||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(Ops.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(Ops.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
|
||||
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
||||
UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||
x11,)),)),)),)),))
|
||||
@track_rewrites()
|
||||
def rewrite(sink): return graph_rewrite(graph_rewrite(sink, view_left), view_right)
|
||||
ret = rewrite(sink)
|
||||
assert len([x for x in ret.sparents if x.op is UOps.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
|
||||
assert len([x for x in ret.sparents if x.op is Ops.VIEW and len(x.src) != 0]) == 0, f"unmerged views left in sink {ret}"
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user