From 7ea2e0035b30dd6a75c0ac2f4446ee2c68f7455a Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 8 Jan 2024 15:02:32 -0800
Subject: [PATCH] move children for speed (#3047)

* move children for speed

* no children anymore
---
 test/test_lazybuffer.py | 21 ---------------------
 tinygrad/lazy.py        | 25 ++++++++++++++-----------
 2 files changed, 14 insertions(+), 32 deletions(-)

diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py
index 92131d9839..bf32055aaa 100644
--- a/test/test_lazybuffer.py
+++ b/test/test_lazybuffer.py
@@ -4,7 +4,6 @@ import unittest
 from tinygrad.lazy import LazyBuffer
 from tinygrad import Device
 from tinygrad.tensor import Tensor
-from tinygrad.jit import CacheCollector
 
 class TestLazyBuffer(unittest.TestCase):
   @unittest.skip("it doesn't work like this anymore")
@@ -46,26 +45,6 @@ class TestLazyBuffer(unittest.TestCase):
     z = Tensor([1, np.e]).numpy()
     np.testing.assert_allclose(y, z)
 
-  @unittest.skipUnless(Device.DEFAULT in ["METAL", "CUDA", "GPU"], "Only GPU backends supports cache")
-  def test_children_count(self):
-    a = Tensor.ones(8,8,8)
-    d1 = a.sum((0))
-    d2 = a.sum((0)).reshape(32,2) # noqa: F841
-    assert len(d1.lazydata.base.srcs[0].base.children) == 1
-    in1 = d1.reshape(16,4)
-    d3 = in1.reshape(8,8)
-    assert len(d3.lazydata.base.srcs[0].base.children) == 1
-
-    CacheCollector.start()
-    l = Tensor.ones(8,8)
-    r = Tensor.ones(8,8)
-    dd = d1 + l
-    dd.realize()
-    de = d3 + r
-    de.realize()
-    cache = CacheCollector.finish()
-    assert len(cache) == 2
-
   def test_device_canonicalize(self):
     a = Tensor([1, 2, 3], f"{Device.DEFAULT}")
     b = Tensor([1, 2, 3], f"{Device.DEFAULT}:0")
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index a059b0ec93..c72d3694c4 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 import sys, math
 import numpy as np
-from typing import Union, Optional, Any, Tuple, List, Set, Dict
+from collections import defaultdict
+from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict
 from tinygrad.dtype import dtypes, DType, ImageDType
 from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
 from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
@@ -9,7 +10,7 @@ from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer, Device
 from tinygrad.graph import log_lazybuffer
-from weakref import ref, WeakSet, WeakValueDictionary, ReferenceType
+from weakref import ref, WeakValueDictionary, ReferenceType
 
 # lazy can recurse a lot
 sys.setrecursionlimit(10000)
@@ -42,8 +43,6 @@ class LazyBuffer:
       self.output_buffer: Optional[Buffer] = None
       self.forced_realize = False
       self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None
-      self.children: WeakSet[LazyBuffer] = WeakSet()
-      for x in srcs: x.base.children.add(self.base)
     else:
       # properties on view
       assert base.base == base, "base must be a base itself"
@@ -213,7 +212,8 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
     [ScheduleItem(op, out, tuple(inputs), {k:var_vals[k] for k in op.vars()})]
 
 # recursively search the entire graph for all LazyBuffers, insert realizes after expands
-def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None], simple_pads:Set[LazyBuffer]):
+def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
+                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]]):
   if buf in allbufs or buf.base.realized: return
   log_lazybuffer(buf)
   if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
@@ -228,14 +228,16 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe
         simple_pads.add(buf.base)
       else:
         realizes.add(buf.base)
-    return _recurse_lb(buf.base, realizes, allbufs, simple_pads)
+    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
   if buf.forced_realize: realizes.add(buf)
   allbufs[buf] = None
   if buf.op in LoadOps: realizes.add(buf.base)
   if buf.op == LoadOps.COPY:
     assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
     realizes.add(buf.srcs[0].base)
-  for x in buf.srcs: _recurse_lb(x, realizes, allbufs, simple_pads)
+  for x in buf.srcs:
+    children[x.base][buf] = None
+    _recurse_lb(x, realizes, allbufs, simple_pads, children)
 
 UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2, UnaryOps.RECIP}
 def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
@@ -252,7 +254,8 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
   realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
   allbufs: Dict[LazyBuffer, None] = {}
   simple_pads: Set[LazyBuffer] = set()
-  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads)
+  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
+  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children)
 
   # check if we have to realize pads
   for p in simple_pads:
@@ -282,7 +285,7 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
             forced_realize = True
             break
           continue
-        for tr_next in tr.children:
+        for tr_next in children[tr].keys():
           if not tr_next.realized:
             # max one reduceop per kernel
             if tr_next.op in ReduceOps:
@@ -299,8 +302,8 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
       if can_chase:
         # can chase this down to contiguous children
        st = tr.st
-        while len(tr.children) == 1:
-          tr_next = next(iter(tr.children))
+        while len(children[tr]) == 1:
+          tr_next = next(iter(children[tr].keys()))
           st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
           if len(st_childs) > 1: break
           if st.size != st_childs[0].st.size: break
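
Note (not part of the patch): a minimal sketch of the bookkeeping pattern the diff switches to. Instead of every buffer keeping a WeakSet of its children at construction time, the scheduler builds one children map with defaultdict(dict) while it recursively walks the graph. Node and recurse below are simplified, hypothetical stand-ins for LazyBuffer and _recurse_lb, not tinygrad APIs.

# illustrative sketch only, assuming a toy DAG of Node objects
from collections import defaultdict
from typing import DefaultDict, Dict, Tuple

class Node:
  def __init__(self, name:str, srcs:Tuple["Node", ...]=()): self.name, self.srcs = name, srcs
  def __repr__(self): return self.name

def recurse(node:Node, allnodes:Dict[Node, None], children:DefaultDict[Node, Dict[Node, None]]):
  if node in allnodes: return            # visit each node once
  allnodes[node] = None
  for src in node.srcs:
    children[src][node] = None           # record "node is a child of src"; dict keys dedupe and keep insertion order
    recurse(src, allnodes, children)

a = Node("a"); b = Node("b", (a,)); c = Node("c", (a, b))
allnodes: Dict[Node, None] = {}
children: DefaultDict[Node, Dict[Node, None]] = defaultdict(dict)
recurse(c, allnodes, children)
print({k: list(v) for k, v in children.items()})  # {a: [c, b], b: [c]}

Using Dict[LazyBuffer, None] as an ordered set presumably keeps iteration deterministic, matching how allbufs is already a Dict[LazyBuffer, None] in create_schedule.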