diff --git a/CLAUDE.md b/CLAUDE.md index b6b660c5d7..17af1c2074 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -88,7 +88,7 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()" ## Debugging Tips 1. **Print UOp graphs**: `print(tensor.uop)` or `print(tensor.uop.sink())` -2. **Check schedule**: `tensor.schedule()` returns list of ScheduleItems +2. **Check schedule**: `tensor.schedule()` returns list of ExecItems 3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks 4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]` diff --git a/docs/abstractions3.py b/docs/abstractions3.py index c34a399bba..649bbd60b1 100644 --- a/docs/abstractions3.py +++ b/docs/abstractions3.py @@ -38,25 +38,19 @@ optim.schedule_step() # this will step the optimizer without running realize # The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point # l1.uop and l2.uop define a computation graph -from tinygrad.engine.schedule import ScheduleItem -schedule: List[ScheduleItem] = Tensor.schedule(l1, l2) +from tinygrad.engine.schedule import ExecItem +schedule: List[ExecItem] = Tensor.schedule(l1, l2) print(f"The schedule contains {len(schedule)} items.") for si in schedule: print(str(si)[:80]) # ***** -# 4. Lower a schedule. +# 4. Lower and run the schedule. -from tinygrad.engine.realize import lower_schedule_item, ExecItem -lowered: List[ExecItem] = [lower_schedule_item(si) for si in tqdm(schedule)] +for si in tqdm(schedule): si.run() # ***** -# 5. Run the schedule - -for ei in tqdm(lowered): ei.run() - -# ***** -# 6. Print the weight change +# 5. Print the weight change print("first weight change\n", l1.numpy()-l1n) print("second weight change\n", l2.numpy()-l2n) diff --git a/docs/developer/developer.md b/docs/developer/developer.md index f932f0a935..d7c7518ff6 100644 --- a/docs/developer/developer.md +++ b/docs/developer/developer.md @@ -17,15 +17,15 @@ The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not al ## Scheduling -The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/schedule.py) converts the graph of UOps into a list of `ScheduleItem`. One `ScheduleItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on. +The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/schedule.py) converts the graph of UOps into a list of `ExecItem`. One `ExecItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on. -::: tinygrad.engine.schedule.ScheduleItem +::: tinygrad.engine.schedule.ExecItem ## Lowering -The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers `ScheduleItem` to `ExecItem` with +The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers `ExecItem` by populating its `prg` field with -::: tinygrad.engine.realize.lower_schedule +::: tinygrad.engine.realize.run_schedule There's a ton of complexity hidden behind this, see the `codegen/` directory. 
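As a quick illustration of the flow the doc changes above describe, here is a minimal hedged sketch assuming the post-change API (Tensor.schedule() returning ExecItems whose prg field is filled in by lower(), with run() lowering lazily); the tensor expression and size are made up for the example:

from tinygrad import Tensor

out = Tensor.randn(16).sum()      # still lazy at this point, nothing has run
schedule = out.schedule()         # list[ExecItem], one entry per kernel/copy
for ei in schedule:
    ei.lower()                    # populates ei.prg with a Runner
    print(type(ei.prg).__name__)  # e.g. CompiledRunner
    ei.run()                      # executes; run() would also lower lazily if prg were None
print(out.item())                 # the output buffer is realized now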
diff --git a/extra/assembly/rdna3/generate.py b/extra/assembly/rdna3/generate.py index 9c0a2ff7e1..61d5f22be6 100644 --- a/extra/assembly/rdna3/generate.py +++ b/extra/assembly/rdna3/generate.py @@ -9,7 +9,6 @@ import xml.etree.ElementTree as ET from tinygrad import nn, Tensor, Device from tinygrad.helpers import get_single_element -from tinygrad.engine.realize import lower_schedule from tinygrad.runtime.support.elf import elf_loader from tinygrad.runtime.ops_amd import ProfileSQTTEvent from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets @@ -118,7 +117,8 @@ if __name__ == "__main__": root = ET.fromstring(xml_str) a = Tensor.empty(16)+1 - for si, ei in lower_schedule(a.schedule()): + for ei in a.schedule(): + ei.lower() # get text _, hdr, _ = elf_loader(ei.prg.lib) text = get_single_element([x for x in hdr if x.name==".text"]).content diff --git a/extra/gemm/amd_matmul.py b/extra/gemm/amd_matmul.py index 6d704766b8..d72034fef5 100644 --- a/extra/gemm/amd_matmul.py +++ b/extra/gemm/amd_matmul.py @@ -36,7 +36,7 @@ if __name__ == "__main__": for _ in range(run_count): tc = (a@b).realize() GlobalCounters.reset() - ei = ExecItem(runner, [a.uop.buffer, b.uop.buffer, c.uop.buffer]) + ei = ExecItem(ast, [a.uop.buffer, b.uop.buffer, c.uop.buffer], prg=runner) with Context(DEBUG=2): for _ in range(run_count): ei.run(wait=True) print(f"custom {(c-tc).square().mean().item()}") diff --git a/extra/gemm/amd_uop_matmul.py b/extra/gemm/amd_uop_matmul.py index 9ca6adb0ad..af5758e693 100644 --- a/extra/gemm/amd_uop_matmul.py +++ b/extra/gemm/amd_uop_matmul.py @@ -147,7 +147,7 @@ def test_matmul(sink:UOp, N=N): hc = Tensor.empty(N, N) Tensor.realize(a, b, hc) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]]) + ei = ExecItem(sink, [t.uop.buffer for t in [hc, a, b]], prg=get_runner(Device.DEFAULT, sink)) ets = [] with Context(DEBUG=2): diff --git a/extra/gemm/simple_matmul.py b/extra/gemm/simple_matmul.py index 45a359be38..379b50474e 100644 --- a/extra/gemm/simple_matmul.py +++ b/extra/gemm/simple_matmul.py @@ -3,7 +3,6 @@ from tinygrad import dtypes, Tensor from tinygrad.helpers import getenv, get_single_element from tinygrad.dtype import _to_np_dtype from tinygrad.codegen.opt import OptOps -from tinygrad.engine.realize import lower_schedule dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float) @@ -40,8 +39,8 @@ if __name__ == "__main__": if getenv("SHOULD_USE_TC"): sched = a.matmul(b, dtype=acc_dtype).schedule() - lowered = list(lower_schedule(sched)) - ei = get_single_element(lowered)[1] + ei = get_single_element(sched) + ei.lower() assert any(opt.op is OptOps.TC for opt in ei.prg.p.applied_opts), f"TC not triggered, {ei.prg.p.applied_opts}" ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32) diff --git a/extra/gemm/tinygrad_nv_matmul.py b/extra/gemm/tinygrad_nv_matmul.py index 1ee3e72e15..5e7a4c265a 100644 --- a/extra/gemm/tinygrad_nv_matmul.py +++ b/extra/gemm/tinygrad_nv_matmul.py @@ -33,5 +33,5 @@ if __name__ == "__main__": new_src = prg.src # can mod source here prg = replace(prg, src=new_src) - ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata) + ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg)) for i in range(5): ei.run(wait=True) diff --git a/extra/gemm/triton_nv_matmul.py b/extra/gemm/triton_nv_matmul.py index 
89e7838bb0..cdd077213c 100644 --- a/extra/gemm/triton_nv_matmul.py +++ b/extra/gemm/triton_nv_matmul.py @@ -88,7 +88,7 @@ if __name__ == "__main__": prg = ProgramSpec("matmul_kernel", src, device=Device.DEFAULT, global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1], mem_estimate=A.nbytes() + B.nbytes() + C.nbytes()) - ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata) + ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg)) tflops = [] for i in range(5): tm = ei.run(wait=True) diff --git a/extra/reduce_speed.py b/extra/reduce_speed.py deleted file mode 100644 index 36bd0d3d5c..0000000000 --- a/extra/reduce_speed.py +++ /dev/null @@ -1,128 +0,0 @@ -import numpy as np -import ctypes -from tinygrad import Tensor, GlobalCounters, Context -from tinygrad.engine.realize import lower_schedule, CompiledRunner -from tinygrad.device import CPUProgram -from dataclasses import replace -from keystone import Ks, KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN - -# only the memory access, over 100 GB/s! (sometimes) -reduce_asm = """ -movi v0.2d, #0000000000000000 -mov w9, #0x30 -mov w10, #0x20 -mov x8, #-0x10 -movi v1.2d, #0000000000000000 -movk w9, #0x300, lsl #16 -movi v2.2d, #0000000000000000 -movk w10, #0x200, lsl #16 -movi v3.2d, #0000000000000000 -mov w11, #0x1000000 -mov w12, #0x3ffff0 -loop: -ldp q4, q5, [x1] -add x13, x1, x11 -add x15, x1, x10 -add x14, x1, x9 -add x8, x8, #0x10 -cmp x8, x12 -ldp q6, q7, [x1, #0x20] -add x1, x1, #0x40 -ldp q4, q5, [x13] -ldp q6, q7, [x13, #0x20] -ldp q4, q5, [x15, #-0x20] -ldp q6, q7, [x15] -ldp q4, q5, [x14, #-0x30] -ldp q6, q7, [x14, #-0x10] -b.lo loop -fadd v0.4s, v1.4s, v0.4s -fadd v0.4s, v2.4s, v0.4s -fadd v0.4s, v3.4s, v0.4s -dup v1.4s, v0.s[1] -dup v2.4s, v0.s[2] -fadd v1.4s, v0.4s, v1.4s -dup v0.4s, v0.s[3] -fadd v1.4s, v2.4s, v1.4s -fadd v0.4s, v0.4s, v1.4s -str s0, [x0] -ret -""" - -ks = Ks(KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN) -arm_bytecode, _ = ks.asm(reduce_asm) -arm_bytecode = bytes(arm_bytecode) - -reduce_src = """ -// data1 is 16M inputs -typedef float float4 __attribute__((aligned(32),vector_size(16))); -void reduce(float* restrict data0, float* restrict data1) { - float4 acc0 = {0.0f, 0.0f, 0.0f, 0.0f}; - float4 acc1 = {0.0f, 0.0f, 0.0f, 0.0f}; - float4 acc2 = {0.0f, 0.0f, 0.0f, 0.0f}; - float4 acc3 = {0.0f, 0.0f, 0.0f, 0.0f}; - float4 acc4 = {0.0f, 0.0f, 0.0f, 0.0f}; - float4 acc5 = {0.0f, 0.0f, 0.0f, 0.0f}; - float4 acc6 = {0.0f, 0.0f, 0.0f, 0.0f}; - float4 acc7 = {0.0f, 0.0f, 0.0f, 0.0f}; - float* data1_1 = data1+4194304; - float* data1_2 = data1+(4194304*2); - float* data1_3 = data1+(4194304*3); - for (int ridx0 = 0; ridx0 < 16777216/4; ridx0+=16) { - float4 val0 = *(float4*)((data1+(ridx0+0))); - float4 val1 = *(float4*)((data1+(ridx0+4))); - float4 val2 = *(float4*)((data1+(ridx0+8))); - float4 val3 = *(float4*)((data1+(ridx0+12))); - acc0 += val0; - acc1 += val1; - acc2 += val2; - acc3 += val3; - val0 = *(float4*)((data1_1+(ridx0+0))); - val1 = *(float4*)((data1_1+(ridx0+4))); - val2 = *(float4*)((data1_1+(ridx0+8))); - val3 = *(float4*)((data1_1+(ridx0+12))); - acc4 += val0; - acc5 += val1; - acc6 += val2; - acc7 += val3; - val0 = *(float4*)((data1_2+(ridx0+0))); - val1 = *(float4*)((data1_2+(ridx0+4))); - val2 = *(float4*)((data1_2+(ridx0+8))); - val3 = *(float4*)((data1_2+(ridx0+12))); - acc0 += val0; - acc1 += val1; - acc2 += val2; - acc3 += val3; - val0 = *(float4*)((data1_3+(ridx0+0))); - val1 = 
*(float4*)((data1_3+(ridx0+4))); - val2 = *(float4*)((data1_3+(ridx0+8))); - val3 = *(float4*)((data1_3+(ridx0+12))); - acc4 += val0; - acc5 += val1; - acc6 += val2; - acc7 += val3; - } - float4 out = acc0+acc1+acc2+acc3+acc4+acc5+acc6+acc7; - *(data0+0) = out[0]+out[1]+out[2]+out[3]; -} -""" - -if __name__ == "__main__": - a = Tensor(np_array:=(np.random.default_rng().random((4096, 4096), dtype=np.float32)-0.5)).realize() - with Context(SPLIT_REDUCEOP=0): - # TODO: make it easy to alter the OptOps for a ScheduleItem - GlobalCounters.reset() - out = a.sum() - sis = out.schedule() - for i,(_,ei) in enumerate(lower_schedule(sis)): - if i == 0: - # change the source code - prg_spec = ei.prg.p - prg_spec = replace(prg_spec, name="reduce", src=reduce_src) - prg = CompiledRunner(prg_spec) - # change the assembly - #prg._prg = CPUProgram(prg_spec.name, arm_bytecode) - print("buffer at:",hex(ctypes.addressof(ei.bufs[1]._buf))) - ei = replace(ei, prg=prg) - ei.run() - print(out.item()) - np.testing.assert_allclose(out.item(), np_array.sum(), atol=1, rtol=1e-4) diff --git a/extra/sched/fuzz_schedule.py b/extra/sched/fuzz_schedule.py deleted file mode 100644 index 34f3313ada..0000000000 --- a/extra/sched/fuzz_schedule.py +++ /dev/null @@ -1,100 +0,0 @@ -import itertools -import numpy as np -from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union -from tinygrad.device import Buffer -from tinygrad.engine.realize import capturing, lower_schedule_item -from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv -from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem -from tinygrad.uop.ops import Ops, UOp -from tinygrad.tensor import Tensor, _to_np_dtype - -ctx_vars = { MULTIOUTPUT: (0, 1) } -FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10) - -def fuzz_schedule(outs:List[UOp]): - # find toposorts across all tunable params - unique_ts: Dict[Tuple[LBScheduleItem, ...], Dict[str, int]] = {} - for combination in itertools.product(*ctx_vars.values()): - for var, val in zip(ctx_vars, combination): var.value = val - ctx_var_values = dict(zip([v.key for v in ctx_vars], combination)) - graph, in_degree, _ = _graph_schedule(outs) - for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = ctx_var_values - toposorts = list(unique_ts.items()) - if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow")) - - # setup ground truth - ground_truth: Dict[UOp, memoryview] = {} - assign_targets: Dict[UOp, UOp] = {} - # IMPORTANT: freeze prerealized bufs before ScheduleItem exec - prerealized: Dict[UOp, memoryview] = {} - seed = Tensor._seed - ts,_ = toposorts[0] - for lsi in ts: - for out in lsi.outputs: - # freeze assign state before exec - if out.op is Ops.ASSIGN: - prerealized[out] = out.buffer.as_buffer() - assign_targets[out.srcs[1]] = out - for x in lsi.inputs: - if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer() - si = ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata) - _exec_si(si, seed) - for out in lsi.outputs: - ground_truth[out] = out.buffer.as_buffer() - del out.srcs # only schedule the LazyBuffer in this fuzz run - - # exec and validate each permutation with new Buffers - for i, (ts, ctx) in enumerate(toposorts[1:]): - if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow")) - rawbufs: Dict[UOp, Buffer] = {} - for lsi in ts: - for out in lsi.outputs: - base = rawbufs[lsi.inputs[0]].base if out.op is 
Ops.BUFFER_VIEW else None - rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base) - if out.op is Ops.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out]) - for x in lsi.inputs: - if x not in rawbufs: - # override the assign_target after ASSIGN - if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]] - elif x.device == "NPY": rawbufs[x] = x.buffer - # copy the pre realized input - else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=bytes(prerealized[x])) - si = ScheduleItem(lsi.ast, tuple(rawbufs[x] for x in lsi.bufs if x.size != 0), lsi.metadata) - _exec_si(si, seed) - for out in lsi.outputs: - outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype)) - try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2) - except Exception as e: - print(f"FAILED FOR {out}") - raise e - -def _exec_si(si:ScheduleItem, seed:int): - ei = lower_schedule_item(si) - if len(capturing): capturing[0].add(ei) - ei.run() - -T = TypeVar("T") -def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:Union[DefaultDict[T, int], Dict[T, int]]) -> List[Tuple[T, ...]]: - visited: Set[T] = set() - ret: List[Tuple[T, ...]] = [] - path: List[T] = [] - - def recurse_paths(path:List[T]): - for v, d in in_degree.items(): - if d != 0 or v in visited: continue - for u in graph[v]: in_degree[u] -= 1 - path.append(v) - visited.add(v) - recurse_paths(path) - if len(ret) >= FUZZ_SCHEDULE_MAX_PATHS: return - # backtrack - for u in graph[v]: in_degree[u] += 1 - path.pop() - visited.remove(v) - if len(path) == len(in_degree): ret.append(tuple(path)) - recurse_paths(path) - - if len(ret) == 0: raise RuntimeError("detected cycle in the graph") - # verify all paths are unique - assert len(ret) == len(set(ret)) - return ret diff --git a/test/external/external_benchmark_sdxl_softmax.py b/test/external/external_benchmark_sdxl_softmax.py index 19837ac9b2..4d03989652 100644 --- a/test/external/external_benchmark_sdxl_softmax.py +++ b/test/external/external_benchmark_sdxl_softmax.py @@ -25,4 +25,4 @@ if __name__ == "__main__": #k.apply_opt(Opt(OptOps.GROUP, 0, 32)) from tinygrad.engine.realize import CompiledRunner, ExecItem run = CompiledRunner(prg:=get_program(k.ast, k.opts, k.applied_opts)) - ExecItem(run, si.bufs).run() + ExecItem(k.ast, list(si.bufs), prg=run).run() diff --git a/test/external/external_test_hsa_driver.py b/test/external/external_test_hsa_driver.py index 29ffaa2340..737dbc29b0 100644 --- a/test/external/external_test_hsa_driver.py +++ b/test/external/external_test_hsa_driver.py @@ -1,10 +1,12 @@ import ctypes, unittest from tinygrad.helpers import init_c_struct_t -from tinygrad.device import Device, Buffer, BufferXfer +from tinygrad.device import Device, Buffer from tinygrad.dtype import dtypes from tinygrad.runtime.support.hsa import AQLQueue from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph -from tinygrad.engine.realize import ExecItem +from tinygrad.engine.schedule import ExecItem +from tinygrad.engine.realize import BufferXfer +from tinygrad.uop.ops import UOp, Ops def get_hsa_inc_prog(dev, inc=1): prg = f""" @@ -102,7 +104,8 @@ class TestHSADriver(unittest.TestCase): test_buf1.copyin(memoryview(bytearray(1*4))) test_buf2.copyin(memoryview(bytearray(1*4))) - jit_cache = [ExecItem(BufferXfer(), [test_buf0, test_buf2]), ExecItem(BufferXfer(), [test_buf2, test_buf1])] + jit_cache = 
[ExecItem(UOp(Ops.NOOP), [test_buf0, test_buf2], prg=BufferXfer(test_buf0.nbytes, test_buf0.device, test_buf2.device)), + ExecItem(UOp(Ops.NOOP), [test_buf2, test_buf1], prg=BufferXfer(test_buf2.nbytes, test_buf2.device, test_buf1.device))] graph = HSAGraph(jit_cache, [], {}) for i in range(10000): diff --git a/test/external/fuzz_graph.py b/test/external/fuzz_graph.py index 4e1d492eae..3b62ccf289 100644 --- a/test/external/fuzz_graph.py +++ b/test/external/fuzz_graph.py @@ -4,7 +4,9 @@ from tinygrad.device import Buffer, Device from tinygrad.helpers import Context, getenv, from_mv from tinygrad.dtype import dtypes from tinygrad.tensor import Tensor, _to_np_dtype -from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner +from tinygrad.engine.realize import BufferXfer, get_runner +from tinygrad.engine.schedule import ExecItem +from tinygrad.uop.ops import UOp, Ops from tinygrad.engine.jit import apply_graph_to_jit BUF_LEN = getenv("BUF_LEN", 128) @@ -35,13 +37,13 @@ def gen_kernel_ji(device, deps): assert len(deps) >= 2 out = alloc_rawbuffer(device) prg = gen_prg(device, len(deps)) - return ExecItem(prg, [out] + deps) + return ExecItem(UOp(Ops.NOOP), [out] + deps, prg=prg) def gen_copy_ji(device, deps): assert len(deps) == 1 out = alloc_rawbuffer(device) prg = BufferXfer(deps[0].nbytes, device, deps[0].device) - return ExecItem(prg, [out] + deps) + return ExecItem(UOp(Ops.NOOP), [out] + deps, prg=prg) def gen_graph(): input_buffers = [] diff --git a/test/test_arange.py b/test/test_arange.py index 00ef30e290..f82014a269 100644 --- a/test/test_arange.py +++ b/test/test_arange.py @@ -3,7 +3,8 @@ import numpy as np from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device, Variable from tinygrad.helpers import Context, getenv from tinygrad.engine.realize import run_schedule -from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program +from tinygrad.engine.realize import CompiledRunner, get_program +from tinygrad.engine.schedule import ExecItem from tinygrad.uop.ops import Ops from tinygrad.renderer import Estimates from tinygrad.renderer.ptx import PTXRenderer @@ -14,7 +15,7 @@ class TestArange(unittest.TestCase): sched = tensor.schedule() self.assertEqual(len(sched), 1) p = get_program(sched[-1].ast, renderer=Device[Device.DEFAULT].renderer) - ExecItem(CompiledRunner(p), [tensor.uop.buffer]).run() + ExecItem(sched[-1].ast, [tensor.uop.buffer], prg=CompiledRunner(p)).run() np.testing.assert_equal(tensor.numpy(), desired) return p.estimates.ops diff --git a/test/test_compile_failures.py b/test/test_compile_failures.py index 7d9b0e33d5..11c1d42abe 100644 --- a/test/test_compile_failures.py +++ b/test/test_compile_failures.py @@ -2,13 +2,12 @@ import unittest, io from contextlib import redirect_stdout from tinygrad import Tensor, dtypes, Device from tinygrad.helpers import OSX, CPU_LLVM, CPU_LVP -from tinygrad.engine.realize import lower_schedule from tinygrad.device import is_dtype_supported from tinygrad.engine.realize import get_program class TestCompileFailures(unittest.TestCase): def compile(self, out:Tensor): - for _ in lower_schedule(out.schedule()): pass + for si in out.schedule(): si.lower() @unittest.skipUnless(is_dtype_supported(dtypes.uchar, Device.DEFAULT), f"no uint8 on {Device.DEFAULT}") def test_interpolate_atari(self): diff --git a/test/test_fusion_op.py b/test/test_fusion_op.py index 6dd9040dc1..a01c80c7d9 100644 --- a/test/test_fusion_op.py +++ b/test/test_fusion_op.py @@ -2,7 +2,7 @@ import unittest import time import numpy as np from tinygrad 
import Tensor, dtypes -from tinygrad.engine.realize import lower_schedule_item, run_schedule +from tinygrad.engine.realize import run_schedule class TestFusionOp(unittest.TestCase): def test_contiguous_add(self): @@ -26,9 +26,9 @@ class TestFusionOp(unittest.TestCase): a = Tensor([1,2,3,4]) for _ in range(24): a = a + a sched = a.schedule() - ei = lower_schedule_item(sched[-1]) + sched[-1].lower() self.assertLess(time.perf_counter()-st, 2.0) - assert len(ei.prg.p.src.splitlines()) < 250 + assert len(sched[-1].prg.p.src.splitlines()) < 250 def test_recursive_add_cmp(self): st = time.perf_counter() diff --git a/test/test_graph.py b/test/test_graph.py index 1d009e218c..33bb4f3793 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -6,7 +6,9 @@ from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.helpers import Context, dedup, from_mv from tinygrad.dtype import dtypes from tinygrad.engine.jit import MultiGraphRunner -from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner, CompiledRunner +from tinygrad.engine.realize import BufferXfer, get_runner, CompiledRunner +from tinygrad.engine.schedule import ExecItem +from tinygrad.uop.ops import UOp, Ops from test.helpers import needs_second_gpu @@ -27,11 +29,11 @@ def helper_exec_op(device, outbuf, inbufs): prg = get_runner(device, si.ast) cached_prgs[(device, len(inbufs))] = prg - return ExecItem(cached_prgs[(device, len(inbufs))], [outbuf] + inbufs) + return ExecItem(UOp(Ops.NOOP), [outbuf] + inbufs, prg=cached_prgs[(device, len(inbufs))]) def helper_copy_op(device, dest, src): prg = BufferXfer(dest.nbytes, device, src.device) - return ExecItem(prg, [dest, src]) + return ExecItem(UOp(Ops.NOOP), [dest, src], prg=prg) def helper_alloc_rawbuffer(device, fill=False): rawbuf = Buffer(device, BUF_SIZE, dtypes.int).ensure_allocated() @@ -69,7 +71,7 @@ def helper_test_graphs(graph_impl, graphs, runs=RUN_CNT): ground_truth_np = [np.frombuffer(x, _to_np_dtype(bufs[i].dtype)) for i,x in enumerate(ground_thruth_bufs)] # Build graphs - gr_ji = [ExecItem(graph_impl(graph, [], {}), []) for graph in graphs] + gr_ji = [ExecItem(UOp(Ops.NOOP), [], prg=graph_impl(graph, [], {})) for graph in graphs] for _ in range(runs): test_bufs = helper_run_jit(gr_ji, bufs, out_buffers) diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py index da1f3aeeea..1ec95939be 100644 --- a/test/test_image_dtype.py +++ b/test/test_image_dtype.py @@ -3,7 +3,6 @@ import numpy as np from tinygrad import Device, dtypes, Tensor, Context from tinygrad.device import LRUAllocator, is_dtype_supported from tinygrad.dtype import ImageDType -from tinygrad.engine.realize import lower_schedule from tinygrad.helpers import prod, unwrap from test.helpers import REAL_DEV @@ -129,8 +128,8 @@ class TestImageDType(unittest.TestCase): loss = x.image_dot(w1).image_dot(w2).float().max() loss.backward() sched = unwrap(w1.grad).schedule() - for s,(_,ei) in zip(sched, lower_schedule(sched[:])): - ei.run() + for s in sched: + s.run() if s.bufs[0].dtype == dtypes.float: lst = s.bufs[0].as_buffer().cast("f").tolist() print(lst) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index c15ad8c2e2..c26b93d0c9 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -7,7 +7,7 @@ from tinygrad.codegen.gpudims import get_grouped_dims from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat from tinygrad.device import Device, Buffer, is_dtype_supported from tinygrad.tensor import Tensor, _to_np_dtype -from 
tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program +from tinygrad.engine.realize import run_schedule, CompiledRunner, get_program from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, getenv from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace from tinygrad.renderer.ptx import PTXRenderer @@ -25,9 +25,9 @@ class TestLinearizer(unittest.TestCase): a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize() np_a, np_b = a.numpy(), b.numpy() c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),)))) - lowered = [x[1] for x in lower_schedule(c.schedule())] - for ei in lowered: ei.run() - rawbufs = lowered[-1].bufs + sched = c.schedule() + for si in sched: si.run() + rawbufs = sched[-1].bufs assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.uop.base.realized, b.uop.base.realized} np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:]) np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 1409822775..c68531654e 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -4,7 +4,7 @@ from tinygrad.device import is_dtype_supported from tinygrad.uop.ops import Ops, UOp from tinygrad.helpers import getenv, prod, Context from tinygrad.nn.state import get_parameters, get_state_dict -from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule +from tinygrad.engine.realize import BufferCopy, CompiledRunner, run_schedule import numpy as np from hypothesis import given, strategies as strat, settings from test.helpers import not_support_multi_device, needs_second_gpu, slow @@ -123,9 +123,10 @@ class TestMultiTensor(unittest.TestCase): out = (X + X) sched = out.schedule() names = [] - for si, ei in lower_schedule(sched): - if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name) - ei.run() + for si in sched: + si.lower() + if isinstance(si.prg, CompiledRunner): names.append(si.prg.p.name) + si.run() self.assertEqual(len(set(names)), 1, "function was relinearized") @unittest.skip("this doesn't fold because shard_ calls contiguous on all lbs") diff --git a/test/test_opt_gemm.py b/test/test_opt_gemm.py index de7d0b438c..12d1bda436 100644 --- a/test/test_opt_gemm.py +++ b/test/test_opt_gemm.py @@ -3,7 +3,8 @@ import unittest from tinygrad import Tensor, Device from tinygrad.helpers import get_single_element from tinygrad.codegen.opt import Opt, OptOps -from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program +from tinygrad.engine.realize import CompiledRunner, get_program +from tinygrad.engine.schedule import ExecItem class TestOptGemm(unittest.TestCase): @classmethod @@ -18,7 +19,7 @@ class TestOptGemm(unittest.TestCase): # TODO: this should be a generic test helper si = get_single_element(t.schedule()) run = CompiledRunner(get_program(si.ast, renderer=Device[Device.DEFAULT].renderer, opts=opts)) - ExecItem(run, si.bufs).run() + ExecItem(si.ast, list(si.bufs), prg=run).run() test = si.bufs[0].numpy().reshape(self.res.shape) np.testing.assert_allclose(self.res, test, atol=1e-4) diff --git a/test/test_quantize_onnx.py b/test/test_quantize_onnx.py index 4a30b9ad2c..72d74b4251 100644 --- a/test/test_quantize_onnx.py +++ b/test/test_quantize_onnx.py @@ -5,7 +5,8 @@ from dataclasses import replace from tinygrad import Tensor, Context, Device, dtypes from tinygrad.uop.ops import Ops from tinygrad.codegen.opt import Opt, OptOps -from tinygrad.engine.realize import 
CompiledRunner, ExecItem, lower_schedule_item, get_program +from tinygrad.engine.realize import CompiledRunner, get_program +from tinygrad.engine.schedule import ExecItem N = 512 @@ -42,8 +43,8 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3): if replace_src is not None: old_name = prg.src.split("__attribute__((noinline)) void ")[1].split("(")[0] prg = replace(prg, src=replace_src + "/* DSP boilerplate */" + prg.src.split("/* DSP boilerplate */")[1].replace(old_name, "fxn")) - ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata) - for _ in range(run_count): ei.run(wait=True) + new_si = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg)) + for _ in range(run_count): new_si.run(wait=True) def get_quantized_model(sz): from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader @@ -74,8 +75,8 @@ class TestQuantizeOnnxCPU(unittest.TestCase): inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)) with Context(QUANTIZE=1): sched = run_onnx({"input":inp})["output"].schedule() - ei = lower_schedule_item(sched[-2]) - daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_REG] + sched[-2].lower() + daccs = [u for u in sched[-2].prg.p.uops if u.op is Ops.DEFINE_REG] assert all(u.dtype.scalar() is dtypes.int for u in daccs) @unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP") diff --git a/test/test_randomness.py b/test/test_randomness.py index 4504ccba65..d3d888fd34 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -4,7 +4,7 @@ from functools import partial from tinygrad import nn, dtypes, Tensor, Device, TinyJit, Variable from tinygrad.helpers import getenv, CI, OSX from tinygrad.device import is_dtype_supported -from tinygrad.engine.realize import lower_schedule, CompiledRunner +from tinygrad.engine.realize import CompiledRunner from tinygrad.renderer.ptx import PTXRenderer from tinygrad.renderer.nir import NIRRenderer from test.helpers import not_support_multi_device, needs_second_gpu @@ -103,10 +103,12 @@ class TestRandomness(unittest.TestCase): @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "PTX and NIR use pointer arithmetic") def test_threefry_doesnt_use_long(self): - for (_,ei) in lower_schedule(Tensor.rand(20).schedule()): - if isinstance(ei.prg, CompiledRunner): - for u in ei.prg.p.uops: - self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}") + sched = Tensor.rand(20).schedule() + for si in sched: + si.lower() + if isinstance(si.prg, CompiledRunner): + for u in si.prg.p.uops: + self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {si.prg.p.name}") def test_threefry_against_reference_full(self): Tensor.manual_seed(1337) diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index 4baebae6b7..5a406067f1 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -12,7 +12,6 @@ from tinygrad.uop.ops import UOp, Ops, python_alu from tinygrad.renderer import ProgramSpec from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.codegen import full_rewrite -from tinygrad.engine.realize import lower_schedule_item def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None): for x in inputs: x.realize() @@ -76,8 +75,8 @@ class TestCStyleFailures(unittest.TestCase): for _ in range(5): ret = python_alu[op](ret, Tensor.empty(1, dtype=dtype)) schedule = 
ret.schedule() assert len(schedule) == 1 - ei = lower_schedule_item(schedule[0]) - src = ei.prg.p.src + schedule[0].lower() + src = schedule[0].prg.p.src self.assertEqual("("*5 not in src, should_strip_paren) def test_repeat_add(self): self._test_src_strip_paren(Ops.ADD) diff --git a/test/test_schedule.py b/test/test_schedule.py index d4c77fbc9f..10d0912344 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -13,7 +13,7 @@ from tinygrad.dtype import DType, ImageDType from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp from tinygrad.schedule.rangeify import Kernel -from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule +from tinygrad.engine.realize import CompiledRunner, run_schedule class KernelCountException(Exception): pass def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True): @@ -24,8 +24,9 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te else: assert isinstance(t, UOp), f"can't schedule {t}" sched = Tensor(t).schedule() - # test lowering all the ScheduleItems to ExecItems - kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink]) + # test lowering all the ExecItems + for si in sched: si.lower() + kernel_cnt = len([si for si in sched if isinstance(si.prg, CompiledRunner) or not filter_sink]) if kernel_cnt != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}") if DEBUG >= 3: @@ -174,7 +175,7 @@ class TestSchedule(unittest.TestCase): child.realize() assert a.uop.is_realized - # NOTE: because empty does not have an ExecItem if realize is called on a childless empty, it never gets allocated. + # NOTE: because empty does not have a lowered ExecItem if realize is called on a childless empty, it never gets allocated. 
def test_childless_empty_never_allocates(self): a = Tensor.empty(10) a.realize() diff --git a/test/test_softmax_fusion.py b/test/test_softmax_fusion.py index da141205da..5588991079 100644 --- a/test/test_softmax_fusion.py +++ b/test/test_softmax_fusion.py @@ -3,7 +3,6 @@ import numpy as np from tinygrad import Tensor, GlobalCounters, Context, Device from tinygrad.dtype import DTypeLike, dtypes from tinygrad.helpers import DEBUG, get_single_element -from tinygrad.engine.realize import lower_schedule_item from tinygrad.device import is_dtype_supported def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Tensor: @@ -27,7 +26,7 @@ def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Te out = e.div(ss).reshape(x_in.shape) return out -def run_one_schedule_item(out): lower_schedule_item(get_single_element(out.schedule())).run() +def run_one_schedule_item(out): get_single_element(out.schedule()).run() class TestFuse(unittest.TestCase): def _test_fuse(self, fxn, *args, atol=1e-6, allow_multiple=False, **kwargs): diff --git a/test/test_uops.py b/test/test_uops.py index a92939d7c0..d0583a05c6 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -9,7 +9,8 @@ from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType from tinygrad.uop.spec import shared_spec from tinygrad.renderer import ProgramSpec from tinygrad.renderer.cstyle import CStyleLanguage -from tinygrad.engine.realize import CompiledRunner, get_program, get_runner, ExecItem +from tinygrad.engine.realize import CompiledRunner, get_program, get_runner +from tinygrad.engine.schedule import ExecItem from tinygrad.codegen import full_rewrite from tinygrad.uop.symbolic import sym from tinygrad.device import is_dtype_supported @@ -568,7 +569,7 @@ class TestZeroRange(unittest.TestCase): class TestUOpPrograms(unittest.TestCase): def _run(self, prog:UOp, *tensors:Tensor): - ExecItem(get_runner(Device.DEFAULT, prog), [t.uop.buffer for t in tensors]).run(wait=True) + ExecItem(prog, [t.uop.buffer for t in tensors], prg=get_runner(Device.DEFAULT, prog)).run(wait=True) def test_simple(self): out = Tensor.empty(10,10,dtype=dtypes.int) diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index 78c35e08ca..d349aea765 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -1,7 +1,8 @@ import unittest from tinygrad import Tensor from tinygrad.helpers import getenv, GlobalCounters, EMULATE -from tinygrad.engine.realize import lower_schedule_item, ProgramSpec, get_program +from tinygrad.engine.realize import get_program +from tinygrad.renderer import ProgramSpec from tinygrad.renderer import Estimates from tinygrad.codegen import full_rewrite from tinygrad.uop.ops import Ops, UOp @@ -17,9 +18,8 @@ def flops_mem(uops, ignore_indexing=False): # **************** new FlopCounter **************** def get_stats(x:Tensor): - si = x.schedule()[-1] - ei = lower_schedule_item(si) - return ei.prg.estimates.ops, ei.prg.estimates.mem + si = x.schedule()[-1].lower() + return si.prg.estimates.ops, si.prg.estimates.mem @unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does extra load/store for packed types") class TestMemoryCount(unittest.TestCase): diff --git a/test/testextra/test_tk.py b/test/testextra/test_tk.py index e3a898f802..51b74874ba 100644 --- a/test/testextra/test_tk.py +++ b/test/testextra/test_tk.py @@ -2,7 +2,8 @@ import unittest, math, time from tinygrad import Tensor, Device, dtypes, Context from tinygrad.uop.ops import UOp, Ops -from tinygrad.engine.realize 
import ExecItem, get_runner +from tinygrad.engine.realize import get_runner +from tinygrad.engine.schedule import ExecItem from tinygrad.engine.jit import TinyJit from tinygrad.helpers import CI import numpy as np @@ -64,7 +65,7 @@ class TestTK(unittest.TestCase): c = Tensor.empty(1, 1, N, N, dtype="float32") Tensor.realize(a, b, c) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (c, a, b)]) + ei = ExecItem(sink, [t.uop.buffer for t in (c, a, b)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) c = c.float() @@ -113,7 +114,7 @@ class TestTK(unittest.TestCase): c = Tensor.empty(1, 1, N, N, dtype="float32") Tensor.realize(a, b, c) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (c, a, b)]) + ei = ExecItem(sink, [t.uop.buffer for t in (c, a, b)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) c = c.float() @@ -149,7 +150,7 @@ class TestTK(unittest.TestCase): b = Tensor.empty(1, 1, N, N, dtype="float32") Tensor.realize(a, b) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) b = b.float() @@ -188,7 +189,7 @@ class TestTK(unittest.TestCase): b = Tensor.empty(1, 1, N, N, dtype="float32") Tensor.realize(a, b) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) b = b.float() @@ -230,7 +231,7 @@ class TestTK(unittest.TestCase): c = Tensor.empty(1, 1, N, N, dtype="float32") Tensor.realize(a, b, c) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, c, a)]) + ei = ExecItem(sink, [t.uop.buffer for t in (b, c, a)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) b = b.float() c = c.float() @@ -270,7 +271,7 @@ class TestTK(unittest.TestCase): b = Tensor.empty(1, 1, N, N, dtype="float32") Tensor.realize(a, b) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) b = b.float() @@ -307,7 +308,7 @@ class TestTK(unittest.TestCase): b = Tensor.empty(1, 1, N, N, dtype="float32") Tensor.realize(a, b) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) b = b.float() @@ -352,7 +353,7 @@ class TestTK(unittest.TestCase): b = Tensor.empty(1, 1, N, N, dtype="float32") Tensor.realize(a, b) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) b = b.float() @@ -397,7 +398,7 @@ class TestTK(unittest.TestCase): b = Tensor.empty(1, 1, N, M, dtype="float32") Tensor.realize(a, b) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) b = b.float() @@ -442,7 +443,7 @@ class TestTK(unittest.TestCase): b = Tensor.empty(1, 1, N, N, dtype="float32") Tensor.realize(a, b) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, 
a)]) + ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) b = b.float() @@ -487,7 +488,7 @@ class TestTK(unittest.TestCase): b = Tensor.empty(1, 1, N, M, dtype="float32") Tensor.realize(a, b) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) b = b.float() @@ -547,7 +548,7 @@ class TestTK(unittest.TestCase): b = Tensor.empty(1, 1, BLOCK_SIZE, N, dtype="float32") Tensor.realize(a, b) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) b = b.float() @@ -607,7 +608,7 @@ class TestTK(unittest.TestCase): b = Tensor.empty(1, 1, N, BLOCK_SIZE, dtype="float32") Tensor.realize(a, b) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)]) + ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): ei.run(wait=True) b = b.float() @@ -717,7 +718,7 @@ class TestTK(unittest.TestCase): out = Tensor.empty(B, N, H, D, dtype=dtypes.bfloat16) Tensor.realize(q, k, v, out) - ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (out, q, k, v)]) + ei = ExecItem(sink, [t.uop.buffer for t in (out, q, k, v)], prg=get_runner(Device.DEFAULT, sink)) for _ in range(5): et = ei.run(wait=True) attn_flops = 2 * B * H * N * N * D + \ diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index f728eb7559..89a652ed96 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -9,7 +9,7 @@ from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, Buf from tinygrad.engine.memory import _internal_memory_planner from tinygrad.nn.state import get_parameters from tinygrad.schedule.rangeify import mop_cleanup -from dataclasses import dataclass +from dataclasses import dataclass, replace from weakref import WeakKeyDictionary class GraphException(Exception): pass @@ -31,7 +31,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer] graph_runner = current_batch_devs[0].graph(current_batch, input_rawbuffers, var_vals) # clear jit inputs to allow their memory to be freed/reused for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None - graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Buffer|None], input_rawbuffers))) + graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_rawbuffers), prg=graph_runner)) max_batch_size *= 2 if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}") except GraphException as e: @@ -89,6 +89,7 @@ class GraphRunner(Runner): estimates = Estimates() for j,ji in enumerate(jit_cache): + assert ji.prg is not None estimates += ji.prg.estimates if isinstance(ji.prg, CompiledRunner): if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(ji.prg.p.vars) if v.expr not in ji.fixedvars] @@ -103,6 +104,7 @@ class GraphRunner(Runner): self.w_dependency_map: dict[int, Any] = {} self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list) + assert jit_cache[0].prg is not None super().__init__(colored(f"", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify()) def updated_vars(self, var_vals: 
dict[str, int]): @@ -186,7 +188,7 @@ class CapturedJit(Generic[ReturnType]): def replan_buffers_memory_layout(self): blacklist = [t.uop.buffer for t in get_parameters(self.ret)] asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True) - self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache] + self.jit_cache = [replace(item, bufs=[asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache] for old, new in asgn.items(): if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer()) self.__post_init__() @@ -249,7 +251,7 @@ class TinyJit(Generic[ReturnType]): return ret def add(self, ei:ExecItem): - self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars)) + self._jit_cache.append(ExecItem(ei.ast, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars, ei.prg)) def reset(self): assert self.fxn is not None, "can't reset without function" @@ -320,8 +322,7 @@ class TinyJit(Generic[ReturnType]): # Exclude buffers involved in transfer ops to preserve parallelism. noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec)) for b in ji.bufs} assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ") - jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None], - item.metadata, item.fixedvars) for item in jit_cache] + jit_cache = [replace(item, bufs=[assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache] input_replace = get_input_replace(jit_cache, input_buffers) if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found") diff --git a/tinygrad/engine/memory.py b/tinygrad/engine/memory.py index 36a4e3b0ba..ae1d42544a 100644 --- a/tinygrad/engine/memory.py +++ b/tinygrad/engine/memory.py @@ -1,6 +1,6 @@ from typing import cast from collections import defaultdict -from tinygrad.engine.schedule import ScheduleItem +from tinygrad.engine.realize import ExecItem from tinygrad.device import Device, Buffer from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up from tinygrad.uop.ops import Ops @@ -63,8 +63,8 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign return assigned -def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]: +def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]: # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs. 
- assigned = _internal_memory_planner([list(si.bufs) for si in schedule], - noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs}) - return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.fixedvars) for si in schedule] + assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule], + noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None}) + return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule] diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index a2edd2ad62..d2451e1c8b 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -1,4 +1,4 @@ -from typing import cast, Generator, Callable +from typing import cast, Callable import time, pprint, random, itertools, math from dataclasses import dataclass, replace, field from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey @@ -7,7 +7,6 @@ from tinygrad.helpers import unwrap from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender from tinygrad.device import Device, Buffer from tinygrad.renderer import Renderer, ProgramSpec, Estimates -from tinygrad.engine.schedule import ScheduleItem from tinygrad.codegen import full_rewrite from tinygrad.codegen.opt import Opt @@ -171,15 +170,43 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner: # **************** lowering functions **************** -@dataclass(frozen=True) +# NOTE: ctx is the buffers +si_lowerer = PatternMatcher([ + (UPat(Ops.SINK, name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)), + (UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])), + (UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \ + if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \ + else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))), + (UPat(Ops.ENCDEC, name="encdec"), lambda ctx,encdec: EncDec(encdec, ctx[0].nbytes, ctx[1].device)), +]) + +@dataclass class ExecItem: - prg: Runner - bufs: list[Buffer|None] - metadata: tuple[Metadata, ...]|None = None + ast: UOp + bufs: list[Buffer|None] = field(default_factory=list) + metadata: tuple[Metadata, ...] 
= () fixedvars: dict[str, int] = field(default_factory=dict) + prg: Runner|None = None + + def lower(self): + """Populate self.prg by lowering the AST.""" + if self.prg is not None: return self + try: self.prg = cast(Runner, si_lowerer.rewrite(self.ast, self.bufs)) + except Exception as e: + if DEBUG >= 2: + print(f"error lowering {self.ast.op}") + print("tensor operations:") + pprint.pprint(self.metadata, indent=2) + raise e + return self + def run(self, _var_vals:dict[str, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None: + if self.prg is None: self.lower() + assert self.prg is not None var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars) - bufs = [unwrap(x) for x in self.bufs] if jit else [unwrap(x).ensure_allocated() for x in self.bufs] + # reorder bufs to match program globals if needed + _bufs = [self.bufs[i] for i in self.prg.p.globals] if isinstance(self.prg, CompiledRunner) else self.bufs + bufs = cast(list[Buffer], [unwrap(x) for x in _bufs] if jit else [unwrap(x).ensure_allocated() for x in _bufs]) if PROFILE: payload = {"metadata":self.metadata, "var_vals":var_vals, "bufs":[b.trace_num for b in bufs], "name":self.prg.display_name} payload["outputs"], payload["inputs"] = (self.prg.p.outs, self.prg.p.ins) if isinstance(self.prg, CompiledRunner) else ([0], [1]) @@ -206,48 +233,28 @@ class ExecItem: self.prg.first_run = False return et -# NOTE: ctx is the buffers -si_lowerer = PatternMatcher([ - (UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])), - (UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))), - (UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \ - if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \ - else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))), - (UPat(Ops.ENCDEC, name="encdec"), lambda ctx,encdec: ((EncDec(encdec, ctx[0].nbytes, ctx[1].device)), list(ctx))), -]) -def lower_schedule_item(si:ScheduleItem) -> ExecItem: - return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata, si.fixedvars) - -def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]: - while len(schedule): - si = schedule.pop(0) - try: yield (si, lower_schedule_item(si)) - except Exception as e: - if DEBUG >= 2: - print(f"error lowering {si.ast.op}") - print("tensor operations:") - pprint.pprint(si.metadata, indent=2) - raise e - # **************** main run function **************** capturing: list = [] # put classes with an add method in here -def run_schedule(schedule:list[ScheduleItem], var_vals:dict[str, int]|None=None, do_update_stats=True): - for si, ei in lower_schedule(schedule): +def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True): + while len(schedule): + ei = schedule.pop(0).lower() if len(capturing) and CAPTURING: capturing[0].add(ei) - if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK: + if VALIDATE_WITH_CPU and ei.ast.op is Ops.SINK: # copy in allocated buffers from the GPU - nb: tuple[Buffer, ...] 
= tuple(Buffer("CPU", b.size, b.dtype) for b in si.bufs) - for cpu_b, gpu_b in zip(nb, si.bufs): - if gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer()) + bufs = [b for b in ei.bufs if b is not None] + nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs] + for cpu_b, gpu_b in zip(nb, bufs): + if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer()) # run on GPU ei.run(var_vals, do_update_stats=do_update_stats) # validate the output buffers match (NOTE: this is assuming the output is buffer 0) - with Context(BEAM=0): lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats) + with Context(BEAM=0): ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats) import numpy as np - np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3) + assert nb[0] is not None + np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3) else: ei.run(var_vals, do_update_stats=do_update_stats) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index d2ee084711..a8c4a6845e 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,25 +1,17 @@ +from __future__ import annotations import time from typing import cast -from dataclasses import dataclass, field from collections import deque from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites from tinygrad.uop.ops import PatternMatcher, UPat, graph_rewrite, graph_rewrite_map from tinygrad.uop.spec import type_verify, tensor_spec from tinygrad.device import Buffer, MultiBuffer -from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize - -# **** ScheduleItem return type - -@dataclass(frozen=True) -class ScheduleItem: - ast: UOp - bufs: tuple[Buffer, ...] = () - metadata: tuple[Metadata, ...] 
= () - fixedvars: dict[str, int] = field(default_factory=dict) +from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize +from tinygrad.engine.realize import ExecItem # **** schedule linearizer -def create_schedule(sched_sink:UOp) -> tuple[list[ScheduleItem], UOp]: +def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]: with cpu_profile(TracingKey("toposort sched_sink")): # construct the KERNEL children graph based on assigns children: dict[UOp, list[UOp]] = {} @@ -70,7 +62,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ScheduleItem], UOp]: if in_degree[x] == 0: queue.append(x) with cpu_profile(TracingKey("expand ranges")): - pre_schedule: list[ScheduleItem] = [] + pre_schedule: list[ExecItem] = [] buf_uops_list: list[UOp] = [] sched_ptr = 0 in_ranges: dict[UOp, int] = {} @@ -89,7 +81,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ScheduleItem], UOp]: else: ast, buf_uops, metadata, fixedvars, bound_ranges = si fixedvars = fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges} - pre_schedule.append(ScheduleItem(ast, (), metadata, fixedvars)) + pre_schedule.append(ExecItem(ast, [], metadata, fixedvars)) buf_uops_list.append(UOp.sink(*buf_uops)) sched_ptr += 1 return pre_schedule, UOp.sink(*buf_uops_list) @@ -131,9 +123,9 @@ pm_post_sched_cache = PatternMatcher([ (UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR),), name="b"), lambda ctx,b: ctx.get(b)), ]) -schedule_cache: dict[bytes, tuple[list[ScheduleItem], UOp]] = {} +schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {} @track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}") -def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]: +def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]: # big_sink srcs are all the Tensors st = time.perf_counter() @@ -180,7 +172,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)} # add bufs to pre_schedule - schedule: list[ScheduleItem] = [] + schedule: list[ExecItem] = [] for i, si in enumerate(pre_schedule): buf_uops = buf_uops_sink.src[i].src # create subbuffers if needed @@ -193,10 +185,10 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer" dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num'] for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])): - schedule.append(ScheduleItem(si.ast, bufs, si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {}))) + schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {}))) else: # ONE -> ONE - schedule.append(ScheduleItem(si.ast, cast(tuple[Buffer, ...], ubufs), si.metadata, si.fixedvars)) + schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars)) with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule) # extract var_vals from BINDs that were stripped (only if there are kernels) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 369675f4f7..c397bbd08b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -12,7 +12,7 @@ from tinygrad.gradient import compute_gradient from tinygrad.mixin import OpMixin from tinygrad.mixin.movement import _align_left from 
tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable -from tinygrad.engine.schedule import ScheduleItem, complete_create_schedule_with_vars +from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars from tinygrad.device import Device, Buffer from tinygrad.engine.realize import run_schedule @@ -236,7 +236,7 @@ class Tensor(OpMixin): """ return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)] - def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[str, int]]: + def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]: """ Creates the schedule needed to realize these Tensor(s), with Variables. @@ -249,7 +249,7 @@ class Tensor(OpMixin): _apply_map_to_tensors(becomes_map, name="Apply Schedule Map") return schedule, var_vals - def schedule(self, *lst:Tensor) -> list[ScheduleItem]: + def schedule(self, *lst:Tensor) -> list[ExecItem]: """Creates the schedule needed to realize these Tensor(s).""" schedule, var_vals = self.schedule_with_vars(*lst) assert len(var_vals) == 0
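The tests and gemm scripts in this diff all switch to constructing ExecItem directly, with the AST first and the runner passed as prg=. A hedged, self-contained sketch of that pattern under the new signature (the tensor names, sizes, and the choice of a matmul are illustrative only, not taken from the diff):

from tinygrad import Tensor, Device
from tinygrad.engine.schedule import ExecItem
from tinygrad.engine.realize import get_runner

a, b = Tensor.randn(4, 4).realize(), Tensor.randn(4, 4).realize()
si = (a @ b).schedule()[-1]                  # an ExecItem whose prg is still None
runner = get_runner(Device.DEFAULT, si.ast)  # compile the kernel AST into a Runner
ei = ExecItem(si.ast, list(si.bufs), si.metadata, prg=runner)
ei.run(wait=True)                            # run with the explicit runner, blocking until done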