mirror of https://github.com/tinygrad/tinygrad.git
multioutput ScheduleItem (#3699)
* refactor realize.py
* update docs
* update test_sched
* update runners and devices
* update openpilot and unit tests
* cleanup runner lowering
* update more tests
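The API change in one place: ScheduleItem.ast is now a Tuple[LazyOp, ...] and the old single out field is now outputs: Tuple[LazyBuffer, ...], so callers splat the ast tuple into Linearizer and iterate over outputs. A minimal sketch of the migration (assuming tinygrad at this commit; the tensor and names are illustrative only):

from tinygrad.tensor import Tensor
from tinygrad.realize import create_schedule
from tinygrad.codegen.linearizer import Linearizer

out = Tensor([1, 2, 3, 4]) + 1
si = create_schedule([out.lazydata])[-1]  # last ScheduleItem holds the compute kernel
lin = Linearizer(*si.ast)                 # was: Linearizer(si.ast)
lin.linearize()
print([o.device for o in si.outputs])     # was: si.out.device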
@@ -240,7 +240,7 @@ result = Tensor(2.0).realize() + Tensor(3.0).realize()
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.realize import create_schedule
 sched = create_schedule([result.lazydata])
-linearizer = Linearizer(sched[-1].ast, opts=ClangCompiler.linearizer_opts)
+linearizer = Linearizer(*sched[-1].ast, opts=ClangCompiler.linearizer_opts)
 linearizer.linearize()
 
 # print the uops
@@ -86,11 +86,11 @@ out = a.e(BinaryOps.ADD, b)
 
 # schedule the computation as a list of kernels
 sched = create_schedule([out])
-for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
+for si in sched: print(si.ast[0].op) # NOTE: the first two convert it to CLANG
 
 # DEBUGGING: print the compute ast as a tree
 from tinygrad.features.graph import print_tree
-print_tree(sched[-1].ast)
+print_tree(sched[-1].ast[0])
 # NOTE: sched[-1].ast is the same as st_0 above
 
 # run that schedule
@@ -28,7 +28,7 @@ if __name__ == "__main__":
   x = Tensor.empty(64, 3, 224, 224)
   out = mdl(x)
   sched = create_schedule([out.lazydata], seen)
-  sched = [x for x in sched if x.ast.op not in LoadOps]
+  sched = [x for x in sched if x.ast[0].op not in LoadOps]
 
   # focus on one kernel
   if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
@@ -37,24 +37,24 @@ if __name__ == "__main__":
   total_tm = 0
   running_gflops = 0
   for i,si in enumerate(sched):
-    rawbufs = bufs_from_lin(Linearizer(si.ast))
+    rawbufs = bufs_from_lin(Linearizer(*si.ast))
 
     # "linearize" the op into uops in different ways
     lins:List[Linearizer] = []
 
     # always try hand coded opt
-    lin = Linearizer(si.ast, opts=device.compiler.linearizer_opts)
+    lin = Linearizer(*si.ast, opts=device.compiler.linearizer_opts)
     lin.hand_coded_optimizations()
     lins.append(lin)
 
     # maybe try tensor cores
-    lin = Linearizer(si.ast, opts=device.compiler.linearizer_opts)
+    lin = Linearizer(*si.ast, opts=device.compiler.linearizer_opts)
     if lin.apply_tensor_cores():
       lins.append(lin)
 
     # try a beam search
     if beam:=getenv("BEAM"):
-      lin = Linearizer(si.ast, opts=device.compiler.linearizer_opts)
+      lin = Linearizer(*si.ast, opts=device.compiler.linearizer_opts)
       lin = beam_search(lin, rawbufs, beam, bool(getenv("BEAM_ESTIMATE", 1)))
       lins.append(lin)
@@ -39,15 +39,15 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
   depends = set(input_lb)
   for si in schedule:
     if any(b in depends for b in si.inputs):
-      depends.add(si.out)
+      for out in si.outputs: depends.add(out)
 
   # run all kernels that don't depend on the inputs
   # NOTE: there's two extra kernels due to fusions that now happen since the weights aren't realized
-  schedule, schedule_independent = partition(schedule, lambda si: si.out in depends)
+  schedule, schedule_independent = partition(schedule, lambda si: any(out in depends for out in si.outputs))
   print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")
 
   # confirm no loadops in the (non independent) schedule except for the ones that load the input buffers
-  assert all(si.ast.op not in LoadOps or si.out in input_lb for si in schedule), "has loadops, can't compile to Thneed"
+  assert all(si.ast[0].op not in LoadOps or out in input_lb for si in schedule for out in si.outputs), "has loadops, can't compile to Thneed"
   return schedule, schedule_independent, inputs
 
 def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[str, Tensor]):
@@ -88,9 +88,9 @@ def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[s
 
   # run code (all buffers have been allocated)
   GlobalCounters.reset()
-  for si in schedule: lower_schedule_item(si)([si.out.realized] + [x.realized for x in si.inputs], {})
+  for si in schedule: lower_schedule_item(si)([x.realized for x in si.outputs+si.inputs], {})
 
-  new_tinygrad_out = Tensor(schedule[-1].out).numpy()
+  new_tinygrad_out = Tensor(schedule[-1].outputs[0]).numpy()
   np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
   print("semi-thneed self-test passed!")
@@ -102,13 +102,13 @@ if __name__ == "__main__":
   #exit(0)
 
   schedule, schedule_independent, inputs = get_schedule(onnx_data)
-  schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps)
+  schedule, schedule_input = partition(schedule, lambda x: x.ast[0].op not in LoadOps)
   print(f"{len(schedule_input)} inputs")
 
   run_schedule(schedule_independent)
   run_schedule(schedule_input)
   with Context(DEBUG=max(DEBUG.value, 2), BEAM=getenv("LATEBEAM")):
-    image_count = sum(isinstance(si.out.dtype, ImageDType) for si in schedule)
+    image_count = sum(isinstance(out.dtype, ImageDType) for si in schedule for out in si.outputs)
     print(f"**** running real kernels {image_count}/{len(schedule)} images ****")
 
     GlobalCounters.reset()
test/external/external_test_hip_compile.py
@@ -10,7 +10,7 @@ class TestHIPCompileSpeed(unittest.TestCase):
   def test_hip_compile(self):
     a, b = Tensor([1,2,3,4,5]), Tensor([1,2,3,4,5])
     out = a + b
-    lin = Linearizer(create_schedule([out.lazydata])[-1].ast)
+    lin = Linearizer(create_schedule([out.lazydata])[-1].ast[0])
    lin.linearize()
 
     reference = """
@@ -13,10 +13,10 @@ class TestConvShapetracker(unittest.TestCase):
     # first run to init the weights, they are saved in seen
     create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
     # run it again to get the kernels
-    sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op not in LoadOps]
+    sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast[0].op not in LoadOps]
     assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
     print(sched[0])
-    for arg in [sched[0].out, *sched[0].inputs]:
+    for arg in [sched[0].outputs[0], *sched[0].inputs]:
       print(arg.st)
       assert len(arg.st.views) == 1
@@ -65,7 +65,7 @@ def universal_test(a, b, dtype, op):
 def universal_test_unary(a, dtype, op):
   if not isinstance(op, tuple): op = (op, op)
   out: Tensor = op[0](Tensor([a], dtype=dtype))
-  ast = create_schedule([out.lazydata])[-1].ast
+  ast = create_schedule([out.lazydata])[-1].ast[0]
   tensor_value = out.numpy()
   numpy_value = op[1](np.array([a]).astype(dtype.np))
   if dtype in dtypes_float:
@@ -74,7 +74,7 @@ class TestReduceOp(unittest.TestCase):
     a = a.sum()
     sched = create_schedule([a.lazydata])
     assert len(sched) == 1
-    assert sched[0].ast.src[0].op is ReduceOps.SUM
+    assert sched[0].ast[0].src[0].op is ReduceOps.SUM
 
   def test_split_reduce_kernel_dim0(self):
     a = Tensor.rand(256, 255).realize()
@@ -82,7 +82,7 @@ class TestReduceOp(unittest.TestCase):
     sched = create_schedule([a.lazydata])
     assert len(sched) == 2
     for s in sched:
-      assert s.ast.src[0].op is ReduceOps.SUM
+      assert s.ast[0].src[0].op is ReduceOps.SUM
 
   def test_split_reduce_kernel_dim1(self):
     a = Tensor.rand(255, 256).realize()
@@ -90,7 +90,7 @@ class TestReduceOp(unittest.TestCase):
     sched = create_schedule([a.lazydata])
     assert len(sched) == 2
     for s in sched:
-      assert s.ast.src[0].op is ReduceOps.SUM
+      assert s.ast[0].src[0].op is ReduceOps.SUM
 
 if __name__ == "__main__":
   unittest.main()
@@ -41,7 +41,7 @@ class TestLinearizer(unittest.TestCase):
     # these are of size 3 to avoid float4 coalesce
     r = a[:-1] + a[1:]
 
-    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
+    k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.linearize()
     num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD])
@@ -74,7 +74,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
     r = a.expand([2]) + b.expand([2])
 
-    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
+    k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.linearize()
     num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -86,7 +86,7 @@ class TestLinearizer(unittest.TestCase):
     x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
     r = Tensor.conv2d(x,w,padding=1).relu()
 
-    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
+    k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.upcast()
     k.linearize()
@@ -102,7 +102,7 @@ class TestLinearizer(unittest.TestCase):
 
     x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
     r = (x@y).relu()
-    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
+    k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
     k.hand_coded_optimizations()
     k.linearize()
 
@@ -120,7 +120,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
     r = Tensor.stack([a, b])
 
-    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
+    k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.linearize()
     num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -130,7 +130,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor(2), Tensor(3)
     r = a * b
 
-    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
+    k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
     k.linearize()
     num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
     assert num_ops <= 0, "more load or alu uops than needed"
@@ -139,14 +139,14 @@ class TestLinearizer(unittest.TestCase):
     for tensor_dtype, acc_dtype in (
      (dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
       a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
-      k = Linearizer(create_schedule([a.lazydata])[-1].ast)
+      k = Linearizer(*create_schedule([a.lazydata])[-1].ast)
       k.linearize()
       local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
       assert local[0].dtype == acc_dtype
 
   def test_arg_acc_dtype(self):
     def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
-      k = Linearizer(create_schedule([c.lazydata])[-1].ast)
+      k = Linearizer(*create_schedule([c.lazydata])[-1].ast)
       k.linearize()
       local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
       assert local[0].dtype == expected_dtype
@@ -185,17 +185,17 @@ class TestLinearizer(unittest.TestCase):
 
   def test_limit_dims_to_max_5d_global(self):
     t = Tensor.rand(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
-    sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
+    sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
     assert len(sched) == 1
-    lin = Linearizer(sched[0].ast)
+    lin = Linearizer(*sched[0].ast)
     assert lin.full_shape[:lin.global_dims] == (5, 6, 7, 8, 9)
     lin.limit_dims_to_max(global_max=[16, 16, 16], local_max=[16, 16, 16])
 
   def test_sum_collapse(self):
     t = Tensor.ones(256,256).sum()
-    sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
+    sched = [si for si in create_schedule([t.lazydata]) if si.ast[0].op not in LoadOps]
     assert len(sched) == 1
-    lin = Linearizer(sched[0].ast)
+    lin = Linearizer(*sched[0].ast)
     assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
 
   def test_simplify_uop(self):
@@ -221,7 +221,7 @@ class TestLinearizer(unittest.TestCase):
     def helper(t, max_ops=0):
       sched = create_schedule([t.lazydata])
       assert len(sched) == 1
-      k = Linearizer(sched[0].ast)
+      k = Linearizer(*sched[0].ast)
       k.hand_coded_optimizations()
       uops = list(k.linearize().uops)
       # ignore kernel optimized IF/LOOP statements for now
@@ -242,8 +242,8 @@ def helper_realized_ast(r:Tensor):
   run_schedule(s[:-1]) # run all kernels except the last one
   # now all input LazyBuffers buffers in s[-1] should be realized
   # allocate an output buffer
-  output_buffer = Buffer(s[-1].out.device, prod((s if isinstance(s, int) else s.max for s in s[-1].out.shape)), s[-1].out.dtype)
-  return s[-1].ast, [output_buffer] + [l.realized for l in s[-1].inputs]
+  output_buffer = Buffer((out:=s[-1].outputs[0]).device, prod((s if isinstance(s, int) else s.max for s in out.shape)), out.dtype)
+  return s[-1].ast[0], [output_buffer] + [l.realized for l in s[-1].inputs]
 
 @unittest.skipUnless(Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4, "need backends that support float4")
 class TestFloat4(unittest.TestCase):
@@ -260,7 +260,7 @@ class TestFloat4(unittest.TestCase):
     c = a + b
 
     s = create_schedule([c.lazydata])[0]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.hand_coded_optimizations()
     k.linearize()
 
@@ -272,7 +272,7 @@ class TestFloat4(unittest.TestCase):
     c = a + b
 
     s = create_schedule([c.lazydata])[0]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.shift_to(0, 4) # float4 dimension
     k.shift_to(0, 2, insert_before=k.shape_len-1)
     k.upcast()
@@ -288,7 +288,7 @@ class TestFloat4(unittest.TestCase):
     c = a + b
 
     s = create_schedule([c.lazydata])[0]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.hand_coded_optimizations() # implicit trigger float4 dim
     k.linearize()
 
@@ -300,7 +300,7 @@ class TestFloat4(unittest.TestCase):
     c = a + b
 
     s = create_schedule([c.lazydata])[0]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
     k.upcast()
     k.shift_to(len(k.full_unupcasted_shape)-1, 2, insert_before=k.shape_len-1)
@@ -318,7 +318,7 @@ class TestFloat4(unittest.TestCase):
     # float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
 
     s = create_schedule([c.lazydata])[0]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.upcast()
     k.linearize()
 
@@ -333,7 +333,7 @@ class TestFloat4(unittest.TestCase):
     # don't.
 
     s = create_schedule([c.lazydata])[0]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.upcast()
     k.upcast()
     k.linearize()
@@ -349,7 +349,7 @@ class TestFloat4(unittest.TestCase):
     # since the top axis is not contiguous.
 
     s = create_schedule([c.lazydata])[0]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.shift_to(0, 4, top=True) # top axes are float4 axes
     k.upcast()
     k.linearize()
@@ -365,7 +365,7 @@ class TestFloat4(unittest.TestCase):
     # since the top axis is not contiguous.
 
     s = create_schedule([c.lazydata])[0]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.shift_to(0, 4) # float4 axis
     k.upcast()
     k.linearize()
@@ -380,7 +380,7 @@ class TestFloat4(unittest.TestCase):
     # should float4 b but not a
 
     s = create_schedule([c.lazydata])[0]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.shift_to(0, 4) # float4 axis
     k.upcast()
     k.linearize()
@@ -393,7 +393,7 @@ class TestHandCodedOpts(unittest.TestCase):
     layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))
 
     s = create_schedule([layer_2.lazydata])[-1]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.hand_coded_optimizations()
     assert len(k.bufs) == 6 # make sure all ops are done in one kernel
     # masked upcast should upcast masked axis of size 7
@@ -406,7 +406,7 @@ class TestHandCodedOpts(unittest.TestCase):
     monster = Tensor.stack([Tensor.stack([Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
 
     s = create_schedule([monster.lazydata])[-1]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.hand_coded_optimizations()
     assert len(k.bufs) == 37 # make sure all ops are done in one kernel
     # should upcast the two Tensor.stacks
@@ -420,7 +420,7 @@ class TestHandCodedOpts(unittest.TestCase):
     wino_schedule = create_schedule([out.lazydata])
     # collect upcasts of tile transform kernels
     for i, si in enumerate(wino_schedule):
-      k = Linearizer(si.ast)
+      k = Linearizer(*si.ast)
       k.hand_coded_optimizations()
       if k.reduceop is not None: continue # not a tile transform kernel (there is a gemm reduce kernel)
       if len(k.bufs) < 36: continue # not a tile transform kernel (there's a permute kernel at the end)
@@ -433,7 +433,7 @@ class TestHandCodedOpts(unittest.TestCase):
     out.mean().backward()
     backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
     for si in backward_schedule:
-      k = Linearizer(si.ast)
+      k = Linearizer(*si.ast)
       k.hand_coded_optimizations()
       k.linearize()
       if len(k.bufs) < 20: continue # not a tile transform kernel
@@ -447,7 +447,7 @@ class TestHandCodedOpts(unittest.TestCase):
     layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4))
 
     s = create_schedule([layer_3.lazydata])[-1]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.hand_coded_optimizations()
     assert len(k.bufs) == 5 # make sure all ops are done in one kernel
     # check that we don't do too many upcasts
@@ -462,7 +462,7 @@ class TestHandCodedOpts(unittest.TestCase):
     c = a @ b
 
     s = create_schedule([c.lazydata])[0]
-    k = Linearizer(s.ast)
+    k = Linearizer(*s.ast)
     k.hand_coded_optimizations()
 
     assert k.group_for_reduces == 1
@@ -777,7 +777,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
     x, y = Tensor.randn(64,64), Tensor.randn(64,64)
     out = x.matmul(y)
 
-    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
+    k = Linearizer(*create_schedule([out.lazydata])[-1].ast)
     k.hand_coded_optimizations()
     k.linearize()
 
@@ -791,7 +791,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
     x = Tensor.randn((4,3,6,6)).realize()
     out = x.flip((0,1)).contiguous()
 
-    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
+    k = Linearizer(*create_schedule([out.lazydata])[-1].ast)
     k.hand_coded_optimizations()
     k.linearize()
 
@@ -808,7 +808,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
 
     opts = [Opt(OptOps.LOCAL, 0, 4), Opt(OptOps.GROUPTOP, 0, 8),
            Opt(OptOps.UNROLL, 0, 4), Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UPCAST, 1, 2)] # upcast accs in both reduces
-    k = Linearizer(create_schedule([out.lazydata])[-1].ast)
+    k = Linearizer(*create_schedule([out.lazydata])[-1].ast)
     for opt in opts: k.apply_opt(opt)
     k.linearize()
 
@@ -830,7 +830,7 @@ class TestLinearizerUOptimize(unittest.TestCase):
 
     x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
     r = (x@y).relu()
-    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
+    k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
     k.hand_coded_optimizations()
     k.linearize()
 
@@ -298,8 +298,8 @@ class TestMultiTensor(unittest.TestCase):
     for p in get_parameters(bn): p.shard_(devices).realize()
 
     out = bn(t)
-    scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.out.device in devices and sched.ast.op is not LoadOps.COPY]
-    assert set(sched.out.device for sched in scheds) == set(devices), "should have ast on each shard device"
+    scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices and sched.ast[0].op is not LoadOps.COPY]
+    assert set(out.device for sched in scheds for out in sched.outputs) == set(devices), "should have ast on each shard device"
     asts = [sched.ast for sched in scheds]
     assert len(asts) == 8, len(asts)
     # test case to show that ast can be different on devices
@@ -634,4 +634,4 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     assert unsynced_si
 
 if __name__ == '__main__':
-  unittest.main()
\ No newline at end of file
+  unittest.main()
@@ -17,22 +17,24 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
   if to_prerealize:
     for pre in to_prerealize:
       for s in create_schedule([pre.lazydata], seen.copy()):
-        if GRAPH: realized_lazybuffer(s.out, 0)
-        seen.add(s.out)
+        for i,out in enumerate(s.outputs):
+          if GRAPH: realized_lazybuffer(out, 0)
+          seen.add(out)
   sched = create_schedule([t.lazydata], seen)
   if GRAPH:
-    for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
-  if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps]
+    for i,s in enumerate(sched):
+      for out in s.outputs: realized_lazybuffer(out, i+1)
+  if filter_loadops: sched = [s for s in sched if s.ast[0].op not in LoadOps]
   if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
   if len(sched) != allowed or DEBUG >= 3:
     for i, s in enumerate(sched):
       print("kernel", i+1)
-      print_tree(s.ast)
+      for op in s.ast: print_tree(op)
   assert len(sched) == allowed
   # test the (non loadops) ops linearize
   for s in sched:
-    if s.ast.op in LoadOps: continue
-    l = Linearizer(s.ast)
+    if s.ast[0].op in LoadOps: continue
+    l = Linearizer(*s.ast)
     l.hand_coded_optimizations()
     l.linearize()
@@ -9,14 +9,15 @@ from tinygrad.tensor import Tensor
 
 class TestTimeLinearizer(unittest.TestCase):
   def test_reasonable_time(self):
-    si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op not in LoadOps][0]
-    rawbufs = [Buffer(Device.DEFAULT, si.out.st.real_size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
-    tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
+    si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
+    out = Buffer(Device.DEFAULT, si.outputs[0].st.real_size(), si.outputs[0].dtype)
+    rawbufs = [out] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
+    tm = time_linearizer(Linearizer(*si.ast), rawbufs, allow_test_size=False, cnt=10)
     assert tm > 0 and tm != float('inf')
 
   def test_bufs_from_lin(self):
-    si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op not in LoadOps][0]
-    rawbufs = bufs_from_lin(lin:=Linearizer(si.ast))
+    si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
+    rawbufs = bufs_from_lin(lin:=Linearizer(*si.ast))
     assert len(rawbufs) == len(lin.membufs)
     assert all(r is not None for r in rawbufs)
     assert all(isinstance(r, Buffer) for r in rawbufs)
@@ -23,10 +23,10 @@ class TestWinograd(unittest.TestCase):
     sched = create_schedule([out.lazydata])
 
     for i,s in enumerate(sched):
-      if s.ast.op in LoadOps: continue
-      ops = s.ast.lazyops
+      if s.ast[0].op in LoadOps: continue
+      ops = [out.lazyops for out in s.ast]
       with Timing(f"linearize {i} with {len(ops):4d} ops: "):
-        l = Linearizer(s.ast)
+        l = Linearizer(*s.ast)
         l.hand_coded_optimizations()
         l.linearize()
     assert len(l.sts) <= 256 # just the current value to prevent regression
@@ -69,4 +69,4 @@ class TestWinograd(unittest.TestCase):
     print(f"mem: normal {mem_normal:9d} wino {mem_wino:9d} ratio {mem_ratio:.2f}")
 
 if __name__ == '__main__':
-  unittest.main(verbosity=2)
\ No newline at end of file
+  unittest.main(verbosity=2)
@@ -242,22 +242,22 @@ class Compiled:
                k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
     return ret
 
-  def get_linearizer(self, ast:LazyOp) -> Linearizer:
+  def get_linearizer(self, *ast:LazyOp) -> Linearizer:
     assert self.compiler is not None, "compiler is required to build AST"
     if DEBUG >= 3:
       from tinygrad.features.graph import print_tree
-      print_tree(ast)
+      for op in ast: print_tree(op)
     from tinygrad.codegen.linearizer import Linearizer
-    k = Linearizer(ast, opts=self.compiler.linearizer_opts)
+    k = Linearizer(*ast, opts=self.compiler.linearizer_opts)
     k.required_optimizations()
     if not NOOPT:
       if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
       if BEAM >= 1:
         lins = [(("tc" if used_tensor_cores else "hc"), k)]
         if used_tensor_cores:
-          lins.append(("hc", Linearizer(ast, opts=self.compiler.linearizer_opts)))
+          lins.append(("hc", Linearizer(*ast, opts=self.compiler.linearizer_opts)))
           lins[-1][1].hand_coded_optimizations()
-        kb = Linearizer(ast, opts=self.compiler.linearizer_opts)
+        kb = Linearizer(*ast, opts=self.compiler.linearizer_opts)
         kb.required_optimizations()
         from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin
         test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization
@@ -268,4 +268,4 @@ class Compiled:
     return k
 
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-  def get_runner(self, ast:LazyOp) -> CompiledASTRunner: return self.to_program(self.get_linearizer(ast))
+  def get_runner(self, *ast:LazyOp) -> CompiledASTRunner: return self.to_program(self.get_linearizer(*ast))
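Since get_linearizer and get_runner are now variadic, a whole multioutput ScheduleItem lowers through a single call, as lower_schedule_item does for BufferOps.STORE further below. A hedged usage sketch; runner_for is a hypothetical helper, not part of this diff:

# sketch: lower one compute ScheduleItem through the variadic runner API
from tinygrad.device import Device
from tinygrad.ops import BufferOps

def runner_for(si):  # hypothetical helper
  # each LazyOp in si.ast stores into the matching buffer in si.outputs
  assert si.ast[0].op is BufferOps.STORE
  return Device[si.outputs[0].device].get_runner(*si.ast)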
@@ -40,8 +40,8 @@ class ConstBuffer:
 
 @dataclass(frozen=True)
 class ScheduleItem:
-  ast: LazyOp
-  out: LazyBuffer
+  ast: Tuple[LazyOp, ...]
+  outputs: Tuple[LazyBuffer, ...]
   inputs: Tuple[LazyBuffer, ...]
   var_vals: Dict[Variable, int]
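Construction sites wrap the single op and output in one-element tuples (see _recursive_schedule further down). A sketch of a single-output item under the new field types; op, out, inputs, and var_vals are assumed to be in scope:

# hypothetical values, shown only to illustrate the tuple-typed fields
si = ScheduleItem(ast=(op,), outputs=(out,), inputs=tuple(inputs), var_vals=var_vals)
assert len(si.ast) == len(si.outputs)  # one STORE per output buffer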
@@ -3,7 +3,7 @@ from collections import defaultdict
 from typing import List, Dict, Optional, cast, Set, DefaultDict
 from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
 from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, Compiled, BufferOptions
-from tinygrad.features.graph import print_tree, realized_lazybuffer, log_lazybuffer
+from tinygrad.features.graph import realized_lazybuffer, log_lazybuffer
 from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG, flatten, prod, dedup, all_int
 from tinygrad.shape.symbolic import Variable
 from tinygrad.dtype import ImageDType, dtypes
@@ -27,55 +27,58 @@ class SyncOp(JITRunner):
     update_stats(colored("synchronize", "RED"), 0, 0, {}, et, 1, device=self.dname)
 
 def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
-  assert all(si.out.device == x.device for x in si.inputs) or si.ast.op in {LoadOps.COPY, LoadOps.WAIT}, \
-    f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"
-  if si.ast.op is LoadOps.EMPTY: return None
-  if si.ast.op in {LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY} and si.out.device.startswith("HIP") and si.inputs[0].device.startswith("HIP"):
+  if si.ast[0].op not in {LoadOps.COPY, LoadOps.WAIT}: assert len(set(x.device for x in si.outputs+si.inputs)) == 1
+  if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
+  assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
+  out, ast = si.outputs[0], si.ast[0]
+  if ast.op in {LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY} and out.device.startswith("HIP") and si.inputs[0].device.startswith("HIP"):
     from tinygrad.runtime.ops_hip import HIPSyncEvent, HIPWaitEvent
-    if si.ast.op is LoadOps.SYNC: return HIPSyncEvent(si.out)
-    if si.ast.op is LoadOps.WAIT: return HIPWaitEvent(si.out.device)
-  if si.ast.op in {LoadOps.SYNC, LoadOps.WAIT} and si.out.device.startswith("HSA") and si.inputs[0].device.startswith("HSA"):
+    if ast.op is LoadOps.SYNC: return HIPSyncEvent(out)
+    if ast.op is LoadOps.WAIT: return HIPWaitEvent(out.device)
+  if ast.op in {LoadOps.SYNC, LoadOps.WAIT} and out.device.startswith("HSA") and si.inputs[0].device.startswith("HSA"):
     # Our HSA runtime handles synchronization
-    if si.ast.op is LoadOps.SYNC: return None
-  if si.ast.op is LoadOps.COPY:
-    if hasattr(Device[si.out.device].allocator, 'transfer') and type(Device[si.out.device]) is type(Device[si.inputs[0].device]): return BufferXfer()
+    if ast.op is LoadOps.SYNC: return None
+  if ast.op is LoadOps.COPY:
+    if hasattr(Device[out.device].allocator, 'transfer') and type(Device[out.device]) is type(Device[si.inputs[0].device]): return BufferXfer()
     if si.inputs[0].device.startswith("DISK"): return BufferRead()
     return BufferCopy()
-  if si.ast.op is LoadOps.CUSTOM: return CustomOp(si.ast.arg)
-  if si.ast.op is LoadOps.SYNC: return SyncOp(si.out.device) if isinstance(Device[si.out.device], Compiled) else None
-  if si.ast.op is LoadOps.WAIT: return None
-  return Device[si.out.device].get_runner(si.ast)
+  if ast.op is LoadOps.CUSTOM: return CustomOp(ast.arg)
+  if ast.op is LoadOps.SYNC: return SyncOp(out.device) if isinstance(Device[out.device], Compiled) else None
+  return None
 
 logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
 def run_schedule(schedule:List[ScheduleItem]):
   while len(schedule):
     si = schedule.pop(0)
-    if logops and si.ast.op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
+    if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
 
     # get the program
     prg = lower_schedule_item(si)
 
-    # invalidate the output buffer if there's a non contig usage of it in inputs
-    if si.out.output_buffer is not None:
-      for i,a in enumerate(si.inputs):
-        if a.realized == si.out.output_buffer:
-          if any(not x.arg.st.contiguous for x in si.ast.lazyops if x.op is BufferOps.LOAD and x.arg.idx == i+1):
-            si.out.output_buffer = None
-            break
+    for out_op, out in zip(si.ast, si.outputs):
+      # invalidate the output buffer if there's a non contig usage of it in inputs
+      if out.output_buffer is not None:
+        for i,a in enumerate(si.inputs):
+          if a.realized == out.output_buffer:
+            if any(not x.arg.st.contiguous for x in out_op.lazyops if x.op is BufferOps.LOAD and x.arg.idx == i+1):
+              out.output_buffer = None
+              break
 
-    # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
-    if si.out.size > 0:
-      options = BufferOptions(host=True, signal=True) if si.ast.op is LoadOps.SYNC else None
-      si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
-        Buffer(si.out.device, si.out.size, si.out.dtype, "PLACEHOLDER" if getattr(prg, "skip_allocation", False) else None, options=options)
-      del si.out.srcs
+    for out in si.outputs:
+      # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
+      if out.size > 0:
+        options = BufferOptions(host=True, signal=True) if si.ast[0].op is LoadOps.SYNC else None
+        out.realized = out.output_buffer if out.output_buffer is not None else \
+          Buffer(out.device, out.size, out.dtype, "PLACEHOLDER" if getattr(prg, "skip_allocation", False) else None, options=options)
+        del out.srcs
 
     # run the function (put it in JIT)
-    real_buffers = [x.realized for x in (si.out,)+si.inputs if x.size != 0]
+    real_buffers = [x.realized for x in si.outputs+si.inputs if x.size != 0]
     assert all(x is not None for x in real_buffers), f"can't run, some inputs aren't realized {real_buffers}"
     if prg: prg.exec(cast(List[Buffer], real_buffers), si.var_vals)
-    elif si.out.size > 0: update_stats(colored(f"empty {si.out.st.size:10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
-    if GRAPH: realized_lazybuffer(si.out, GlobalCounters.kernel_count)
+    elif (out:=si.outputs[0]).size > 0: update_stats(colored(f"empty {out.st.size:10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device)
+    if GRAPH:
+      for out in si.outputs: realized_lazybuffer(out, GlobalCounters.kernel_count)
 
 # *** schedule creation ***
@@ -135,7 +138,7 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
   op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
   op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))
 
-  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
+  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem((op,), (out,), tuple(inputs), var_vals)]
 
 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
 def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
@@ -73,4 +73,6 @@ class DiskRunner(JITRunner):
 class DiskDevice(Compiled):
   def __init__(self, device:str): super().__init__(device, DiskAllocator(device[len("disk:"):]), None, None)
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
-  def get_runner(self, ast:LazyOp): return DiskRunner(ast)
+  def get_runner(self, *ast:LazyOp):
+    assert len(ast) == 1, "DiskRunner doesn't support multioutput kernels."
+    return DiskRunner(ast[0])