diff --git a/docs/abstractions2.py b/docs/abstractions2.py index 07e2670b85..a7cf946e5a 100644 --- a/docs/abstractions2.py +++ b/docs/abstractions2.py @@ -77,7 +77,7 @@ assert out.as_buffer().cast('I')[0] == 5 print("******** third, the LazyBuffer ***********") from tinygrad.engine.realize import run_schedule -from tinygrad.engine.schedule import create_schedule +from tinygrad.engine.schedule import create_schedule_with_vars # allocate some values + load in values a = UOp.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE) @@ -91,7 +91,7 @@ b = b.buf_uop_view() out = a.alu(Ops.ADD, b) # schedule the computation as a list of kernels -sched = create_schedule([out]) +sched, _ = create_schedule_with_vars([out]) for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG # DEBUGGING: print the compute ast diff --git a/examples/handcode_opt.py b/examples/handcode_opt.py index 2938f7072c..404fe9c94e 100644 --- a/examples/handcode_opt.py +++ b/examples/handcode_opt.py @@ -6,7 +6,6 @@ from tinygrad import Tensor, Device, dtypes, nn from tinygrad.codegen.kernel import Kernel from tinygrad.ops import Ops, sym_infer from tinygrad.device import Compiled -from tinygrad.engine.schedule import create_schedule from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA @@ -18,12 +17,12 @@ def get_sched_resnet(): # run model twice to get only what changes, these are the kernels of the model for _ in range(2): out = mdl(Tensor.empty(BS, 3, 224, 224)) - targets = [out.lazydata] + targets = [out] if getenv("BACKWARD"): optim.zero_grad() out.sparse_categorical_crossentropy(Tensor.empty(BS, dtype=dtypes.int)).backward() - targets += [x.lazydata for x in optim.schedule_step()] - sched = create_schedule(targets) + targets += [x for x in optim.schedule_step()] + sched = Tensor.schedule(*targets) print(f"schedule length {len(sched)}") return sched @@ -42,17 +41,16 @@ def get_sched_bert(): next_sentence_labels = Tensor.empty((BS, 1), dtype=dtypes.float32) # run model twice to get only what changes, these are the kernels of the model - seen = set() for _ in range(2): lm_logits, seq_relationship_logits = mdl(input_ids, attention_mask, masked_positions, segment_ids) - targets = [lm_logits.lazydata, seq_relationship_logits.lazydata] + targets = [lm_logits, seq_relationship_logits] if getenv("BACKWARD"): optim.zero_grad() loss = mdl.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels) # ignore grad norm and loss scaler for now loss.backward() - targets += [x.lazydata for x in optim.schedule_step()] - sched = create_schedule(targets) + targets += [x for x in optim.schedule_step()] + sched = Tensor.schedule(*targets) print(f"schedule length {len(sched)}") return sched diff --git a/examples/llm.c/export.py b/examples/llm.c/export.py index c0d52f32cd..d257b81778 100755 --- a/examples/llm.c/export.py +++ b/examples/llm.c/export.py @@ -5,7 +5,6 @@ from tinygrad import Device, nn, Tensor, dtypes, Variable Device.DEFAULT = "CLANG" from train_gpt2 import GPT, GPTConfig from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCounters, ansilen, to_function_name -from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import get_kernel, run_schedule from tinygrad.engine.memory import memory_planner from tinygrad.ops import Ops @@ -37,7 +36,7 @@ if __name__ == "__main__": tensors = optimizer.schedule_step() else: tensors = [] - sched = 
create_schedule([loss.lazydata] + [x.lazydata for x in tensors]) + sched = loss.schedule(*tensors) print(f"calls {i}:", len(sched)) #run_schedule(sched[:]) sched = memory_planner(sched) diff --git a/extra/gemm/tvm_gemm.py b/extra/gemm/tvm_gemm.py index b6851e7f61..d09dd36c35 100644 --- a/extra/gemm/tvm_gemm.py +++ b/extra/gemm/tvm_gemm.py @@ -30,14 +30,13 @@ except ImportError: import os from tinygrad.tensor import Tensor -from tinygrad.engine.schedule import create_schedule # define the compute A = Tensor.rand(M, K, device="clang") B = Tensor.rand(K, N, device="clang") C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2) -sched = create_schedule([C.lazydata]) +sched = C.schedule() from tinygrad.codegen.kernel import Kernel from tinygrad.device import CompilerOptions lin = Kernel(sched[-1].ast, CompilerOptions(has_local=False, supports_float4=False)) diff --git a/extra/hip_gpu_driver/test_pm4.py b/extra/hip_gpu_driver/test_pm4.py index 27eb005705..b5a4541217 100644 --- a/extra/hip_gpu_driver/test_pm4.py +++ b/extra/hip_gpu_driver/test_pm4.py @@ -4,7 +4,6 @@ from tinygrad import Tensor, Device import tinygrad.runtime.autogen.amd_gpu as amd_gpu import tinygrad.runtime.autogen.kfd as kfd import tinygrad.runtime.autogen.hsa as hsa -from tinygrad.engine.schedule import create_schedule from tinygrad.runtime.ops_amd import kio, AMDProgram from tinygrad.helpers import to_mv @@ -49,7 +48,7 @@ if __name__ == "__main__": a = Tensor([0.,1.,2.], device="KFD").realize() b = a + 7 b.lazydata.buffer.allocate() - si = create_schedule([b.lazydata])[-1] + si = b.schedule()[-1] runner = dev.get_runner(*si.ast) prg: AMDProgram = runner.clprg print("device initted") diff --git a/test/external/external_test_hcq.py b/test/external/external_test_hcq.py index 74dec06e6b..0303948bd1 100644 --- a/test/external/external_test_hcq.py +++ b/test/external/external_test_hcq.py @@ -2,7 +2,6 @@ import unittest, ctypes, struct, time, array from tinygrad import Device, Tensor, dtypes from tinygrad.helpers import to_mv, CI from tinygrad.device import Buffer, BufferSpec -from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import get_runner def _time_queue(q, d): @@ -21,7 +20,7 @@ class TestHCQ(unittest.TestCase): #TestHCQ.d1: AMDDevice = Device["AMD:1"] TestHCQ.a = Tensor([0.,1.], device=Device.DEFAULT).realize() TestHCQ.b = self.a + 1 - si = create_schedule([self.b.lazydata])[-1] + si = self.b.schedule()[-1] TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast) TestHCQ.b.lazydata.buffer.allocate() # wow that's a lot of abstraction layers diff --git a/test/external/external_test_nv.py b/test/external/external_test_nv.py index a9b9f3f2db..26b976e99b 100644 --- a/test/external/external_test_nv.py +++ b/test/external/external_test_nv.py @@ -1,7 +1,6 @@ import unittest, struct, array, ctypes from tinygrad import Device, dtypes, Tensor from tinygrad.helpers import to_mv -from tinygrad.engine.schedule import create_schedule from tinygrad.runtime.ops_nv import NVDevice, HWQueue from tinygrad.engine.search import Opt, OptOps from test.test_linearizer_failures import helper_test_lin @@ -20,7 +19,7 @@ class TestNV(unittest.TestCase): TestNV.d0: NVDevice = Device["NV"] TestNV.a = Tensor([0.,1.], device="NV").realize() TestNV.b = self.a + 1 - si = create_schedule([self.b.lazydata])[-1] + si = self.b.schedule()[-1] TestNV.d0_runner = get_runner(TestNV.d0.device, si.ast) TestNV.b.lazydata.buffer.allocate() TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, 
TestNV.a.lazydata.buffer._buf.va_addr) @@ -65,4 +64,3 @@ class TestNV(unittest.TestCase): if __name__ == "__main__": unittest.main() - diff --git a/test/external/fuzz_graph.py b/test/external/fuzz_graph.py index 0fa0b55b02..74c3af1dea 100644 --- a/test/external/fuzz_graph.py +++ b/test/external/fuzz_graph.py @@ -4,7 +4,6 @@ from tinygrad.device import Buffer, Device from tinygrad.helpers import Context, getenv, from_mv from tinygrad.dtype import dtypes from tinygrad.tensor import Tensor, _to_np_dtype -from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner from tinygrad.engine.jit import apply_graph_to_jit @@ -19,7 +18,7 @@ def gen_prg(device, inputs_cnt): s = fst[0] for i in range(1, inputs_cnt): s = s.xor(fst[i]) - si = create_schedule([s.lazydata])[-1] + si = s.schedule()[-1] prg = get_runner(device, si.ast) cached_prgs[(device, inputs_cnt)] = prg return prg diff --git a/test/test_conv_shapetracker.py b/test/test_conv_shapetracker.py index 3c14545d2f..f97c153f72 100644 --- a/test/test_conv_shapetracker.py +++ b/test/test_conv_shapetracker.py @@ -3,7 +3,6 @@ import unittest from tinygrad.ops import Ops from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d -from tinygrad.engine.schedule import create_schedule from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.helpers import prod from test.unit.test_shapetracker import shapetracker_getitem @@ -11,11 +10,10 @@ from test.unit.test_shapetracker import shapetracker_getitem class TestConvShapetracker(unittest.TestCase): def test_conv_3x3_one_view(self): conv = Conv2d(16, 32, (3, 3)) - # first run to init the weights, they are scheduled. conv(Tensor.empty(1, 16, 10, 10)).schedule() # run it again to get the kernels - sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is Ops.SINK] + sched = [si for si in conv(Tensor.empty(1, 16, 10, 10)).schedule() if si.ast.op is Ops.SINK] assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}" for st in [x.st_arg for x in sched[0].ast.toposort if x.op is Ops.LOAD]: assert len(st.views) == 1 diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py index 5606301d1d..a9528742ff 100644 --- a/test/test_dtype_alu.py +++ b/test/test_dtype_alu.py @@ -6,7 +6,6 @@ import numpy as np from hypothesis import given, strategies as strat, settings, HealthCheck from tinygrad.dtype import DType from tinygrad.helpers import CI, getenv -from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import run_schedule from tinygrad.ops import GroupOp from tinygrad.tensor import _to_np_dtype @@ -72,7 +71,7 @@ def universal_test(a, b, dtype, op): def universal_test_unary(a, dtype, op): if not isinstance(op, tuple): op = (op, op) out: Tensor = op[0](Tensor([a], dtype=dtype)) - sched = create_schedule([out.lazydata]) + sched = out.schedule() ast = sched[-1].ast run_schedule(sched) tensor_value = out.numpy() diff --git a/test/test_fusion_op.py b/test/test_fusion_op.py index b979ef2b2f..574273af82 100644 --- a/test/test_fusion_op.py +++ b/test/test_fusion_op.py @@ -2,7 +2,6 @@ import unittest import time import numpy as np from tinygrad import Tensor, dtypes -from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import lower_schedule_item, run_schedule class TestFusionOp(unittest.TestCase): @@ -17,7 +16,7 @@ class TestFusionOp(unittest.TestCase): def test_expand_fuse(self): bt = Tensor(np.ones((10, 1)), 
dtype=dtypes.float32) out = (bt*2).expand(10,10).sum(1) - sched = create_schedule([out.lazydata]) + sched = out.schedule() run_schedule(sched) outd = out.tolist() assert all(x == 20.0 for x in outd) @@ -26,7 +25,7 @@ class TestFusionOp(unittest.TestCase): st = time.perf_counter() a = Tensor([1,2,3,4]) for _ in range(24): a = a + a - sched = create_schedule([a.lazydata]) + sched = a.schedule() ei = lower_schedule_item(sched[-1]) self.assertLess(time.perf_counter()-st, 2.0) assert len(ei.prg.p.src.splitlines()) < 250 @@ -35,13 +34,13 @@ class TestFusionOp(unittest.TestCase): st = time.perf_counter() a = Tensor([1,2,3,4]) for _ in range(24): a = a + a - sched1 = create_schedule([a.lazydata]) + sched1 = a.schedule() b = Tensor([1,2,3,4]) for _ in range(24): b = b + b - sched2 = create_schedule([b.lazydata]) + sched2 = b.schedule() c = Tensor([1,2,3,4]) for _ in range(23): c = c + c - sched3 = create_schedule([c.lazydata]) + sched3 = c.schedule() self.assertEqual(sched1[-1].ast, sched2[-1].ast) with self.assertRaises(AssertionError): self.assertEqual(sched1[-1].ast, sched3[-1].ast) self.assertLess(time.perf_counter()-st, 2.0) diff --git a/test/test_graph.py b/test/test_graph.py index e55ee912a0..989c51232f 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -3,7 +3,6 @@ import unittest, ctypes from tinygrad.device import Device, Buffer from tinygrad.tensor import Tensor, _to_np_dtype -from tinygrad.engine.schedule import create_schedule from tinygrad.helpers import Context, CI, dedup, from_mv from tinygrad.dtype import dtypes from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner, CompiledRunner @@ -21,7 +20,7 @@ def helper_exec_op(device, outbuf, inbufs): s = fst[0] for i in range(1, len(inbufs)): s = s.xor(fst[i]) - si = create_schedule([s.lazydata])[-1] + si = s.schedule()[-1] prg = get_runner(device, si.ast) cached_prgs[(device, len(inbufs))] = prg diff --git a/test/test_hcq.py b/test/test_hcq.py index 626fd72e8e..b6d0469220 100644 --- a/test/test_hcq.py +++ b/test/test_hcq.py @@ -3,7 +3,6 @@ from tinygrad import Device, Tensor, dtypes from tinygrad.helpers import CI, getenv from tinygrad.device import Buffer, BufferSpec from tinygrad.runtime.support.hcq import HCQCompiled -from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import get_runner, CompiledRunner from tinygrad.codegen.kernel import Kernel, Opt, OptOps from tinygrad import Variable @@ -159,7 +158,7 @@ class TestHCQ(unittest.TestCase): a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize() b = a + 1 - si = create_schedule([b.lazydata])[-1] + si = b.schedule()[-1] k = Kernel(si.ast, opts=TestHCQ.d0.renderer) for i in range(3): k.apply_opt(Opt(op=OptOps.LOCAL, axis=0, amt=3)) @@ -442,7 +441,7 @@ class TestHCQ(unittest.TestCase): def test_memory_barrier(self): a = Tensor([0, 1], device=Device.DEFAULT, dtype=dtypes.int8).realize() b = a + 1 - runner = get_runner(TestHCQ.d0.device, create_schedule([b.lazydata])[-1].ast) + runner = get_runner(TestHCQ.d0.device, b.schedule()[-1].ast) buf1 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated() buf2 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated() diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py index 040c7ad926..785f27df25 100644 --- a/test/test_lazybuffer.py +++ b/test/test_lazybuffer.py @@ -4,7 +4,6 @@ import unittest from tinygrad import Tensor, Device, dtypes from tinygrad.engine.realize import 
run_schedule from tinygrad.ops import Ops, UOp -from tinygrad.engine.schedule import create_schedule class TestLazyBuffer(unittest.TestCase): def test_fromcpu_shape_tracker(self): @@ -74,14 +73,14 @@ class TestLazyBuffer(unittest.TestCase): b = Tensor.randn(2, 2).realize() add = (a+b).contiguous() out = add+2 - sched = create_schedule([out.lazydata]) + sched = out.schedule() self.assertEqual(len(sched), 2) run_schedule(sched) np.testing.assert_allclose(out.numpy(), a.numpy()+b.numpy()+2) def test_forced_realized_metaop(self): empty = Tensor.empty(1).contiguous() - sched = create_schedule([empty.lazydata]) + sched = empty.schedule() self.assertEqual(len(sched), 1) self.assertIs(sched[0].ast.op, Ops.EMPTY) run_schedule(sched) @@ -90,14 +89,14 @@ class TestReduceOp(unittest.TestCase): def test_no_split_reduce_kernel(self): a = Tensor.rand(4, 4).realize() a = a.sum() - sched = create_schedule([a.lazydata]) + sched = a.schedule() assert len(sched) == 1 self.assertIs(sched[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS) def test_split_reduce_kernel_dim0(self): a = Tensor.rand(256, 255).realize() a = a.sum() - sched = create_schedule([a.lazydata]) + sched = a.schedule() assert len(sched) == 2 for s in sched: self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS) @@ -105,7 +104,7 @@ class TestReduceOp(unittest.TestCase): def test_split_reduce_kernel_dim1(self): a = Tensor.rand(255, 256).realize() a = a.sum() - sched = create_schedule([a.lazydata]) + sched = a.schedule() assert len(sched) == 2 for s in sched: self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 10a01a1483..a8a26f1b92 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -12,14 +12,14 @@ from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View # from tinygrad.ops import Variable from tinygrad.tensor import Tensor, _to_np_dtype -from tinygrad.engine.schedule import BUF_LIMIT, create_schedule +from tinygrad.engine.schedule import BUF_LIMIT from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX from tinygrad.dtype import DType, dtypes def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer]]: if isinstance(r, Tensor): r = [r] - s = create_schedule([x.lazydata for x in r]) + s = Tensor.schedule(*r) run_schedule(s[:-1]) # run all kernels except the last one # now all input LazyBuffers buffers in s[-1] should be realized # create fresh buffers for the output buffer @@ -30,7 +30,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in) np_a, np_b = a.numpy(), b.numpy() r = a.matmul(b, acc_dtype=dtype_out) - sched = create_schedule([r.lazydata]) + sched = r.schedule() realized_ast = sched[-1].ast run_schedule(sched) out = r.numpy() @@ -48,7 +48,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0, ensure_triggered:bool=True): a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in) r = a.matmul(b, acc_dtype=dtype_out) - sched = create_schedule([r.lazydata]) + sched = r.schedule() realized_ast = sched[-1].ast k = Kernel(realized_ast) k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt) @@ -67,7 +67,7 @@ class 
TestLinearizer(unittest.TestCase): a, b = Tensor.randn(4), Tensor.randn(4) np_a, np_b = a.numpy(), b.numpy() c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))) - lowered = list(lower_schedule(create_schedule([c.lazydata]))) + lowered = list(lower_schedule(c.schedule())) for ei in lowered: ei.run() rawbufs = lowered[-1].bufs assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized} @@ -924,7 +924,7 @@ class TestLinearizer(unittest.TestCase): # these are of size 3 to avoid float4 coalesce r = a[:-1] + a[1:] - k = Kernel(create_schedule([r.lazydata])[-1].ast) + k = Kernel(r.schedule()[-1].ast) k.upcast() k.linearize() num_loads = len([uop for uop in k.uops if uop.op is Ops.LOAD]) @@ -955,7 +955,7 @@ class TestLinearizer(unittest.TestCase): a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() r = a.expand([2]) + b.expand([2]) - k = Kernel(create_schedule([r.lazydata])[-1].ast) + k = Kernel(r.schedule()[-1].ast) k.upcast() k.linearize() num_ops = len([uop for uop in k.uops if uop.op in GroupOp.ALU]) @@ -966,7 +966,7 @@ class TestLinearizer(unittest.TestCase): x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize() r = Tensor.conv2d(x,w,padding=1).relu() - k = Kernel(create_schedule([r.lazydata])[-1].ast) + k = Kernel(r.schedule()[-1].ast) k.upcast() k.upcast() k.linearize() @@ -983,7 +983,7 @@ class TestLinearizer(unittest.TestCase): def test_upcast_with_locals(self): x, y = Tensor.rand(1,128), Tensor.rand(128, 128) r = (x@y).relu() - k = Kernel(create_schedule([r.lazydata])[-1].ast) + k = Kernel(r.schedule()[-1].ast) k.hand_coded_optimizations() k.linearize() @@ -1000,7 +1000,7 @@ class TestLinearizer(unittest.TestCase): a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize() r = Tensor.stack(a, b) - k = Kernel(create_schedule([r.lazydata])[-1].ast) + k = Kernel(r.schedule()[-1].ast) k.upcast() k.linearize() num_ops = len([uop for uop in k.uops if uop.op in GroupOp.ALU]) @@ -1011,14 +1011,14 @@ class TestLinearizer(unittest.TestCase): (dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)): if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype): a = Tensor([1, 2, 3], dtype=tensor_dtype).sum() - k = Kernel(create_schedule([a.lazydata])[-1].ast) + k = Kernel(a.schedule()[-1].ast) k.linearize() local = [uop for uop in k.uops if uop.op is Ops.DEFINE_ACC] assert local[0].dtype == acc_dtype def test_arg_acc_dtype(self): def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType): - k = Kernel(create_schedule([c.lazydata])[-1].ast) + k = Kernel(c.schedule()[-1].ast) k.linearize() local = [uop for uop in k.uops if uop.op is Ops.DEFINE_ACC] assert local[0].dtype == expected_dtype @@ -1225,7 +1225,7 @@ class TestLinearizer(unittest.TestCase): def test_div_collapse(self): def helper(t, msg, max_ops=0): - sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is Ops.SINK] + sched = [si for si in t.schedule() if si.ast.op is Ops.SINK] assert len(sched) == 1 lin = Kernel(sched[0].ast) @@ -1246,7 +1246,7 @@ class TestLinearizer(unittest.TestCase): def test_sum_collapse(self): t = Tensor([2]).reshape(1, 1).expand(256, 256).sum() - sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is Ops.SINK] + sched = [si for si in t.schedule() if si.ast.op is Ops.SINK] assert len(sched) == 1 lin = Kernel(sched[0].ast) assert not any(u.op is Ops.RANGE for u in lin.linearize().uops), "found loop in sum collapse" 
@@ -1262,7 +1262,7 @@ class TestLinearizer(unittest.TestCase): a = Tensor.ones(4, 4).contiguous().realize() b = a.shrink(((1, 2), None)).pad(((1, 2), None)) a.assign(b.where(2, a)) - sched = create_schedule([a.lazydata]) + sched = a.schedule() assert len(sched) == 1 sched_copy = sched[:] run_schedule(sched) @@ -1424,7 +1424,7 @@ class TestFloat4(unittest.TestCase): b = Tensor.rand(2, 8).realize() c = a + b - s = create_schedule([c.lazydata])[0] + s = c.schedule()[0] k = Kernel(s.ast) k.hand_coded_optimizations() k.linearize() @@ -1437,7 +1437,7 @@ class TestFloat4(unittest.TestCase): b = Tensor.rand(2, 8).realize() c = a + b - s = create_schedule([c.lazydata])[0] + s = c.schedule()[0] k = Kernel(s.ast) k.shift_to(0, 4) # float4 dimension k.shift_to(0, 2, insert_before=k.shape_len-1) @@ -1455,7 +1455,7 @@ class TestFloat4(unittest.TestCase): b = Tensor.rand(2, size).realize() c = a + b - s = create_schedule([c.lazydata])[0] + s = c.schedule()[0] k = Kernel(s.ast) k.shift_to(0, 4) k.shift_to(0, shift, insert_before=k.shape_len-1) @@ -1479,7 +1479,7 @@ class TestFloat4(unittest.TestCase): b = Tensor.rand(9).realize().shrink(((1, 9),)) c = a + b - s = create_schedule([c.lazydata])[0] + s = c.schedule()[0] k = Kernel(s.ast) k.hand_coded_optimizations() # implicit trigger float4 dim k.linearize() @@ -1492,7 +1492,7 @@ class TestFloat4(unittest.TestCase): b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),)) c = a + b - s = create_schedule([c.lazydata])[0] + s = c.schedule()[0] k = Kernel(s.ast) k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim k.upcast() @@ -1510,7 +1510,7 @@ class TestFloat4(unittest.TestCase): b = Tensor.rand(2, size).realize().shrink(((0, 2), (1, size),)) c = a + b - s = create_schedule([c.lazydata])[0] + s = c.schedule()[0] k = Kernel(s.ast) k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim k.upcast() @@ -1535,7 +1535,7 @@ class TestFloat4(unittest.TestCase): # only the first and last conv dot products are aligned in a, and b is never aligned, so no # float4 should be emitted (the reduce axis of size 4 is the float4 axis here) - s = create_schedule([c.lazydata])[0] + s = c.schedule()[0] k = Kernel(s.ast) k.upcast() k.linearize() @@ -1551,7 +1551,7 @@ class TestFloat4(unittest.TestCase): # don't. # UPDATE: now we do this fusion - s = create_schedule([c.lazydata])[0] + s = c.schedule()[0] k = Kernel(s.ast) k.upcast() k.upcast() @@ -1567,7 +1567,7 @@ class TestFloat4(unittest.TestCase): # we will upcast the top axis of sz 4. they should not be coalesced into float4, # since the top axis is not contiguous. - s = create_schedule([c.lazydata])[0] + s = c.schedule()[0] k = Kernel(s.ast) k.shift_to(0, 4, top=True) # top axes are float4 axes k.upcast() @@ -1583,7 +1583,7 @@ class TestFloat4(unittest.TestCase): # we will upcast the top axis of sz 4. they should not be coalesced into float4, # since the top axis is not contiguous. 
- s = create_schedule([c.lazydata])[0] + s = c.schedule()[0] k = Kernel(s.ast) k.shift_to(0, 4) # float4 axis k.upcast() @@ -1598,7 +1598,7 @@ class TestFloat4(unittest.TestCase): # should float4 b but not a - s = create_schedule([c.lazydata])[0] + s = c.schedule()[0] k = Kernel(s.ast) k.shift_to(0, 4) # float4 axis k.upcast() @@ -1692,7 +1692,7 @@ class TestHandCodedOpts(unittest.TestCase): layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)]) layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20)) - s = create_schedule([layer_2.lazydata])[-1] + s = layer_2.schedule()[-1] k = Kernel(s.ast) k.hand_coded_optimizations() assert len(k.bufs) == 6 # make sure all ops are done in one kernel @@ -1705,7 +1705,7 @@ class TestHandCodedOpts(unittest.TestCase): def test_masked_upcast_wino(self): monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)]) - s = create_schedule([monster.lazydata])[-1] + s = monster.schedule()[-1] k = Kernel(s.ast) k.hand_coded_optimizations() assert len(k.bufs) == 37 # make sure all ops are done in one kernel @@ -1719,7 +1719,7 @@ class TestHandCodedOpts(unittest.TestCase): out.mean().backward() upcasts = [] - wino_schedule = create_schedule([out.lazydata]) + wino_schedule = out.schedule() # collect upcasts of tile transform kernels for i, si in enumerate(wino_schedule): k = Kernel(si.ast) @@ -1732,7 +1732,7 @@ class TestHandCodedOpts(unittest.TestCase): # this test case's inputs are too small, so one of the 4-stacks became a local, which is fine i guess assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1 - backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata]) + backward_schedule = Tensor.schedule(x.grad, w.grad) for si in backward_schedule: k = Kernel(si.ast) k.hand_coded_optimizations() diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 2cbcc0ad14..a80014c084 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -4,7 +4,6 @@ from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes from tinygrad.ops import Ops from tinygrad.helpers import CI, getenv, prod, Context from tinygrad.nn.state import get_parameters, get_state_dict -from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule from tinygrad.multi import all_reduce, MultiLazyBuffer import numpy as np @@ -69,7 +68,7 @@ class TestMultiTensor(unittest.TestCase): X = Tensor.ones(256).contiguous().realize() X.shard_(devices_2, 0) out = (X + X) - sched = create_schedule(out.lazydata.lbs) + sched = out.schedule() names = [] for si, ei in zip(sched[:], lower_schedule(sched)): if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name) @@ -492,7 +491,7 @@ class TestMultiTensor(unittest.TestCase): for p in get_parameters(bn): p.shard_(devices_4).realize() out = bn(t) - scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not Ops.COPY] + scheds = [sched for sched in out.schedule() if sched.outputs[0].device in devices_4 and sched.ast.op is not Ops.COPY] assert set(out.device for sched in scheds for out in sched.outputs) == set(devices_4), "should have ast on each shard device" asts = [sched.ast for sched in scheds] assert len(asts) @@ -723,7 +722,7 @@ class TestHandleData(unittest.TestCase): device = (d0, d1, d2, d3) t = Tensor([1, 2, 3, 4]).shard(device).realize() not_covered = t.to(d5) - sched = 
create_schedule([not_covered.lazydata]) + sched = not_covered.schedule() assert len(sched) == 1 # setup again because create_schedule has side effect t = Tensor([1, 2, 3, 4]).shard(device).realize() @@ -733,7 +732,7 @@ class TestHandleData(unittest.TestCase): for d in device: t = Tensor([1, 2, 3, 4]).shard(device).realize() covered = t.to(d) - sched = create_schedule([covered.lazydata]) + sched = covered.schedule() assert len(sched) == 0 # setup again because create_schedule has side effect t = Tensor([1, 2, 3, 4]).shard(device).realize() @@ -1001,9 +1000,9 @@ class TestBatchNorm(unittest.TestCase): p.to_(devices) synced_out = synced_bn(x) - synced_si = list(create_schedule(synced_out.lazydata.lbs)) + synced_si = list(synced_out.schedule()) unsynced_out = unsynced_bn(x) - unsynced_si = list(create_schedule(unsynced_out.lazydata.lbs)) + unsynced_si = list(unsynced_out.schedule()) # TODO: test synced / unsynced batchnorm cross device kernel and copies assert synced_si diff --git a/test/test_pickle.py b/test/test_pickle.py index 0880959373..ec501969f1 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -1,7 +1,6 @@ import unittest, pickle, types import numpy as np from tinygrad import Tensor, TinyJit, Variable, dtypes -from tinygrad.engine.schedule import create_schedule from tinygrad.helpers import GlobalCounters from tinygrad.ops import PatternMatcher, UPat, UOp @@ -99,7 +98,7 @@ class TestPickle(unittest.TestCase): def test_pickle_schedule(self): a = Tensor([1,2]) out = a + 2 - sched = create_schedule([out.lazydata]) + sched = out.schedule() pk = pickle.dumps(sched) sched_pk = pickle.loads(pk) self.assertEqual(sched_pk[-1].ast, sched[-1].ast) diff --git a/test/test_schedule.py b/test/test_schedule.py index 68e536a938..d94a0f10ce 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -16,7 +16,7 @@ from tinygrad.shape.view import View from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices, symbolic from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context from tinygrad.codegen.kernel import Kernel, verify_ast -from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule, view_right, view_left, remove_movement_ops +from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule from extra.models.llama import precompute_freqs_cis @@ -28,7 +28,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t) else: assert isinstance(t, UOp), f"can't schedule {t}" - sched = create_schedule([t]) + sched, _ = create_schedule_with_vars([t]) if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK] if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}") @@ -55,7 +55,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs): w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize() ret = Tensor.conv2d(img, w).relu().mean().backward() dtypes.default_float = old_default_float - with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata]) + with Context(**kwargs): s = Tensor.schedule(ret, img.grad, w.grad) run_schedule(s.copy()) cnt = len([si for si in s if si.ast.op is Ops.SINK]) assert cnt == allowed, f"expected 
{allowed} kernels, got {cnt}" @@ -1394,11 +1394,11 @@ class TestSchedule(unittest.TestCase): def test_const_schedule(self): constv = Tensor.empty(2, 2).lazydata.const_like(10) - self.assertEqual(len(create_schedule([constv])), 0) + check_schedule(constv, 0) def test_const_schedule_contig(self): constv = Tensor.empty(2, 2).lazydata.const_like(10).contiguous() - self.assertEqual(len(create_schedule([constv])), 1) + check_schedule(constv, 1) @unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU") def test_image_matmul(self): diff --git a/test/test_search.py b/test/test_search.py index 37747b2106..d22e03bc59 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -4,7 +4,6 @@ from test.helpers import ast_const from tinygrad.codegen.kernel import Opt, OptOps from tinygrad.codegen.kernel import Kernel from tinygrad.ops import UOp, Ops -from tinygrad.engine.schedule import create_schedule from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search from tinygrad.device import Device, Buffer from tinygrad.tensor import Tensor @@ -16,7 +15,8 @@ from tinygrad.shape.view import View class TestTimeLinearizer(unittest.TestCase): def test_reasonable_time(self): - si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0] + a = Tensor([1,2,3,4]).realize() + si = (a+1).schedule()[0] out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate() memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.toposort if x.op is Ops.LOAD} rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))] @@ -24,7 +24,8 @@ class TestTimeLinearizer(unittest.TestCase): assert tm > 0 and tm != float('inf') def test_bufs_from_lin(self): - si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0] + a = Tensor([1,2,3,4]).realize() + si = (a+1).schedule()[0] rawbufs = bufs_from_lin(lin:=Kernel(si.ast)) assert len(rawbufs) == len(lin.membufs) == 2 assert all(r is not None for r in rawbufs) @@ -34,7 +35,7 @@ class TestTimeLinearizer(unittest.TestCase): def test_bufs_from_lin_alt(self): a = Tensor.randn(4, 4).realize() b = a+a[0] - si = [si for si in b.schedule() if si.ast.op is Ops.SINK][0] + si = b.schedule()[0] rawbufs = bufs_from_lin(k:=Kernel(si.ast)) assert len(rawbufs) == len(k.membufs) == 2 assert all(r is not None for r in rawbufs) diff --git a/test/test_tensor.py b/test/test_tensor.py index 3a5853e2c5..610fa7c6af 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -3,7 +3,6 @@ import numpy as np import torch import unittest, copy, mmap, random, math, array from tinygrad import Tensor, Device, dtypes -from tinygrad.engine.schedule import create_schedule from tinygrad.helpers import getenv, temp, _METADATA, mv_address from extra.gradcheck import numerical_jacobian, jacobian, gradcheck from hypothesis import given, settings, strategies as strat @@ -725,7 +724,7 @@ class TestTensorMetadata(unittest.TestCase): W = Tensor.rand(3, 3, requires_grad=True) out = x.matmul(W) self.assertEqual(out.lazydata.metadata.name, "matmul") - si = create_schedule([out.lazydata])[-1] + si = out.schedule()[-1] self.assertEqual(len(si.metadata), 1) self.assertEqual(si.metadata[0].name, "matmul") @@ -733,7 +732,7 @@ class TestTensorMetadata(unittest.TestCase): x = Tensor.rand(3, requires_grad=True) out = x.relu() self.assertEqual(out.lazydata.metadata.name, "relu") - si = create_schedule([out.lazydata])[-1] + si = 
out.schedule()[-1] self.assertEqual(len(si.metadata), 1) self.assertEqual(si.metadata[0].name, "relu") @@ -744,7 +743,7 @@ class TestTensorMetadata(unittest.TestCase): self.assertEqual(out.lazydata.metadata.name, "__mul__") self.assertEqual(out.lazydata.src[0].metadata.name, "relu") self.assertEqual(out.lazydata.src[1].metadata.name, "sigmoid") - si = create_schedule([out.lazydata])[-1] + si = out.schedule()[-1] self.assertEqual(len(si.metadata), 3) self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"}) @@ -758,7 +757,7 @@ class TestTensorMetadata(unittest.TestCase): self.assertTrue(x.grad.lazydata.metadata.backward) self.assertEqual(y.grad.lazydata.metadata.name, "sigmoid") self.assertTrue(y.grad.lazydata.metadata.backward) - si = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata])[-1] + si = Tensor.schedule(out, x.grad, y.grad)[-1] self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}") self.assertEqual(set(m.name for m in si.metadata), {"sigmoid", "sigmoid", "relu"}) bw = [m for m in si.metadata if m.backward] diff --git a/test/test_uops.py b/test/test_uops.py index 68bb875f99..e908b9d735 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -9,7 +9,7 @@ from tinygrad.dtype import dtypes, DType from tinygrad.device import Buffer, Device from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401 from tinygrad.renderer import ProgramSpec -from tinygrad.engine.schedule import create_schedule, to_si +from tinygrad.engine.schedule import to_si from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.uopgraph import full_graph_rewrite, sym @@ -237,12 +237,12 @@ class TestExecALU(TestUOps): class TestConstantFolding(unittest.TestCase): def test_cast_const(self): t = Tensor(1, dtype=dtypes.float).cast(dtypes.int) - si = create_schedule([t.lazydata]) + si = t.schedule() assert len(si) == 0 def test_bitcast_const(self): t = Tensor(1, dtype=dtypes.float).bitcast(dtypes.int) - si = create_schedule([t.lazydata]) + si = t.schedule() assert len(si) == 1 ji = lower_schedule_item(si[-1]) assert any(uop.op is Ops.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast" diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index 323c8bbbc5..e95e539b35 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -1,7 +1,6 @@ import unittest from tinygrad import Tensor from tinygrad.helpers import getenv, GlobalCounters -from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import lower_schedule_item, ProgramSpec from tinygrad.renderer import Estimates from tinygrad.codegen.linearize import linearize_uop @@ -16,7 +15,7 @@ def flops_mem(uops, ignore_indexing=False): # **************** new FlopCounter **************** def get_stats(x:Tensor): - si = create_schedule([x.lazydata])[-1] + si = x.schedule()[-1] ei = lower_schedule_item(si) return ei.prg.estimates.ops, ei.prg.estimates.mem diff --git a/test/test_winograd.py b/test/test_winograd.py index 0acbee9de5..36d984678f 100644 --- a/test/test_winograd.py +++ b/test/test_winograd.py @@ -3,7 +3,6 @@ from tinygrad import Tensor, GlobalCounters, dtypes from tinygrad.ops import Ops from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv from tinygrad.codegen.kernel import Kernel -from tinygrad.engine.schedule import create_schedule class TestWinograd(unittest.TestCase): def 
setUp(self): @@ -20,7 +19,7 @@ class TestWinograd(unittest.TestCase): out = Tensor.conv2d(x, w) with Timing("scheduling: "): - sched = create_schedule([out.lazydata]) + sched = out.schedule() for i,s in enumerate(sched): if s.ast.op is not Ops.SINK: continue
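
Every hunk above applies the same migration, so a minimal sketch of the before/after API may help when porting other call sites. The snippet is illustrative only: the tensor names (a, out, other) are made up, and it assumes a tinygrad checkout where Tensor.schedule and create_schedule_with_vars behave as they are used in this diff.

from tinygrad import Tensor

a = Tensor([1., 2., 3., 4.])
out = a + 1
other = a * 2

# old API (removed by this diff):
#   from tinygrad.engine.schedule import create_schedule
#   sched = create_schedule([out.lazydata])

# new API: schedule straight off the Tensor; extra outputs go in as arguments
sched = out.schedule()                 # list of ScheduleItem for everything out depends on
# sched = Tensor.schedule(out, other)  # several outputs at once, as in handcode_opt.py

# lower-level entry point, used where a raw UOp/LazyBuffer is already in hand
# (docs/abstractions2.py, test_schedule.py); it also returns the variable bindings,
# which is why those call sites unpack "sched, _ = ..."
from tinygrad.engine.schedule import create_schedule_with_vars
sched2, var_vals = create_schedule_with_vars([other.lazydata])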