diff --git a/docs/abstractions.py b/docs/abstractions.py
index c8c1e01d6c..8f8ebeb123 100644
--- a/docs/abstractions.py
+++ b/docs/abstractions.py
@@ -250,7 +250,8 @@ result = Tensor(2.0).realize() + Tensor(3.0).realize()
 
 # use the real Linearizer to linearize 2+3
 from tinygrad.codegen.linearizer import Linearizer
-sched = result.lazydata.schedule()
+from tinygrad.realize import create_schedule
+sched = create_schedule([result.lazydata])
 linearizer = Linearizer(sched[-1].ast, ClangCompiler.linearizer_opts)
 linearizer.linearize()
 
diff --git a/docs/abstractions2.py b/docs/abstractions2.py
index 38c13ffa56..e0db5effa8 100644
--- a/docs/abstractions2.py
+++ b/docs/abstractions2.py
@@ -73,7 +73,7 @@ assert out.as_buffer().cast('I')[0] == 5
 print("******** third, the LazyBuffer ***********")
 
 from tinygrad.lazy import LazyBuffer, LoadOps
-from tinygrad.realize import run_schedule
+from tinygrad.realize import run_schedule, create_schedule
 
 # allocate some values + load in values
 # TODO: remove numpy here
@@ -87,7 +87,7 @@ b.realized = Buffer("CPU", 1, dtypes.int32, np.array([3], np.int32).flatten())
 out = a.e(BinaryOps.ADD, b)
 
 # schedule the computation as a list of kernels
-sched = out.schedule()
+sched = create_schedule([out])
 for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
 
 # DEBUGGING: print the compute ast as a tree
diff --git a/docs/linearizer_v2.md b/docs/linearizer_v2.md
index 6e7966c616..5b02e7fb41 100644
--- a/docs/linearizer_v2.md
+++ b/docs/linearizer_v2.md
@@ -20,10 +20,8 @@ More generically, the whole network is a DAG. Ignore the forward/backward stuff,
 
 This is a rewrite of a lot of tinygrad. I don't think continuing to support Interpreted backends is worth it, have to deal with disk in a smart way.
 
-We keep the frontend: tensor.py + mlops.py + lazy.py
-We keep the backend (renderer/runtime): cstyle.py + device.py + ops_*.py
-We keep the shapetracker/symbolic: shapetracker.py + view.py + symbolic.py
 We keep the features and nn stuff.
-But codegen is all rewritten.
-
-
+We keep the frontend (Tensor -> LazyBuffer): tensor.py + mlops.py + lazy.py
+We keep the shapetracker/symbolic (part of the frontend): shapetracker.py + view.py + symbolic.py
+Codegen is all rewritten.
+We keep the backend (uops renderer/runtime): cstyle.py + device.py + ops_*.py
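
[reviewer note] The user-visible change in this patch: scheduling moves off LazyBuffer/MultiLazyBuffer and becomes a free function in tinygrad.realize that takes a list of output buffers. A minimal before/after sketch (the tensor is illustrative):

    from tinygrad import Tensor
    from tinygrad.realize import create_schedule

    t = Tensor.rand(4, 4) + 1
    # before this patch: sched = t.lazydata.schedule()
    sched = create_schedule([t.lazydata])  # List[ScheduleItem]; the last item computes t
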
diff --git a/examples/handcode_resnet50_opt.py b/examples/handcode_resnet50_opt.py
index 75ece98b2b..7f88087f1c 100644
--- a/examples/handcode_resnet50_opt.py
+++ b/examples/handcode_resnet50_opt.py
@@ -8,6 +8,7 @@ from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
 from tinygrad.helpers import ansilen, DEBUG, getenv
 from tinygrad.shape.symbolic import sym_infer
 from tinygrad.dtype import dtypes
+from tinygrad.realize import create_schedule
 
 if __name__ == "__main__":
   if getenv("HALF"):
@@ -21,12 +22,12 @@ if __name__ == "__main__":
   print(f"optimizing for {Device.DEFAULT}")
 
   # first model run to init the weights, they are saved in seen
-  mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen)
+  create_schedule([mdl(Tensor.empty(64, 3, 224, 224)).lazydata], seen)
 
   # run model again to get only what changes, these are the kernels of the model
   x = Tensor.empty(64, 3, 224, 224)
   out = mdl(x)
-  sched = out.lazydata.schedule(seen)
+  sched = create_schedule([out.lazydata], seen)
   sched = [x for x in sched if x.ast.op not in LoadOps]
 
   # focus on one kernel
diff --git a/extra/autopad.py b/extra/autopad.py
index 2fce66171b..7aebc2358f 100644
--- a/extra/autopad.py
+++ b/extra/autopad.py
@@ -3,13 +3,14 @@ from tinygrad.ops import LoadOps
 from tinygrad.codegen.linearizer import Linearizer
 from test.external.fuzz_linearizer import run_linearizer
 from tinygrad.codegen.kernel import Opt, OptOps
+from tinygrad.realize import create_schedule
 
 N = 17**3
 a = Tensor.rand(N, N)
 b = Tensor.rand(N, N)
 c = a @ b
-sched = [si for si in c.lazydata.schedule() if si.ast.op not in LoadOps]
+sched = [si for si in create_schedule([c.lazydata]) if si.ast.op not in LoadOps]
 assert len(sched) == 1
 lin = Linearizer(sched[0].ast)
 
@@ -24,7 +25,7 @@ run_linearizer(lin)
 ###
 
 a = Tensor.rand(61, 61).sum(axis=0)
-sched = [si for si in a.lazydata.schedule() if si.ast.op not in LoadOps]
+sched = [si for si in create_schedule([a.lazydata]) if si.ast.op not in LoadOps]
 assert len(sched) == 1
 lin = Linearizer(sched[0].ast)
diff --git a/extra/gemm/tvm_gemm.py b/extra/gemm/tvm_gemm.py
index 11318a348e..d89d189852 100644
--- a/extra/gemm/tvm_gemm.py
+++ b/extra/gemm/tvm_gemm.py
@@ -30,13 +30,14 @@ except ImportError:
 
 import os
 from tinygrad.tensor import Tensor
+from tinygrad.realize import create_schedule
 
 # define the compute
 A = Tensor.rand(M, K, device="clang")
 B = Tensor.rand(K, N, device="clang")
 C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
 
-sched = C.lazydata.schedule()
+sched = create_schedule([C.lazydata])
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.codegen.kernel import LinearizerOptions
 lin = Linearizer(sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False))
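
[reviewer note] Because create_schedule takes a list of outputs, several tensors can be scheduled in one call and the shared graph is only walked once (the internal seen set keeps common work from being emitted twice); the winograd backward test below uses this for x.grad and w.grad. A sketch with illustrative shapes:

    from tinygrad import Tensor
    from tinygrad.realize import create_schedule

    x = Tensor.rand(8, 8)
    y, z = x.sum(axis=0), x.sum(axis=1)
    sched = create_schedule([y.lazydata, z.lazydata])  # one schedule covering both outputs
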
diff --git a/openpilot/compile2.py b/openpilot/compile2.py
index 9153e71889..238f5001ec 100644
--- a/openpilot/compile2.py
+++ b/openpilot/compile2.py
@@ -16,7 +16,7 @@ from extra.onnx import get_run_onnx
 from tinygrad import Tensor, Device, GlobalCounters, dtypes
 from tinygrad.dtype import ImageDType
 from tinygrad.helpers import partition, Context, fetch, getenv, GRAPH, DEBUG
-from tinygrad.realize import run_schedule, lower_schedule_item
+from tinygrad.realize import run_schedule, lower_schedule_item, create_schedule
 from tinygrad.ops import LoadOps, ScheduleItem
 
 Device.DEFAULT = "GPU"
@@ -32,7 +32,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
   # run the model
   inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
   ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
-  schedule = ret.lazydata.schedule()
+  schedule = create_schedule([ret.lazydata])
 
   # filter schedule that don't depend on the inputs
   input_lb = [x.lazydata.base for x in inputs.values()]
diff --git a/test/external/external_test_uops_graphing.py b/test/external/external_test_uops_graphing.py
index 8c49828a9c..56bd18cd7b 100644
--- a/test/external/external_test_uops_graphing.py
+++ b/test/external/external_test_uops_graphing.py
@@ -4,6 +4,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.renderer.cstyle import OpenCLRenderer
 from tinygrad.features.graph import graph_uops
+from tinygrad.realize import create_schedule
 from tinygrad.nn import Conv2d
 
 class TestUopsGraph(unittest.TestCase):
@@ -11,7 +12,7 @@ class TestUopsGraph(unittest.TestCase):
     N = 1024
     a = Tensor.rand(N,N)
     b = Tensor.rand(N,N)
-    si = (a@b).lazydata.schedule()[-1]
+    si = create_schedule([(a@b).lazydata])[-1]
     lin = Linearizer(si.ast)
     lin.hand_coded_optimizations()
     print(lin.colored_shape())
@@ -22,7 +23,7 @@ class TestUopsGraph(unittest.TestCase):
 
   def test_reduce(self):
     a = Tensor.rand(1024*1024)
-    si = a.sum().lazydata.schedule()[-1]
+    si = create_schedule([a.sum().lazydata])[-1]
     lin = Linearizer(si.ast)
     lin.hand_coded_optimizations()
     uops = lin.linearize().uops
@@ -32,7 +33,7 @@ class TestUopsGraph(unittest.TestCase):
   def test_conv(self):
     x = Tensor.rand(1,3,16,16)
     c = Conv2d(3, 16, (3,3))
-    si = c(x).elu().lazydata.schedule()[-1]
+    si = create_schedule([c(x).elu().lazydata])[-1]
     lin = Linearizer(si.ast)
     lin.hand_coded_optimizations()
     uops = lin.linearize().uops
diff --git a/test/test_conv_shapetracker.py b/test/test_conv_shapetracker.py
index 9cf7ff24a2..a3fc0da6dd 100644
--- a/test/test_conv_shapetracker.py
+++ b/test/test_conv_shapetracker.py
@@ -3,6 +3,7 @@ import unittest
 from tinygrad.tensor import Tensor
 from tinygrad.ops import LoadOps
 from tinygrad.nn import Conv2d
+from tinygrad.realize import create_schedule
 
 class TestConvShapetracker(unittest.TestCase):
   def test_conv_3x3_one_view(self):
@@ -10,9 +11,9 @@ class TestConvShapetracker(unittest.TestCase):
     seen = set()
 
     # first run to init the weights, they are saved in seen
-    conv(Tensor.empty(1, 16, 10, 10)).lazydata.schedule(seen)
+    create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
     # run it again to get the kernels
-    sched = [si for si in conv(Tensor.empty(1, 16, 10, 10)).lazydata.schedule(seen) if si.ast.op not in LoadOps]
+    sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op not in LoadOps]
     assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
     print(sched[0])
     for arg in [sched[0].out, *sched[0].inputs]:
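
[reviewer note] The optional seen set is how callers diff two schedules, as in the conv test above: buffers scheduled by the first call are recorded in seen and skipped by the second, leaving only the per-call kernels. A sketch of the pattern, assuming some model mdl:

    seen = set()
    create_schedule([mdl(Tensor.empty(1, 3, 32, 32)).lazydata], seen)          # weight-init kernels land in seen
    sched = create_schedule([mdl(Tensor.empty(1, 3, 32, 32)).lazydata], seen)  # only the model's own kernels remain
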
diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py
index e9b4a8111a..be27bd7e8c 100644
--- a/test/test_dtype_alu.py
+++ b/test/test_dtype_alu.py
@@ -6,6 +6,7 @@ import numpy as np
 from hypothesis import given, strategies as strat, settings
 from tinygrad.dtype import DType
 from tinygrad.helpers import CI, getenv
+from tinygrad.realize import create_schedule
 from tinygrad.ops import UnaryOps, get_lazyop_info
 from test.test_dtype import is_dtype_supported
 
@@ -64,7 +65,7 @@ def universal_test(a, b, dtype, op):
 def universal_test_unary(a, dtype, op):
   if not isinstance(op, tuple): op = (op, op)
   out: Tensor = op[0](Tensor([a], dtype=dtype))
-  ast = out.lazydata.schedule()[-1].ast
+  ast = create_schedule([out.lazydata])[-1].ast
   tensor_value = out.numpy()
   numpy_value = op[1](np.array([a]).astype(dtype.np))
   if dtype in dtypes_float:
diff --git a/test/test_fusion_op.py b/test/test_fusion_op.py
index 0562e608df..56e4a42049 100644
--- a/test/test_fusion_op.py
+++ b/test/test_fusion_op.py
@@ -3,8 +3,7 @@ import time
 import numpy as np
 from tinygrad import Tensor, dtypes
 from tinygrad.device import InterpretedASTRunner
-from tinygrad.lazy import create_schedule
-from tinygrad.realize import run_schedule, lower_schedule_item
+from tinygrad.realize import run_schedule, create_schedule, lower_schedule_item
 
 class TestFusionOp(unittest.TestCase):
   def test_contiguous_add(self):
diff --git a/test/test_lazyop.py b/test/test_lazyop.py
index 7e097d3d4a..da5d8a1227 100644
--- a/test/test_lazyop.py
+++ b/test/test_lazyop.py
@@ -1,5 +1,6 @@
 import unittest
 from tinygrad.tensor import Tensor
+from tinygrad.realize import create_schedule
 
 # stuff needed to unpack a kernel
 # ruff: noqa: F401
@@ -16,7 +17,7 @@ inf, nan = float('inf'), float('nan')
 class TestLazyOp(unittest.TestCase):
   def test_lazyop_str(self):
     t = Tensor.rand(10) + Tensor.rand(10)
-    s = t.lazydata.schedule()
+    s = create_schedule([t.lazydata])
     ast = s[-1].ast
     ast_remade = eval(str(ast))
     self.assertEqual(ast, ast_remade)
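
[reviewer note] test_lazyop_str round-trips a scheduled AST through str/eval; for eyeballing a kernel, the tree printer works too. A sketch using print_tree (the same helper this patch imports in realize.py):

    from tinygrad import Tensor
    from tinygrad.features.graph import print_tree
    from tinygrad.realize import create_schedule

    sched = create_schedule([(Tensor.rand(10) + Tensor.rand(10)).lazydata])
    print_tree(sched[-1].ast)  # BufferOps.STORE at the root, LOADs at the leaves
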
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index f90d0dc75c..1bcf32ce14 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -10,7 +10,7 @@ from tinygrad.shape.view import View
 from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node, create_rednode
 from tinygrad.tensor import Tensor
 from tinygrad.features.jit import CacheCollector
-from tinygrad.realize import run_schedule
+from tinygrad.realize import create_schedule, run_schedule
 from tinygrad.helpers import prod, Context
 from tinygrad.dtype import DType, dtypes
 
@@ -33,7 +33,7 @@ class TestLinearizer(unittest.TestCase):
     # these are of size 3 to avoid float4 coalesce
     r = a[:-1] + a[1:]
 
-    k = Linearizer(r.lazydata.schedule()[-1].ast)
+    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.linearize()
     num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD])
@@ -46,7 +46,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
     r = a.expand([2]) + b.expand([2])
 
-    k = Linearizer(r.lazydata.schedule()[-1].ast)
+    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.linearize()
     num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -56,7 +56,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
     r = Tensor.stack([a, b])
 
-    k = Linearizer(r.lazydata.schedule()[-1].ast)
+    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.linearize()
     num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -67,7 +67,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor(2), Tensor(3)
     r = a * b
 
-    k = Linearizer(r.lazydata.schedule()[-1].ast)
+    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
     k.linearize()
     num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
     assert num_ops <= 0, "more load or alu uops than needed"
@@ -76,14 +76,14 @@ class TestLinearizer(unittest.TestCase):
     for tensor_dtype, acc_dtype in (
       (dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
       a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
-      k = Linearizer(a.lazydata.schedule()[-1].ast)
+      k = Linearizer(create_schedule([a.lazydata])[-1].ast)
       k.linearize()
       local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
       assert local[0].dtype == acc_dtype
 
   def test_arg_acc_dtype(self):
     def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
-      k = Linearizer(c.lazydata.schedule()[-1].ast)
+      k = Linearizer(create_schedule([c.lazydata])[-1].ast)
       k.linearize()
       local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
       assert local[0].dtype == expected_dtype
@@ -121,7 +121,7 @@ class TestLinearizer(unittest.TestCase):
 
   def test_limit_dims_to_max_5d_global(self):
     t = Tensor.rand(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
-    sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps]
+    sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
     assert len(sched) == 1
     lin = Linearizer(sched[0].ast)
     assert lin.full_shape[:lin.global_dims] == (5, 6, 7, 8, 9)
@@ -129,7 +129,7 @@ class TestLinearizer(unittest.TestCase):
 
   def test_sum_collapse(self):
     t = Tensor.ones(256,256).sum()
-    sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps]
+    sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
     assert len(sched) == 1
     lin = Linearizer(sched[0].ast)
     assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
@@ -154,7 +154,7 @@ class TestLinearizer(unittest.TestCase):
                    arg=TernaryOps.WHERE).uop == UOps.ALU
 
 def helper_realized_ast(r:Tensor):
-  s = r.lazydata.schedule()
+  s = create_schedule([r.lazydata])
  run_schedule(s[:-1]) # run all kernels except the last one
   # now all input LazyBuffers buffers in s[-1] should be realized
   # allocate an output buffer
@@ -176,7 +176,7 @@ class TestFloat4(unittest.TestCase):
     b = Tensor.rand(2, 8).realize()
     c = a + b
 
-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()
     k.linearize()
@@ -188,7 +188,7 @@ class TestFloat4(unittest.TestCase):
     b = Tensor.rand(2, 8).realize()
     c = a + b
 
-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.shift_to(0, 4) # float4 dimension
     k.shift_to(0, 2, insert_before=k.shape_len-1)
@@ -204,7 +204,7 @@ class TestFloat4(unittest.TestCase):
     b = Tensor.rand(9).realize().shrink(((1, 9),))
     c = a + b
 
-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations() # implicit trigger float4 dim
     k.linearize()
@@ -216,7 +216,7 @@ class TestFloat4(unittest.TestCase):
     b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
     c = a + b
 
-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
     k.upcast()
@@ -234,7 +234,7 @@ class TestFloat4(unittest.TestCase):
     # only the first and last conv dot products are aligned in a, and b is never aligned, so no
     # float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
 
-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.upcast()
     k.linearize()
@@ -249,7 +249,7 @@ class TestFloat4(unittest.TestCase):
     # dimension, then we could do float4 for only that one set of loads, but we currently
     # don't.
 
-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.upcast()
     k.upcast()
@@ -265,7 +265,7 @@ class TestFloat4(unittest.TestCase):
     # we will upcast the top axis of sz 4. they should not be coalesced into float4,
     # since the top axis is not contiguous.
 
-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.shift_to(0, 4, top=True) # top axes are float4 axes
     k.upcast()
@@ -281,7 +281,7 @@ class TestFloat4(unittest.TestCase):
     # we will upcast the top axis of sz 4. they should not be coalesced into float4,
     # since the top axis is not contiguous.
 
-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.shift_to(0, 4) # float4 axis
     k.upcast()
@@ -296,7 +296,7 @@ class TestFloat4(unittest.TestCase):
 
     # should float4 b but not a
 
-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.shift_to(0, 4) # float4 axis
     k.upcast()
@@ -310,7 +310,7 @@ class TestHandCodedOpts(unittest.TestCase):
     layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)])
     layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))
 
-    s = layer_2.lazydata.schedule()[-1]
+    s = create_schedule([layer_2.lazydata])[-1]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()
     assert len(k.bufs) == 6 # make sure all ops are done in one kernel
@@ -323,7 +323,7 @@ class TestHandCodedOpts(unittest.TestCase):
   def test_masked_upcast_wino(self):
     monster = Tensor.stack([Tensor.stack([Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
 
-    s = monster.lazydata.schedule()[-1]
+    s = create_schedule([monster.lazydata])[-1]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()
     assert len(k.bufs) == 37 # make sure all ops are done in one kernel
@@ -335,7 +335,7 @@ class TestHandCodedOpts(unittest.TestCase):
     x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
     out = Tensor.conv2d(x,w, padding=1)
     upcasts = []
-    wino_schedule = out.lazydata.schedule()
+    wino_schedule = create_schedule([out.lazydata])
     # collect upcasts of tile transform kernels
     for i, si in enumerate(wino_schedule):
       k = Linearizer(si.ast)
@@ -349,7 +349,7 @@ class TestHandCodedOpts(unittest.TestCase):
     assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1
 
     out.mean().backward()
-    backward_schedule = x.grad.lazydata.schedule() + w.grad.lazydata.schedule()
+    backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
     for si in backward_schedule:
       k = Linearizer(si.ast)
       k.hand_coded_optimizations()
@@ -364,7 +364,7 @@ class TestHandCodedOpts(unittest.TestCase):
     layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 7, 4))
     layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4))
 
-    s = layer_3.lazydata.schedule()[-1]
+    s = create_schedule([layer_3.lazydata])[-1]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()
     assert len(k.bufs) == 5 # make sure all ops are done in one kernel
@@ -379,7 +379,7 @@ class TestHandCodedOpts(unittest.TestCase):
     b = Tensor.rand(N, N).realize()
     c = a @ b
 
-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()
 
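
[reviewer note] The multi-device tests below change because MultiLazyBuffer no longer forwards .schedule() (the passthrough is deleted in tinygrad/features/multi.py below); callers hand its per-shard LazyBuffers (.lbs) straight to create_schedule. A sketch, assuming this machine exposes two devices:

    from tinygrad import Tensor, Device
    from tinygrad.realize import create_schedule

    devices = (f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1")  # assumption: two devices available
    t = Tensor.rand(8, 8).shard(devices, axis=0)
    sched = create_schedule(t.lazydata.lbs)  # one output LazyBuffer per shard
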
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 77f5ce202a..92a22e7f40 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -5,6 +5,7 @@ from tinygrad.device import BufferCopy
 from tinygrad.ops import LoadOps, ReduceOps
 from tinygrad.helpers import CI
 from tinygrad.nn.state import get_parameters
+from tinygrad.realize import create_schedule
 import numpy as np
 from hypothesis import given, strategies as strat, settings
 
@@ -296,7 +297,7 @@ class TestMultiTensor(unittest.TestCase):
     for p in get_parameters(bn): p.shard_(devices).realize()
 
     out = bn(t)
-    scheds = [sched for sched in out.lazydata.schedule() if sched.out.device in devices and sched.ast.op is not LoadOps.COPY]
+    scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.out.device in devices and sched.ast.op is not LoadOps.COPY]
     assert set(sched.out.device for sched in scheds) == set(devices), "should have ast on each shard device"
     asts = [sched.ast for sched in scheds]
     assert len(asts) == 8, len(asts)
@@ -527,9 +528,9 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
       p.shard_(devices)
 
     synced_out = synced_bn(x)
-    synced_si = [si for si in synced_out.lazydata.schedule()]
+    synced_si = [si for si in create_schedule(synced_out.lazydata.lbs)]
     unsynced_out = unsynced_bn(x)
-    unsynced_si = [si for si in unsynced_out.lazydata.schedule()]
+    unsynced_si = [si for si in create_schedule(unsynced_out.lazydata.lbs)]
 
     # TODO: test synced / unsynced batchnorm cross device kernel and copies
     assert synced_si
diff --git a/test/test_schedule.py b/test/test_schedule.py
index 2e39cc28cd..3ba137fdbd 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -10,16 +10,17 @@ from tinygrad.device import Device, Compiled
 from tinygrad.helpers import DEBUG, GRAPH
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.graph import print_tree, realized_lazybuffer
+from tinygrad.realize import create_schedule
 from tinygrad import nn, dtypes
 
 def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
-      for s in pre.lazydata.schedule(seen.copy()):
+      for s in create_schedule([pre.lazydata], seen.copy()):
         if GRAPH: realized_lazybuffer(s.out, 0)
         seen.add(s.out)
-  sched = t.lazydata.schedule(seen)
+  sched = create_schedule([t.lazydata], seen)
   if GRAPH:
     for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
   if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps]
diff --git a/test/test_search.py b/test/test_search.py
index 485fd8fd7f..f893230a81 100644
--- a/test/test_search.py
+++ b/test/test_search.py
@@ -1,6 +1,7 @@
 import unittest
 
 from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.realize import create_schedule
 from tinygrad.features.search import time_linearizer
 from tinygrad.device import Compiled, Device, Buffer
 from tinygrad.ops import LoadOps
@@ -11,7 +12,7 @@ class TestTimeLinearizer(unittest.TestCase):
     if not isinstance(Device[Device.DEFAULT], Compiled): raise unittest.SkipTest("only test for compiled backends")
 
   def test_reasonable_time(self):
-    si = [si for si in Tensor([1,2,3,4]).add(1).lazydata.schedule() if si.ast.op not in LoadOps][0]
+    si = [si for si in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if si.ast.op not in LoadOps][0]
    rawbufs = [Buffer(Device.DEFAULT, si.out.st.real_size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
     tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
     assert tm > 0 and tm != float('inf')
diff --git a/test/test_winograd.py b/test/test_winograd.py
index 281ba611fd..5f13447637 100644
--- a/test/test_winograd.py
+++ b/test/test_winograd.py
@@ -3,6 +3,7 @@ from tinygrad import Tensor, GlobalCounters
 from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG
 from tinygrad.ops import LoadOps
 from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.realize import create_schedule
 
 class TestWinograd(unittest.TestCase):
   def setUp(self):
@@ -19,7 +20,7 @@ class TestWinograd(unittest.TestCase):
     out = Tensor.conv2d(x, w)
 
     with Timing("scheduling: "):
-      sched = out.lazydata.schedule()
+      sched = create_schedule([out.lazydata])
 
     for i,s in enumerate(sched):
       if s.ast.op in LoadOps: continue
diff --git a/tinygrad/features/multi.py b/tinygrad/features/multi.py
index 0c159e1caf..384440d8f9 100644
--- a/tinygrad/features/multi.py
+++ b/tinygrad/features/multi.py
@@ -4,7 +4,7 @@ import functools, itertools, operator
 from tinygrad.helpers import all_same, dedup, round_up, prod, DEBUG
 from tinygrad.dtype import DType, Scalar
 from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
-from tinygrad.lazy import LazyBuffer, create_schedule
+from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import sint
 
 def all_reduce(op:ReduceOps, lbs):
@@ -57,7 +57,6 @@ class MultiLazyBuffer:
   def is_unrealized_contiguous_const(self): return False
 
   # passthroughs
-  def schedule(self, seen=None): return create_schedule(self.real_lbs, seen)
   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
   def const(self, val:Scalar) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
   def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 10793f1219..5c546873c3 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -1,19 +1,14 @@
 from __future__ import annotations
-import sys, math
-from collections import defaultdict
-from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict, cast
-from tinygrad.dtype import dtypes, DType, ImageDType, Scalar
-from tinygrad.helpers import prod, flatten, getenv, dedup, DEBUG, all_int, all_same, GRAPH
-from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
-from tinygrad.shape.symbolic import sint, Variable
+import math
+from typing import Union, Optional, Any, Tuple, List, Dict, cast
+from tinygrad.dtype import dtypes, DType, Scalar
+from tinygrad.helpers import prod, getenv, all_int, all_same
+from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op
+from tinygrad.shape.symbolic import sint
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
-from tinygrad.features.graph import log_lazybuffer
 from weakref import ref, ReferenceType
 
-# lazy can recurse a lot
-sys.setrecursionlimit(10000)
-
 lazycache: Dict[Any, ReferenceType[LazyBuffer]] = {}
 def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                       base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
@@ -76,8 +71,6 @@ class LazyBuffer:
   def is_unrealized_const(self): return not self.base.realized and self.base.op is LoadOps.CONST
   def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST
 
-  def schedule(self, seen=None): return create_schedule([self], seen)
-
   def _copy(self, device:str) -> LazyBuffer:
     sync_size = 1 if self.device.startswith("HIP") else 0
     sync = LazyBuffer.loadop(LoadOps.SYNC, (sync_size,), dtypes.uint32, self.device, src=self, enable_cache=True)
@@ -124,6 +117,7 @@ class LazyBuffer:
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))
 
   def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
+    # TODO: this logic should move to the scheduler
     if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
     assert len(self.shape)==len(new_shape) and all(ns in (1,s) for s,ns in zip(self.shape,new_shape)), f"not a contraction {self.shape=} {new_shape=}"
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
@@ -149,166 +143,3 @@ class LazyBuffer:
   def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
   def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
-
-# *** schedule creation ***
-
-# recursively create a lazyop
-def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
-  if (buf, st) in cache: return cache[(buf, st)]
-  if buf != buf.base:
-    st = buf.st + st
-    buf = buf.base
-  # all buffers here are base now
-  assert buf.op is not None
-
-  # consts are always fused and generated
-  if buf.op is LoadOps.CONST:
-    unbound_st, st_var_vals = st.simplify().unbind()
-    var_vals.update(st_var_vals)
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, unbound_st))
-
-  # if we aren't fusing it, it's a load and we add it to the inputs
-  if buf.realized or (buf in realizes and not first):
-    if buf not in inputs: inputs.append(buf)
-    unbound_st, st_var_vals = st.simplify().unbind()
-    var_vals.update(st_var_vals)
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
-
-  # if a CONTIGUOUS made it all the way here, just skip it
-  if buf.op is LoadOps.CONTIGUOUS:
-    assert first
-    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
-
-  # if it's a reduce, we have to change the shapetracker
-  if buf.op in ReduceOps:
-    assert st.contiguous, "ReduceOps late fusion must be contiguous"
-    st = ShapeTracker.from_shape(buf.srcs[0].shape)
-
-  # otherwise we fuse it like normal
-  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
-  return ret
-
-# recursively walk back in the graph to create the schedule
-def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyBuffer],
-                        reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> List[ScheduleItem]:
-  if out in seen or out.realized or out.op == LoadOps.CONST: return []
-  assert out.base == out
-  seen.add(out)
-
-  inputs: List[LazyBuffer] = []
-  var_vals: Dict[Variable, int] = out.st.var_vals.copy()
-  if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
-    op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
-  else:
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
-    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
-    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))
-
-  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
-
-# recursively search the entire graph for all LazyBuffers, insert realizes after expands
-def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
-                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
-  if buf in allbufs or buf.base.realized: return
-  if GRAPH: log_lazybuffer(buf, scheduled)
-  if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
-                                            not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
-    if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
-    buf.dtype = dtypes.float32  # NOTE: this is what makes the dtype above not match
-  if buf.base != buf:
-    # realize all places where the buffer is expanded
-    if prod(buf.base.st.shape) < prod(buf.st.shape):
-      if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
-         prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
-        simple_pads.add(buf.base)
-      else:
-        realizes.add(buf.base)
-    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
-  if buf.forced_realize: realizes.add(buf)
-  allbufs[buf] = None
-  if buf.op in LoadOps: realizes.add(buf.base)
-  if buf.op == LoadOps.COPY:
-    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
-    realizes.add(buf.srcs[0].base)
-  for x in buf.srcs:
-    children[x.base][buf] = None
-    _recurse_lb(x, realizes, allbufs, simple_pads, children)
-
-UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
-def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
-  if buf in realizes or buf.realized: return True
-  # NOTE: this broke to_image_idx and coder with JIT
-  if buf.op in UNSAFE_PAD_OPS: return False
-  return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
-
-def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
-  if seen is None: seen = set()
-
-  # start by just realizing the buffers passed in
-  realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
-  allbufs: Dict[LazyBuffer, None] = {}
-  simple_pads: Set[LazyBuffer] = set()
-  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
-  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
-
-  # check if we have to realize pads
-  for p in simple_pads:
-    if not _is_padding_okay(p, realizes):
-      realizes.add(p)
-
-  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
-  reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
-  for r in allbufs.keys():
-    if r != r.base or r.op not in ReduceOps or r in realizes: continue
-
-    # follow the reduce down
-    child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
-    realized_children: Dict[LazyBuffer, ShapeTracker] = {}
-    forced_realize = False
-    can_chase = True
-    while not forced_realize and len(child_set):
-      next_child_set = {}
-      for tr,st in child_set.items():
-        if tr in realizes:
-          realized_children[tr] = st
-          # can only have one output buffer
-          # can only reduce contiguous
-          # max one reduceop per kernel
-          if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
-            can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
-            forced_realize = True
-            break
-          continue
-        for tr_next in children[tr].keys():
-          if not tr_next.realized:
-            # max one reduceop per kernel
-            if tr_next.op in ReduceOps:
-              forced_realize = True
-              break
-            st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
-            if len(st_childs) > 1:
-              forced_realize = True
-              break
-            next_child_set[tr_next] = st + st_childs[0].st
-      child_set = next_child_set
-    if forced_realize:
-      tr = r
-      if can_chase:
-        # can chase this down to contiguous children
-        st = tr.st
-        while len(children[tr]) == 1:
-          tr_next = next(iter(children[tr].keys()))
-          st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
-          if len(st_childs) > 1: break
-          if st.size != st_childs[0].st.size: break
-          st = st + st_childs[0].st
-          if not st.contiguous or tr_next.op in ReduceOps: break
-          tr = tr_next
-      reduce_for_op[tr] = r
-      realizes.add(tr)
-    else:
-      assert len(realized_children) == 1
-      reduce_for_op[next(iter(realized_children.keys()))] = r
-
-  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in outs)
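
[reviewer note] The block deleted above reappears essentially verbatim in tinygrad/realize.py below, together with the recursion-limit bump that used to sit at the top of lazy.py; the pass structure is unchanged (walk the graph inserting realizes after expands, realize unsafe pads, pair each reduce with a single elementwise consumer, then emit ScheduleItems). One observable consequence of the reduce pairing, as a sketch in the style of the schedule tests:

    from tinygrad import Tensor
    from tinygrad.ops import LoadOps
    from tinygrad.realize import create_schedule

    r = Tensor.rand(16, 16).sum() + 1
    sched = [si for si in create_schedule([r.lazydata]) if si.ast.op not in LoadOps]
    assert len(sched) == 1  # the reduce and the +1 fuse into one kernel
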
diff --git a/tinygrad/realize.py b/tinygrad/realize.py
index f7ab930add..ee5b66f45a 100644
--- a/tinygrad/realize.py
+++ b/tinygrad/realize.py
@@ -1,9 +1,14 @@
-from typing import List, Dict, Optional, cast
-from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
+import sys
+from collections import defaultdict
+from typing import List, Dict, Optional, cast, Set, DefaultDict
+from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
 from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, InterpretedASTRunner, Compiled, BufferOptions
-from tinygrad.features.graph import print_tree, realized_lazybuffer
-from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG
+from tinygrad.features.graph import print_tree, realized_lazybuffer, log_lazybuffer
+from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG, flatten, prod, dedup, all_int
 from tinygrad.shape.symbolic import Variable
+from tinygrad.dtype import ImageDType, dtypes
+from tinygrad.lazy import LazyBuffer
+from tinygrad.shape.shapetracker import ShapeTracker
 
 # *** schedule running ***
 
@@ -68,3 +73,169 @@ def run_schedule(schedule:List[ScheduleItem]):
       if prg: prg.exec(cast(List[Buffer], real_buffers), si.var_vals)
       elif si.out.size > 0: update_stats(colored(f"empty {si.out.st.size:10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
       if GRAPH: realized_lazybuffer(si.out, GlobalCounters.kernel_count)
+
+# *** schedule creation ***
+
+# creation can recurse a lot
+sys.setrecursionlimit(10000)
+
+# recursively create a lazyop
+def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
+                      realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
+  if (buf, st) in cache: return cache[(buf, st)]
+  if buf != buf.base:
+    st = buf.st + st
+    buf = buf.base
+  # all buffers here are base now
+  assert buf.op is not None
+
+  # consts are always fused and generated
+  if buf.op is LoadOps.CONST:
+    unbound_st, st_var_vals = st.simplify().unbind()
+    var_vals.update(st_var_vals)
+    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, unbound_st))
+
+  # if we aren't fusing it, it's a load and we add it to the inputs
+  if buf.realized or (buf in realizes and not first):
+    if buf not in inputs: inputs.append(buf)
+    unbound_st, st_var_vals = st.simplify().unbind()
+    var_vals.update(st_var_vals)
+    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
+
+  # if a CONTIGUOUS made it all the way here, just skip it
+  if buf.op is LoadOps.CONTIGUOUS:
+    assert first
+    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
+
+  # if it's a reduce, we have to change the shapetracker
+  if buf.op in ReduceOps:
+    assert st.contiguous, "ReduceOps late fusion must be contiguous"
+    st = ShapeTracker.from_shape(buf.srcs[0].shape)
+
+  # otherwise we fuse it like normal
+  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
+  return ret
+
+# recursively walk back in the graph to create the schedule
+def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyBuffer],
+                        reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> List[ScheduleItem]:
+  if out in seen or out.realized or out.op == LoadOps.CONST: return []
+  assert out.base == out
+  seen.add(out)
+
+  inputs: List[LazyBuffer] = []
+  var_vals: Dict[Variable, int] = out.st.var_vals.copy()
+  if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
+    op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
+  else:
+    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
+    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))
+
+  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
+
+# recursively search the entire graph for all LazyBuffers, insert realizes after expands
+def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
+                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
+  if buf in allbufs or buf.base.realized: return
+  if GRAPH: log_lazybuffer(buf, scheduled)
+  if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
+                                            not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
+    if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
+    buf.dtype = dtypes.float32  # NOTE: this is what makes the dtype above not match
+  if buf.base != buf:
+    # realize all places where the buffer is expanded
+    if prod(buf.base.st.shape) < prod(buf.st.shape):
+      if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
+         prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
+        simple_pads.add(buf.base)
+      else:
+        realizes.add(buf.base)
+    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
+  if buf.forced_realize: realizes.add(buf)
+  allbufs[buf] = None
+  if buf.op in LoadOps: realizes.add(buf.base)
+  if buf.op == LoadOps.COPY:
+    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
+    realizes.add(buf.srcs[0].base)
+  for x in buf.srcs:
+    children[x.base][buf] = None
+    _recurse_lb(x, realizes, allbufs, simple_pads, children)
+
+UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
+def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
+  if buf in realizes or buf.realized: return True
+  # NOTE: this broke to_image_idx and coder with JIT
+  if buf.op in UNSAFE_PAD_OPS: return False
+  return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
+
+def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+  if seen is None: seen = set()
+
+  # start by just realizing the buffers passed in
+  realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
+  allbufs: Dict[LazyBuffer, None] = {}
+  simple_pads: Set[LazyBuffer] = set()
+  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
+  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
+
+  # check if we have to realize pads
+  for p in simple_pads:
+    if not _is_padding_okay(p, realizes):
+      realizes.add(p)
+
+  # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
+  reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
+  for r in allbufs.keys():
+    if r != r.base or r.op not in ReduceOps or r in realizes: continue
+
+    # follow the reduce down
+    child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
+    realized_children: Dict[LazyBuffer, ShapeTracker] = {}
+    forced_realize = False
+    can_chase = True
+    while not forced_realize and len(child_set):
+      next_child_set = {}
+      for tr,st in child_set.items():
+        if tr in realizes:
+          realized_children[tr] = st
+          # can only have one output buffer
+          # can only reduce contiguous
+          # max one reduceop per kernel
+          if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
+            can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
+            forced_realize = True
+            break
+          continue
+        for tr_next in children[tr].keys():
+          if not tr_next.realized:
+            # max one reduceop per kernel
+            if tr_next.op in ReduceOps:
+              forced_realize = True
+              break
+            st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
+            if len(st_childs) > 1:
+              forced_realize = True
+              break
+            next_child_set[tr_next] = st + st_childs[0].st
+      child_set = next_child_set
+    if forced_realize:
+      tr = r
+      if can_chase:
+        # can chase this down to contiguous children
+        st = tr.st
+        while len(children[tr]) == 1:
+          tr_next = next(iter(children[tr].keys()))
+          st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
+          if len(st_childs) > 1: break
+          if st.size != st_childs[0].st.size: break
+          st = st + st_childs[0].st
+          if not st.contiguous or tr_next.op in ReduceOps: break
+          tr = tr_next
+      reduce_for_op[tr] = r
+      realizes.add(tr)
+    else:
+      assert len(realized_children) == 1
+      reduce_for_op[next(iter(realized_children.keys()))] = r
+
+  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in outs)
\ No newline at end of file
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7ce0ada442..67ecf19165 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -8,12 +8,12 @@ import numpy as np
 
 from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype
 from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
-from tinygrad.lazy import LazyBuffer, create_schedule
+from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
 from tinygrad.ops import LoadOps
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import sint
-from tinygrad.realize import run_schedule
+from tinygrad.realize import run_schedule, create_schedule
 
 # **** start with two base classes, Tensor and Function ****
 
@@ -127,7 +127,7 @@ class Tensor:
     run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
 
   def realize(self) -> Tensor:
-    run_schedule(self.lazydata.schedule())
+    Tensor.corealize([self])
     return self
 
   def assign(self, x) -> Tensor:
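
[reviewer note] Net effect on the Tensor API: realize() is now corealize([self]), so single- and multi-tensor realization share one path, and corealize is the one place MultiLazyBuffer outputs get flattened before scheduling. Usage is unchanged; a sketch:

    from tinygrad import Tensor

    a, b = Tensor.rand(4), Tensor.rand(4)
    Tensor.corealize([a, b])  # one schedule for both graphs, then run_schedule
    assert a.lazydata.base.realized is not None  # realized buffers are set by run_schedule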