move children for speed (#3047)
* move children for speed
* no children anymore
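In short: LazyBuffer previously kept a per-instance `children` WeakSet that was updated on every buffer construction; this commit drops that field and instead has the scheduler build a children map (`DefaultDict[LazyBuffer, Dict[LazyBuffer, None]]`) while `_recurse_lb` walks the graph, so graph construction no longer pays that bookkeeping cost. A minimal standalone sketch of the pattern follows; the names (`Node`, `build_children`) are illustrative only, not tinygrad's API.

from collections import defaultdict
from typing import DefaultDict, Dict, Tuple

class Node:
  def __init__(self, srcs: Tuple["Node", ...] = ()):
    self.srcs = tuple(srcs)  # edges point from a node to its sources, like LazyBuffer.srcs

def build_children(roots) -> DefaultDict[Node, Dict[Node, None]]:
  # a dict with None values doubles as an insertion-ordered, de-duplicated set,
  # the same trick the diff uses for children[x.base][buf] = None
  children: DefaultDict[Node, Dict[Node, None]] = defaultdict(dict)
  seen: Dict[Node, None] = {}
  def recurse(n: Node):
    if n in seen: return
    seen[n] = None
    for s in n.srcs:
      children[s][n] = None  # record the reverse edge only at schedule time
      recurse(s)
  for r in roots: recurse(r)
  return children

a = Node(); b = Node((a,)); c = Node((a, b))
assert set(build_children([c])[a]) == {b, c}  # a's children found without any per-node set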
test/test_lazybuffer.py
@@ -4,7 +4,6 @@ import unittest
 from tinygrad.lazy import LazyBuffer
 from tinygrad import Device
 from tinygrad.tensor import Tensor
-from tinygrad.jit import CacheCollector
 
 class TestLazyBuffer(unittest.TestCase):
   @unittest.skip("it doesn't work like this anymore")
@@ -46,26 +45,6 @@ class TestLazyBuffer(unittest.TestCase):
     z = Tensor([1, np.e]).numpy()
     np.testing.assert_allclose(y, z)
 
-  @unittest.skipUnless(Device.DEFAULT in ["METAL", "CUDA", "GPU"], "Only GPU backends supports cache")
-  def test_children_count(self):
-    a = Tensor.ones(8,8,8)
-    d1 = a.sum((0))
-    d2 = a.sum((0)).reshape(32,2) # noqa: F841
-    assert len(d1.lazydata.base.srcs[0].base.children) == 1
-    in1 = d1.reshape(16,4)
-    d3 = in1.reshape(8,8)
-    assert len(d3.lazydata.base.srcs[0].base.children) == 1
-
-    CacheCollector.start()
-    l = Tensor.ones(8,8)
-    r = Tensor.ones(8,8)
-    dd = d1 + l
-    dd.realize()
-    de = d3 + r
-    de.realize()
-    cache = CacheCollector.finish()
-    assert len(cache) == 2
-
   def test_device_canonicalize(self):
     a = Tensor([1, 2, 3], f"{Device.DEFAULT}")
     b = Tensor([1, 2, 3], f"{Device.DEFAULT}:0")
tinygrad/lazy.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 import sys, math
 import numpy as np
-from typing import Union, Optional, Any, Tuple, List, Set, Dict
+from collections import defaultdict
+from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict
 from tinygrad.dtype import dtypes, DType, ImageDType
 from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
 from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
@@ -9,7 +10,7 @@ from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer, Device
 from tinygrad.graph import log_lazybuffer
-from weakref import ref, WeakSet, WeakValueDictionary, ReferenceType
+from weakref import ref, WeakValueDictionary, ReferenceType
 
 # lazy can recurse a lot
 sys.setrecursionlimit(10000)
@@ -42,8 +43,6 @@ class LazyBuffer:
       self.output_buffer: Optional[Buffer] = None
       self.forced_realize = False
       self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
-      self.children: WeakSet[LazyBuffer] = WeakSet()
-      for x in srcs: x.base.children.add(self.base)
     else:
       # properties on view
       assert base.base == base, "base must be a base itself"
@@ -213,7 +212,8 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
     [ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in op.vars()})]
 
 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
-def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None], simple_pads:Set[LazyBuffer]):
+def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
+                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]]):
   if buf in allbufs or buf.base.realized: return
   log_lazybuffer(buf)
   if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
@@ -228,14 +228,16 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe
         simple_pads.add(buf.base)
       else:
         realizes.add(buf.base)
-    return _recurse_lb(buf.base, realizes, allbufs, simple_pads)
+    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
   if buf.forced_realize: realizes.add(buf)
   allbufs[buf] = None
   if buf.op in LoadOps: realizes.add(buf.base)
   if buf.op == LoadOps.COPY:
     assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
     realizes.add(buf.srcs[0].base)
-  for x in buf.srcs: _recurse_lb(x, realizes, allbufs, simple_pads)
+  for x in buf.srcs:
+    children[x.base][buf] = None
+    _recurse_lb(x, realizes, allbufs, simple_pads, children)
 
 UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
 def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
@@ -252,7 +254,8 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
   realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
   allbufs: Dict[LazyBuffer, None] = {}
   simple_pads: Set[LazyBuffer] = set()
-  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads)
+  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
+  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children)
 
   # check if we have to realize pads
   for p in simple_pads:
@@ -282,7 +285,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
             forced_realize = True
             break
           continue
-        for tr_next in tr.children:
+        for tr_next in children[tr].keys():
          if not tr_next.realized:
            # max one reduceop per kernel
            if tr_next.op in ReduceOps:
@@ -299,8 +302,8 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
       if can_chase:
         # can chase this down to contiguous children
         st = tr.st
-        while len(tr.children) == 1:
-          tr_next = next(iter(tr.children))
+        while len(children[tr]) == 1:
+          tr_next = next(iter(children[tr].keys()))
           st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
           if len(st_childs) > 1: break
           if st.size != st_childs[0].st.size: break
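Design note: the new map uses Dict[LazyBuffer, None] values rather than a set, presumably so that iterating children[tr].keys() is de-duplicated yet deterministic (insertion order), whereas the old per-buffer WeakSet had no stable iteration order. The idiom itself, independent of tinygrad:

kids = {}
for k in ["b", "c", "b"]:  # "b" inserted twice is kept once; first-insertion order survives
  kids[k] = None
assert list(kids.keys()) == ["b", "c"]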