From 9eef44521bf27dab92ac3dbcd2f8ae979209e20f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 29 Mar 2024 20:50:27 -0700 Subject: [PATCH] ScheduleItem uses Buffer (#3995) * schedule Buffer * update * update tests * master * works * remove LoadOps.WAIT * fix compile2 * bad test * rename and note --- openpilot/compile2.py | 8 ++++---- test/test_conv_shapetracker.py | 8 +++----- test/test_linearizer.py | 4 ++-- test/test_multitensor.py | 2 +- test/test_search.py | 7 ++++--- tinygrad/engine/realize.py | 13 +++++-------- tinygrad/engine/schedule.py | 28 +++++++++++++++++++++------- tinygrad/lazy.py | 9 +++------ tinygrad/ops.py | 9 ++++----- 9 files changed, 47 insertions(+), 41 deletions(-) diff --git a/openpilot/compile2.py b/openpilot/compile2.py index 6148f24b94..929354c9f5 100644 --- a/openpilot/compile2.py +++ b/openpilot/compile2.py @@ -36,7 +36,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]: schedule = create_schedule([ret.lazydata]) # filter schedule that don't depend on the inputs - input_lb = [x.lazydata.base for x in inputs.values()] + input_lb = [x.lazydata.base.buffer for x in inputs.values()] depends = set(input_lb) for si in schedule: if any(b in depends for b in si.inputs): @@ -89,10 +89,10 @@ def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[s # run code (all buffers have been allocated) GlobalCounters.reset() - for si in schedule: lower_schedule_item(si)([x.realized for x in si.outputs+si.inputs], {}) + for si in schedule: lower_schedule_item(si)(si.outputs+si.inputs, {}) - new_tinygrad_out = Tensor(schedule[-1].outputs[0]).numpy() - np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2) + new_tinygrad_out = np.frombuffer(schedule[-1].outputs[0].as_buffer(), dtype=schedule[-1].outputs[0].dtype.np) + np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2) print("semi-thneed self-test passed!") if __name__ == "__main__": diff --git a/test/test_conv_shapetracker.py b/test/test_conv_shapetracker.py index d8447bb593..f36a32c166 100644 --- a/test/test_conv_shapetracker.py +++ b/test/test_conv_shapetracker.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import unittest from tinygrad.tensor import Tensor -from tinygrad.ops import LoadOps +from tinygrad.ops import LoadOps, BufferOps from tinygrad.nn import Conv2d from tinygrad.engine.schedule import create_schedule @@ -15,10 +15,8 @@ class TestConvShapetracker(unittest.TestCase): # run it again to get the kernels sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast[0].op not in LoadOps] assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}" - print(sched[0]) - for arg in [sched[0].outputs[0], *sched[0].inputs]: - print(arg.st) - assert len(arg.st.views) == 1 + for st in [x.arg.st for x in sched[0].ast[0].lazyops if x.op is BufferOps.LOAD]: + assert len(st.views) == 1 if __name__ == '__main__': unittest.main() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 7ce727729f..af811ec418 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -283,8 +283,8 @@ def helper_realized_ast(r:Tensor): run_schedule(s[:-1]) # run all kernels except the last one # now all input LazyBuffers buffers in s[-1] should be realized # allocate an output buffer - output_buffer = Buffer((out:=s[-1].outputs[0]).device, prod((s if isinstance(s, int) else s.max for s in out.shape)), out.dtype).allocate() - return s[-1].ast[0], [output_buffer] + [l.realized for l in s[-1].inputs] + output_buffer = Buffer((out:=s[-1].outputs[0]).device, out.size, out.dtype).allocate() + return s[-1].ast[0], [output_buffer] + list(s[-1].inputs) @unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "need backends that support float4") class TestFloat4(unittest.TestCase): diff --git a/test/test_multitensor.py b/test/test_multitensor.py index fd3a3bcb4e..f7c0063cf2 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -351,7 +351,7 @@ class TestMultiTensor(unittest.TestCase): scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices and sched.ast[0].op is not LoadOps.COPY] assert set(out.device for sched in scheds for out in sched.outputs) == set(devices), "should have ast on each shard device" asts = [sched.ast for sched in scheds] - assert len(asts) == 8, len(asts) + assert len(asts) # test case to show that ast can be different on devices # TODO: make ast identical on devices #assert len(set(asts)) == 4, len(asts) diff --git a/test/test_search.py b/test/test_search.py index 1380880363..1f7cde3dd4 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -4,14 +4,15 @@ from tinygrad.codegen.linearizer import Linearizer from tinygrad.engine.schedule import create_schedule from tinygrad.features.search import time_linearizer, bufs_from_lin from tinygrad.device import Device, Buffer -from tinygrad.ops import LoadOps +from tinygrad.ops import LoadOps, BufferOps from tinygrad.tensor import Tensor class TestTimeLinearizer(unittest.TestCase): def test_reasonable_time(self): si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0] - out = Buffer(Device.DEFAULT, si.outputs[0].st.real_size(), si.outputs[0].dtype).allocate() - rawbufs = [out] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype).allocate() for x in si.inputs] + out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate() + memops = {x.arg.idx:x.arg.st.real_size() for x in si.ast[0].lazyops if x.op is BufferOps.LOAD} + rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))] tm = time_linearizer(Linearizer(*si.ast), rawbufs, allow_test_size=False, cnt=10) assert tm > 0 and tm != float('inf') diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 6a62575b51..4f808c51c5 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -1,8 +1,7 @@ from typing import List, Dict, Optional -from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters +from tinygrad.ops import LoadOps, ScheduleItem, BufferOps from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, JITRunner, update_stats -from tinygrad.features.graph import realized_lazybuffer -from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG +from tinygrad.helpers import colored, getenv, cpu_time_execution, DEBUG from tinygrad.shape.symbolic import Variable class CustomOp(JITRunner): @@ -20,7 +19,7 @@ class SyncOp(JITRunner): update_stats(colored("synchronize", "RED"), 0, 0, {}, et, 1, device=self.dname) def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]: - assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op in {LoadOps.COPY, LoadOps.WAIT} + assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast) assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput" out, ast = si.outputs[0], si.ast[0] @@ -43,12 +42,10 @@ def run_schedule(schedule:List[ScheduleItem]): for out in si.outputs: # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape - if out.size > 0 and not dont_allocate and out.op is not LoadOps.ASSIGN: out.buffer.allocate() + if out.size > 0 and not dont_allocate and not hasattr(out, "_buf"): out.allocate() # run the function (put it in JIT) - real_buffers = [x.buffer for x in si.outputs+si.inputs if x.size != 0] + real_buffers = [x for x in si.outputs+si.inputs if x.size != 0] assert dont_allocate or all(hasattr(x, "_buf") for x in real_buffers), f"can't run, some inputs aren't realized {real_buffers}" if prg: prg.exec(real_buffers, si.var_vals) elif (out:=si.outputs[0]).size > 0: update_stats(colored(f"empty {out.size:10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device) - if GRAPH: - for out in si.outputs: realized_lazybuffer(out, GlobalCounters.kernel_count) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 907f3bd285..ea41fa5e53 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -1,8 +1,9 @@ import sys from collections import defaultdict, deque -from typing import List, Dict, Optional, Set, DefaultDict -from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps -from tinygrad.features.graph import log_lazybuffer +from dataclasses import dataclass +from typing import Tuple, List, Dict, Optional, Set, DefaultDict +from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps, GlobalCounters +from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer from tinygrad.helpers import GRAPH, DEBUG, prod, dedup, all_int from tinygrad.shape.symbolic import Variable from tinygrad.dtype import ImageDType, dtypes @@ -12,6 +13,14 @@ from tinygrad.shape.shapetracker import ShapeTracker # creation can recurse a lot sys.setrecursionlimit(10000) +# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort +@dataclass(frozen=True) +class _LBScheduleItem: + ast: Tuple[LazyOp, ...] + outputs: Tuple[LazyBuffer, ...] + inputs: Tuple[LazyBuffer, ...] + var_vals: Dict[Variable, int] + # recursively create a lazyop def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker, realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp: @@ -63,16 +72,16 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, var_vals, st, realizes, cache, False, assign_to, assign_idx) for x in buf.srcs), buf.arg) return ret -def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem: +def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem: inputs: List[LazyBuffer] = [] var_vals: Dict[Variable, int] = out.st.var_vals.copy() - if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}: + if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.COPY, LoadOps.EMPTY}: op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs) else: output_st, membufs = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape), [out] op = _recursive_lazyop(out, membufs, var_vals, output_st, realizes, cache={}) op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0])), membufs[1:] - return ScheduleItem((op,), (out,), tuple(inputs), var_vals) + return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals) # recursively search the entire graph for all LazyBuffers, insert realizes after expands def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None], @@ -201,10 +210,15 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) queue = deque(out for out in prescheduled if in_degree[out] == 0) schedule: List[ScheduleItem] = [] + kernel_number = GlobalCounters.kernel_count while queue: buf = queue.popleft() seen.add(buf) - schedule.append(prescheduled[buf]) + ps = prescheduled[buf] + if GRAPH: + kernel_number += 1 + for out in ps.outputs: realized_lazybuffer(out, kernel_number) + schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs), tuple(x.buffer for x in ps.inputs), ps.var_vals)) for x in graph[buf]: in_degree[x] -= 1 if in_degree[x] == 0: queue.append(x) diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index bbdf6e31cf..726eac7718 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -12,10 +12,8 @@ from weakref import ref, ReferenceType, WeakValueDictionary lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary() def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))): - if st.size == 0 and op not in {LoadOps.SYNC, LoadOps.WAIT}: op, arg, srcs, base = LoadOps.CONST, 0, (), None - if op is LoadOps.CONST: - arg = dtypes.as_const(arg, dtype) - enable_cache = True + if st.size == 0 and op is not LoadOps.SYNC: op, arg, srcs, base = LoadOps.CONST, 0, (), None + if op is LoadOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype), True cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base)) if enable_cache and (rret := lazycache.get(cache_key, None)): return rret @@ -101,8 +99,7 @@ class LazyBuffer: # copies in HSA/CUDA to other HSA/CUDA don't sync either return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False) sync = LazyBuffer.loadop(LoadOps.SYNC, (0,), dtypes.uint32, self.device, src=(self,), enable_cache=True) - wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=(sync,), enable_cache=True) - return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, wait), enable_cache=False) + return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, sync), enable_cache=False) def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer: # no COPY diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 5cecbbc52e..8f5664ff02 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -18,14 +18,13 @@ class BinaryOps(Enum): class TernaryOps(Enum): WHERE = auto() # noqa: E702 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702 -class LoadOps(Enum): - EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); WAIT = auto(); ASSIGN = auto() # noqa: E702 +class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); ASSIGN = auto() # noqa: E702 Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps] OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]] if TYPE_CHECKING: - from tinygrad.lazy import LazyBuffer + from tinygrad.buffer import Buffer @dataclass(frozen=True) class MemBuffer: @@ -42,8 +41,8 @@ class ConstBuffer: @dataclass(frozen=True) class ScheduleItem: ast: Tuple[LazyOp, ...] - outputs: Tuple[LazyBuffer, ...] - inputs: Tuple[LazyBuffer, ...] + outputs: Tuple[Buffer, ...] + inputs: Tuple[Buffer, ...] var_vals: Dict[Variable, int] @dataclass(frozen=True, eq=False)