remove ScheduleItem and merge it with ExecItem (#13759)

* remove ExecItem and merge it with ScheduleItem

* less diff

* fix issues

* min diff

* don't change bufs in _lower

* min diff

* update

* revert

* fixes

* diff
George Hotz
2025-12-19 17:04:24 -04:00
committed by GitHub
parent df6cde8a00
commit 744af193f0
35 changed files with 172 additions and 395 deletions

View File

@@ -88,7 +88,7 @@ VIZ=1 python -c "from tinygrad import Tensor; Tensor.ones(10).sum().realize()"
## Debugging Tips
1. **Print UOp graphs**: `print(tensor.uop)` or `print(tensor.uop.sink())`
2. **Check schedule**: `tensor.schedule()` returns list of ScheduleItems
2. **Check schedule**: `tensor.schedule()` returns list of ExecItems
3. **Trace graph rewrites**: Use `VIZ=1` or add print in PatternMatcher callbacks
4. **Find UOps by type**: `[u for u in uop.toposort() if u.op is Ops.SOMETHING]`
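
A short sketch tying the tips above together; the `(Tensor.empty(10) + 1).sum()` workload and the `Ops.ADD` filter are illustrative choices (not from the repo), and the exact graph contents depend on the tinygrad version:

```python
from tinygrad import Tensor
from tinygrad.uop.ops import Ops

t = (Tensor.empty(10) + 1).sum()
print(t.uop)                                              # tip 1: print the UOp graph
adds = [u for u in t.uop.toposort() if u.op is Ops.ADD]   # tip 4: find UOps by type
print(f"{len(adds)} ADD uops")
for ei in t.schedule(): print(str(ei)[:80])               # tip 2: one ExecItem per kernel
```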

View File

@@ -38,25 +38,19 @@ optim.schedule_step() # this will step the optimizer without running realize
# The weight Tensors have been assigned to, but not yet realized. Everything is still lazy at this point
# l1.uop and l2.uop define a computation graph
from tinygrad.engine.schedule import ScheduleItem
schedule: List[ScheduleItem] = Tensor.schedule(l1, l2)
from tinygrad.engine.schedule import ExecItem
schedule: List[ExecItem] = Tensor.schedule(l1, l2)
print(f"The schedule contains {len(schedule)} items.")
for si in schedule: print(str(si)[:80])
# *****
# 4. Lower a schedule.
# 4. Lower and run the schedule.
from tinygrad.engine.realize import lower_schedule_item, ExecItem
lowered: List[ExecItem] = [lower_schedule_item(si) for si in tqdm(schedule)]
for si in tqdm(schedule): si.run()
# *****
# 5. Run the schedule
for ei in tqdm(lowered): ei.run()
# *****
# 6. Print the weight change
# 5. Print the weight change
print("first weight change\n", l1.numpy()-l1n)
print("second weight change\n", l2.numpy()-l2n)

View File

@@ -17,15 +17,15 @@ The `UOp` graph specifies the compute in terms of low level tinygrad ops. Not al
## Scheduling
The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/schedule.py) converts the graph of UOps into a list of `ScheduleItem`. One `ScheduleItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
The [scheduler](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/schedule.py) converts the graph of UOps into a list of `ExecItem`. One `ExecItem` is one kernel on the GPU, and the scheduler is responsible for breaking the large compute graph into subgraphs that can fit in a kernel. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
::: tinygrad.engine.schedule.ScheduleItem
::: tinygrad.engine.schedule.ExecItem
## Lowering
The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers `ScheduleItem` to `ExecItem` with
The code in [realize](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/engine/realize.py) lowers `ExecItem` by populating its `prg` field with
::: tinygrad.engine.realize.lower_schedule
::: tinygrad.engine.realize.run_schedule
There's a ton of complexity hidden behind this, see the `codegen/` directory.
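
A hedged sketch of the `ast`/`bufs`/`prg` split described above; the `Tensor.empty(8) + 1` workload is illustrative and assumed to produce a single kernel:

```python
from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule

out = Tensor.empty(8) + 1
sched = out.schedule()                       # scheduler output: one ExecItem per kernel
print(sched[0].ast.op, len(sched[0].bufs))   # ast: what to run, bufs: what to run it on
run_schedule(sched)                          # lowering fills prg, then each item runs
```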

View File

@@ -9,7 +9,6 @@ import xml.etree.ElementTree as ET
from tinygrad import nn, Tensor, Device
from tinygrad.helpers import get_single_element
from tinygrad.engine.realize import lower_schedule
from tinygrad.runtime.support.elf import elf_loader
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
@@ -118,7 +117,8 @@ if __name__ == "__main__":
root = ET.fromstring(xml_str)
a = Tensor.empty(16)+1
for si, ei in lower_schedule(a.schedule()):
for ei in a.schedule():
ei.lower()
# get text
_, hdr, _ = elf_loader(ei.prg.lib)
text = get_single_element([x for x in hdr if x.name==".text"]).content

View File

@@ -36,7 +36,7 @@ if __name__ == "__main__":
for _ in range(run_count): tc = (a@b).realize()
GlobalCounters.reset()
ei = ExecItem(runner, [a.uop.buffer, b.uop.buffer, c.uop.buffer])
ei = ExecItem(ast, [a.uop.buffer, b.uop.buffer, c.uop.buffer], prg=runner)
with Context(DEBUG=2):
for _ in range(run_count): ei.run(wait=True)
print(f"custom {(c-tc).square().mean().item()}")

View File

@@ -147,7 +147,7 @@ def test_matmul(sink:UOp, N=N):
hc = Tensor.empty(N, N)
Tensor.realize(a, b, hc)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in [hc, a, b]])
ei = ExecItem(sink, [t.uop.buffer for t in [hc, a, b]], prg=get_runner(Device.DEFAULT, sink))
ets = []
with Context(DEBUG=2):

View File

@@ -3,7 +3,6 @@ from tinygrad import dtypes, Tensor
from tinygrad.helpers import getenv, get_single_element
from tinygrad.dtype import _to_np_dtype
from tinygrad.codegen.opt import OptOps
from tinygrad.engine.realize import lower_schedule
dtype_in = (dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else
dtypes.fp8e4m3 if getenv("FP8E4M3") else dtypes.fp8e5m2 if getenv("FP8E5M2") else dtypes.float)
@@ -40,8 +39,8 @@ if __name__ == "__main__":
if getenv("SHOULD_USE_TC"):
sched = a.matmul(b, dtype=acc_dtype).schedule()
lowered = list(lower_schedule(sched))
ei = get_single_element(lowered)[1]
ei = get_single_element(sched)
ei.lower()
assert any(opt.op is OptOps.TC for opt in ei.prg.p.applied_opts), f"TC not triggered, {ei.prg.p.applied_opts}"
ref = a.numpy().astype(np.float32) @ b.numpy().astype(np.float32)

View File

@@ -33,5 +33,5 @@ if __name__ == "__main__":
new_src = prg.src
# can mod source here
prg = replace(prg, src=new_src)
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg))
for i in range(5): ei.run(wait=True)

View File

@@ -88,7 +88,7 @@ if __name__ == "__main__":
prg = ProgramSpec("matmul_kernel", src, device=Device.DEFAULT,
global_size=[M//BLOCK_SIZE_M, N//BLOCK_SIZE_N, 1], local_size=[32*compiled.metadata.num_warps, 1, 1],
mem_estimate=A.nbytes() + B.nbytes() + C.nbytes())
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
ei = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg))
tflops = []
for i in range(5):
tm = ei.run(wait=True)

View File

@@ -1,128 +0,0 @@
import numpy as np
import ctypes
from tinygrad import Tensor, GlobalCounters, Context
from tinygrad.engine.realize import lower_schedule, CompiledRunner
from tinygrad.device import CPUProgram
from dataclasses import replace
from keystone import Ks, KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN
# only the memory access, over 100 GB/s! (sometimes)
reduce_asm = """
movi v0.2d, #0000000000000000
mov w9, #0x30
mov w10, #0x20
mov x8, #-0x10
movi v1.2d, #0000000000000000
movk w9, #0x300, lsl #16
movi v2.2d, #0000000000000000
movk w10, #0x200, lsl #16
movi v3.2d, #0000000000000000
mov w11, #0x1000000
mov w12, #0x3ffff0
loop:
ldp q4, q5, [x1]
add x13, x1, x11
add x15, x1, x10
add x14, x1, x9
add x8, x8, #0x10
cmp x8, x12
ldp q6, q7, [x1, #0x20]
add x1, x1, #0x40
ldp q4, q5, [x13]
ldp q6, q7, [x13, #0x20]
ldp q4, q5, [x15, #-0x20]
ldp q6, q7, [x15]
ldp q4, q5, [x14, #-0x30]
ldp q6, q7, [x14, #-0x10]
b.lo loop
fadd v0.4s, v1.4s, v0.4s
fadd v0.4s, v2.4s, v0.4s
fadd v0.4s, v3.4s, v0.4s
dup v1.4s, v0.s[1]
dup v2.4s, v0.s[2]
fadd v1.4s, v0.4s, v1.4s
dup v0.4s, v0.s[3]
fadd v1.4s, v2.4s, v1.4s
fadd v0.4s, v0.4s, v1.4s
str s0, [x0]
ret
"""
ks = Ks(KS_ARCH_ARM64, KS_MODE_LITTLE_ENDIAN)
arm_bytecode, _ = ks.asm(reduce_asm)
arm_bytecode = bytes(arm_bytecode)
reduce_src = """
// data1 is 16M inputs
typedef float float4 __attribute__((aligned(32),vector_size(16)));
void reduce(float* restrict data0, float* restrict data1) {
float4 acc0 = {0.0f, 0.0f, 0.0f, 0.0f};
float4 acc1 = {0.0f, 0.0f, 0.0f, 0.0f};
float4 acc2 = {0.0f, 0.0f, 0.0f, 0.0f};
float4 acc3 = {0.0f, 0.0f, 0.0f, 0.0f};
float4 acc4 = {0.0f, 0.0f, 0.0f, 0.0f};
float4 acc5 = {0.0f, 0.0f, 0.0f, 0.0f};
float4 acc6 = {0.0f, 0.0f, 0.0f, 0.0f};
float4 acc7 = {0.0f, 0.0f, 0.0f, 0.0f};
float* data1_1 = data1+4194304;
float* data1_2 = data1+(4194304*2);
float* data1_3 = data1+(4194304*3);
for (int ridx0 = 0; ridx0 < 16777216/4; ridx0+=16) {
float4 val0 = *(float4*)((data1+(ridx0+0)));
float4 val1 = *(float4*)((data1+(ridx0+4)));
float4 val2 = *(float4*)((data1+(ridx0+8)));
float4 val3 = *(float4*)((data1+(ridx0+12)));
acc0 += val0;
acc1 += val1;
acc2 += val2;
acc3 += val3;
val0 = *(float4*)((data1_1+(ridx0+0)));
val1 = *(float4*)((data1_1+(ridx0+4)));
val2 = *(float4*)((data1_1+(ridx0+8)));
val3 = *(float4*)((data1_1+(ridx0+12)));
acc4 += val0;
acc5 += val1;
acc6 += val2;
acc7 += val3;
val0 = *(float4*)((data1_2+(ridx0+0)));
val1 = *(float4*)((data1_2+(ridx0+4)));
val2 = *(float4*)((data1_2+(ridx0+8)));
val3 = *(float4*)((data1_2+(ridx0+12)));
acc0 += val0;
acc1 += val1;
acc2 += val2;
acc3 += val3;
val0 = *(float4*)((data1_3+(ridx0+0)));
val1 = *(float4*)((data1_3+(ridx0+4)));
val2 = *(float4*)((data1_3+(ridx0+8)));
val3 = *(float4*)((data1_3+(ridx0+12)));
acc4 += val0;
acc5 += val1;
acc6 += val2;
acc7 += val3;
}
float4 out = acc0+acc1+acc2+acc3+acc4+acc5+acc6+acc7;
*(data0+0) = out[0]+out[1]+out[2]+out[3];
}
"""
if __name__ == "__main__":
a = Tensor(np_array:=(np.random.default_rng().random((4096, 4096), dtype=np.float32)-0.5)).realize()
with Context(SPLIT_REDUCEOP=0):
# TODO: make it easy to alter the OptOps for a ScheduleItem
GlobalCounters.reset()
out = a.sum()
sis = out.schedule()
for i,(_,ei) in enumerate(lower_schedule(sis)):
if i == 0:
# change the source code
prg_spec = ei.prg.p
prg_spec = replace(prg_spec, name="reduce", src=reduce_src)
prg = CompiledRunner(prg_spec)
# change the assembly
#prg._prg = CPUProgram(prg_spec.name, arm_bytecode)
print("buffer at:",hex(ctypes.addressof(ei.bufs[1]._buf)))
ei = replace(ei, prg=prg)
ei.run()
print(out.item())
np.testing.assert_allclose(out.item(), np_array.sum(), atol=1, rtol=1e-4)

View File

@@ -1,100 +0,0 @@
import itertools
import numpy as np
from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
from tinygrad.device import Buffer
from tinygrad.engine.realize import capturing, lower_schedule_item
from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem
from tinygrad.uop.ops import Ops, UOp
from tinygrad.tensor import Tensor, _to_np_dtype
ctx_vars = { MULTIOUTPUT: (0, 1) }
FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10)
def fuzz_schedule(outs:List[UOp]):
# find toposorts across all tunable params
unique_ts: Dict[Tuple[LBScheduleItem, ...], Dict[str, int]] = {}
for combination in itertools.product(*ctx_vars.values()):
for var, val in zip(ctx_vars, combination): var.value = val
ctx_var_values = dict(zip([v.key for v in ctx_vars], combination))
graph, in_degree, _ = _graph_schedule(outs)
for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = ctx_var_values
toposorts = list(unique_ts.items())
if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))
# setup ground truth
ground_truth: Dict[UOp, memoryview] = {}
assign_targets: Dict[UOp, UOp] = {}
# IMPORTANT: freeze prerealized bufs before ScheduleItem exec
prerealized: Dict[UOp, memoryview] = {}
seed = Tensor._seed
ts,_ = toposorts[0]
for lsi in ts:
for out in lsi.outputs:
# freeze assign state before exec
if out.op is Ops.ASSIGN:
prerealized[out] = out.buffer.as_buffer()
assign_targets[out.srcs[1]] = out
for x in lsi.inputs:
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
si = ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata)
_exec_si(si, seed)
for out in lsi.outputs:
ground_truth[out] = out.buffer.as_buffer()
del out.srcs # only schedule the LazyBuffer in this fuzz run
# exec and validate each permutation with new Buffers
for i, (ts, ctx) in enumerate(toposorts[1:]):
if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
rawbufs: Dict[UOp, Buffer] = {}
for lsi in ts:
for out in lsi.outputs:
base = rawbufs[lsi.inputs[0]].base if out.op is Ops.BUFFER_VIEW else None
rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base)
if out.op is Ops.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
for x in lsi.inputs:
if x not in rawbufs:
# override the assign_target after ASSIGN
if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
elif x.device == "NPY": rawbufs[x] = x.buffer
# copy the pre realized input
else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=bytes(prerealized[x]))
si = ScheduleItem(lsi.ast, tuple(rawbufs[x] for x in lsi.bufs if x.size != 0), lsi.metadata)
_exec_si(si, seed)
for out in lsi.outputs:
outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2)
except Exception as e:
print(f"FAILED FOR {out}")
raise e
def _exec_si(si:ScheduleItem, seed:int):
ei = lower_schedule_item(si)
if len(capturing): capturing[0].add(ei)
ei.run()
T = TypeVar("T")
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:Union[DefaultDict[T, int], Dict[T, int]]) -> List[Tuple[T, ...]]:
visited: Set[T] = set()
ret: List[Tuple[T, ...]] = []
path: List[T] = []
def recurse_paths(path:List[T]):
for v, d in in_degree.items():
if d != 0 or v in visited: continue
for u in graph[v]: in_degree[u] -= 1
path.append(v)
visited.add(v)
recurse_paths(path)
if len(ret) >= FUZZ_SCHEDULE_MAX_PATHS: return
# backtrack
for u in graph[v]: in_degree[u] += 1
path.pop()
visited.remove(v)
if len(path) == len(in_degree): ret.append(tuple(path))
recurse_paths(path)
if len(ret) == 0: raise RuntimeError("detected cycle in the graph")
# verify all paths are unique
assert len(ret) == len(set(ret))
return ret

View File

@@ -25,4 +25,4 @@ if __name__ == "__main__":
#k.apply_opt(Opt(OptOps.GROUP, 0, 32))
from tinygrad.engine.realize import CompiledRunner, ExecItem
run = CompiledRunner(prg:=get_program(k.ast, k.opts, k.applied_opts))
ExecItem(run, si.bufs).run()
ExecItem(k.ast, list(si.bufs), prg=run).run()

View File

@@ -1,10 +1,12 @@
import ctypes, unittest
from tinygrad.helpers import init_c_struct_t
from tinygrad.device import Device, Buffer, BufferXfer
from tinygrad.device import Device, Buffer
from tinygrad.dtype import dtypes
from tinygrad.runtime.support.hsa import AQLQueue
from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.schedule import ExecItem
from tinygrad.engine.realize import BufferXfer
from tinygrad.uop.ops import UOp, Ops
def get_hsa_inc_prog(dev, inc=1):
prg = f"""
@@ -102,7 +104,8 @@ class TestHSADriver(unittest.TestCase):
test_buf1.copyin(memoryview(bytearray(1*4)))
test_buf2.copyin(memoryview(bytearray(1*4)))
jit_cache = [ExecItem(BufferXfer(), [test_buf0, test_buf2]), ExecItem(BufferXfer(), [test_buf2, test_buf1])]
jit_cache = [ExecItem(UOp(Ops.NOOP), [test_buf0, test_buf2], prg=BufferXfer(test_buf0.nbytes, test_buf0.device, test_buf2.device)),
ExecItem(UOp(Ops.NOOP), [test_buf2, test_buf1], prg=BufferXfer(test_buf2.nbytes, test_buf2.device, test_buf1.device))]
graph = HSAGraph(jit_cache, [], {})
for i in range(10000):

View File

@@ -4,7 +4,9 @@ from tinygrad.device import Buffer, Device
from tinygrad.helpers import Context, getenv, from_mv
from tinygrad.dtype import dtypes
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner
from tinygrad.engine.realize import BufferXfer, get_runner
from tinygrad.engine.schedule import ExecItem
from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.jit import apply_graph_to_jit
BUF_LEN = getenv("BUF_LEN", 128)
@@ -35,13 +37,13 @@ def gen_kernel_ji(device, deps):
assert len(deps) >= 2
out = alloc_rawbuffer(device)
prg = gen_prg(device, len(deps))
return ExecItem(prg, [out] + deps)
return ExecItem(UOp(Ops.NOOP), [out] + deps, prg=prg)
def gen_copy_ji(device, deps):
assert len(deps) == 1
out = alloc_rawbuffer(device)
prg = BufferXfer(deps[0].nbytes, device, deps[0].device)
return ExecItem(prg, [out] + deps)
return ExecItem(UOp(Ops.NOOP), [out] + deps, prg=prg)
def gen_graph():
input_buffers = []

View File

@@ -3,7 +3,8 @@ import numpy as np
from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device, Variable
from tinygrad.helpers import Context, getenv
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.engine.schedule import ExecItem
from tinygrad.uop.ops import Ops
from tinygrad.renderer import Estimates
from tinygrad.renderer.ptx import PTXRenderer
@@ -14,7 +15,7 @@ class TestArange(unittest.TestCase):
sched = tensor.schedule()
self.assertEqual(len(sched), 1)
p = get_program(sched[-1].ast, renderer=Device[Device.DEFAULT].renderer)
ExecItem(CompiledRunner(p), [tensor.uop.buffer]).run()
ExecItem(sched[-1].ast, [tensor.uop.buffer], prg=CompiledRunner(p)).run()
np.testing.assert_equal(tensor.numpy(), desired)
return p.estimates.ops

View File

@@ -2,13 +2,12 @@ import unittest, io
from contextlib import redirect_stdout
from tinygrad import Tensor, dtypes, Device
from tinygrad.helpers import OSX, CPU_LLVM, CPU_LVP
from tinygrad.engine.realize import lower_schedule
from tinygrad.device import is_dtype_supported
from tinygrad.engine.realize import get_program
class TestCompileFailures(unittest.TestCase):
def compile(self, out:Tensor):
for _ in lower_schedule(out.schedule()): pass
for si in out.schedule(): si.lower()
@unittest.skipUnless(is_dtype_supported(dtypes.uchar, Device.DEFAULT), f"no uint8 on {Device.DEFAULT}")
def test_interpolate_atari(self):

View File

@@ -2,7 +2,7 @@ import unittest
import time
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.engine.realize import lower_schedule_item, run_schedule
from tinygrad.engine.realize import run_schedule
class TestFusionOp(unittest.TestCase):
def test_contiguous_add(self):
@@ -26,9 +26,9 @@ class TestFusionOp(unittest.TestCase):
a = Tensor([1,2,3,4])
for _ in range(24): a = a + a
sched = a.schedule()
ei = lower_schedule_item(sched[-1])
sched[-1].lower()
self.assertLess(time.perf_counter()-st, 2.0)
assert len(ei.prg.p.src.splitlines()) < 250
assert len(sched[-1].prg.p.src.splitlines()) < 250
def test_recursive_add_cmp(self):
st = time.perf_counter()

View File

@@ -6,7 +6,9 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import Context, dedup, from_mv
from tinygrad.dtype import dtypes
from tinygrad.engine.jit import MultiGraphRunner
from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner, CompiledRunner
from tinygrad.engine.realize import BufferXfer, get_runner, CompiledRunner
from tinygrad.engine.schedule import ExecItem
from tinygrad.uop.ops import UOp, Ops
from test.helpers import needs_second_gpu
@@ -27,11 +29,11 @@ def helper_exec_op(device, outbuf, inbufs):
prg = get_runner(device, si.ast)
cached_prgs[(device, len(inbufs))] = prg
return ExecItem(cached_prgs[(device, len(inbufs))], [outbuf] + inbufs)
return ExecItem(UOp(Ops.NOOP), [outbuf] + inbufs, prg=cached_prgs[(device, len(inbufs))])
def helper_copy_op(device, dest, src):
prg = BufferXfer(dest.nbytes, device, src.device)
return ExecItem(prg, [dest, src])
return ExecItem(UOp(Ops.NOOP), [dest, src], prg=prg)
def helper_alloc_rawbuffer(device, fill=False):
rawbuf = Buffer(device, BUF_SIZE, dtypes.int).ensure_allocated()
@@ -69,7 +71,7 @@ def helper_test_graphs(graph_impl, graphs, runs=RUN_CNT):
ground_truth_np = [np.frombuffer(x, _to_np_dtype(bufs[i].dtype)) for i,x in enumerate(ground_thruth_bufs)]
# Build graphs
gr_ji = [ExecItem(graph_impl(graph, [], {}), []) for graph in graphs]
gr_ji = [ExecItem(UOp(Ops.NOOP), [], prg=graph_impl(graph, [], {})) for graph in graphs]
for _ in range(runs):
test_bufs = helper_run_jit(gr_ji, bufs, out_buffers)

View File

@@ -3,7 +3,6 @@ import numpy as np
from tinygrad import Device, dtypes, Tensor, Context
from tinygrad.device import LRUAllocator, is_dtype_supported
from tinygrad.dtype import ImageDType
from tinygrad.engine.realize import lower_schedule
from tinygrad.helpers import prod, unwrap
from test.helpers import REAL_DEV
@@ -129,8 +128,8 @@ class TestImageDType(unittest.TestCase):
loss = x.image_dot(w1).image_dot(w2).float().max()
loss.backward()
sched = unwrap(w1.grad).schedule()
for s,(_,ei) in zip(sched, lower_schedule(sched[:])):
ei.run()
for s in sched:
s.run()
if s.bufs[0].dtype == dtypes.float:
lst = s.bufs[0].as_buffer().cast("f").tolist()
print(lst)

View File

@@ -7,7 +7,7 @@ from tinygrad.codegen.gpudims import get_grouped_dims
from tinygrad.uop.ops import UOp, Ops, GroupOp, AxisType, PatternMatcher, graph_rewrite, UPat
from tinygrad.device import Device, Buffer, is_dtype_supported
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner, get_program
from tinygrad.engine.realize import run_schedule, CompiledRunner, get_program
from tinygrad.helpers import Context, flatten, dedup, TC_SELECT, TC_OPT, getenv
from tinygrad.dtype import DType, dtypes, PtrDType, AddrSpace
from tinygrad.renderer.ptx import PTXRenderer
@@ -25,9 +25,9 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
np_a, np_b = a.numpy(), b.numpy()
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
lowered = [x[1] for x in lower_schedule(c.schedule())]
for ei in lowered: ei.run()
rawbufs = lowered[-1].bufs
sched = c.schedule()
for si in sched: si.run()
rawbufs = sched[-1].bufs
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.uop.base.realized, b.uop.base.realized}
np_c = (np_a[:2] - np_a[2:]) - (np_b[:2] - np_b[2:])
np.testing.assert_allclose(np_c, c.numpy(), atol=1e-4, rtol=1e-4)

View File

@@ -4,7 +4,7 @@ from tinygrad.device import is_dtype_supported
from tinygrad.uop.ops import Ops, UOp
from tinygrad.helpers import getenv, prod, Context
from tinygrad.nn.state import get_parameters, get_state_dict
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
from tinygrad.engine.realize import BufferCopy, CompiledRunner, run_schedule
import numpy as np
from hypothesis import given, strategies as strat, settings
from test.helpers import not_support_multi_device, needs_second_gpu, slow
@@ -123,9 +123,10 @@ class TestMultiTensor(unittest.TestCase):
out = (X + X)
sched = out.schedule()
names = []
for si, ei in lower_schedule(sched):
if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name)
ei.run()
for si in sched:
si.lower()
if isinstance(si.prg, CompiledRunner): names.append(si.prg.p.name)
si.run()
self.assertEqual(len(set(names)), 1, "function was relinearized")
@unittest.skip("this doesn't fold because shard_ calls contiguous on all lbs")

View File

@@ -3,7 +3,8 @@ import unittest
from tinygrad import Tensor, Device
from tinygrad.helpers import get_single_element
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, ExecItem, get_program
from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.engine.schedule import ExecItem
class TestOptGemm(unittest.TestCase):
@classmethod
@@ -18,7 +19,7 @@ class TestOptGemm(unittest.TestCase):
# TODO: this should be a generic test helper
si = get_single_element(t.schedule())
run = CompiledRunner(get_program(si.ast, renderer=Device[Device.DEFAULT].renderer, opts=opts))
ExecItem(run, si.bufs).run()
ExecItem(si.ast, list(si.bufs), prg=run).run()
test = si.bufs[0].numpy().reshape(self.res.shape)
np.testing.assert_allclose(self.res, test, atol=1e-4)

View File

@@ -5,7 +5,8 @@ from dataclasses import replace
from tinygrad import Tensor, Context, Device, dtypes
from tinygrad.uop.ops import Ops
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item, get_program
from tinygrad.engine.realize import CompiledRunner, get_program
from tinygrad.engine.schedule import ExecItem
N = 512
@@ -42,8 +43,8 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
if replace_src is not None:
old_name = prg.src.split("__attribute__((noinline)) void ")[1].split("(")[0]
prg = replace(prg, src=replace_src + "/* DSP boilerplate */" + prg.src.split("/* DSP boilerplate */")[1].replace(old_name, "fxn"))
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
for _ in range(run_count): ei.run(wait=True)
new_si = ExecItem(si.ast, [x.ensure_allocated() for x in si.bufs], si.metadata, prg=CompiledRunner(prg))
for _ in range(run_count): new_si.run(wait=True)
def get_quantized_model(sz):
from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
@@ -74,8 +75,8 @@ class TestQuantizeOnnxCPU(unittest.TestCase):
inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
with Context(QUANTIZE=1):
sched = run_onnx({"input":inp})["output"].schedule()
ei = lower_schedule_item(sched[-2])
daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_REG]
sched[-2].lower()
daccs = [u for u in sched[-2].prg.p.uops if u.op is Ops.DEFINE_REG]
assert all(u.dtype.scalar() is dtypes.int for u in daccs)
@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")

View File

@@ -4,7 +4,7 @@ from functools import partial
from tinygrad import nn, dtypes, Tensor, Device, TinyJit, Variable
from tinygrad.helpers import getenv, CI, OSX
from tinygrad.device import is_dtype_supported
from tinygrad.engine.realize import lower_schedule, CompiledRunner
from tinygrad.engine.realize import CompiledRunner
from tinygrad.renderer.ptx import PTXRenderer
from tinygrad.renderer.nir import NIRRenderer
from test.helpers import not_support_multi_device, needs_second_gpu
@@ -103,10 +103,12 @@ class TestRandomness(unittest.TestCase):
@unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "PTX and NIR use pointer arithmetic")
def test_threefry_doesnt_use_long(self):
for (_,ei) in lower_schedule(Tensor.rand(20).schedule()):
if isinstance(ei.prg, CompiledRunner):
for u in ei.prg.p.uops:
self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}")
sched = Tensor.rand(20).schedule()
for si in sched:
si.lower()
if isinstance(si.prg, CompiledRunner):
for u in si.prg.p.uops:
self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {si.prg.p.name}")
def test_threefry_against_reference_full(self):
Tensor.manual_seed(1337)

View File

@@ -12,7 +12,6 @@ from tinygrad.uop.ops import UOp, Ops, python_alu
from tinygrad.renderer import ProgramSpec
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.codegen import full_rewrite
from tinygrad.engine.realize import lower_schedule_item
def _test_uop_result(inputs:list[Tensor], stores:list[UOp], local_size=None):
for x in inputs: x.realize()
@@ -76,8 +75,8 @@ class TestCStyleFailures(unittest.TestCase):
for _ in range(5): ret = python_alu[op](ret, Tensor.empty(1, dtype=dtype))
schedule = ret.schedule()
assert len(schedule) == 1
ei = lower_schedule_item(schedule[0])
src = ei.prg.p.src
schedule[0].lower()
src = schedule[0].prg.p.src
self.assertEqual("("*5 not in src, should_strip_paren)
def test_repeat_add(self): self._test_src_strip_paren(Ops.ADD)

View File

@@ -13,7 +13,7 @@ from tinygrad.dtype import DType, ImageDType
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.schedule.rangeify import Kernel
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from tinygrad.engine.realize import CompiledRunner, run_schedule
class KernelCountException(Exception): pass
def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Tensor]|None=None, filter_sink=True):
@@ -24,8 +24,9 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
else:
assert isinstance(t, UOp), f"can't schedule {t}"
sched = Tensor(t).schedule()
# test lowering all the ScheduleItems to ExecItems
kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
# test lowering all the ExecItems
for si in sched: si.lower()
kernel_cnt = len([si for si in sched if isinstance(si.prg, CompiledRunner) or not filter_sink])
if kernel_cnt != allowed:
print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
if DEBUG >= 3:
@@ -174,7 +175,7 @@ class TestSchedule(unittest.TestCase):
child.realize()
assert a.uop.is_realized
# NOTE: because empty does not have an ExecItem if realize is called on a childless empty, it never gets allocated.
# NOTE: because empty does not have a lowered ExecItem if realize is called on a childless empty, it never gets allocated.
def test_childless_empty_never_allocates(self):
a = Tensor.empty(10)
a.realize()

View File

@@ -3,7 +3,6 @@ import numpy as np
from tinygrad import Tensor, GlobalCounters, Context, Device
from tinygrad.dtype import DTypeLike, dtypes
from tinygrad.helpers import DEBUG, get_single_element
from tinygrad.engine.realize import lower_schedule_item
from tinygrad.device import is_dtype_supported
def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Tensor:
@@ -27,7 +26,7 @@ def single_kernel_softmax(x_in:Tensor, axis=-1, dtype:DTypeLike|None=None) -> Te
out = e.div(ss).reshape(x_in.shape)
return out
def run_one_schedule_item(out): lower_schedule_item(get_single_element(out.schedule())).run()
def run_one_schedule_item(out): get_single_element(out.schedule()).run()
class TestFuse(unittest.TestCase):
def _test_fuse(self, fxn, *args, atol=1e-6, allow_multiple=False, **kwargs):

View File

@@ -9,7 +9,8 @@ from tinygrad.uop.ops import Ops, UOp, UPat, KernelInfo, exec_alu, AxisType
from tinygrad.uop.spec import shared_spec
from tinygrad.renderer import ProgramSpec
from tinygrad.renderer.cstyle import CStyleLanguage
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner, ExecItem
from tinygrad.engine.realize import CompiledRunner, get_program, get_runner
from tinygrad.engine.schedule import ExecItem
from tinygrad.codegen import full_rewrite
from tinygrad.uop.symbolic import sym
from tinygrad.device import is_dtype_supported
@@ -568,7 +569,7 @@ class TestZeroRange(unittest.TestCase):
class TestUOpPrograms(unittest.TestCase):
def _run(self, prog:UOp, *tensors:Tensor):
ExecItem(get_runner(Device.DEFAULT, prog), [t.uop.buffer for t in tensors]).run(wait=True)
ExecItem(prog, [t.uop.buffer for t in tensors], prg=get_runner(Device.DEFAULT, prog)).run(wait=True)
def test_simple(self):
out = Tensor.empty(10,10,dtype=dtypes.int)

View File

@@ -1,7 +1,8 @@
import unittest
from tinygrad import Tensor
from tinygrad.helpers import getenv, GlobalCounters, EMULATE
from tinygrad.engine.realize import lower_schedule_item, ProgramSpec, get_program
from tinygrad.engine.realize import get_program
from tinygrad.renderer import ProgramSpec
from tinygrad.renderer import Estimates
from tinygrad.codegen import full_rewrite
from tinygrad.uop.ops import Ops, UOp
@@ -17,9 +18,8 @@ def flops_mem(uops, ignore_indexing=False):
# **************** new FlopCounter ****************
def get_stats(x:Tensor):
si = x.schedule()[-1]
ei = lower_schedule_item(si)
return ei.prg.estimates.ops, ei.prg.estimates.mem
si = x.schedule()[-1].lower()
return si.prg.estimates.ops, si.prg.estimates.mem
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does extra load/store for packed types")
class TestMemoryCount(unittest.TestCase):

View File

@@ -2,7 +2,8 @@ import unittest, math, time
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.realize import ExecItem, get_runner
from tinygrad.engine.realize import get_runner
from tinygrad.engine.schedule import ExecItem
from tinygrad.engine.jit import TinyJit
from tinygrad.helpers import CI
import numpy as np
@@ -64,7 +65,7 @@ class TestTK(unittest.TestCase):
c = Tensor.empty(1, 1, N, N, dtype="float32")
Tensor.realize(a, b, c)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (c, a, b)])
ei = ExecItem(sink, [t.uop.buffer for t in (c, a, b)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
c = c.float()
@@ -113,7 +114,7 @@ class TestTK(unittest.TestCase):
c = Tensor.empty(1, 1, N, N, dtype="float32")
Tensor.realize(a, b, c)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (c, a, b)])
ei = ExecItem(sink, [t.uop.buffer for t in (c, a, b)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
c = c.float()
@@ -149,7 +150,7 @@ class TestTK(unittest.TestCase):
b = Tensor.empty(1, 1, N, N, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
b = b.float()
@@ -188,7 +189,7 @@ class TestTK(unittest.TestCase):
b = Tensor.empty(1, 1, N, N, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
b = b.float()
@@ -230,7 +231,7 @@ class TestTK(unittest.TestCase):
c = Tensor.empty(1, 1, N, N, dtype="float32")
Tensor.realize(a, b, c)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, c, a)])
ei = ExecItem(sink, [t.uop.buffer for t in (b, c, a)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
b = b.float()
c = c.float()
@@ -270,7 +271,7 @@ class TestTK(unittest.TestCase):
b = Tensor.empty(1, 1, N, N, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
b = b.float()
@@ -307,7 +308,7 @@ class TestTK(unittest.TestCase):
b = Tensor.empty(1, 1, N, N, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
b = b.float()
@@ -352,7 +353,7 @@ class TestTK(unittest.TestCase):
b = Tensor.empty(1, 1, N, N, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
b = b.float()
@@ -397,7 +398,7 @@ class TestTK(unittest.TestCase):
b = Tensor.empty(1, 1, N, M, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
b = b.float()
@@ -442,7 +443,7 @@ class TestTK(unittest.TestCase):
b = Tensor.empty(1, 1, N, N, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
b = b.float()
@@ -487,7 +488,7 @@ class TestTK(unittest.TestCase):
b = Tensor.empty(1, 1, N, M, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
b = b.float()
@@ -547,7 +548,7 @@ class TestTK(unittest.TestCase):
b = Tensor.empty(1, 1, BLOCK_SIZE, N, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
b = b.float()
@@ -607,7 +608,7 @@ class TestTK(unittest.TestCase):
b = Tensor.empty(1, 1, N, BLOCK_SIZE, dtype="float32")
Tensor.realize(a, b)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (b, a)])
ei = ExecItem(sink, [t.uop.buffer for t in (b, a)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5): ei.run(wait=True)
b = b.float()
@@ -717,7 +718,7 @@ class TestTK(unittest.TestCase):
out = Tensor.empty(B, N, H, D, dtype=dtypes.bfloat16)
Tensor.realize(q, k, v, out)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (out, q, k, v)])
ei = ExecItem(sink, [t.uop.buffer for t in (out, q, k, v)], prg=get_runner(Device.DEFAULT, sink))
for _ in range(5):
et = ei.run(wait=True)
attn_flops = 2 * B * H * N * N * D + \

View File

@@ -9,7 +9,7 @@ from tinygrad.engine.realize import ExecItem, capturing, ViewOp, BufferCopy, Buf
from tinygrad.engine.memory import _internal_memory_planner
from tinygrad.nn.state import get_parameters
from tinygrad.schedule.rangeify import mop_cleanup
from dataclasses import dataclass
from dataclasses import dataclass, replace
from weakref import WeakKeyDictionary
class GraphException(Exception): pass
@@ -31,7 +31,7 @@ def apply_graph_to_jit(jit_cache: list[ExecItem], input_rawbuffers: list[Buffer]
graph_runner = current_batch_devs[0].graph(current_batch, input_rawbuffers, var_vals)
# clear jit inputs to allow their memory to be freed/reused
for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
graphed_jit_cache.append(ExecItem(graph_runner, cast(list[Buffer|None], input_rawbuffers)))
graphed_jit_cache.append(ExecItem(UOp(Ops.NOOP), cast(list[Buffer|None], input_rawbuffers), prg=graph_runner))
max_batch_size *= 2
if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_batch_devs[0]}")
except GraphException as e:
@@ -89,6 +89,7 @@ class GraphRunner(Runner):
estimates = Estimates()
for j,ji in enumerate(jit_cache):
assert ji.prg is not None
estimates += ji.prg.estimates
if isinstance(ji.prg, CompiledRunner):
if ji.prg.p.vars: self.var_vals_replace[j] = [(i, self.vars.index(v.expr)) for i, v in enumerate(ji.prg.p.vars) if v.expr not in ji.fixedvars]
@@ -103,6 +104,7 @@ class GraphRunner(Runner):
self.w_dependency_map: dict[int, Any] = {}
self.r_dependency_map: dict[int, list[Any]] = collections.defaultdict(list)
assert jit_cache[0].prg is not None
super().__init__(colored(f"<batched {len(jit_cache)}>", "cyan"), jit_cache[0].prg.device.split(":")[0], estimates.simplify())
def updated_vars(self, var_vals: dict[str, int]):
@@ -186,7 +188,7 @@ class CapturedJit(Generic[ReturnType]):
def replan_buffers_memory_layout(self):
blacklist = [t.uop.buffer for t in get_parameters(self.ret)]
asgn = _internal_memory_planner([[b for item in self.jit_cache for b in item.bufs if b is not None and b not in blacklist]], ignore_checks=True)
self.jit_cache = [ExecItem(item.prg, [asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
self.jit_cache = [replace(item, bufs=[asgn.get(b,b) if b is not None else None for b in item.bufs]) for item in self.jit_cache]
for old, new in asgn.items():
if old.is_allocated(): new.ensure_allocated().copyin(old.as_buffer())
self.__post_init__()
@@ -249,7 +251,7 @@ class TinyJit(Generic[ReturnType]):
return ret
def add(self, ei:ExecItem):
self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars))
self._jit_cache.append(ExecItem(ei.ast, [self.add_buffer(buf) for buf in ei.bufs if buf is not None], ei.metadata, ei.fixedvars, ei.prg))
def reset(self):
assert self.fxn is not None, "can't reset without function"
@@ -320,8 +322,7 @@ class TinyJit(Generic[ReturnType]):
# Exclude buffers involved in transfer ops to preserve parallelism.
noopt_buffers = {b for ji in jit_cache if isinstance(ji.prg, (BufferXfer, BufferCopy, EncDec)) for b in ji.bufs}
assigned = _internal_memory_planner([cast(list[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None],
item.metadata, item.fixedvars) for item in jit_cache]
jit_cache = [replace(item, bufs=[assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]
input_replace = get_input_replace(jit_cache, input_buffers)
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")

View File

@@ -1,6 +1,6 @@
from typing import cast
from collections import defaultdict
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.engine.realize import ExecItem
from tinygrad.device import Device, Buffer
from tinygrad.helpers import NO_MEMORY_PLANNER, dedup, DEBUG, round_up
from tinygrad.uop.ops import Ops
@@ -63,8 +63,8 @@ def _internal_memory_planner(buffers:list[list[Buffer]], noopt_buffers=None, ign
return assigned
def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
def memory_planner(schedule:list[ExecItem]) -> list[ExecItem]:
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
assigned = _internal_memory_planner([list(si.bufs) for si in schedule],
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.fixedvars) for si in schedule]
assigned = _internal_memory_planner([[b for b in si.bufs if b is not None] for si in schedule],
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs if b is not None})
return [ExecItem(si.ast, [assigned.get(x, x) if x is not None else None for x in si.bufs], si.metadata, si.fixedvars) for si in schedule]

View File

@@ -1,4 +1,4 @@
from typing import cast, Generator, Callable
from typing import cast, Callable
import time, pprint, random, itertools, math
from dataclasses import dataclass, replace, field
from tinygrad.helpers import all_same, colored, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA, TracingKey
@@ -7,7 +7,6 @@ from tinygrad.helpers import unwrap
from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, sym_infer, graph_rewrite, print_uops, track_rewrites, KernelInfo, pyrender
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
from tinygrad.engine.schedule import ScheduleItem
from tinygrad.codegen import full_rewrite
from tinygrad.codegen.opt import Opt
@@ -171,15 +170,43 @@ def get_runner(device:str, ast:UOp) -> CompiledRunner:
# **************** lowering functions ****************
@dataclass(frozen=True)
# NOTE: ctx is the buffers
si_lowerer = PatternMatcher([
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: get_runner(ctx[0].device, sink)),
(UPat(Ops.BUFFER_VIEW), lambda ctx: ViewOp(ctx[0])),
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: (BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device))),
(UPat(Ops.ENCDEC, name="encdec"), lambda ctx,encdec: EncDec(encdec, ctx[0].nbytes, ctx[1].device)),
])
@dataclass
class ExecItem:
prg: Runner
bufs: list[Buffer|None]
metadata: tuple[Metadata, ...]|None = None
ast: UOp
bufs: list[Buffer|None] = field(default_factory=list)
metadata: tuple[Metadata, ...] = ()
fixedvars: dict[str, int] = field(default_factory=dict)
prg: Runner|None = None
def lower(self):
"""Populate self.prg by lowering the AST."""
if self.prg is not None: return self
try: self.prg = cast(Runner, si_lowerer.rewrite(self.ast, self.bufs))
except Exception as e:
if DEBUG >= 2:
print(f"error lowering {self.ast.op}")
print("tensor operations:")
pprint.pprint(self.metadata, indent=2)
raise e
return self
def run(self, _var_vals:dict[str, int]|None=None, wait=False, jit=False, do_update_stats=True) -> float|None:
if self.prg is None: self.lower()
assert self.prg is not None
var_vals = self.fixedvars if _var_vals is None else (_var_vals|self.fixedvars)
bufs = [unwrap(x) for x in self.bufs] if jit else [unwrap(x).ensure_allocated() for x in self.bufs]
# reorder bufs to match program globals if needed
_bufs = [self.bufs[i] for i in self.prg.p.globals] if isinstance(self.prg, CompiledRunner) else self.bufs
bufs = cast(list[Buffer], [unwrap(x) for x in _bufs] if jit else [unwrap(x).ensure_allocated() for x in _bufs])
if PROFILE:
payload = {"metadata":self.metadata, "var_vals":var_vals, "bufs":[b.trace_num for b in bufs], "name":self.prg.display_name}
payload["outputs"], payload["inputs"] = (self.prg.p.outs, self.prg.p.ins) if isinstance(self.prg, CompiledRunner) else ([0], [1])
@@ -206,48 +233,28 @@ class ExecItem:
self.prg.first_run = False
return et
# NOTE: ctx is the buffers
si_lowerer = PatternMatcher([
(UPat(Ops.SINK, name="sink"), lambda ctx,sink: (runner:=get_runner(ctx[0].device, sink), [ctx[x] for x in runner.p.globals])),
(UPat(Ops.BUFFER_VIEW), lambda ctx: (ViewOp(ctx[0]), list(ctx))),
(UPat(Ops.COPY, name="copy"), lambda ctx,copy: ((BufferXfer(ctx[0].nbytes, ctx[0].device, ctx[1].device) \
if hasattr(Device[ctx[0].device].allocator, '_transfer') and all_same([x.device.split(":")[0] for x in ctx]) \
else BufferCopy(ctx[0].nbytes, ctx[0].device, ctx[1].device)), list(ctx))),
(UPat(Ops.ENCDEC, name="encdec"), lambda ctx,encdec: ((EncDec(encdec, ctx[0].nbytes, ctx[1].device)), list(ctx))),
])
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata, si.fixedvars)
def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]:
while len(schedule):
si = schedule.pop(0)
try: yield (si, lower_schedule_item(si))
except Exception as e:
if DEBUG >= 2:
print(f"error lowering {si.ast.op}")
print("tensor operations:")
pprint.pprint(si.metadata, indent=2)
raise e
# **************** main run function ****************
capturing: list = [] # put classes with an add method in here
def run_schedule(schedule:list[ScheduleItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
for si, ei in lower_schedule(schedule):
def run_schedule(schedule:list[ExecItem], var_vals:dict[str, int]|None=None, do_update_stats=True):
while len(schedule):
ei = schedule.pop(0).lower()
if len(capturing) and CAPTURING: capturing[0].add(ei)
if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
if VALIDATE_WITH_CPU and ei.ast.op is Ops.SINK:
# copy in allocated buffers from the GPU
nb: tuple[Buffer, ...] = tuple(Buffer("CPU", b.size, b.dtype) for b in si.bufs)
for cpu_b, gpu_b in zip(nb, si.bufs):
if gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
bufs = [b for b in ei.bufs if b is not None]
nb: list[Buffer|None] = [Buffer("CPU", b.size, b.dtype) for b in bufs]
for cpu_b, gpu_b in zip(nb, bufs):
if cpu_b is not None and gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
# run on GPU
ei.run(var_vals, do_update_stats=do_update_stats)
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
with Context(BEAM=0): lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata, si.fixedvars)).run(var_vals, do_update_stats=do_update_stats)
with Context(BEAM=0): ExecItem(ei.ast, nb, ei.metadata, ei.fixedvars).run(var_vals, do_update_stats=do_update_stats)
import numpy as np
np.testing.assert_allclose(si.bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
assert nb[0] is not None
np.testing.assert_allclose(bufs[0].numpy(), nb[0].numpy(), rtol=1e-3, atol=1e-3)
else:
ei.run(var_vals, do_update_stats=do_update_stats)
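
For reference, a hedged sketch of the two `ExecItem` construction patterns the new dataclass supports (field names taken from the class above; the workload is an assumption, and the import path shown is `tinygrad.engine.realize`, where the class is defined, though the updated tests also import it via `tinygrad.engine.schedule`):

```python
from tinygrad import Tensor, Device
from tinygrad.engine.realize import ExecItem, get_runner

si = (Tensor.empty(4) + 1).schedule()[0]      # an already-built ExecItem from the scheduler

# scheduler-style: ast + bufs only; prg is filled lazily by lower()/run()
unlowered = ExecItem(si.ast, list(si.bufs), si.metadata, si.fixedvars)
unlowered.run()                               # lowers on demand, then executes

# pre-lowered: hand a Runner in via the prg keyword, as the updated tests do
prelowered = ExecItem(si.ast, list(si.bufs), prg=get_runner(Device.DEFAULT, si.ast))
prelowered.run(wait=True)
```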

View File

@@ -1,25 +1,17 @@
from __future__ import annotations
import time
from typing import cast
from dataclasses import dataclass, field
from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites
from tinygrad.uop.ops import PatternMatcher, UPat, graph_rewrite, graph_rewrite_map
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import Metadata, DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize
# **** ScheduleItem return type
@dataclass(frozen=True)
class ScheduleItem:
ast: UOp
bufs: tuple[Buffer, ...] = ()
metadata: tuple[Metadata, ...] = ()
fixedvars: dict[str, int] = field(default_factory=dict)
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize
from tinygrad.engine.realize import ExecItem
# **** schedule linearizer
def create_schedule(sched_sink:UOp) -> tuple[list[ScheduleItem], UOp]:
def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
with cpu_profile(TracingKey("toposort sched_sink")):
# construct the KERNEL children graph based on assigns
children: dict[UOp, list[UOp]] = {}
@@ -70,7 +62,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ScheduleItem], UOp]:
if in_degree[x] == 0: queue.append(x)
with cpu_profile(TracingKey("expand ranges")):
pre_schedule: list[ScheduleItem] = []
pre_schedule: list[ExecItem] = []
buf_uops_list: list[UOp] = []
sched_ptr = 0
in_ranges: dict[UOp, int] = {}
@@ -89,7 +81,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ScheduleItem], UOp]:
else:
ast, buf_uops, metadata, fixedvars, bound_ranges = si
fixedvars = fixedvars | {s.src[0].arg[0]:in_ranges[s.src[1]] for s in bound_ranges}
pre_schedule.append(ScheduleItem(ast, (), metadata, fixedvars))
pre_schedule.append(ExecItem(ast, [], metadata, fixedvars))
buf_uops_list.append(UOp.sink(*buf_uops))
sched_ptr += 1
return pre_schedule, UOp.sink(*buf_uops_list)
@@ -131,9 +123,9 @@ pm_post_sched_cache = PatternMatcher([
(UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR),), name="b"), lambda ctx,b: ctx.get(b)),
])
schedule_cache: dict[bytes, tuple[list[ScheduleItem], UOp]] = {}
schedule_cache: dict[bytes, tuple[list[ExecItem], UOp]] = {}
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[1]))}")
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ScheduleItem], dict[str, int]]:
def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], list[ExecItem], dict[str, int]]:
# big_sink srcs are all the Tensors
st = time.perf_counter()
@@ -180,7 +172,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
tensor_map = {tm_src[i]:tm_src[i+1] for i in range(0, len(tm_src), 2)}
# add bufs to pre_schedule
schedule: list[ScheduleItem] = []
schedule: list[ExecItem] = []
for i, si in enumerate(pre_schedule):
buf_uops = buf_uops_sink.src[i].src
# create subbuffers if needed
@@ -193,10 +185,10 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
assert all(isinstance(x, MultiBuffer) for x in ubufs), "kernel must all be multibuffer"
dnums = [x for x in si.ast.variables() if x.arg[0] == '_device_num']
for j, bufs in enumerate(zip(*[x.bufs for x in cast(tuple[MultiBuffer, ...], ubufs)])):
schedule.append(ScheduleItem(si.ast, bufs, si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
schedule.append(ExecItem(si.ast, list(bufs), si.metadata, si.fixedvars | ({dnums[0].expr:j} if len(dnums) else {})))
else:
# ONE -> ONE
schedule.append(ScheduleItem(si.ast, cast(tuple[Buffer, ...], ubufs), si.metadata, si.fixedvars))
schedule.append(ExecItem(si.ast, list(ubufs), si.metadata, si.fixedvars))
with cpu_profile(TracingKey("memory planner")): schedule = memory_planner(schedule)
# extract var_vals from BINDs that were stripped (only if there are kernels)

View File

@@ -12,7 +12,7 @@ from tinygrad.gradient import compute_gradient
from tinygrad.mixin import OpMixin
from tinygrad.mixin.movement import _align_left
from tinygrad.uop.ops import smax, smin, resolve, UOp, Ops, sint, identity_element, all_metadata, _index_to_concrete_int, sint_to_uop, Variable
from tinygrad.engine.schedule import ScheduleItem, complete_create_schedule_with_vars
from tinygrad.engine.schedule import ExecItem, complete_create_schedule_with_vars
from tinygrad.device import Device, Buffer
from tinygrad.engine.realize import run_schedule
@@ -236,7 +236,7 @@ class Tensor(OpMixin):
"""
return [Tensor(u) for u in UOp.custom_kernel(*[t.uop for t in (self,)+lst], fxn=fxn, grad_fxn=grad_fxn)]
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[str, int]]:
def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ExecItem], dict[str, int]]:
"""
Creates the schedule needed to realize these Tensor(s), with Variables.
@@ -249,7 +249,7 @@ class Tensor(OpMixin):
_apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
return schedule, var_vals
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
def schedule(self, *lst:Tensor) -> list[ExecItem]:
"""Creates the schedule needed to realize these Tensor(s)."""
schedule, var_vals = self.schedule_with_vars(*lst)
assert len(var_vals) == 0
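
A hedged sketch of `schedule_with_vars` with a bound `Variable`, assuming the symbolic-shape reshape pattern used elsewhere in tinygrad; the variable name and shapes are illustrative:

```python
from tinygrad import Tensor, Variable
from tinygrad.engine.realize import run_schedule

v = Variable("n", 1, 10).bind(3)
t = Tensor.empty(3, 4).reshape(v, 4) + 1
sched, var_vals = t.schedule_with_vars()      # var_vals carries the bound value, e.g. {"n": 3}
run_schedule(sched, var_vals)                 # ExecItem.run merges var_vals with its fixedvars
```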