Mirror of https://github.com/tinygrad/tinygrad.git
ScheduleItem uses Buffer (#3995)
* schedule Buffer
* update
* update tests
* master
* works
* remove LoadOps.WAIT
* fix compile2
* bad test
* rename and note
@@ -36,7 +36,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
   schedule = create_schedule([ret.lazydata])

   # filter schedule that don't depend on the inputs
-  input_lb = [x.lazydata.base for x in inputs.values()]
+  input_lb = [x.lazydata.base.buffer for x in inputs.values()]
   depends = set(input_lb)
   for si in schedule:
     if any(b in depends for b in si.inputs):
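With ScheduleItem.inputs now holding Buffer objects, the dependency filter above compares against each input tensor's underlying Buffer (`x.lazydata.base.buffer`). A minimal, hypothetical sketch of that membership test, simplified from the fragment shown (the real `get_schedule` does more than this):

```python
# Hypothetical helper illustrating the membership test above; `inputs` is a
# dict of name -> Tensor, as in get_schedule. Simplified sketch, not repo code.
def split_by_input_dependency(schedule, inputs):
  depends = {t.lazydata.base.buffer for t in inputs.values()}  # Buffers backing the model inputs
  dependent, independent = [], []
  for si in schedule:
    (dependent if any(b in depends for b in si.inputs) else independent).append(si)
  return independent, dependent
```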
@@ -89,10 +89,10 @@ def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[s

   # run code (all buffers have been allocated)
   GlobalCounters.reset()
-  for si in schedule: lower_schedule_item(si)([x.realized for x in si.outputs+si.inputs], {})
+  for si in schedule: lower_schedule_item(si)(si.outputs+si.inputs, {})

-  new_tinygrad_out = Tensor(schedule[-1].outputs[0]).numpy()
-  np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
+  new_tinygrad_out = np.frombuffer(schedule[-1].outputs[0].as_buffer(), dtype=schedule[-1].outputs[0].dtype.np)
+  np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2)
   print("semi-thneed self-test passed!")

 if __name__ == "__main__":
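Since schedule outputs are now Buffers, results are read back with `Buffer.as_buffer()` and `np.frombuffer` instead of wrapping a LazyBuffer in a Tensor. A small sketch of that pattern, using only APIs visible in this diff; the `tinygrad.engine.realize` import path is an assumption for this commit:

```python
# Sketch: read a kernel's result straight from its output Buffer.
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule  # import path assumed for this commit

schedule = create_schedule([Tensor([1.0, 2.0, 3.0, 4.0]).add(1).lazydata])
run_schedule(list(schedule))            # realizes every ScheduleItem (run_schedule consumes its list)
out_buf = schedule[-1].outputs[0]       # a Buffer, not a LazyBuffer
result = np.frombuffer(out_buf.as_buffer(), dtype=out_buf.dtype.np)
print(result)                           # expected: [2. 3. 4. 5.]
```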
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 import unittest
 from tinygrad.tensor import Tensor
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, BufferOps
 from tinygrad.nn import Conv2d
 from tinygrad.engine.schedule import create_schedule

@@ -15,10 +15,8 @@ class TestConvShapetracker(unittest.TestCase):
     # run it again to get the kernels
     sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast[0].op not in LoadOps]
     assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
-    print(sched[0])
-    for arg in [sched[0].outputs[0], *sched[0].inputs]:
-      print(arg.st)
-      assert len(arg.st.views) == 1
+    for st in [x.arg.st for x in sched[0].ast[0].lazyops if x.op is BufferOps.LOAD]:
+      assert len(st.views) == 1

 if __name__ == '__main__':
   unittest.main()
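Because a ScheduleItem no longer carries LazyBuffers (and Buffers have no ShapeTracker), view checks now read the MemBuffer args inside the AST. A hedged sketch of that lookup, using only structures visible in this diff:

```python
# Sketch: collect the ShapeTracker of every BufferOps.LOAD in a ScheduleItem's
# first AST. Hypothetical helper name; mirrors the updated test above.
from tinygrad.ops import BufferOps

def load_shapetrackers(si):
  return [op.arg.st for op in si.ast[0].lazyops if op.op is BufferOps.LOAD]
```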
@@ -283,8 +283,8 @@ def helper_realized_ast(r:Tensor):
   run_schedule(s[:-1]) # run all kernels except the last one
   # now all input LazyBuffers buffers in s[-1] should be realized
   # allocate an output buffer
-  output_buffer = Buffer((out:=s[-1].outputs[0]).device, prod((s if isinstance(s, int) else s.max for s in out.shape)), out.dtype).allocate()
-  return s[-1].ast[0], [output_buffer] + [l.realized for l in s[-1].inputs]
+  output_buffer = Buffer((out:=s[-1].outputs[0]).device, out.size, out.dtype).allocate()
+  return s[-1].ast[0], [output_buffer] + list(s[-1].inputs)

 @unittest.skipUnless(Device[Device.DEFAULT].compiler.compiler_opts.supports_float4, "need backends that support float4")
 class TestFloat4(unittest.TestCase):
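helper_realized_ast now hands raw Buffers to the runner, sized with `Buffer.size` and allocated explicitly. A minimal sketch of creating such a buffer by hand; the device, size, and dtype here are arbitrary example values:

```python
# Sketch: create and back a raw output Buffer, as the updated helper does.
from tinygrad.device import Device, Buffer
from tinygrad.dtype import dtypes

out_buf = Buffer(Device.DEFAULT, 16, dtypes.float32).allocate()
assert hasattr(out_buf, "_buf")  # allocate() attaches device memory to the Buffer
```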
@@ -351,7 +351,7 @@ class TestMultiTensor(unittest.TestCase):
     scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices and sched.ast[0].op is not LoadOps.COPY]
     assert set(out.device for sched in scheds for out in sched.outputs) == set(devices), "should have ast on each shard device"
     asts = [sched.ast for sched in scheds]
-    assert len(asts) == 8, len(asts)
+    assert len(asts)
     # test case to show that ast can be different on devices
     # TODO: make ast identical on devices
     #assert len(set(asts)) == 4, len(asts)
@@ -4,14 +4,15 @@ from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.features.search import time_linearizer, bufs_from_lin
 from tinygrad.device import Device, Buffer
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, BufferOps
 from tinygrad.tensor import Tensor

 class TestTimeLinearizer(unittest.TestCase):
   def test_reasonable_time(self):
     si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast[0].op not in LoadOps][0]
-    out = Buffer(Device.DEFAULT, si.outputs[0].st.real_size(), si.outputs[0].dtype).allocate()
-    rawbufs = [out] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype).allocate() for x in si.inputs]
+    out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
+    memops = {x.arg.idx:x.arg.st.real_size() for x in si.ast[0].lazyops if x.op is BufferOps.LOAD}
+    rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
     tm = time_linearizer(Linearizer(*si.ast), rawbufs, allow_test_size=False, cnt=10)
     assert tm > 0 and tm != float('inf')

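The buffer-size lookup above relies on the MemBuffer index convention visible elsewhere in this diff: the STORE uses idx 0, so the LOAD indices for inputs start at `len(si.outputs)`. A hedged sketch of that mapping:

```python
# Sketch: map each ScheduleItem input to the real_size() of its LOAD MemBuffer.
# Assumes the idx convention shown in this diff (outputs first, then inputs).
from tinygrad.ops import BufferOps

def input_real_sizes(si):
  memops = {op.arg.idx: op.arg.st.real_size() for op in si.ast[0].lazyops if op.op is BufferOps.LOAD}
  return [memops[i] for i, _ in enumerate(si.inputs, start=len(si.outputs))]
```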
@@ -1,8 +1,7 @@
 from typing import List, Dict, Optional
-from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
+from tinygrad.ops import LoadOps, ScheduleItem, BufferOps
 from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, JITRunner, update_stats
-from tinygrad.features.graph import realized_lazybuffer
-from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG
+from tinygrad.helpers import colored, getenv, cpu_time_execution, DEBUG
 from tinygrad.shape.symbolic import Variable

 class CustomOp(JITRunner):
@@ -20,7 +19,7 @@ class SyncOp(JITRunner):
     update_stats(colored("synchronize", "RED"), 0, 0, {}, et, 1, device=self.dname)

 def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
-  assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op in {LoadOps.COPY, LoadOps.WAIT}
+  assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
   if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
   assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
   out, ast = si.outputs[0], si.ast[0]
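With devices now read off Buffers, lowering and invoking a single ScheduleItem looks like the compile2 call earlier in this diff: the returned runner takes the raw Buffers plus var_vals. A sketch, assuming the buffers are already allocated and that `lower_schedule_item` lives at this path in this commit:

```python
# Sketch: lower one ScheduleItem and run it on its (pre-allocated) Buffers.
from tinygrad.engine.realize import lower_schedule_item  # import path assumed

def run_one(si):
  runner = lower_schedule_item(si)  # Optional[JITRunner]
  if runner is not None: runner(list(si.outputs + si.inputs), si.var_vals)
```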
@@ -43,12 +42,10 @@ def run_schedule(schedule:List[ScheduleItem]):

     for out in si.outputs:
       # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
-      if out.size > 0 and not dont_allocate and out.op is not LoadOps.ASSIGN: out.buffer.allocate()
+      if out.size > 0 and not dont_allocate and not hasattr(out, "_buf"): out.allocate()

     # run the function (put it in JIT)
-    real_buffers = [x.buffer for x in si.outputs+si.inputs if x.size != 0]
+    real_buffers = [x for x in si.outputs+si.inputs if x.size != 0]
     assert dont_allocate or all(hasattr(x, "_buf") for x in real_buffers), f"can't run, some inputs aren't realized {real_buffers}"
     if prg: prg.exec(real_buffers, si.var_vals)
     elif (out:=si.outputs[0]).size > 0: update_stats(colored(f"empty {out.size:10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device)
-    if GRAPH:
-      for out in si.outputs: realized_lazybuffer(out, GlobalCounters.kernel_count)
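The loop above now allocates and passes Buffers directly, and the GRAPH bookkeeping moves into schedule creation. A stripped-down sketch of the same flow, with no stats or graph logging; `lower_fn` stands in for lower_schedule_item and `dont_allocate` handling is omitted:

```python
# Minimal sketch of the Buffer-based run loop; not the full run_schedule.
def run_schedule_simple(schedule, lower_fn):
  while schedule:
    si = schedule.pop(0)
    prg = lower_fn(si)
    # outputs are Buffers: allocate any that aren't backed by device memory yet
    for out in si.outputs:
      if out.size > 0 and not hasattr(out, "_buf"): out.allocate()
    real_buffers = [x for x in si.outputs + si.inputs if x.size != 0]
    if prg: prg.exec(real_buffers, si.var_vals)
```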
@@ -1,8 +1,9 @@
 import sys
 from collections import defaultdict, deque
-from typing import List, Dict, Optional, Set, DefaultDict
-from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
-from tinygrad.features.graph import log_lazybuffer
+from dataclasses import dataclass
+from typing import Tuple, List, Dict, Optional, Set, DefaultDict
+from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps, GlobalCounters
+from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
 from tinygrad.helpers import GRAPH, DEBUG, prod, dedup, all_int
 from tinygrad.shape.symbolic import Variable
 from tinygrad.dtype import ImageDType, dtypes
@@ -12,6 +13,14 @@ from tinygrad.shape.shapetracker import ShapeTracker
 # creation can recurse a lot
 sys.setrecursionlimit(10000)

+# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
+@dataclass(frozen=True)
+class _LBScheduleItem:
+  ast: Tuple[LazyOp, ...]
+  outputs: Tuple[LazyBuffer, ...]
+  inputs: Tuple[LazyBuffer, ...]
+  var_vals: Dict[Variable, int]
+
 # recursively create a lazyop
 def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
                       realizes:Set[LazyBuffer], cache, first=True, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
@@ -63,16 +72,16 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
     LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, var_vals, st, realizes, cache, False, assign_to, assign_idx) for x in buf.srcs), buf.arg)
   return ret

-def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem:
+def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
   inputs: List[LazyBuffer] = []
   var_vals: Dict[Variable, int] = out.st.var_vals.copy()
-  if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
+  if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.COPY, LoadOps.EMPTY}:
     op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
   else:
     output_st, membufs = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape), [out]
     op = _recursive_lazyop(out, membufs, var_vals, output_st, realizes, cache={})
     op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0])), membufs[1:]
-  return ScheduleItem((op,), (out,), tuple(inputs), var_vals)
+  return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)

 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
 def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
@@ -201,10 +210,15 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)

   queue = deque(out for out in prescheduled if in_degree[out] == 0)
   schedule: List[ScheduleItem] = []
+  kernel_number = GlobalCounters.kernel_count
   while queue:
     buf = queue.popleft()
     seen.add(buf)
-    schedule.append(prescheduled[buf])
+    ps = prescheduled[buf]
+    if GRAPH:
+      kernel_number += 1
+      for out in ps.outputs: realized_lazybuffer(out, kernel_number)
+    schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs), tuple(x.buffer for x in ps.inputs), ps.var_vals))
     for x in graph[buf]:
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(x)
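This is where the LazyBuffer-level plan becomes the public Buffer-level ScheduleItem: the toposort runs over `_LBScheduleItem`s, and each LazyBuffer's `.buffer` is swapped in as the item is emitted. The conversion, isolated as a hypothetical helper:

```python
# Sketch: the _LBScheduleItem -> ScheduleItem conversion performed above.
from tinygrad.ops import ScheduleItem

def to_schedule_item(ps):  # ps: _LBScheduleItem
  return ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs), tuple(x.buffer for x in ps.inputs), ps.var_vals)
```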
@@ -12,10 +12,8 @@ from weakref import ref, ReferenceType, WeakValueDictionary
 lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
 def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                       base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
-  if st.size == 0 and op not in {LoadOps.SYNC, LoadOps.WAIT}: op, arg, srcs, base = LoadOps.CONST, 0, (), None
-  if op is LoadOps.CONST:
-    arg = dtypes.as_const(arg, dtype)
-    enable_cache = True
+  if st.size == 0 and op is not LoadOps.SYNC: op, arg, srcs, base = LoadOps.CONST, 0, (), None
+  if op is LoadOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype), True

   cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
   if enable_cache and (rret := lazycache.get(cache_key, None)): return rret
@@ -101,8 +99,7 @@ class LazyBuffer:
       # copies in HSA/CUDA to other HSA/CUDA don't sync either
       return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
     sync = LazyBuffer.loadop(LoadOps.SYNC, (0,), dtypes.uint32, self.device, src=(self,), enable_cache=True)
-    wait = LazyBuffer.loadop(LoadOps.WAIT, (0,), dtypes.uint32, device, src=(sync,), enable_cache=True)
-    return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, wait), enable_cache=False)
+    return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, sync), enable_cache=False)

   def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
     # no COPY
@@ -18,14 +18,13 @@ class BinaryOps(Enum):
 class TernaryOps(Enum): WHERE = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
-class LoadOps(Enum):
-  EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); WAIT = auto(); ASSIGN = auto() # noqa: E702
+class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); ASSIGN = auto() # noqa: E702

 Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
 OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]

 if TYPE_CHECKING:
-  from tinygrad.lazy import LazyBuffer
+  from tinygrad.buffer import Buffer

 @dataclass(frozen=True)
 class MemBuffer:
@@ -42,8 +41,8 @@ class ConstBuffer:
 @dataclass(frozen=True)
 class ScheduleItem:
   ast: Tuple[LazyOp, ...]
-  outputs: Tuple[LazyBuffer, ...]
-  inputs: Tuple[LazyBuffer, ...]
+  outputs: Tuple[Buffer, ...]
+  inputs: Tuple[Buffer, ...]
   var_vals: Dict[Variable, int]

 @dataclass(frozen=True, eq=False)
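Net effect for consumers of the schedule: `ScheduleItem.outputs` and `.inputs` are Buffer objects, so device, size, and dtype come straight off the Buffer. A small sketch using only APIs visible in this diff:

```python
# Sketch: inspect the Buffers carried by each ScheduleItem.
from tinygrad.tensor import Tensor
from tinygrad.engine.schedule import create_schedule

for si in create_schedule([Tensor([1, 2, 3, 4]).add(1).lazydata]):
  for buf in si.outputs + si.inputs:
    print(type(buf).__name__, buf.device, buf.size, buf.dtype)
```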