lowerer is kernel [run_process_replay] (#5437)

This commit is contained in:
George Hotz
2024-07-12 18:50:55 -07:00
committed by GitHub
parent b8342fb085
commit 03c2dc8bd7
33 changed files with 215 additions and 213 deletions

View File

@@ -4,8 +4,7 @@ import unittest
from dataclasses import replace
from test.external.fuzz_linearizer import compare_linearizer
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
from tinygrad.codegen.lowerer import Lowerer
from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel
from tinygrad.codegen.lowerer import get_grouped_dims
from tinygrad.codegen.uops import UOp, UOps
from tinygrad.device import Device, Buffer
@@ -38,7 +37,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
realized_ast = sched[-1].ast
run_schedule(sched)
out = r.numpy()
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.linearize()
assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
@@ -54,7 +53,7 @@ def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, d
r = a.matmul(b, acc_dtype=dtype_out)
sched = create_schedule([r.lazydata])
realized_ast = sched[-1].ast
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
k.linearize()
wmmas = len([uop for uop in k.uops if uop.op is UOps.WMMA])
@@ -211,7 +210,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skip("AST has implicit movement ops")
def test_early_end_local(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
k = Lowerer(ast)
k = Kernel(ast)
k.hand_coded_optimizations()
k.linearize()
self.assertEqual(len(endifs:=[x for x in k.uops if x.op is UOps.ENDIF]), len(ifs:=[x for x in k.uops if x.op is UOps.IF]))
@@ -243,7 +242,7 @@ class TestLinearizer(unittest.TestCase):
LazyOp(op=BufferOps.STORE, src=(ast2,), arg=MemBuffer(idx=order.index(2), dtype=dtypes.float, st=ShapeTracker.from_shape((1,)))),
LazyOp(op=BufferOps.STORE, src=(ast3,), arg=MemBuffer(idx=order.index(3), dtype=dtypes.float, st=ShapeTracker.from_shape((1,))))
]
k = Lowerer([asts[i] for i in order])
k = Kernel([asts[i] for i in order])
def recursive_reduceops(x: LazyOp): return [c for v in x.src for c in recursive_reduceops(v)] + [v for v in list(x.src) if v.op in ReduceOps]
for i,r in enumerate(k.reduceops): assert not any([r in recursive_reduceops(x) for x in k.reduceops[:i]]), "reduceops are out of order"
x = Tensor.randn(32).realize()
@@ -256,7 +255,7 @@ class TestLinearizer(unittest.TestCase):
def test_multireduce_store_locals(self):
# ensure the result of local reducop is stored and loaded back into every thread for future use
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
k = Lowerer(ast)
k = Kernel(ast)
k.hand_coded_optimizations()
k.linearize()
local_buf = [u for u in k.uops if u.op is UOps.DEFINE_LOCAL]
@@ -273,7 +272,7 @@ class TestLinearizer(unittest.TestCase):
def test_multireduce_upcasting(self):
# when upcasting multiple reductions, ensure ast_parse will create multiple uops even when using the result of past reductions
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 7), strides=(7, 1), offset=0, mask=None, contiguous=True),),))),), arg=(1,)),), arg=None),)),), arg=(1,)),), arg=MemBuffer(idx=0, dtype=dtypes.float32, st=ShapeTracker(views=(View(shape=(8, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))), # noqa: E501
k = Lowerer(ast)
k = Kernel(ast)
k.upcast()
k.linearize()
define_globals = [u for u in k.uops if u.op is UOps.DEFINE_GLOBAL]
@@ -302,7 +301,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skip("AST has implicit movement ops")
def test_multireduce_loop_scope(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None))), LazyOp(op=UnaryOps.RECIP, src=(LazyOp(op=UnaryOps.SQRT, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None), LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))), LazyOp(op=UnaryOps.NEG, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(864, 32, 1), offset=0, mask=None, contiguous=True),)))),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 32), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=None)), arg=None)), arg=None),), arg=(2,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.03125, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=None),)),),),), arg=(2,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3, 27, 1), strides=(27, 1, 0), offset=0, mask=None, contiguous=True),),))), # noqa: E501
k = Lowerer(ast)
k = Kernel(ast)
k.hand_coded_optimizations()
k.linearize()
def get_recursive_children(x:UOp): return set.union(set(x.src), *[get_recursive_children(v) for v in x.src])
@@ -377,7 +376,7 @@ class TestLinearizer(unittest.TestCase):
# these are of size 3 to avoid float4 coalesce
r = a[:-1] + a[1:]
k = Lowerer(create_schedule([r.lazydata])[-1].ast)
k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_loads = len([uop for uop in k.uops if uop.op is UOps.LOAD])
@@ -395,7 +394,7 @@ class TestLinearizer(unittest.TestCase):
b = LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=DT, st=ST)), VAL))
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(a,b)),), arg=MemBuffer(idx=0, dtype=DT, st=ST))
lin = Lowerer(ast)
lin = Kernel(ast)
lin.linearize()
assert len(lin.uops.uops) <= 7, "too many uops"
@@ -408,7 +407,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = a.expand([2]) + b.expand([2])
k = Lowerer(create_schedule([r.lazydata])[-1].ast)
k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
@@ -419,7 +418,7 @@ class TestLinearizer(unittest.TestCase):
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
r = Tensor.conv2d(x,w,padding=1).relu()
k = Lowerer(create_schedule([r.lazydata])[-1].ast)
k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.upcast()
k.linearize()
@@ -435,7 +434,7 @@ class TestLinearizer(unittest.TestCase):
def test_upcast_with_locals(self):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = Lowerer(create_schedule([r.lazydata])[-1].ast)
k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()
@@ -469,7 +468,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack(a, b)
k = Lowerer(create_schedule([r.lazydata])[-1].ast)
k = Kernel(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.op is UOps.ALU])
@@ -479,14 +478,14 @@ class TestLinearizer(unittest.TestCase):
for tensor_dtype, acc_dtype in (
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Lowerer(create_schedule([a.lazydata])[-1].ast)
k = Kernel(create_schedule([a.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
assert local[0].dtype == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
k = Lowerer(create_schedule([c.lazydata])[-1].ast)
k = Kernel(create_schedule([c.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.op is UOps.DEFINE_ACC]
assert local[0].dtype == expected_dtype
@@ -550,7 +549,7 @@ class TestLinearizer(unittest.TestCase):
c = a.conv2d(b, padding=1, acc_dtype=tc.dtype_out)
realized_ast, real_bufs = helper_realized_ast(c)
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
k.apply_tensor_cores(1, axis=axis, tc_opt=2)
k.linearize()
assert len([uop for uop in k.uops if uop.op is UOps.WMMA]) > 0, "tensor core not triggered"
@@ -567,7 +566,7 @@ class TestLinearizer(unittest.TestCase):
# check that get_linearizer_actions produces all 9 options
from tinygrad.engine.search import get_linearizer_actions
tc_actions = [k for i, k in get_linearizer_actions(Lowerer(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
tc_actions = [k for i, k in get_linearizer_actions(Kernel(realized_ast), False).items() if k.applied_opts[0].op == OptOps.TC]
assert len(tc_actions) == 9, f"get_linearizer_actions should contain 9 possible TC actions, only got {len(tc_actions)}"
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
@@ -677,7 +676,7 @@ class TestLinearizer(unittest.TestCase):
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
assert len(sched) == 1
lin = Lowerer(sched[0].ast)
lin = Kernel(sched[0].ast)
assert sum(u.arg is UnaryOps.RECIP for u in lin.linearize().uops) == max_ops, msg
a = Tensor.rand((4,4))
@@ -697,7 +696,7 @@ class TestLinearizer(unittest.TestCase):
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
assert len(sched) == 1
lin = Lowerer(sched[0].ast)
lin = Kernel(sched[0].ast)
assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
def test_assign_fold(self):
@@ -716,7 +715,7 @@ class TestLinearizer(unittest.TestCase):
sched_copy = sched[:]
run_schedule(sched)
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
lin = Lowerer(sched_copy[-1].ast)
lin = Kernel(sched_copy[-1].ast)
lin.hand_coded_optimizations()
lin.linearize()
assert not any(u.arg == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded"
@@ -844,7 +843,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.hand_coded_optimizations()
k.linearize()
@@ -856,7 +855,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.shift_to(0, 4) # float4 dimension
k.shift_to(0, 2, insert_before=k.shape_len-1)
k.upcast()
@@ -872,7 +871,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.hand_coded_optimizations() # implicit trigger float4 dim
k.linearize()
@@ -884,7 +883,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
k.upcast()
k.shift_to(len(k.full_unupcasted_shape)-1, 2, insert_before=k.shape_len-1)
@@ -902,7 +901,7 @@ class TestFloat4(unittest.TestCase):
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.upcast()
k.linearize()
@@ -917,7 +916,7 @@ class TestFloat4(unittest.TestCase):
# don't.
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.upcast()
k.upcast()
k.linearize()
@@ -933,7 +932,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.shift_to(0, 4, top=True) # top axes are float4 axes
k.upcast()
k.linearize()
@@ -949,7 +948,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()
@@ -964,7 +963,7 @@ class TestFloat4(unittest.TestCase):
# should float4 b but not a
s = create_schedule([c.lazydata])[0]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()
@@ -977,7 +976,7 @@ class TestHandCodedOpts(unittest.TestCase):
layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))
s = create_schedule([layer_2.lazydata])[-1]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 6 # make sure all ops are done in one kernel
# masked upcast should upcast masked axis of size 7
@@ -989,7 +988,7 @@ class TestHandCodedOpts(unittest.TestCase):
monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
s = create_schedule([monster.lazydata])[-1]
k = Lowerer(s.ast)
k = Kernel(s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 37 # make sure all ops are done in one kernel
# should upcast the two Tensor.stacks
@@ -1003,7 +1002,7 @@ class TestHandCodedOpts(unittest.TestCase):
wino_schedule = create_schedule([out.lazydata])
# collect upcasts of tile transform kernels
for i, si in enumerate(wino_schedule):
k = Lowerer(si.ast)
k = Kernel(si.ast)
k.hand_coded_optimizations()
if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
@@ -1016,7 +1015,7 @@ class TestHandCodedOpts(unittest.TestCase):
out.mean().backward()
backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
for si in backward_schedule:
k = Lowerer(si.ast)
k = Kernel(si.ast)
k.hand_coded_optimizations()
k.linearize()
if len(k.bufs) < 20: continue # not a tile transform kernel
@@ -1058,11 +1057,11 @@ def helper_linearizer_opt(r:Union[Tensor, List[Tensor]], *args, **kwargs):
return _helper_linearizer_opt_ast(realized_ast, real_bufs, *args, **kwargs)
def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts=[],
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> List[Lowerer]:
lins: List[Lowerer] = []
apply_tc=False, atol=1e-4, rtol=1e-4, color_sizes=[], wanna_output=[]) -> List[Kernel]:
lins: List[Kernel] = []
outbufs = [real_bufs[i] for i in range(len(realized_ast.src))]
def get_prg(k:Lowerer): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
def get_prg(k:Kernel): return CompiledRunner(replace(k.to_program(), dname=Device.DEFAULT))
def check_opt(opts, create_k, expected_color_size):
k = create_k()
@@ -1082,7 +1081,7 @@ def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
# Get baseline if it is not provided, which is not optimized at all.
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
lins.append(k)
prg = get_prg(k)
prg.exec(real_bufs)
@@ -1092,7 +1091,7 @@ def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
# Check correctness of handcoded optimiztions.
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
lins.append(k)
k.hand_coded_optimizations()
prg = get_prg(k)
@@ -1101,7 +1100,7 @@ def _helper_linearizer_opt_ast(realized_ast:LazyOp, real_bufs:List[Buffer], opts
for i, buf in enumerate(outbufs):
np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), _to_np_dtype(buf.dtype)), wanna_output[i], atol=atol, rtol=rtol)
for i, x in enumerate(opts): # Check custom transformations if any.
check_opt(x, lambda: Lowerer(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
check_opt(x, lambda: Kernel(realized_ast), color_sizes[i] if i < len(color_sizes) else None)
return lins
# creates a back-to-back multi reduce AST by merging r0 and r1.
@@ -1438,14 +1437,14 @@ class TestKernelOpts(unittest.TestCase):
[Opt(OptOps.LOCAL, 0, 2), Opt(OptOps.LOCAL, 2, 2)],
]
for x in invalid_opts:
k = Lowerer(realized_ast)
k = Kernel(realized_ast)
with self.assertRaises(AssertionError):
assert k.apply_tensor_cores(use_tensor_cores=1, extra_opts=x), "no valid tensor core" # for METAL in runners
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_buf_index_not_found_tensor_core(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPNE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
k = Lowerer(ast, opts=Device[Device.DEFAULT].renderer)
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
with self.assertRaises(KernelOptError):
k.apply_opt(Opt(OptOps.TC, 0, 1))
@@ -1462,7 +1461,7 @@ class TestKernelOpts(unittest.TestCase):
c, d = Tensor.rand(M, K, dtype=tc.dtype_in).realize(), Tensor.rand(K, N, dtype=tc.dtype_in).realize()
r1 = c.matmul(d, acc_dtype=tc.dtype_out)
ast = _temp_create_multireduce_ast(r0, r1)
lin = Lowerer(ast)
lin = Kernel(ast)
lin.apply_opt(Opt(op=OptOps.TC, axis=0, amt=2))
lin.linearize()
result = compare_linearizer(lin)