Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
Tensor.kernelize (#9845)
* add kernelize
* remove that
* kernelize returns self
* update abstractions2.py
* kernelize in test_schedule
* temp: assert BUFFER_VIEW's existence
* ASSIGN must have a buffer or subbuffer target
* assert and shrink
* fix
* padded setitem
* var
* toposort once
* extra
* base_buffer
* end with BUFFER_VIEW
* setitem for disk
* test_setitem_becomes_subbuffer
* mul slice test
* torch backend fix 1
* non-deterministic
* keep subbuffer
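The headline change: Tensor.kernelize groups a tensor's lazy graph into kernels without linearizing a schedule or running anything, and it returns self so it chains. A minimal usage sketch based on the API added in this diff (the values are illustrative):

from tinygrad import Tensor

x, y = Tensor([1.0, 2.0]), Tensor([3.0, 4.0])
out = (x + y).kernelize()  # the graph is grouped into kernels; nothing has executed yet
out.realize()              # scheduling and execution happen here
print(out.tolist())        # [4.0, 6.0]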
@@ -78,6 +78,7 @@ print("******** third, the UOp ***********")
 
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.schedule import create_schedule_with_vars
+from tinygrad.engine.grouper import get_becomes_map
 
 # allocate some values + load in values
 a = UOp.new_buffer(DEVICE, 1, dtypes.int32)
@@ -89,7 +90,19 @@ b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))
 out = a + b
 s = UOp(Ops.SINK, dtypes.void, (out,))
 
-# schedule the computation as a list of kernels
+# group the computation into kernels
+becomes_map = get_becomes_map(s)
+
+# the compute maps to an assign
+assign = becomes_map[a+b]
+
+# the first source is the output buffer (data)
+assert assign.src[0].op is Ops.BUFFER
+# the second source is the kernel (compute)
+assert assign.src[1].op is Ops.KERNEL
+
+# schedule the kernel graph in a linear list
+s = UOp(Ops.SINK, dtypes.void, (assign,))
 sched, _, becomes_map = create_schedule_with_vars(s)
 assert len(sched) == 1
 
@@ -98,7 +111,7 @@ print(sched[-1].ast)
 # NOTE: sched[-1].ast is the same as st_0 above
 
 # the output will be stored in a new buffer
-out = becomes_map[a+b]
+out = becomes_map[assign]
 assert out.op is Ops.BUFFER and not out.buffer.is_allocated()
 print(out)
 
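Pulled together, the updated UOp walkthrough (abstractions2.py per the commit message) now goes sink -> get_becomes_map -> ASSIGN(BUFFER, KERNEL) -> create_schedule_with_vars -> run_schedule. A sketch of the whole sequence assembled from the hunks above; the struct/UOp/dtypes imports and the DEVICE constant live elsewhere in that file and are assumed here:

import struct
from tinygrad import Device, dtypes
from tinygrad.ops import UOp, Ops
from tinygrad.engine.grouper import get_becomes_map
from tinygrad.engine.schedule import create_schedule_with_vars
from tinygrad.engine.realize import run_schedule

DEVICE = Device.DEFAULT  # assumption: the real script pins a specific device string

# allocate two one-element int32 buffers and load 2 and 3 into them
a = UOp.new_buffer(DEVICE, 1, dtypes.int32)
b = UOp.new_buffer(DEVICE, 1, dtypes.int32)
a.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 2))))
b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3))))

out = a + b
s = UOp(Ops.SINK, dtypes.void, (out,))

# group the computation into kernels: a+b now maps to ASSIGN(output BUFFER, KERNEL)
becomes_map = get_becomes_map(s)
assign = becomes_map[a+b]
assert assign.src[0].op is Ops.BUFFER and assign.src[1].op is Ops.KERNEL

# linearize the kernel graph into ScheduleItems and run them
sched, _, becomes_map = create_schedule_with_vars(UOp(Ops.SINK, dtypes.void, (assign,)))
run_schedule(sched)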
@@ -3,6 +3,7 @@
 # A002 Function argument `input` is shadowing a Python builtin
 # A006 Lambda argument `input` is shadowing a Python builtin
 from tinygrad import Tensor, dtypes, Device
+from tinygrad.ops import Ops
 from tinygrad.helpers import getenv, prod
 import torch.lib
 TORCH_DEBUG = getenv("TORCH_DEBUG")
@@ -67,6 +68,7 @@ def realize_with_views(self: Tensor, views: Tensor):
   if not self.lazydata.st.contiguous: self.replace(self.contiguous())
   self.replace(self.clone().realize())
   for v in views:
+    if v.lazydata.base.op is Ops.BUFFER_VIEW: continue # skip subbuffer, we just use the real buffer view
     ret = self
     st = ShapeTracker(self.lazydata.st.views + v.lazydata.st.views) # TODO: is this right?
     for mo in cached_to_movement_ops(self.shape, st): ret = apply_mop(ret, mo)
@@ -15,7 +15,7 @@ from tinygrad.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite,
 from tinygrad.codegen.symbolic import symbolic_simple
 from tinygrad.spec import type_verify, shape_spec
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
-from tinygrad.engine.grouper import view_left, view_right, sym
+from tinygrad.engine.grouper import view_left, view_right, sym, get_becomes_map
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
@@ -29,7 +29,9 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
   elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
   else:
     assert isinstance(t, UOp), f"can't schedule {t}"
-    sched, _, __ = create_schedule_with_vars(t.sink())
+    sink = UOp.sink(t)
+    becomes_map = get_becomes_map(sink)
+    sched, _, __ = create_schedule_with_vars(sink.substitute(becomes_map))
   # test lowering all the ScheduleItems to ExecItems
   lowered = [x[1] for x in lower_schedule(sched.copy())]
   if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)]
@@ -2510,12 +2512,15 @@ class TestUOpBecome(unittest.TestCase):
     assert b.lazydata is c.lazydata
     assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER),)).match(c.lazydata, {})
 
-  def test_setitem_becomes_view_of_base(self):
+  def test_setitem_becomes_subbuffer(self):
     a = Tensor.full((4,), 2.).contiguous().realize()
     b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
     b.realize()
-    assert b.lazydata.is_realized
-    assert b.lazydata.base.buffer._base is None
+    assert a.lazydata.is_realized
+    assert a.lazydata.buffer._base is None
+    # b is a subbuffer of a
+    assert b.lazydata.op is Ops.BUFFER_VIEW
+    assert b.lazydata.src[0] is a.lazydata
 
   def test_setitem_offset(self):
     a = Tensor.full((16,), 0.).contiguous().realize()
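In plain terms, the renamed test asserts that assigning into a shrink of a realized tensor no longer materializes a separate buffer for b: b becomes a BUFFER_VIEW whose src[0] is a, so the write lands in a's storage. A hedged illustration of that behaviour (the printed result is my reading of the in-place semantics, not something this diff asserts):

from tinygrad import Tensor

a = Tensor.full((4,), 2.).contiguous().realize()
b = a.shrink(((0, 2),)).assign(Tensor.full((2,), 1.0))
b.realize()

# b is a subbuffer of a, so the assignment is visible through a
print(a.tolist())  # expected: [1.0, 1.0, 2.0, 2.0]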
|||||||
@@ -465,7 +465,6 @@ def get_becomes_map(big_sink:UOp) -> dict[UOp, UOp]:
|
|||||||
# map tensors to buffer/const/assign, optionally apply a VIEW on top
|
# map tensors to buffer/const/assign, optionally apply a VIEW on top
|
||||||
becomes_map: dict[UOp, UOp] = {}
|
becomes_map: dict[UOp, UOp] = {}
|
||||||
for k,v in tensor_map.items():
|
for k,v in tensor_map.items():
|
||||||
if (kernel:=tensor_map.get(v.base)) is not None and kernel.base.op is Ops.ASSIGN: v = kernel.view(unwrap(v.st))
|
|
||||||
if k is v: continue
|
if k is v: continue
|
||||||
op = v.base.op
|
op = v.base.op
|
||||||
if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v
|
if op in {Ops.BUFFER, Ops.ASSIGN}: becomes_map[k] = v
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from collections import deque
 from tinygrad.ops import UOp, Variable, Ops, UPat, PatternMatcher, graph_rewrite, buffers
 from tinygrad.device import Buffer
 from tinygrad.helpers import Metadata, DEBUG, unwrap
-from tinygrad.engine.grouper import get_becomes_map
 
 # **** ScheduleItem return type
 
@@ -34,14 +33,11 @@ pm_unbind = PatternMatcher([
 
 # **** schedule linearizer
 
-def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
-  becomes_map = get_becomes_map(big_sink)
-  sched_sink = UOp.sink(*[becomes_map.get(x,x) for x in big_sink.src])
-
+def create_schedule_with_vars(sched_sink:UOp) -> tuple[list[ScheduleItem], dict[Variable, int], dict[UOp, UOp]]:
   # bfs toposort
   children: dict[UOp, list[UOp]] = {}
   in_degree: dict[UOp, int] = {}
-  for u in sched_sink.toposort:
+  for u in (toposort:=sched_sink.toposort):
     if u.op is not Ops.ASSIGN: continue
     in_degree[u] = 0
     for s in u.src[1].src:
@@ -67,11 +63,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
   if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
 
   # map ASSIGN to BUFFER after ScheduleItems are constructed
-  for k,v in becomes_map.items():
-    if v.base.op is Ops.ASSIGN:
-      # if the UOp was already an assign Tensor UOp we just map it to the existing buffer
-      if k.op is Ops.ASSIGN: becomes_map[k] = k.src[0]
-      # otherwise we map it to the new buffer, ignoring NOOP ShapeTrackers
-      else: becomes_map[k] = new_buf if (new_buf:=v.base.src[0]).st == v.st else new_buf.view(unwrap(v.st))
+  becomes_map = {u:u.buf_uop for u in toposort if u.op is Ops.ASSIGN}
+  assert all(u.op in {Ops.BUFFER, Ops.BUFFER_VIEW} for u in becomes_map.values()), f"Schedule didn't end with BUFFER {becomes_map.values()}"
 
   return schedule, var_vals, becomes_map
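With get_becomes_map hoisted out, create_schedule_with_vars now expects an already-kernelized sink and only does the linearization plus the final ASSIGN-to-buffer mapping. A sketch of the new calling convention, mirroring the check_schedule change above (names past the imports are illustrative):

from tinygrad import Tensor
from tinygrad.ops import UOp, Ops
from tinygrad.engine.grouper import get_becomes_map
from tinygrad.engine.schedule import create_schedule_with_vars

t = (Tensor([1., 2.]) + Tensor([3., 4.])).lazydata
sink = UOp.sink(t)
kernelized = sink.substitute(get_becomes_map(sink))  # group into ASSIGN(BUFFER, KERNEL) nodes
schedule, var_vals, becomes_map = create_schedule_with_vars(kernelized)

# every ASSIGN in the kernelized graph now maps straight to its output buffer
assert all(u.op in {Ops.BUFFER, Ops.BUFFER_VIEW} for u in becomes_map.values())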
@@ -21,7 +21,7 @@ tensor_uop_spec = buffer_spec+PatternMatcher([
  # "make things that can't be images not images" can change the buffer dtype
  # this is fine as long as it's a realized buffer and base dtypes match.
   ((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.base.op is Ops.BUFFER)),
- (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}),)), lambda: False),
+ (UPat(Ops.VIEW, src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.BUFFER_VIEW, Ops.CONST, Ops.DEVICE}),)), lambda: False),
 
  # Tensor variable bindings
  (UPat(Ops.BIND, dtypes.int, (UPat(Ops.DEFINE_VAR), UPat.cvar(dtype=dtypes.int)), arg=None), lambda: True),
@@ -15,6 +15,7 @@ from tinygrad.device import Device, Buffer
 from tinygrad.engine.realize import run_schedule
 from tinygrad.engine.memory import memory_planner
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
+from tinygrad.engine.grouper import get_becomes_map
 
 # *** all in scope Tensors are here. this gets relevant UOps ***
 
@@ -223,12 +224,7 @@ class Tensor(SimpleMathTrait):
 
   # ***** data handlers ****
 
-  def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
-    """
-    Creates the schedule needed to realize these Tensor(s), with Variables.
-
-    NOTE: A Tensor can only be scheduled once.
-    """
+  def kernelize(self, *lst:Tensor) -> Tensor:
     big_sink = UOp.sink(*[x.lazydata for x in (self,)+lst])
 
     # TODO: move this to scheduler tensor_map pass
@@ -240,7 +236,18 @@ class Tensor(SimpleMathTrait):
     # verify Tensors match the spec
     if __debug__: type_verify(list(big_sink.toposort), tensor_uop_spec)
 
-    schedule, var_vals, becomes_map = create_schedule_with_vars(big_sink)
+    becomes_map = get_becomes_map(big_sink)
+    _apply_map_to_tensors(becomes_map, name="Apply Kernelize Map")
+    return self
+
+  def schedule_with_vars(self, *lst:Tensor) -> tuple[list[ScheduleItem], dict[Variable, int]]:
+    """
+    Creates the schedule needed to realize these Tensor(s), with Variables.
+
+    NOTE: A Tensor can only be scheduled once.
+    """
+    self.kernelize(*lst)
+    schedule, var_vals, becomes_map = create_schedule_with_vars(UOp.sink(*[x.lazydata for x in (self,)+lst]))
     _apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
     return memory_planner(schedule), var_vals
 
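schedule_with_vars is now a thin wrapper: kernelize the tensors, then linearize. Since kernelize takes extra tensors via *lst, several outputs can be grouped together long before anything is scheduled. A small sketch of that usage (values illustrative):

from tinygrad import Tensor

a, b = Tensor([1., 2.]), Tensor([3., 4.])
x, y = a + b, a * b

x.kernelize(y)        # group both computations into kernels; returns x, nothing has run
Tensor.realize(x, y)  # realize -> schedule_with_vars -> kernelize again -> run the kernels
print(x.tolist(), y.tolist())  # [4.0, 6.0] [3.0, 8.0]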
@@ -1202,7 +1209,7 @@ class Tensor(SimpleMathTrait):
 
   def __setitem__(self, indices, v:Tensor|ConstType) -> None:
     if isinstance(self.device, str) and self.device.startswith("DISK"):
-      self._getitem(indices).assign(v)
+      self.realize()._getitem(indices).assign(v)
       return
     # NOTE: check that setitem target is valid first
     if not unwrap(self.lazydata.st).contiguous: raise RuntimeError("setitem target needs to be contiguous")
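The DISK branch now realizes the target before indexing, which is what "setitem for disk" in the commit message refers to. A hedged sketch of the path this exercises; the file path and the exact disk-device spelling are assumptions, not taken from the diff:

from tinygrad import Tensor, dtypes, Device

# assumption: "DISK:<path>" backs the tensor with a file the device can create/open
t = Tensor.empty(4, dtype=dtypes.int32, device="DISK:/tmp/kernelize_setitem.bin")
t[2:4] = Tensor([7, 9], dtype=dtypes.int32)  # routed through realize()._getitem(indices).assign(v)
print(t[2:4].to(Device.DEFAULT).tolist())    # expected: [7, 9]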