diff --git a/docs/abstractions2.py b/docs/abstractions2.py
index d73e42d6a9..668bc9f21c 100644
--- a/docs/abstractions2.py
+++ b/docs/abstractions2.py
@@ -78,6 +78,7 @@ print("******** third, the UOp ***********")
 
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.schedule import create_schedule_with_vars
+from tinygrad.engine.grouper import get_becomes_map
 
 # allocate some values + load in values
 a = UOp.new_buffer(DEVICE, 1, dtypes.int32)
@@ -89,7 +90,19 @@ b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
 
 out = a + b
 s = UOp(Ops.SINK, dtypes.void, (out,))
 
-# schedule the computation as a list of kernels
+# group the computation into kernels
+becomes_map = get_becomes_map(s)
+
+# the compute maps to an assign
+assign = becomes_map[a+b]
+
+# the first source is the output buffer (data)
+assert assign.src[0].op is Ops.BUFFER
+# the second source is the kernel (compute)
+assert assign.src[1].op is Ops.KERNEL
+
+# schedule the kernel graph in a linear list
+s = UOp(Ops.SINK, dtypes.void, (assign,))
 sched, _, becomes_map = create_schedule_with_vars(s)
 assert len(sched) == 1
@@ -98,7 +111,7 @@ print(sched[-1].ast)
 # NOTE: sched[-1].ast is the same as st_0 above
 
 # the output will be stored in a new buffer
-out = becomes_map[a+b]
+out = becomes_map[assign]
 assert out.op is Ops.BUFFER and not out.buffer.is_allocated()
 print(out)
 
diff --git a/extra/torch_backend/backend.py b/extra/torch_backend/backend.py
index 69f8f6ae99..28b00d1a4f 100644
--- a/extra/torch_backend/backend.py
+++ b/extra/torch_backend/backend.py
@@ -3,6 +3,7 @@
 # A002 Function argument `input` is shadowing a Python builtin
 # A006 Lambda argument `input` is shadowing a Python builtin
 from tinygrad import Tensor, dtypes, Device
+from tinygrad.ops import Ops
 from tinygrad.helpers import getenv, prod
 import torch.lib
 TORCH_DEBUG = getenv("TORCH_DEBUG")
@@ -67,6 +68,7 @@ def realize_with_views(self: Tensor, views: Tensor):
   if not self.lazydata.st.contiguous: self.replace(self.contiguous())
   self.replace(self.clone().realize())
   for v in views:
+    if v.lazydata.base.op is Ops.BUFFER_VIEW: continue # skip subbuffer, we just use the real buffer view
     ret = self
     st = ShapeTracker(self.lazydata.st.views + v.lazydata.st.views) # TODO: is this right?
     for mo in cached_to_movement_ops(self.shape, st): ret = apply_mop(ret, mo)
diff --git a/test/test_schedule.py b/test/test_schedule.py
index ecf8dfb7cc..8fa38e3704 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -15,7 +15,7 @@ from tinygrad.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite,
 from tinygrad.codegen.symbolic import symbolic_simple
 from tinygrad.spec import type_verify, shape_spec
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
-from tinygrad.engine.grouper import view_left, view_right, sym
+from tinygrad.engine.grouper import view_left, view_right, sym, get_becomes_map
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
@@ -29,7 +29,9 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
   elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
   else:
     assert isinstance(t, UOp), f"can't schedule {t}"
-    sched, _, __ = create_schedule_with_vars(t.sink())
+    sink = UOp.sink(t)
+    becomes_map = get_becomes_map(sink)
+    sched, _, __ = create_schedule_with_vars(sink.substitute(becomes_map))
   # test lowering all the ScheduleItems to ExecItems
   lowered = [x[1] for x in lower_schedule(sched.copy())]
   if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)]
@@ -2510,12 +2512,15 @@ class TestUOpBecome(unittest.TestCase):
     assert b.lazydata is c.lazydata
     assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.lazydata, {})
 
-  def test_setitem_becomes_view_of_base(self):
+  def test_setitem_becomes_subbuffer(self):
     a = Tensor.full((4,), 2.).contiguous().realize()
     b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
     b.realize()
-    assert b.lazydata.is_realized
-    assert b.lazydata.base.buffer._base is None
+    assert a.lazydata.is_realized
+    assert a.lazydata.buffer._base is None
+    # b is a subbuffer of a
+    assert b.lazydata.op is Ops.BUFFER_VIEW
+    assert b.lazydata.src[0] is a.lazydata
 
   def test_setitem_offset(self):
     a = Tensor.full((16,), 0.).contiguous().realize()
diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py
index 23e7070eef..ff06f5c646 100644
--- a/tinygrad/engine/grouper.py
+++ b/tinygrad/engine/grouper.py
@@ -465,7 +465,6 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
   # map tensors to buffer/const/assign, optionally apply a VIEW on top
   becomes_map: dict[UOp, UOp] = {}
   for k,v in tensor_map.items():
-    if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
     if k is v: continue
     op = v.base.op
     if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index c3a1fd8157..0db4ffbe0d 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -3,7 +3,6 @@ from collections import deque
 from tinygrad.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers
 from tinygrad.device import Buffer
 from tinygrad.helpers import Metadata, DEBUG, unwrap
-from tinygrad.engine.grouper import get_becomes_map
 
 # **** ScheduleItem return type
 
@@ -34,14 +33,11 @@ pm_unbind = PatternMatcher([
 
 # **** schedule linearizer
 
-def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
-  becomes_map = get_becomes_map(big_sink)
-  sched_sink = UOp.sink(*[becomes_map.get(x,x) for x in big_sink.src])
-
+def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
   # bfs toposort
   children: dict[UOp, list[UOp]] = {}
   in_degree: dict[UOp, int] = {}
-  for u in sched_sink.toposort:
+  for u in (toposort:=sched_sink.toposort):
     if u.op is not Ops.ASSIGN: continue
     in_degree[u] = 0
     for s in u.src[1].src:
@@ -67,11 +63,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
 
   # map ASSIGN to BUFFER after ScheduleItems are constructed
-  for k,v in becomes_map.items():
-    if v.base.op is Ops.ASSIGN:
-      # if the UOp was already an assign Tensor UOp we just map it to the existing buffer
-      if k.op is Ops.ASSIGN: becomes_map[k] = k.src[0]
-      # otherwise we map it to the new buffer, ignoring NOOP ShapeTrackers
-      else: becomes_map[k] = new_buf if (new_buf:=v.base.src[0]).st == v.st else new_buf.view(unwrap(v.st))
+  becomes_map = {u:u.buf_uop for u in toposort if u.op is Ops.ASSIGN}
+  assert all(u.op in {Ops.BUFFER, Ops.BUFFER_VIEW} for u in becomes_map.values()), f"Schedule didn't end with BUFFER {becomes_map.values()}"
 
   return schedule, var_vals, becomes_map
diff --git a/tinygrad/spec.py b/tinygrad/spec.py
index 6273642ca4..b9bd90a131 100644
--- a/tinygrad/spec.py
+++ b/tinygrad/spec.py
@@ -21,7 +21,7 @@ tensor_uop_spec = buffer_spec+PatternMatcher([
   # "make things that can't be images not images" can change the buffer dtype
   # this is fine as long as it's a realized buffer and base dtypes match.
   ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.base.op is Ops.BUFFER)),
-  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}),)), lambda: False),
+  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.DEVICE}),)), lambda: False),
 
   # Tensor variable bindings
   (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 4ef54a27c1..c377af2abc 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -15,6 +15,7 @@ from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
+from tinygrad.engine.grouper import get_becomes_map
 
 # *** all in scope Tensors are here. this gets relevant UOps ***
 
@@ -223,12 +224,7 @@ class Tensor(SimpleMathTrait):
 
   # ***** data handlers ****
 
-  def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
-    """
-    Creates the schedule needed to realize these Tensor(s), with Variables.
-
-    NOTE: A Tensor can only be scheduled once.
-    """
+  def kernelize(self, *lst:Tensor) -> Tensor:
     big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
 
     # TODO: move this to scheduler tensor_map pass
@@ -240,7 +236,18 @@ class Tensor(SimpleMathTrait):
 
     # verify Tensors match the spec
     if __debug__: type_verify(list(big_sink.toposort), tensor_uop_spec)
 
-    schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
+    becomes_map = get_becomes_map(big_sink)
+    _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
+    return self
+
+  def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
+    """
+    Creates the schedule needed to realize these Tensor(s), with Variables.
+
+    NOTE: A Tensor can only be scheduled once.
+    """
+    self.kernelize(*lst)
+    schedule, var_vals, becomes_map = create_schedule_with_vars(UOp.sink(*[x.lazydata for x in (self,)+lst]))
     _apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
     return memory_planner(schedule), var_vals
@@ -1202,7 +1209,7 @@ class Tensor(SimpleMathTrait):
 
   def __setitem__(self, indices, v:Tensor|ConstType) -> None:
     if isinstance(self.device, str) and self.device.startswith("DISK"):
-      self._getitem(indices).assign(v)
+      self.realize()._getitem(indices).assign(v)
       return
     # NOTE: check that setitem target is valid first
     if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
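The diff splits scheduling into two passes: get_becomes_map groups the lazy UOp graph into kernels (computed Tensors become ASSIGNs of an output BUFFER and a KERNEL), and create_schedule_with_vars only linearizes that kernel graph. At the Tensor level the first pass is exposed as Tensor.kernelize. Below is a minimal Tensor-level sketch of the new two-step flow; it is illustrative only, and the asserts about the post-kernelize graph shape are assumptions extrapolated from the UOp-level example in docs/abstractions2.py, not behavior this diff states for the Tensor API.

from tinygrad import Tensor
from tinygrad.ops import Ops

a, b = Tensor([1, 2, 3]), Tensor([4, 5, 6])
out = a + b

# step 1: kernelize groups the lazy graph into kernels in place and returns self
out.kernelize()
# assumption: the tensor's base UOp is now an ASSIGN whose second source is the KERNEL,
# mirroring the UOp-level asserts in docs/abstractions2.py above
assert out.lazydata.base.op is Ops.ASSIGN
assert out.lazydata.base.src[1].op is Ops.KERNEL

# step 2: realize() runs create_schedule_with_vars on the kernelized graph and executes it;
# afterwards the ASSIGN is mapped to the realized output BUFFER
out.realize()
print(out.tolist())  # [5, 7, 9]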