Tensor.kernelize (#9845)

* add kernelize

* remove that

* kernelize returns self

* update abstractions2.py

* kernelize in test_schedule

* temp: assert BUFFER_VIEW's existence

* ASSIGN must have a buffer or subbuffer target

* assert and shrink

* fix

* padded setitem

* var

* toposort once

* extra

* base_buffer

* end with BUFFER_VIEW

* setitem for disk

* test_setitem_becomes_subbuffer

* mul slice test

* torch backend fix 1

* non-deterministic

* keep subbuffer
Authored by qazal on 2025-04-20 15:53:49 +03:00, committed via GitHub
parent dd16087f62, commit e20ef7196a
7 changed files with 47 additions and 29 deletions
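The net effect of the commit, as a minimal usage sketch (based on the Tensor diff below; output values are what the math implies, exact formatting depends on the backend):

# minimal usage sketch of the new API: kernelize groups the pending
# computation into kernels and returns self, so it can be chained
from tinygrad import Tensor

a = Tensor.full((4,), 2.0).contiguous()
b = Tensor.full((4,), 3.0).contiguous()
out = (a + b).kernelize()   # the tensor graph now holds ASSIGN(BUFFER, KERNEL) nodes
print(out.tolist())         # realization still works as before: [5.0, 5.0, 5.0, 5.0]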

View File

@@ -78,6 +78,7 @@ print("******** third, the UOp ***********")
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.schedule import create_schedule_with_vars
+from tinygrad.engine.grouper import get_becomes_map

 # allocate some values + load in values
 a = UOp.new_buffer(DEVICE, 1, dtypes.int32)
@@ -89,7 +90,19 @@ b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
 out = a + b
 s = UOp(Ops.SINK, dtypes.void, (out,))

-# schedule the computation as a list of kernels
+# group the computation into kernels
+becomes_map = get_becomes_map(s)
+
+# the compute maps to an assign
+assign = becomes_map[a+b]
+# the first source is the output buffer (data)
+assert assign.src[0].op is Ops.BUFFER
+# the second source is the kernel (compute)
+assert assign.src[1].op is Ops.KERNEL
+
+# schedule the kernel graph in a linear list
+s = UOp(Ops.SINK, dtypes.void, (assign,))
 sched, _, becomes_map = create_schedule_with_vars(s)
 assert len(sched) == 1
@@ -98,7 +111,7 @@ print(sched[-1].ast)
 # NOTE: sched[-1].ast is the same as st_0 above

 # the output will be stored in a new buffer
-out = becomes_map[a+b]
+out = becomes_map[assign]
 assert out.op is Ops.BUFFER and not out.buffer.is_allocated()
 print(out)
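Put together without diff markers, the new UOp-level flow from this file reads roughly as follows (a sketch assuming the default Device; the assertions mirror the ones in the hunk above):

import struct
from tinygrad import dtypes, Device
from tinygrad.ops import UOp, Ops
from tinygrad.engine.grouper import get_becomes_map
from tinygrad.engine.schedule import create_schedule_with_vars
from tinygrad.engine.realize import run_schedule

DEVICE = Device.DEFAULT

# allocate two int32 buffers and load values
a = UOp.new_buffer(DEVICE, 1, dtypes.int32)
b = UOp.new_buffer(DEVICE, 1, dtypes.int32)
a.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))

# group the computation into kernels
out = a + b
becomes_map = get_becomes_map(UOp(Ops.SINK, dtypes.void, (out,)))

# the compute becomes an ASSIGN of (output BUFFER, KERNEL)
assign = becomes_map[out]
assert assign.src[0].op is Ops.BUFFER and assign.src[1].op is Ops.KERNEL

# linearize the kernel graph into a schedule
sched, _, becomes_map = create_schedule_with_vars(UOp(Ops.SINK, dtypes.void, (assign,)))

# the tensor-level result maps to a new (not yet allocated) BUFFER
out_buf = becomes_map[assign]
assert out_buf.op is Ops.BUFFER and not out_buf.buffer.is_allocated()

# run the single kernel
run_schedule(sched)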

View File

@@ -3,6 +3,7 @@
 # A002 Function argument `input` is shadowing a Python builtin
 # A006 Lambda argument `input` is shadowing a Python builtin
 from tinygrad import Tensor, dtypes, Device
+from tinygrad.ops import Ops
 from tinygrad.helpers import getenv, prod
 import torch.lib
 TORCH_DEBUG = getenv("TORCH_DEBUG")
@@ -67,6 +68,7 @@ def realize_with_views(self: Tensor, views: Tensor):
   if not self.lazydata.st.contiguous: self.replace(self.contiguous())
   self.replace(self.clone().realize())
   for v in views:
+    if v.lazydata.base.op is Ops.BUFFER_VIEW: continue  # skip subbuffer, we just use the real buffer view
     ret = self
     st = ShapeTracker(self.lazydata.st.views + v.lazydata.st.views) # TODO: is this right?
     for mo in cached_to_movement_ops(self.shape, st): ret = apply_mop(ret, mo)
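The new early-continue can be read in isolation as the following filter (a sketch; views_needing_replay is a hypothetical helper, not part of the backend, and the list annotation is an assumption):

from tinygrad import Tensor
from tinygrad.ops import Ops

def views_needing_replay(views: list[Tensor]) -> list[Tensor]:
  # views whose base realized as a BUFFER_VIEW already alias the real buffer,
  # so the movement-op replay in realize_with_views is unnecessary for them
  return [v for v in views if v.lazydata.base.op is not Ops.BUFFER_VIEW]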

View File

@@ -15,7 +15,7 @@ from tinygrad.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite,
 from tinygrad.codegen.symbolic import symbolic_simple
 from tinygrad.spec import type_verify, shape_spec
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
-from tinygrad.engine.grouper import view_left, view_right, sym
+from tinygrad.engine.grouper import view_left, view_right, sym, get_becomes_map
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
@@ -29,7 +29,9 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
   elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
   else:
     assert isinstance(t, UOp), f"can't schedule {t}"
-    sched, _, __ = create_schedule_with_vars(t.sink())
+    sink = UOp.sink(t)
+    becomes_map = get_becomes_map(sink)
+    sched, _, __ = create_schedule_with_vars(sink.substitute(becomes_map))
   # test lowering all the ScheduleItems to ExecItems
   lowered = [x[1] for x in lower_schedule(sched.copy())]
   if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)]
@@ -2510,12 +2512,15 @@ class TestUOpBecome(unittest.TestCase):
     assert b.lazydata is c.lazydata
     assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.lazydata, {})

-  def test_setitem_becomes_view_of_base(self):
+  def test_setitem_becomes_subbuffer(self):
     a = Tensor.full((4,), 2.).contiguous().realize()
     b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
     b.realize()
-    assert b.lazydata.is_realized
-    assert b.lazydata.base.buffer._base is None
+    assert a.lazydata.is_realized
+    assert a.lazydata.buffer._base is None
+    # b is a subbuffer of a
+    assert b.lazydata.op is Ops.BUFFER_VIEW
+    assert b.lazydata.src[0] is a.lazydata

   def test_setitem_offset(self):
     a = Tensor.full((16,), 0.).contiguous().realize()
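Outside the test harness, the behavior test_setitem_becomes_subbuffer pins down looks like this (a sketch; the expected tolist output assumes the usual setitem-into-shrink semantics):

from tinygrad import Tensor
from tinygrad.ops import Ops

a = Tensor.full((4,), 2.).contiguous().realize()
b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
b.realize()

# b is a subbuffer of a, not a fresh allocation
assert b.lazydata.op is Ops.BUFFER_VIEW and b.lazydata.src[0] is a.lazydata
print(a.tolist())  # the write lands in a's buffer: expected [1.0, 1.0, 2.0, 2.0]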

View File

@@ -465,7 +465,6 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
   # map tensors to buffer/const/assign, optionally apply a VIEW on top
   becomes_map: dict[UOp, UOp] = {}
   for k,v in tensor_map.items():
-    if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
     if k is v: continue
     op = v.base.op
     if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v

View File

@@ -3,7 +3,6 @@ from collections import deque
 from tinygrad.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers
 from tinygrad.device import Buffer
 from tinygrad.helpers import Metadata, DEBUG, unwrap
-from tinygrad.engine.grouper import get_becomes_map

 # **** ScheduleItem return type
@@ -34,14 +33,11 @@ pm_unbind = PatternMatcher([
 # **** schedule linearizer

-def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
-  becomes_map = get_becomes_map(big_sink)
-  sched_sink = UOp.sink(*[becomes_map.get(x,x) for x in big_sink.src])
+def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
   # bfs toposort
   children: dict[UOp, list[UOp]] = {}
   in_degree: dict[UOp, int] = {}
-  for u in sched_sink.toposort:
+  for u in (toposort:=sched_sink.toposort):
     if u.op is not Ops.ASSIGN: continue
     in_degree[u] = 0
     for s in u.src[1].src:
@@ -67,11 +63,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")

   # map ASSIGN to BUFFER after ScheduleItems are constructed
-  for k,v in becomes_map.items():
-    if v.base.op is Ops.ASSIGN:
-      # if the UOp was already an assign Tensor UOp we just map it to the existing buffer
-      if k.op is Ops.ASSIGN: becomes_map[k] = k.src[0]
-      # otherwise we map it to the new buffer, ignoring NOOP ShapeTrackers
-      else: becomes_map[k] = new_buf if (new_buf:=v.base.src[0]).st == v.st else new_buf.view(unwrap(v.st))
+  becomes_map = {u:u.buf_uop for u in toposort if u.op is Ops.ASSIGN}
+  assert all(u.op in {Ops.BUFFER, Ops.BUFFER_VIEW} for u in becomes_map.values()), f"Schedule didn't end with BUFFER {becomes_map.values()}"
   return schedule, var_vals, becomes_map
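For reference, the "bfs toposort" the linearizer builds with children/in_degree is a standard Kahn traversal; a generic sketch with illustrative string nodes (not tinygrad UOps):

from collections import deque

def bfs_toposort(children: dict[str, list[str]], in_degree: dict[str, int]) -> list[str]:
  queue = deque(n for n, d in in_degree.items() if d == 0)  # kernels with no pending deps
  order: list[str] = []
  while queue:
    n = queue.popleft()
    order.append(n)
    for child in children.get(n, []):
      in_degree[child] -= 1
      if in_degree[child] == 0: queue.append(child)         # all of child's deps are scheduled
  return order

# B depends on A; C depends on A and B
print(bfs_toposort({"A": ["B", "C"], "B": ["C"]}, {"A": 0, "B": 1, "C": 2}))  # ['A', 'B', 'C']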

View File

@@ -21,7 +21,7 @@ tensor_uop_spec = buffer_spec+PatternMatcher([
   # "make things that can't be images not images" can change the buffer dtype
   # this is fine as long as it's a realized buffer and base dtypes match.
   ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.base.op is Ops.BUFFER)),
-  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}),)), lambda: False),
+  (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.DEVICE}),)), lambda: False),

   # Tensor variable bindings
   (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),

View File

@@ -15,6 +15,7 @@ from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
+from tinygrad.engine.grouper import get_becomes_map

 # *** all in scope Tensors are here. this gets relevant UOps ***
@@ -223,12 +224,7 @@ class Tensor(SimpleMathTrait):
   # ***** data handlers ****

-  def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
-    """
-    Creates the schedule needed to realize these Tensor(s), with Variables.
-
-    NOTE: A Tensor can only be scheduled once.
-    """
+  def kernelize(self, *lst:Tensor) -> Tensor:
     big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])

     # TODO: move this to scheduler tensor_map pass
@@ -240,7 +236,18 @@ class Tensor(SimpleMathTrait):

     # verify Tensors match the spec
     if __debug__: type_verify(list(big_sink.toposort), tensor_uop_spec)
-    schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
+    becomes_map = get_becomes_map(big_sink)
+    _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
+    return self
+
+  def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
+    """
+    Creates the schedule needed to realize these Tensor(s), with Variables.
+
+    NOTE: A Tensor can only be scheduled once.
+    """
+    self.kernelize(*lst)
+    schedule, var_vals, becomes_map = create_schedule_with_vars(UOp.sink(*[x.lazydata for x in (self,)+lst]))
     _apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
     return memory_planner(schedule), var_vals
@@ -1202,7 +1209,7 @@ class Tensor(SimpleMathTrait):
   def __setitem__(self, indices, v:Tensor|ConstType) -> None:
     if isinstance(self.device, str) and self.device.startswith("DISK"):
-      self._getitem(indices).assign(v)
+      self.realize()._getitem(indices).assign(v)
       return
     # NOTE: check that setitem target is valid first
     if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
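With the Tensor changes above, the explicit two-phase flow can be spelled out as follows (a sketch using only methods shown in this diff; run_schedule takes the schedule plus its variable values, as Tensor.realize does internally):

from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule

x = Tensor.full((8,), 1.0).contiguous()
y = (x + x).kernelize()                   # apply the kernelize map: graph now holds ASSIGN(BUFFER, KERNEL)
sched, var_vals = y.schedule_with_vars()  # linearize into ScheduleItems (this calls kernelize itself)
run_schedule(sched, var_vals)             # execute the kernels
print(y.tolist())                         # [2.0] * 8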