move children for speed (#3047)

* move children for speed

* no children anymore
This commit is contained in:
George Hotz
2024-01-08 15:02:32 -08:00
committed by GitHub
parent 655c6f61d3
commit 7ea2e0035b
2 changed files with 14 additions and 32 deletions

View File

@@ -4,7 +4,6 @@ import unittest
from tinygrad.lazy import LazyBuffer
from tinygrad import Device
from tinygrad.tensor import Tensor
from tinygrad.jit import CacheCollector
class TestLazyBuffer(unittest.TestCase):
@unittest.skip("it doesn't work like this anymore")
@@ -46,26 +45,6 @@ class TestLazyBuffer(unittest.TestCase):
z = Tensor([1, np.e]).numpy()
np.testing.assert_allclose(y, z)
@unittest.skipUnless(Device.DEFAULT in ["METAL", "CUDA", "GPU"], "Only GPU backends supports cache")
# Checks child-tracking on lazy buffers (removed by this commit, which drops
# the per-buffer `children` WeakSet in favor of a schedule-time defaultdict):
# reshape views of the same base reduce must register only ONE child on the
# source base, and realizing two adds built on those views must still emit
# two distinct kernels into the cache.
# NOTE(review): relies on backend kernel caching — hence the METAL/CUDA/GPU
# skip guard on the preceding decorator line.
def test_children_count(self):
a = Tensor.ones(8,8,8)
d1 = a.sum((0))
# d2 is a reshaped view of the same reduce as d1; kept alive only to
# exercise child registration (unused otherwise, hence noqa: F841).
d2 = a.sum((0)).reshape(32,2) # noqa: F841
# both d1 and d2 resolve to the same base, so the source buffer tracks
# exactly one child despite two consumers.
assert len(d1.lazydata.base.srcs[0].base.children) == 1
in1 = d1.reshape(16,4)
# chained reshapes are views on the same base — still a single child.
d3 = in1.reshape(8,8)
assert len(d3.lazydata.base.srcs[0].base.children) == 1
# capture the kernels produced by realizing two independent adds.
CacheCollector.start()
l = Tensor.ones(8,8)
r = Tensor.ones(8,8)
dd = d1 + l
dd.realize()
de = d3 + r
de.realize()
cache = CacheCollector.finish()
# two realizes on differently-shaped views -> two cached kernels.
assert len(cache) == 2
def test_device_canonicalize(self):
a = Tensor([1, 2, 3], f"{Device.DEFAULT}")
b = Tensor([1, 2, 3], f"{Device.DEFAULT}:0")

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
import sys, math
import numpy as np
from typing import Union, Optional, Any, Tuple, List, Set, Dict
from collections import defaultdict
from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict
from tinygrad.dtype import dtypes, DType, ImageDType
from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
@@ -9,7 +10,7 @@ from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer, Device
from tinygrad.graph import log_lazybuffer
from weakref import ref, WeakSet, WeakValueDictionary, ReferenceType
from weakref import ref, WeakValueDictionary, ReferenceType
# lazy can recurse a lot
sys.setrecursionlimit(10000)
@@ -42,8 +43,6 @@ class LazyBuffer:
self.output_buffer: Optional[Buffer] = None
self.forced_realize = False
self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
self.children: WeakSet[LazyBuffer] = WeakSet()
for x in srcs: x.base.children.add(self.base)
else:
# properties on view
assert base.base == base, "base must be a base itself"
@@ -213,7 +212,8 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
[ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in op.vars()})]
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None], simple_pads:Set[LazyBuffer]):
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]]):
if buf in allbufs or buf.base.realized: return
log_lazybuffer(buf)
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
@@ -228,14 +228,16 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe
simple_pads.add(buf.base)
else:
realizes.add(buf.base)
return _recurse_lb(buf.base, realizes, allbufs, simple_pads)
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
if buf.forced_realize: realizes.add(buf)
allbufs[buf] = None
if buf.op in LoadOps: realizes.add(buf.base)
if buf.op == LoadOps.COPY:
assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
realizes.add(buf.srcs[0].base)
for x in buf.srcs: _recurse_lb(x, realizes, allbufs, simple_pads)
for x in buf.srcs:
children[x.base][buf] = None
_recurse_lb(x, realizes, allbufs, simple_pads, children)
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
@@ -252,7 +254,8 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
allbufs: Dict[LazyBuffer, None] = {}
simple_pads: Set[LazyBuffer] = set()
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads)
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children)
# check if we have to realize pads
for p in simple_pads:
@@ -282,7 +285,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
forced_realize = True
break
continue
for tr_next in tr.children:
for tr_next in children[tr].keys():
if not tr_next.realized:
# max one reduceop per kernel
if tr_next.op in ReduceOps:
@@ -299,8 +302,8 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
if can_chase:
# can chase this down to contiguous children
st = tr.st
while len(tr.children) == 1:
tr_next = next(iter(tr.children))
while len(children[tr]) == 1:
tr_next = next(iter(children[tr].keys()))
st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
if len(st_childs) > 1: break
if st.size != st_childs[0].st.size: break