multioutput ScheduleItem (#3699)

* refactor realize.py

* update docs

* update test_sched

* update runners and devices

* update openpilot and unit tests

* cleanup runner lowering

* update more tests
This commit is contained in:
qazal
2024-03-13 17:59:38 +02:00
committed by GitHub
parent 08064a0e29
commit 337cd53444
17 changed files with 125 additions and 117 deletions

View File

@@ -10,7 +10,7 @@ class TestHIPCompileSpeed(unittest.TestCase):
def test_hip_compile(self):
a, b = Tensor([1,2,3,4,5]), Tensor([1,2,3,4,5])
out = a + b
lin = Linearizer(create_schedule([out.lazydata])[-1].ast)
lin = Linearizer(create_schedule([out.lazydata])[-1].ast[0])
lin.linearize()
reference = """

View File

@@ -13,10 +13,10 @@ class TestConvShapetracker(unittest.TestCase):
# first run to init the weights, they are saved in seen
create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
# run it again to get the kernels
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op not in LoadOps]
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast[0].op not in LoadOps]
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
print(sched[0])
for arg in [sched[0].out, *sched[0].inputs]:
for arg in [sched[0].outputs[0], *sched[0].inputs]:
print(arg.st)
assert len(arg.st.views) == 1

View File

@@ -65,7 +65,7 @@ def universal_test(a, b, dtype, op):
def universal_test_unary(a, dtype, op):
if not isinstance(op, tuple): op = (op, op)
out: Tensor = op[0](Tensor([a], dtype=dtype))
ast = create_schedule([out.lazydata])[-1].ast
ast = create_schedule([out.lazydata])[-1].ast[0]
tensor_value = out.numpy()
numpy_value = op[1](np.array([a]).astype(dtype.np))
if dtype in dtypes_float:

View File

@@ -74,7 +74,7 @@ class TestReduceOp(unittest.TestCase):
a = a.sum()
sched = create_schedule([a.lazydata])
assert len(sched) == 1
assert sched[0].ast.src[0].op is ReduceOps.SUM
assert sched[0].ast[0].src[0].op is ReduceOps.SUM
def test_split_reduce_kernel_dim0(self):
a = Tensor.rand(256, 255).realize()
@@ -82,7 +82,7 @@ class TestReduceOp(unittest.TestCase):
sched = create_schedule([a.lazydata])
assert len(sched) == 2
for s in sched:
assert s.ast.src[0].op is ReduceOps.SUM
assert s.ast[0].src[0].op is ReduceOps.SUM
def test_split_reduce_kernel_dim1(self):
a = Tensor.rand(255, 256).realize()
@@ -90,7 +90,7 @@ class TestReduceOp(unittest.TestCase):
sched = create_schedule([a.lazydata])
assert len(sched) == 2
for s in sched:
assert s.ast.src[0].op is ReduceOps.SUM
assert s.ast[0].src[0].op is ReduceOps.SUM
if __name__ == "__main__":
unittest.main()

View File

@@ -41,7 +41,7 @@ class TestLinearizer(unittest.TestCase):
# these are of size 3 to avoid float4 coalesce
r = a[:-1] + a[1:]
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD])
@@ -74,7 +74,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = a.expand([2]) + b.expand([2])
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -86,7 +86,7 @@ class TestLinearizer(unittest.TestCase):
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
r = Tensor.conv2d(x,w,padding=1).relu()
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.upcast()
k.linearize()
@@ -102,7 +102,7 @@ class TestLinearizer(unittest.TestCase):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()
@@ -120,7 +120,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack([a, b])
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -130,7 +130,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor(2), Tensor(3)
r = a * b
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
assert num_ops <= 0, "more load or alu uops than needed"
@@ -139,14 +139,14 @@ class TestLinearizer(unittest.TestCase):
for tensor_dtype, acc_dtype in (
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Linearizer(create_schedule([a.lazydata])[-1].ast)
k = Linearizer(*create_schedule([a.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
assert local[0].dtype == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
k = Linearizer(create_schedule([c.lazydata])[-1].ast)
k = Linearizer(*create_schedule([c.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
assert local[0].dtype == expected_dtype
@@ -185,17 +185,17 @@ class TestLinearizer(unittest.TestCase):
def test_limit_dims_to_max_5d_global(self):
t = Tensor.rand(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
lin = Linearizer(*sched[0].ast)
assert lin.full_shape[:lin.global_dims] == (5, 6, 7, 8, 9)
lin.limit_dims_to_max(global_max=[16, 16, 16], local_max=[16, 16, 16])
def test_sum_collapse(self):
t = Tensor.ones(256,256).sum()
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
lin = Linearizer(*sched[0].ast)
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
def test_simplify_uop(self):
@@ -221,7 +221,7 @@ class TestLinearizer(unittest.TestCase):
def helper(t, max_ops=0):
sched = create_schedule([t.lazydata])
assert len(sched) == 1
k = Linearizer(sched[0].ast)
k = Linearizer(*sched[0].ast)
k.hand_coded_optimizations()
uops = list(k.linearize().uops)
# ignore kernel optimized IF/LOOP statements for now
@@ -242,8 +242,8 @@ def helper_realized_ast(r:Tensor):
run_schedule(s[:-1]) # run all kernels except the last one
# now all input LazyBuffers buffers in s[-1] should be realized
# allocate an output buffer
output_buffer = Buffer(s[-1].out.device, prod((s if isinstance(s, int) else s.max for s in s[-1].out.shape)), s[-1].out.dtype)
return s[-1].ast, [output_buffer] + [l.realized for l in s[-1].inputs]
output_buffer = Buffer((out:=s[-1].outputs[0]).device, prod((s if isinstance(s, int) else s.max for s in out.shape)), out.dtype)
return s[-1].ast[0], [output_buffer] + [l.realized for l in s[-1].inputs]
@unittest.skipUnless(Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4, "need backends that support float4")
class TestFloat4(unittest.TestCase):
@@ -260,7 +260,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations()
k.linearize()
@@ -272,7 +272,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.shift_to(0, 4) # float4 dimension
k.shift_to(0, 2, insert_before=k.shape_len-1)
k.upcast()
@@ -288,7 +288,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations() # implicit trigger float4 dim
k.linearize()
@@ -300,7 +300,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
k.upcast()
k.shift_to(len(k.full_unupcasted_shape)-1, 2, insert_before=k.shape_len-1)
@@ -318,7 +318,7 @@ class TestFloat4(unittest.TestCase):
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.upcast()
k.linearize()
@@ -333,7 +333,7 @@ class TestFloat4(unittest.TestCase):
# don't.
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.upcast()
k.upcast()
k.linearize()
@@ -349,7 +349,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.shift_to(0, 4, top=True) # top axes are float4 axes
k.upcast()
k.linearize()
@@ -365,7 +365,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()
@@ -380,7 +380,7 @@ class TestFloat4(unittest.TestCase):
# should float4 b but not a
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()
@@ -393,7 +393,7 @@ class TestHandCodedOpts(unittest.TestCase):
layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))
s = create_schedule([layer_2.lazydata])[-1]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 6 # make sure all ops are done in one kernel
# masked upcast should upcast masked axis of size 7
@@ -406,7 +406,7 @@ class TestHandCodedOpts(unittest.TestCase):
monster = Tensor.stack([Tensor.stack([Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
s = create_schedule([monster.lazydata])[-1]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 37 # make sure all ops are done in one kernel
# should upcast the two Tensor.stacks
@@ -420,7 +420,7 @@ class TestHandCodedOpts(unittest.TestCase):
wino_schedule = create_schedule([out.lazydata])
# collect upcasts of tile transform kernels
for i, si in enumerate(wino_schedule):
k = Linearizer(si.ast)
k = Linearizer(*si.ast)
k.hand_coded_optimizations()
if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
@@ -433,7 +433,7 @@ class TestHandCodedOpts(unittest.TestCase):
out.mean().backward()
backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
for si in backward_schedule:
k = Linearizer(si.ast)
k = Linearizer(*si.ast)
k.hand_coded_optimizations()
k.linearize()
if len(k.bufs) < 20: continue # not a tile transform kernel
@@ -447,7 +447,7 @@ class TestHandCodedOpts(unittest.TestCase):
layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4))
s = create_schedule([layer_3.lazydata])[-1]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 5 # make sure all ops are done in one kernel
# check that we don't do too many upcasts
@@ -462,7 +462,7 @@ class TestHandCodedOpts(unittest.TestCase):
c = a @ b
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations()
assert k.group_for_reduces == 1
@@ -777,7 +777,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
x, y = Tensor.randn(64,64), Tensor.randn(64,64)
out = x.matmul(y)
k = Linearizer(create_schedule([out.lazydata])[-1].ast)
k = Linearizer(*create_schedule([out.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()
@@ -791,7 +791,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
x = Tensor.randn((4,3,6,6)).realize()
out = x.flip((0,1)).contiguous()
k = Linearizer(create_schedule([out.lazydata])[-1].ast)
k = Linearizer(*create_schedule([out.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()
@@ -808,7 +808,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
opts = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
k = Linearizer(create_schedule([out.lazydata])[-1].ast)
k = Linearizer(*create_schedule([out.lazydata])[-1].ast)
for opt in opts: k.apply_opt(opt)
k.linearize()
@@ -830,7 +830,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()

View File

@@ -298,8 +298,8 @@ class TestMultiTensor(unittest.TestCase):
for p in get_parameters(bn): p.shard_(devices).realize()
out = bn(t)
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.out.device in devices and sched.ast.op is not LoadOps.COPY]
assert set(sched.out.device for sched in scheds) == set(devices), "should have ast on each shard device"
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices and sched.ast[0].op is not LoadOps.COPY]
assert set(out.device for sched in scheds for out in sched.outputs) == set(devices), "should have ast on each shard device"
asts = [sched.ast for sched in scheds]
assert len(asts) == 8, len(asts)
# test case to show that ast can be different on devices
@@ -634,4 +634,4 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
assert unsynced_si
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -17,22 +17,24 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
if to_prerealize:
for pre in to_prerealize:
for s in create_schedule([pre.lazydata], seen.copy()):
if GRAPH: realized_lazybuffer(s.out, 0)
seen.add(s.out)
for i,out in enumerate(s.outputs):
if GRAPH: realized_lazybuffer(out, 0)
seen.add(out)
sched = create_schedule([t.lazydata], seen)
if GRAPH:
for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps]
for i,s in enumerate(sched):
for out in s.outputs: realized_lazybuffer(out, i+1)
if filter_loadops: sched = [s for s in sched if s.ast[0].op not in LoadOps]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
for i, s in enumerate(sched):
print("kernel", i+1)
print_tree(s.ast)
for op in s.ast: print_tree(op)
assert len(sched) == allowed
# test the (non loadops) ops linearize
for s in sched:
if s.ast.op in LoadOps: continue
l = Linearizer(s.ast)
if s.ast[0].op in LoadOps: continue
l = Linearizer(*s.ast)
l.hand_coded_optimizations()
l.linearize()

View File

@@ -9,14 +9,15 @@ from tinygrad.tensor import Tensor
class TestTimeLinearizer(unittest.TestCase):
def test_reasonable_time(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op not in LoadOps][0]
rawbufs = [Buffer(Device.DEFAULT, si.out.st.real_size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
out = Buffer(Device.DEFAULT, si.outputs[0].st.real_size(), si.outputs[0].dtype)
rawbufs = [out] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
tm = time_linearizer(Linearizer(*si.ast), rawbufs, allow_test_size=False, cnt=10)
assert tm > 0 and tm != float('inf')
def test_bufs_from_lin(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op not in LoadOps][0]
rawbufs = bufs_from_lin(lin:=Linearizer(si.ast))
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
rawbufs = bufs_from_lin(lin:=Linearizer(*si.ast))
assert len(rawbufs) == len(lin.membufs)
assert all(r is not None for r in rawbufs)
assert all(isinstance(r, Buffer) for r in rawbufs)

View File

@@ -23,10 +23,10 @@ class TestWinograd(unittest.TestCase):
sched = create_schedule([out.lazydata])
for i,s in enumerate(sched):
if s.ast.op in LoadOps: continue
ops = s.ast.lazyops
if s.ast[0].op in LoadOps: continue
ops = [out.lazyops for out in s.ast]
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
l = Linearizer(s.ast)
l = Linearizer(*s.ast)
l.hand_coded_optimizations()
l.linearize()
assert len(l.sts) <= 256 # just the current value to prevent regression
@@ -69,4 +69,4 @@ class TestWinograd(unittest.TestCase):
print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
if __name__ == '__main__':
unittest.main(verbosity=2)
unittest.main(verbosity=2)