multioutput ScheduleItem (#3699)

* refactor realize.py

* update docs

* update test_sched

* update runners and devices

* update openpilot and unit tests

* cleanup runner lowering

* update more tests
This commit is contained in:
qazal
2024-03-13 17:59:38 +02:00
committed by GitHub
parent 08064a0e29
commit 337cd53444
17 changed files with 125 additions and 117 deletions

View File

@@ -240,7 +240,7 @@ result = Tensor(2.0).realize() + Tensor(3.0).realize()
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.realize import create_schedule
sched = create_schedule([result.lazydata])
linearizer = Linearizer(sched[-1].ast, opts=ClangCompiler.linearizer_opts)
linearizer = Linearizer(*sched[-1].ast, opts=ClangCompiler.linearizer_opts)
linearizer.linearize()
# print the uops

View File

@@ -86,11 +86,11 @@ out = a.e(BinaryOps.ADD, b)
# schedule the computation as a list of kernels
sched = create_schedule([out])
for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
for si in sched: print(si.ast[0].op) # NOTE: the first two convert it to CLANG
# DEBUGGING: print the compute ast as a tree
from tinygrad.features.graph import print_tree
print_tree(sched[-1].ast)
print_tree(sched[-1].ast[0])
# NOTE: sched[-1].ast is the same as st_0 above
# run that schedule

View File

@@ -28,7 +28,7 @@ if __name__ == "__main__":
x = Tensor.empty(64, 3, 224, 224)
out = mdl(x)
sched = create_schedule([out.lazydata], seen)
sched = [x for x in sched if x.ast.op not in LoadOps]
sched = [x for x in sched if x.ast[0].op not in LoadOps]
# focus on one kernel
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
@@ -37,24 +37,24 @@ if __name__ == "__main__":
total_tm = 0
running_gflops = 0
for i,si in enumerate(sched):
rawbufs = bufs_from_lin(Linearizer(si.ast))
rawbufs = bufs_from_lin(Linearizer(*si.ast))
# "linearize" the op into uops in different ways
lins:List[Linearizer] = []
# always try hand coded opt
lin = Linearizer(si.ast, opts=device.compiler.linearizer_opts)
lin = Linearizer(*si.ast, opts=device.compiler.linearizer_opts)
lin.hand_coded_optimizations()
lins.append(lin)
# maybe try tensor cores
lin = Linearizer(si.ast, opts=device.compiler.linearizer_opts)
lin = Linearizer(*si.ast, opts=device.compiler.linearizer_opts)
if lin.apply_tensor_cores():
lins.append(lin)
# try a beam search
if beam:=getenv("BEAM"):
lin = Linearizer(si.ast, opts=device.compiler.linearizer_opts)
lin = Linearizer(*si.ast, opts=device.compiler.linearizer_opts)
lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
lins.append(lin)

View File

@@ -39,15 +39,15 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
depends = set(input_lb)
for si in schedule:
if any(b in depends for b in si.inputs):
depends.add(si.out)
for out in si.outputs: depends.add(out)
# run all kernels that don't depend on the inputs
# NOTE: there's two extra kernels due to fusions that now happen since the weights aren't realized
schedule, schedule_independent = partition(schedule, lambda si: si.out in depends)
schedule, schedule_independent = partition(schedule, lambda si: any(out in depends for out in si.outputs))
print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")
# confirm no loadops in the (non independent) schedule except for the ones that load the input buffers
assert all(si.ast.op not in LoadOps or si.out in input_lb for si in schedule), "has loadops, can't compile to Thneed"
assert all(si.ast[0].op not in LoadOps or out in input_lb for si in schedule for out in si.outputs), "has loadops, can't compile to Thneed"
return schedule, schedule_independent, inputs
def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[str, Tensor]):
@@ -88,9 +88,9 @@ def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[s
# run code (all buffers have been allocated)
GlobalCounters.reset()
for si in schedule: lower_schedule_item(si)([si.out.realized] + [x.realized for x in si.inputs], {})
for si in schedule: lower_schedule_item(si)([x.realized for x in si.outputs+si.inputs], {})
new_tinygrad_out = Tensor(schedule[-1].out).numpy()
new_tinygrad_out = Tensor(schedule[-1].outputs[0]).numpy()
np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
print("semi-thneed self-test passed!")
@@ -102,13 +102,13 @@ if __name__ == "__main__":
#exit(0)
schedule, schedule_independent, inputs = get_schedule(onnx_data)
schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps)
schedule, schedule_input = partition(schedule, lambda x: x.ast[0].op not in LoadOps)
print(f"{len(schedule_input)} inputs")
run_schedule(schedule_independent)
run_schedule(schedule_input)
with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
image_count = sum(isinstance(out.dtype, ImageDType) for si in schedule for out in si.outputs)
print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
GlobalCounters.reset()

View File

@@ -10,7 +10,7 @@ class TestHIPCompileSpeed(unittest.TestCase):
def test_hip_compile(self):
a, b = Tensor([1,2,3,4,5]), Tensor([1,2,3,4,5])
out = a + b
lin = Linearizer(create_schedule([out.lazydata])[-1].ast)
lin = Linearizer(create_schedule([out.lazydata])[-1].ast[0])
lin.linearize()
reference = """

View File

@@ -13,10 +13,10 @@ class TestConvShapetracker(unittest.TestCase):
# first run to init the weights, they are saved in seen
create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
# run it again to get the kernels
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op not in LoadOps]
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast[0].op not in LoadOps]
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
print(sched[0])
for arg in [sched[0].out, *sched[0].inputs]:
for arg in [sched[0].outputs[0], *sched[0].inputs]:
print(arg.st)
assert len(arg.st.views) == 1

View File

@@ -65,7 +65,7 @@ def universal_test(a, b, dtype, op):
def universal_test_unary(a, dtype, op):
if not isinstance(op, tuple): op = (op, op)
out: Tensor = op[0](Tensor([a], dtype=dtype))
ast = create_schedule([out.lazydata])[-1].ast
ast = create_schedule([out.lazydata])[-1].ast[0]
tensor_value = out.numpy()
numpy_value = op[1](np.array([a]).astype(dtype.np))
if dtype in dtypes_float:

View File

@@ -74,7 +74,7 @@ class TestReduceOp(unittest.TestCase):
a = a.sum()
sched = create_schedule([a.lazydata])
assert len(sched) == 1
assert sched[0].ast.src[0].op is ReduceOps.SUM
assert sched[0].ast[0].src[0].op is ReduceOps.SUM
def test_split_reduce_kernel_dim0(self):
a = Tensor.rand(256, 255).realize()
@@ -82,7 +82,7 @@ class TestReduceOp(unittest.TestCase):
sched = create_schedule([a.lazydata])
assert len(sched) == 2
for s in sched:
assert s.ast.src[0].op is ReduceOps.SUM
assert s.ast[0].src[0].op is ReduceOps.SUM
def test_split_reduce_kernel_dim1(self):
a = Tensor.rand(255, 256).realize()
@@ -90,7 +90,7 @@ class TestReduceOp(unittest.TestCase):
sched = create_schedule([a.lazydata])
assert len(sched) == 2
for s in sched:
assert s.ast.src[0].op is ReduceOps.SUM
assert s.ast[0].src[0].op is ReduceOps.SUM
if __name__ == "__main__":
unittest.main()

View File

@@ -41,7 +41,7 @@ class TestLinearizer(unittest.TestCase):
# these are of size 3 to avoid float4 coalesce
r = a[:-1] + a[1:]
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD])
@@ -74,7 +74,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = a.expand([2]) + b.expand([2])
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -86,7 +86,7 @@ class TestLinearizer(unittest.TestCase):
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
r = Tensor.conv2d(x,w,padding=1).relu()
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.upcast()
k.linearize()
@@ -102,7 +102,7 @@ class TestLinearizer(unittest.TestCase):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()
@@ -120,7 +120,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack([a, b])
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -130,7 +130,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor(2), Tensor(3)
r = a * b
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
assert num_ops <= 0, "more load or alu uops than needed"
@@ -139,14 +139,14 @@ class TestLinearizer(unittest.TestCase):
for tensor_dtype, acc_dtype in (
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Linearizer(create_schedule([a.lazydata])[-1].ast)
k = Linearizer(*create_schedule([a.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
assert local[0].dtype == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
k = Linearizer(create_schedule([c.lazydata])[-1].ast)
k = Linearizer(*create_schedule([c.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
assert local[0].dtype == expected_dtype
@@ -185,17 +185,17 @@ class TestLinearizer(unittest.TestCase):
def test_limit_dims_to_max_5d_global(self):
t = Tensor.rand(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
lin = Linearizer(*sched[0].ast)
assert lin.full_shape[:lin.global_dims] == (5, 6, 7, 8, 9)
lin.limit_dims_to_max(global_max=[16, 16, 16], local_max=[16, 16, 16])
def test_sum_collapse(self):
t = Tensor.ones(256,256).sum()
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
lin = Linearizer(*sched[0].ast)
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
def test_simplify_uop(self):
@@ -221,7 +221,7 @@ class TestLinearizer(unittest.TestCase):
def helper(t, max_ops=0):
sched = create_schedule([t.lazydata])
assert len(sched) == 1
k = Linearizer(sched[0].ast)
k = Linearizer(*sched[0].ast)
k.hand_coded_optimizations()
uops = list(k.linearize().uops)
# ignore kernel optimized IF/LOOP statements for now
@@ -242,8 +242,8 @@ def helper_realized_ast(r:Tensor):
run_schedule(s[:-1]) # run all kernels except the last one
# now all input LazyBuffers buffers in s[-1] should be realized
# allocate an output buffer
output_buffer = Buffer(s[-1].out.device, prod((s if isinstance(s, int) else s.max for s in s[-1].out.shape)), s[-1].out.dtype)
return s[-1].ast, [output_buffer] + [l.realized for l in s[-1].inputs]
output_buffer = Buffer((out:=s[-1].outputs[0]).device, prod((s if isinstance(s, int) else s.max for s in out.shape)), out.dtype)
return s[-1].ast[0], [output_buffer] + [l.realized for l in s[-1].inputs]
@unittest.skipUnless(Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4, "need backends that support float4")
class TestFloat4(unittest.TestCase):
@@ -260,7 +260,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations()
k.linearize()
@@ -272,7 +272,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.shift_to(0, 4) # float4 dimension
k.shift_to(0, 2, insert_before=k.shape_len-1)
k.upcast()
@@ -288,7 +288,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations() # implicit trigger float4 dim
k.linearize()
@@ -300,7 +300,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
k.upcast()
k.shift_to(len(k.full_unupcasted_shape)-1, 2, insert_before=k.shape_len-1)
@@ -318,7 +318,7 @@ class TestFloat4(unittest.TestCase):
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.upcast()
k.linearize()
@@ -333,7 +333,7 @@ class TestFloat4(unittest.TestCase):
# don't.
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.upcast()
k.upcast()
k.linearize()
@@ -349,7 +349,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.shift_to(0, 4, top=True) # top axes are float4 axes
k.upcast()
k.linearize()
@@ -365,7 +365,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()
@@ -380,7 +380,7 @@ class TestFloat4(unittest.TestCase):
# should float4 b but not a
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()
@@ -393,7 +393,7 @@ class TestHandCodedOpts(unittest.TestCase):
layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))
s = create_schedule([layer_2.lazydata])[-1]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 6 # make sure all ops are done in one kernel
# masked upcast should upcast masked axis of size 7
@@ -406,7 +406,7 @@ class TestHandCodedOpts(unittest.TestCase):
monster = Tensor.stack([Tensor.stack([Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
s = create_schedule([monster.lazydata])[-1]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 37 # make sure all ops are done in one kernel
# should upcast the two Tensor.stacks
@@ -420,7 +420,7 @@ class TestHandCodedOpts(unittest.TestCase):
wino_schedule = create_schedule([out.lazydata])
# collect upcasts of tile transform kernels
for i, si in enumerate(wino_schedule):
k = Linearizer(si.ast)
k = Linearizer(*si.ast)
k.hand_coded_optimizations()
if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
@@ -433,7 +433,7 @@ class TestHandCodedOpts(unittest.TestCase):
out.mean().backward()
backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
for si in backward_schedule:
k = Linearizer(si.ast)
k = Linearizer(*si.ast)
k.hand_coded_optimizations()
k.linearize()
if len(k.bufs) < 20: continue # not a tile transform kernel
@@ -447,7 +447,7 @@ class TestHandCodedOpts(unittest.TestCase):
layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4))
s = create_schedule([layer_3.lazydata])[-1]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 5 # make sure all ops are done in one kernel
# check that we don't do too many upcasts
@@ -462,7 +462,7 @@ class TestHandCodedOpts(unittest.TestCase):
c = a @ b
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k = Linearizer(*s.ast)
k.hand_coded_optimizations()
assert k.group_for_reduces == 1
@@ -777,7 +777,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
x, y = Tensor.randn(64,64), Tensor.randn(64,64)
out = x.matmul(y)
k = Linearizer(create_schedule([out.lazydata])[-1].ast)
k = Linearizer(*create_schedule([out.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()
@@ -791,7 +791,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
x = Tensor.randn((4,3,6,6)).realize()
out = x.flip((0,1)).contiguous()
k = Linearizer(create_schedule([out.lazydata])[-1].ast)
k = Linearizer(*create_schedule([out.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()
@@ -808,7 +808,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
opts = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
k = Linearizer(create_schedule([out.lazydata])[-1].ast)
k = Linearizer(*create_schedule([out.lazydata])[-1].ast)
for opt in opts: k.apply_opt(opt)
k.linearize()
@@ -830,7 +830,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()

View File

@@ -298,8 +298,8 @@ class TestMultiTensor(unittest.TestCase):
for p in get_parameters(bn): p.shard_(devices).realize()
out = bn(t)
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.out.device in devices and sched.ast.op is not LoadOps.COPY]
assert set(sched.out.device for sched in scheds) == set(devices), "should have ast on each shard device"
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices and sched.ast[0].op is not LoadOps.COPY]
assert set(out.device for sched in scheds for out in sched.outputs) == set(devices), "should have ast on each shard device"
asts = [sched.ast for sched in scheds]
assert len(asts) == 8, len(asts)
# test case to show that ast can be different on devices
@@ -634,4 +634,4 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
assert unsynced_si
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@@ -17,22 +17,24 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
if to_prerealize:
for pre in to_prerealize:
for s in create_schedule([pre.lazydata], seen.copy()):
if GRAPH: realized_lazybuffer(s.out, 0)
seen.add(s.out)
for i,out in enumerate(s.outputs):
if GRAPH: realized_lazybuffer(out, 0)
seen.add(out)
sched = create_schedule([t.lazydata], seen)
if GRAPH:
for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps]
for i,s in enumerate(sched):
for out in s.outputs: realized_lazybuffer(out, i+1)
if filter_loadops: sched = [s for s in sched if s.ast[0].op not in LoadOps]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
for i, s in enumerate(sched):
print("kernel", i+1)
print_tree(s.ast)
for op in s.ast: print_tree(op)
assert len(sched) == allowed
# test the (non loadops) ops linearize
for s in sched:
if s.ast.op in LoadOps: continue
l = Linearizer(s.ast)
if s.ast[0].op in LoadOps: continue
l = Linearizer(*s.ast)
l.hand_coded_optimizations()
l.linearize()

View File

@@ -9,14 +9,15 @@ from tinygrad.tensor import Tensor
class TestTimeLinearizer(unittest.TestCase):
def test_reasonable_time(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op not in LoadOps][0]
rawbufs = [Buffer(Device.DEFAULT, si.out.st.real_size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
out = Buffer(Device.DEFAULT, si.outputs[0].st.real_size(), si.outputs[0].dtype)
rawbufs = [out] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
tm = time_linearizer(Linearizer(*si.ast), rawbufs, allow_test_size=False, cnt=10)
assert tm > 0 and tm != float('inf')
def test_bufs_from_lin(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op not in LoadOps][0]
rawbufs = bufs_from_lin(lin:=Linearizer(si.ast))
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
rawbufs = bufs_from_lin(lin:=Linearizer(*si.ast))
assert len(rawbufs) == len(lin.membufs)
assert all(r is not None for r in rawbufs)
assert all(isinstance(r, Buffer) for r in rawbufs)

View File

@@ -23,10 +23,10 @@ class TestWinograd(unittest.TestCase):
sched = create_schedule([out.lazydata])
for i,s in enumerate(sched):
if s.ast.op in LoadOps: continue
ops = s.ast.lazyops
if s.ast[0].op in LoadOps: continue
ops = [out.lazyops for out in s.ast]
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
l = Linearizer(s.ast)
l = Linearizer(*s.ast)
l.hand_coded_optimizations()
l.linearize()
assert len(l.sts) <= 256 # just the current value to prevent regression
@@ -69,4 +69,4 @@ class TestWinograd(unittest.TestCase):
print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
if __name__ == '__main__':
unittest.main(verbosity=2)
unittest.main(verbosity=2)

View File

@@ -242,22 +242,22 @@ class Compiled:
k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
return ret
def get_linearizer(self, ast:LazyOp) -> Linearizer:
def get_linearizer(self, *ast:LazyOp) -> Linearizer:
assert self.compiler is not None, "compiler is required to build AST"
if DEBUG >= 3:
from tinygrad.features.graph import print_tree
print_tree(ast)
for op in ast: print_tree(op)
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, opts=self.compiler.linearizer_opts)
k = Linearizer(*ast, opts=self.compiler.linearizer_opts)
k.required_optimizations()
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1:
lins = [(("tc" if used_tensor_cores else "hc"), k)]
if used_tensor_cores:
lins.append(("hc", Linearizer(ast, opts=self.compiler.linearizer_opts)))
lins.append(("hc", Linearizer(*ast, opts=self.compiler.linearizer_opts)))
lins[-1][1].hand_coded_optimizations()
kb = Linearizer(ast, opts=self.compiler.linearizer_opts)
kb = Linearizer(*ast, opts=self.compiler.linearizer_opts)
kb.required_optimizations()
from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization
@@ -268,4 +268,4 @@ class Compiled:
return k
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def get_runner(self, ast:LazyOp) -> CompiledASTRunner: return self.to_program(self.get_linearizer(ast))
def get_runner(self, *ast:LazyOp) -> CompiledASTRunner: return self.to_program(self.get_linearizer(*ast))

View File

@@ -40,8 +40,8 @@ class ConstBuffer:
@dataclass(frozen=True)
class ScheduleItem:
ast: LazyOp
out: LazyBuffer
ast: Tuple[LazyOp, ...]
outputs: Tuple[LazyBuffer, ...]
inputs: Tuple[LazyBuffer, ...]
var_vals: Dict[Variable, int]

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
from typing import List, Dict, Optional, cast, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, Compiled, BufferOptions
from tinygrad.features.graph import print_tree, realized_lazybuffer, log_lazybuffer
from tinygrad.features.graph import realized_lazybuffer, log_lazybuffer
from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG, flatten, prod, dedup, all_int
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes
@@ -27,55 +27,58 @@ class SyncOp(JITRunner):
update_stats(colored("synchronize", "RED"), 0, 0, {}, et, 1, device=self.dname)
def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
assert all(si.out.device == x.device for x in si.inputs) or si.ast.op in {LoadOps.COPY, LoadOps.WAIT}, \
f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
if si.ast.op is LoadOps.EMPTY: return None
if si.ast.op in {LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY} and si.out.device.startswith("HIP") and si.inputs[0].device.startswith("HIP"):
if si.ast[0].op not in {LoadOps.COPY, LoadOps.WAIT}: assert len(set(x.device for x in si.outputs+si.inputs)) == 1
if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
out, ast = si.outputs[0], si.ast[0]
if ast.op in {LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY} and out.device.startswith("HIP") and si.inputs[0].device.startswith("HIP"):
from tinygrad.runtime.ops_hip import HIPSyncEvent, HIPWaitEvent
if si.ast.op is LoadOps.SYNC: return HIPSyncEvent(si.out)
if si.ast.op is LoadOps.WAIT: return HIPWaitEvent(si.out.device)
if si.ast.op in {LoadOps.SYNC, LoadOps.WAIT} and si.out.device.startswith("HSA") and si.inputs[0].device.startswith("HSA"):
if ast.op is LoadOps.SYNC: return HIPSyncEvent(out)
if ast.op is LoadOps.WAIT: return HIPWaitEvent(out.device)
if ast.op in {LoadOps.SYNC, LoadOps.WAIT} and out.device.startswith("HSA") and si.inputs[0].device.startswith("HSA"):
# Our HSA runtime handles synchronization
if si.ast.op is LoadOps.SYNC: return None
if si.ast.op is LoadOps.COPY:
if hasattr(Device[si.out.device].allocator, 'transfer') and type(Device[si.out.device]) is type(Device[si.inputs[0].device]): return BufferXfer()
if ast.op is LoadOps.SYNC: return None
if ast.op is LoadOps.COPY:
if hasattr(Device[out.device].allocator, 'transfer') and type(Device[out.device]) is type(Device[si.inputs[0].device]): return BufferXfer()
if si.inputs[0].device.startswith("DISK"): return BufferRead()
return BufferCopy()
if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg)
if si.ast.op is LoadOps.SYNC: return SyncOp(si.out.device) if isinstance(Device[si.out.device], Compiled) else None
if si.ast.op is LoadOps.WAIT: return None
return Device[si.out.device].get_runner(si.ast)
if ast.op is LoadOps.CUSTOM: return CustomOp(ast.arg)
if ast.op is LoadOps.SYNC: return SyncOp(out.device) if isinstance(Device[out.device], Compiled) else None
return None
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
def run_schedule(schedule:List[ScheduleItem]):
while len(schedule):
si = schedule.pop(0)
if logops and si.ast.op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
# get the program
prg = lower_schedule_item(si)
# invalidate the output buffer if there's a non contig usage of it in inputs
if si.out.output_buffer is not None:
for i,a in enumerate(si.inputs):
if a.realized == si.out.output_buffer:
if any(not x.arg.st.contiguous for x in si.ast.lazyops if x.op is BufferOps.LOAD and x.arg.idx == i+1):
si.out.output_buffer = None
break
for out_op, out in zip(si.ast, si.outputs):
# invalidate the output buffer if there's a non contig usage of it in inputs
if out.output_buffer is not None:
for i,a in enumerate(si.inputs):
if a.realized == out.output_buffer:
if any(not x.arg.st.contiguous for x in out_op.lazyops if x.op is BufferOps.LOAD and x.arg.idx == i+1):
out.output_buffer = None
break
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
if si.out.size > 0:
options = BufferOptions(host=True, signal=True) if si.ast.op is LoadOps.SYNC else None
si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
Buffer(si.out.device, si.out.size, si.out.dtype, "PLACEHOLDER" if getattr(prg, "skip_allocation", False) else None, options=options)
del si.out.srcs
for out in si.outputs:
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
if out.size > 0:
options = BufferOptions(host=True, signal=True) if si.ast[0].op is LoadOps.SYNC else None
out.realized = out.output_buffer if out.output_buffer is not None else \
Buffer(out.device, out.size, out.dtype, "PLACEHOLDER" if getattr(prg, "skip_allocation", False) else None, options=options)
del out.srcs
# run the function (put it in JIT)
real_buffers = [x.realized for x in (si.out,)+si.inputs if x.size != 0]
real_buffers = [x.realized for x in si.outputs+si.inputs if x.size != 0]
assert all(x is not None for x in real_buffers), f"can't run, some inputs aren't realized {real_buffers}"
if prg: prg.exec(cast(List[Buffer], real_buffers), si.var_vals)
elif si.out.size > 0: update_stats(colored(f"empty {si.out.st.size:10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
if GRAPH: realized_lazybuffer(si.out, GlobalCounters.kernel_count)
elif (out:=si.outputs[0]).size > 0: update_stats(colored(f"empty {out.st.size:10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device)
if GRAPH:
for out in si.outputs: realized_lazybuffer(out, GlobalCounters.kernel_count)
# *** schedule creation ***
@@ -135,7 +138,7 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))
return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem((op,), (out,), tuple(inputs), var_vals)]
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],

View File

@@ -73,4 +73,6 @@ class DiskRunner(JITRunner):
class DiskDevice(Compiled):
def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), None, None)
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def get_runner(self, ast:LazyOp): return DiskRunner(ast)
def get_runner(self, *ast:LazyOp):
assert len(ast) == 1, "DiskRunner doesn't support multioutput kernels."
return DiskRunner(ast[0])