simplify CacheCollector (#1944)

* rewrite cc

* fix

* fix tests

* fix all tests

* is it better

* better with shape

* cleaner

* linter fix

* no ;

* better comment

* better comments

* no thneed changes
This commit is contained in:
nimlgen
2023-09-29 20:13:04 +03:00
committed by GitHub
parent 90326dbdc3
commit 692bec7b6f
3 changed files with 41 additions and 58 deletions

View File

@@ -59,8 +59,9 @@ class TestCacheCollector(unittest.TestCase):
assert cache[0][1][1] == inps[0], "Input should be on its place." assert cache[0][1][1] == inps[0], "Input should be on its place."
assert cache[1][1][2] == inps[1], "Input should be on its place." assert cache[1][1][2] == inps[1], "Input should be on its place."
assert cache[-1][1][0] == out, "Output does not match." assert cache[-1][1][0] == out, "Output does not match."
assert get_bufs_count(cache) == 4, "Should have 4 buffers in total" assert get_bufs_count(cache) == 5, "Should have 5 buffers in total"
assert cache[-1][1][0] == cache[0][1][0], "Should reuse final output buffer as output in 1st kernel" # This is not worth added complexity on real models
# assert cache[-1][1][0] == cache[0][1][0], "Should reuse final output buffer as output in 1st kernel"
FAKE_GLOBAL_ALLOCATOR = None FAKE_GLOBAL_ALLOCATOR = None
def test_cache_collector_cycle_avoidance(self): def test_cache_collector_cycle_avoidance(self):
@@ -78,8 +79,8 @@ class TestCacheCollector(unittest.TestCase):
assert cache[0][1][1] == inps[0], "Input should be on its place." assert cache[0][1][1] == inps[0], "Input should be on its place."
assert cache[1][1][2] == inps[1], "Input should be on its place." assert cache[1][1][2] == inps[1], "Input should be on its place."
assert cache[-1][1][0] == out, "Output does not match." assert cache[-1][1][0] == out, "Output does not match."
assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
assert cache[-1][1][0] != cache[0][1][0] and cache[0][1][0] == cache[3][1][0], "Output buffers from 1st and 4th kernel could not be the same as the 5th." assert cache[-1][1][0] != cache[0][1][0] and cache[0][1][0] == cache[3][1][0], "Output buffers from 1st and 4th kernel could not be the same as the 5th."
assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
FAKE_GLOBAL_ALLOCATOR = None FAKE_GLOBAL_ALLOCATOR = None
def test_cache_collector_all_alive(self): def test_cache_collector_all_alive(self):
@@ -144,7 +145,7 @@ class TestCacheCollector(unittest.TestCase):
assert cache[i][1][0]._device == '1', f"Device does not match {i}, has {cache[i][1][0]._device}." assert cache[i][1][0]._device == '1', f"Device does not match {i}, has {cache[i][1][0]._device}."
for i in range(3, 6): for i in range(3, 6):
assert cache[i][1][0]._device == '2', f"Device does not match {i}, has {cache[i][1][0]._device}." assert cache[i][1][0]._device == '2', f"Device does not match {i}, has {cache[i][1][0]._device}."
assert get_bufs_count(cache) == 6 assert get_bufs_count(cache) == 7
FAKE_GLOBAL_ALLOCATOR = None FAKE_GLOBAL_ALLOCATOR = None
def test_cache_collector_anybufs_inputs(self): def test_cache_collector_anybufs_inputs(self):
@@ -161,7 +162,7 @@ class TestCacheCollector(unittest.TestCase):
cache = CacheCollector.finish() cache = CacheCollector.finish()
assert cache[0][1][1] == inps[0], "Input should be on its place." assert cache[0][1][1] == inps[0], "Input should be on its place."
assert cache[1][1][2] == inps[1], "Input should be on its place." assert cache[1][1][2] == inps[1], "Input should be on its place."
assert get_bufs_count(cache) == 7 assert get_bufs_count(cache) == 8
FAKE_GLOBAL_ALLOCATOR = None FAKE_GLOBAL_ALLOCATOR = None
def test_cache_collector_optimize_when_not_cached_anymore(self): def test_cache_collector_optimize_when_not_cached_anymore(self):
@@ -201,7 +202,7 @@ class TestCacheCollector(unittest.TestCase):
assert cache[1][1][2] == inps[1], "Input should be on its place." assert cache[1][1][2] == inps[1], "Input should be on its place."
assert cache[-1][1][0] == out, "Output does not match." assert cache[-1][1][0] == out, "Output does not match."
assert cache[0][1][0] != cache[3][1][0], "Cannot reuse 4th output buffer, it's an output buffer which might ovewrite itself" assert cache[0][1][0] != cache[3][1][0], "Cannot reuse 4th output buffer, it's an output buffer which might ovewrite itself"
assert get_bufs_count(cache) == 7, "Should have 7 buffers in total" assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
FAKE_GLOBAL_ALLOCATOR = None FAKE_GLOBAL_ALLOCATOR = None
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -2,7 +2,7 @@ from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional, Set
from weakref import ref from weakref import ref
from collections import defaultdict from collections import defaultdict
import functools, itertools import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts from tinygrad.helpers import DEBUG, DType, merge_dicts, ImageDType
from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
@@ -70,63 +70,47 @@ class TinyJit:
class _CacheCollector: class _CacheCollector:
class _Placeholder: class _Placeholder:
def __init__(self, buf): self.size, self.dtype, self._device, self.ref, self.buftype = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf) def __init__(self, buf): self.size, self.dtype, self._device, self.ref, self.buftype = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf)
def alive(self): return self.ref() is not None
def alloc_rawbuf(self): return self.buftype(self.size, self.dtype, **({'device':self._device} if self._device is not None else dict())) def alloc_rawbuf(self): return self.buftype(self.size, self.dtype, **({'device':self._device} if self._device is not None else dict()))
def __init__(self): def __init__(self):
self.cache: Optional[List[Tuple[Callable, List[Any], Dict[Any,Any]]]] = None self.cache: Optional[List[Tuple[Callable, List[Any], Dict[Any,Any]]]] = None
self.placeholders: Dict[RawBuffer, _CacheCollector._Placeholder] = {} # Rawbuffers are replaced with placeholders to allow freeing of the real buffer while collecting cache. self.placeholders: Dict[ref[RawBuffer], _CacheCollector._Placeholder] = {} # Output rawbufs are replaced with placeholders to allow freeing of the real buffer while collecting cache.
self.last_buftype: Dict[Tuple[int,...], int] = {} # Last index of the cached entry where a buffer with the shape (shape is a key) is used as input to the prog.
self.last_placeholder_index: Dict[_CacheCollector._Placeholder, int] = {} # Last index where the placeholder is used as output. This allows tracking when we need to stick to the original buffer if it is still alive.
self.freed_placeholders: Dict[Tuple[int,...], List[_CacheCollector._Placeholder]] = defaultdict(list)
self.circular_signatures: Set[Any] = set() self.circular_signatures: Set[Any] = set()
def start(self): def start(self): self.cache, self.placeholders, self.circular_signatures = [], {}, set()
self.cache, self.placeholders, self.last_buftype, self.last_placeholder_index, self.freed_buffers, self.circular_signatures = [], {}, {}, {}, defaultdict(list), set()
def add(self, prg, rawbufs, var_vals): def add(self, prg, rawbufs, var_vals):
if self.cache is None: return if self.cache is None: return
# When we got buffers with the same signature, we can use just 1(max 2, see cycle avoidance below) buffer insted of all of them. # Substitute output buffers with placeholders to find the most optimal reusage.
# Current implementation of a signature is an underlying buffer, because if 2 or more different RawBuffers shares the same, all but the very last are dead. if ref(rawbufs[0]) not in self.placeholders: self.placeholders[ref(rawbufs[0])] = _CacheCollector._Placeholder(rawbufs[0])
for buf in rawbufs[1:]: cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs]
# Check if the input matches any of placeholder to determine if it's existing or newly created input. self.cache.append((prg, cached_rawbufs, var_vals))
# In case of newly created input remove placeholder and capture the whole buffer.
if isinstance(buf, RawBuffer) and self._get_signature(buf) in self.placeholders and self.placeholders[self._get_signature(buf)].ref != ref(buf):
self.placeholders.pop(self._get_signature(buf))
if isinstance(buf, RawBuffer) and self._get_signature(buf) not in self.placeholders:
self.last_buftype[self._buftype_key(buf)] = len(self.cache)
# Creating/updating a placeholder for the current output buffer. If we update output, set the ref to point to the new RawBuffer,
# since the previous RawBuffer is dead (overwise we won't get a new RawBuffer with the same signature). Do not care about dead buffers, they 100% could be replaced with any other buffer.
if self._get_signature(rawbufs[0]) in self.placeholders: self.placeholders[self._get_signature(rawbufs[0])].ref = ref(rawbufs[0])
else:
# This is a new output buffer. Try to reuse any freed placeholders with the same type to "merge" these buffers.
# If this output buffer is "output_buffer", reusage of the output buffer is scary.
plh = self.freed_placeholders[self._buftype_key(rawbufs[0])].pop() if self._get_signature(rawbufs[0]) not in self.circular_signatures and self.freed_placeholders[self._buftype_key(rawbufs[0])] else _CacheCollector._Placeholder(rawbufs[0])
self.placeholders.setdefault(self._get_signature(rawbufs[0]), plh).ref = ref(rawbufs[0])
self.last_placeholder_index[self.placeholders[self._get_signature(rawbufs[0])]] = len(self.cache)
self.cache.append((prg,[self.placeholders.get(self._get_signature(x), x) for x in rawbufs],var_vals))
def finish(self): def finish(self):
if self.cache is None: return [] if self.cache is None: return []
placeholder_mapper, cache_result = {}, []
for j,(p,cached_bufs,var_vals) in enumerate(self.cache): rawbuf_pool: List[Tuple[RawBuffer, List[Tuple[int, int]]]] = []
if cached_bufs[0].__class__ is _CacheCollector._Placeholder: buf_usage_bounds: Dict[_CacheCollector._Placeholder, Tuple[int, int]] = {}
if cached_bufs[0].alive(): buf_map: Dict[_CacheCollector._Placeholder, RawBuffer] = {}
# Since the placeholder is alive (someone holds refed RawBuffer) to avoid hazards when this output buffer could be used as input on the other launch (e.g., LSTM),
# we allocate a backing buffer and and use it until the penultimate entry (the last entry is 100% safe to use the original RawBuffer). for j,(_,bufs,_) in enumerate(self.cache):
if self.last_buftype.get(self._buftype_key(cached_bufs[0]), -1) < j or self.last_placeholder_index[cached_bufs[0]] == j: for buf in bufs:
# Safe to use the original buffer when all inputs buffers of the same size and dtype are behind or this is the last usage of this buffer as output. if buf.__class__ is not _CacheCollector._Placeholder: continue
placeholder_mapper[cached_bufs[0]] = cached_bufs[0].ref() if buf.ref() is not None: buf_map[buf] = buf.ref() # rawbufs that are referenced are not replaced but are used as is.
elif cached_bufs[0] not in placeholder_mapper: else: buf_usage_bounds[buf] = buf_usage_bounds.get(buf, (j, j))[0], j
placeholder_mapper[cached_bufs[0]] = cached_bufs[0].alloc_rawbuf() # Allocating a backing buffer.
elif cached_bufs[0] not in placeholder_mapper: # The query list contains a query for every placeholder that should be replaced with the actual rawbuffer. Queries are served from the largest to the smallest.
placeholder_mapper[cached_bufs[0]] = cached_bufs[0].alloc_rawbuf() # For each query, find any rawbuffer that is free within the query timeframe or allocate a new one.
cache_result.append((p, [placeholder_mapper.get(buf, buf) for buf in cached_bufs], var_vals)) query_list = sorted([(buf.size*buf.dtype.itemsize, buf_usage_bounds[buf][0], buf_usage_bounds[buf][1], buf) for buf in buf_usage_bounds.keys()], key=lambda x: x[0], reverse=True)
self.cache, self.placeholders, self.last_buftype, self.last_placeholder_index, self.freed_buffers, self.circular_signatures = None, {}, {}, {}, defaultdict(list), set() for _, start, end, buf in query_list:
pool_idx = next((i for i,(with_buf, usages) in enumerate(rawbuf_pool) if self._can_substitute(buf, with_buf) and self._no_intersect(start,end,usages)), -1)
if pool_idx == -1:
rawbuf_pool.append((buf.alloc_rawbuf(), []))
pool_idx = len(rawbuf_pool) - 1
buf_map[buf] = rawbuf_pool[pool_idx][0]
rawbuf_pool[pool_idx][1].append((start, end))
cache_result = [(p, [buf_map.get(buf, buf) for buf in cached_bufs], var_vals) for p, cached_bufs, var_vals in self.cache]
self.cache = None
return cache_result return cache_result
def _mark_output_buffer(self, output_buffer): self.circular_signatures.add(self._get_signature(output_buffer)) def _no_intersect(self, start:int, end:int, usages:List[Tuple[int, int]]): return all(en < start or end < st for st, en in usages)
def _on_buf_free(self, underlying_buf): def _can_substitute(self, buf, with_buf): return buf._device==with_buf._device and (buf.size*buf.dtype.itemsize<=with_buf.size*with_buf.dtype.itemsize if not isinstance(buf.dtype, ImageDType) and not isinstance(with_buf.dtype, ImageDType) else buf.size==with_buf.size and buf.dtype==with_buf.dtype and buf.dtype.shape==with_buf.dtype.shape)
if underlying_buf not in self.placeholders: return def _mark_output_buffer(self, output_buffer): self.circular_signatures.add(ref(output_buffer))
self.freed_placeholders[self._buftype_key(self.placeholders[underlying_buf])].append(self.placeholders[underlying_buf])
self.placeholders.pop(underlying_buf)
def _get_signature(self, buf): return buf._buf if getattr(buf, '_buf', None) is not None and getattr(buf, '_allocator', None) is not None else buf
def _buftype_key(self, buf): return (buf.size, buf.dtype, buf._device, buf.dtype.shape if hasattr(buf.dtype, 'shape') else None)
CacheCollector = _CacheCollector() CacheCollector = _CacheCollector()

View File

@@ -92,8 +92,6 @@ class LRUAllocator:
self.buffer_info[newbuf] = (size, dtype, device) self.buffer_info[newbuf] = (size, dtype, device)
return newbuf return newbuf
def _free_buffer(self, buf_to_free): def _free_buffer(self, buf_to_free):
from tinygrad.jit import CacheCollector
CacheCollector._on_buf_free(buf_to_free)
self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free) self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free) GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
self.buffer_info.pop(buf_to_free) self.buffer_info.pop(buf_to_free)