mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
simplify CacheCollector (#1944)
* rewrite cc * fix * fix tests * fix all tests * is it better * better with shape * cleaner * linter fix * no ; * better comment * better comments * no thneed changes
This commit is contained in:
@@ -59,8 +59,9 @@ class TestCacheCollector(unittest.TestCase):
|
|||||||
assert cache[0][1][1] == inps[0], "Input should be on its place."
|
assert cache[0][1][1] == inps[0], "Input should be on its place."
|
||||||
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
||||||
assert cache[-1][1][0] == out, "Output does not match."
|
assert cache[-1][1][0] == out, "Output does not match."
|
||||||
assert get_bufs_count(cache) == 4, "Should have 4 buffers in total"
|
assert get_bufs_count(cache) == 5, "Should have 5 buffers in total"
|
||||||
assert cache[-1][1][0] == cache[0][1][0], "Should reuse final output buffer as output in 1st kernel"
|
# This is not worth added complexity on real models
|
||||||
|
# assert cache[-1][1][0] == cache[0][1][0], "Should reuse final output buffer as output in 1st kernel"
|
||||||
FAKE_GLOBAL_ALLOCATOR = None
|
FAKE_GLOBAL_ALLOCATOR = None
|
||||||
|
|
||||||
def test_cache_collector_cycle_avoidance(self):
|
def test_cache_collector_cycle_avoidance(self):
|
||||||
@@ -78,8 +79,8 @@ class TestCacheCollector(unittest.TestCase):
|
|||||||
assert cache[0][1][1] == inps[0], "Input should be on its place."
|
assert cache[0][1][1] == inps[0], "Input should be on its place."
|
||||||
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
||||||
assert cache[-1][1][0] == out, "Output does not match."
|
assert cache[-1][1][0] == out, "Output does not match."
|
||||||
assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
|
|
||||||
assert cache[-1][1][0] != cache[0][1][0] and cache[0][1][0] == cache[3][1][0], "Output buffers from 1st and 4th kernel could not be the same as the 5th."
|
assert cache[-1][1][0] != cache[0][1][0] and cache[0][1][0] == cache[3][1][0], "Output buffers from 1st and 4th kernel could not be the same as the 5th."
|
||||||
|
assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
|
||||||
FAKE_GLOBAL_ALLOCATOR = None
|
FAKE_GLOBAL_ALLOCATOR = None
|
||||||
|
|
||||||
def test_cache_collector_all_alive(self):
|
def test_cache_collector_all_alive(self):
|
||||||
@@ -144,7 +145,7 @@ class TestCacheCollector(unittest.TestCase):
|
|||||||
assert cache[i][1][0]._device == '1', f"Device does not match {i}, has {cache[i][1][0]._device}."
|
assert cache[i][1][0]._device == '1', f"Device does not match {i}, has {cache[i][1][0]._device}."
|
||||||
for i in range(3, 6):
|
for i in range(3, 6):
|
||||||
assert cache[i][1][0]._device == '2', f"Device does not match {i}, has {cache[i][1][0]._device}."
|
assert cache[i][1][0]._device == '2', f"Device does not match {i}, has {cache[i][1][0]._device}."
|
||||||
assert get_bufs_count(cache) == 6
|
assert get_bufs_count(cache) == 7
|
||||||
FAKE_GLOBAL_ALLOCATOR = None
|
FAKE_GLOBAL_ALLOCATOR = None
|
||||||
|
|
||||||
def test_cache_collector_anybufs_inputs(self):
|
def test_cache_collector_anybufs_inputs(self):
|
||||||
@@ -161,7 +162,7 @@ class TestCacheCollector(unittest.TestCase):
|
|||||||
cache = CacheCollector.finish()
|
cache = CacheCollector.finish()
|
||||||
assert cache[0][1][1] == inps[0], "Input should be on its place."
|
assert cache[0][1][1] == inps[0], "Input should be on its place."
|
||||||
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
||||||
assert get_bufs_count(cache) == 7
|
assert get_bufs_count(cache) == 8
|
||||||
FAKE_GLOBAL_ALLOCATOR = None
|
FAKE_GLOBAL_ALLOCATOR = None
|
||||||
|
|
||||||
def test_cache_collector_optimize_when_not_cached_anymore(self):
|
def test_cache_collector_optimize_when_not_cached_anymore(self):
|
||||||
@@ -201,7 +202,7 @@ class TestCacheCollector(unittest.TestCase):
|
|||||||
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
assert cache[1][1][2] == inps[1], "Input should be on its place."
|
||||||
assert cache[-1][1][0] == out, "Output does not match."
|
assert cache[-1][1][0] == out, "Output does not match."
|
||||||
assert cache[0][1][0] != cache[3][1][0], "Cannot reuse 4th output buffer, it's an output buffer which might ovewrite itself"
|
assert cache[0][1][0] != cache[3][1][0], "Cannot reuse 4th output buffer, it's an output buffer which might ovewrite itself"
|
||||||
assert get_bufs_count(cache) == 7, "Should have 7 buffers in total"
|
assert get_bufs_count(cache) == 6, "Should have 6 buffers in total"
|
||||||
FAKE_GLOBAL_ALLOCATOR = None
|
FAKE_GLOBAL_ALLOCATOR = None
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional, Set
|
|||||||
from weakref import ref
|
from weakref import ref
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import functools, itertools
|
import functools, itertools
|
||||||
from tinygrad.helpers import DEBUG, DType, merge_dicts
|
from tinygrad.helpers import DEBUG, DType, merge_dicts, ImageDType
|
||||||
from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor
|
from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
@@ -70,63 +70,47 @@ class TinyJit:
|
|||||||
class _CacheCollector:
|
class _CacheCollector:
|
||||||
class _Placeholder:
|
class _Placeholder:
|
||||||
def __init__(self, buf): self.size, self.dtype, self._device, self.ref, self.buftype = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf)
|
def __init__(self, buf): self.size, self.dtype, self._device, self.ref, self.buftype = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf)
|
||||||
def alive(self): return self.ref() is not None
|
|
||||||
def alloc_rawbuf(self): return self.buftype(self.size, self.dtype, **({'device':self._device} if self._device is not None else dict()))
|
def alloc_rawbuf(self): return self.buftype(self.size, self.dtype, **({'device':self._device} if self._device is not None else dict()))
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.cache: Optional[List[Tuple[Callable, List[Any], Dict[Any,Any]]]] = None
|
self.cache: Optional[List[Tuple[Callable, List[Any], Dict[Any,Any]]]] = None
|
||||||
self.placeholders: Dict[RawBuffer, _CacheCollector._Placeholder] = {} # Rawbuffers are replaced with placeholders to allow freeing of the real buffer while collecting cache.
|
self.placeholders: Dict[ref[RawBuffer], _CacheCollector._Placeholder] = {} # Output rawbufs are replaced with placeholders to allow freeing of the real buffer while collecting cache.
|
||||||
self.last_buftype: Dict[Tuple[int,...], int] = {} # Last index of the cached entry where a buffer with the shape (shape is a key) is used as input to the prog.
|
|
||||||
self.last_placeholder_index: Dict[_CacheCollector._Placeholder, int] = {} # Last index where the placeholder is used as output. This allows tracking when we need to stick to the original buffer if it is still alive.
|
|
||||||
self.freed_placeholders: Dict[Tuple[int,...], List[_CacheCollector._Placeholder]] = defaultdict(list)
|
|
||||||
self.circular_signatures: Set[Any] = set()
|
self.circular_signatures: Set[Any] = set()
|
||||||
def start(self):
|
def start(self): self.cache, self.placeholders, self.circular_signatures = [], {}, set()
|
||||||
self.cache, self.placeholders, self.last_buftype, self.last_placeholder_index, self.freed_buffers, self.circular_signatures = [], {}, {}, {}, defaultdict(list), set()
|
|
||||||
def add(self, prg, rawbufs, var_vals):
|
def add(self, prg, rawbufs, var_vals):
|
||||||
if self.cache is None: return
|
if self.cache is None: return
|
||||||
# When we got buffers with the same signature, we can use just 1(max 2, see cycle avoidance below) buffer insted of all of them.
|
# Substitute output buffers with placeholders to find the most optimal reusage.
|
||||||
# Current implementation of a signature is an underlying buffer, because if 2 or more different RawBuffers shares the same, all but the very last are dead.
|
if ref(rawbufs[0]) not in self.placeholders: self.placeholders[ref(rawbufs[0])] = _CacheCollector._Placeholder(rawbufs[0])
|
||||||
for buf in rawbufs[1:]:
|
cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs]
|
||||||
# Check if the input matches any of placeholder to determine if it's existing or newly created input.
|
self.cache.append((prg, cached_rawbufs, var_vals))
|
||||||
# In case of newly created input remove placeholder and capture the whole buffer.
|
|
||||||
if isinstance(buf, RawBuffer) and self._get_signature(buf) in self.placeholders and self.placeholders[self._get_signature(buf)].ref != ref(buf):
|
|
||||||
self.placeholders.pop(self._get_signature(buf))
|
|
||||||
if isinstance(buf, RawBuffer) and self._get_signature(buf) not in self.placeholders:
|
|
||||||
self.last_buftype[self._buftype_key(buf)] = len(self.cache)
|
|
||||||
|
|
||||||
# Creating/updating a placeholder for the current output buffer. If we update output, set the ref to point to the new RawBuffer,
|
|
||||||
# since the previous RawBuffer is dead (overwise we won't get a new RawBuffer with the same signature). Do not care about dead buffers, they 100% could be replaced with any other buffer.
|
|
||||||
if self._get_signature(rawbufs[0]) in self.placeholders: self.placeholders[self._get_signature(rawbufs[0])].ref = ref(rawbufs[0])
|
|
||||||
else:
|
|
||||||
# This is a new output buffer. Try to reuse any freed placeholders with the same type to "merge" these buffers.
|
|
||||||
# If this output buffer is "output_buffer", reusage of the output buffer is scary.
|
|
||||||
plh = self.freed_placeholders[self._buftype_key(rawbufs[0])].pop() if self._get_signature(rawbufs[0]) not in self.circular_signatures and self.freed_placeholders[self._buftype_key(rawbufs[0])] else _CacheCollector._Placeholder(rawbufs[0])
|
|
||||||
self.placeholders.setdefault(self._get_signature(rawbufs[0]), plh).ref = ref(rawbufs[0])
|
|
||||||
self.last_placeholder_index[self.placeholders[self._get_signature(rawbufs[0])]] = len(self.cache)
|
|
||||||
self.cache.append((prg,[self.placeholders.get(self._get_signature(x), x) for x in rawbufs],var_vals))
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
if self.cache is None: return []
|
if self.cache is None: return []
|
||||||
placeholder_mapper, cache_result = {}, []
|
|
||||||
for j,(p,cached_bufs,var_vals) in enumerate(self.cache):
|
rawbuf_pool: List[Tuple[RawBuffer, List[Tuple[int, int]]]] = []
|
||||||
if cached_bufs[0].__class__ is _CacheCollector._Placeholder:
|
buf_usage_bounds: Dict[_CacheCollector._Placeholder, Tuple[int, int]] = {}
|
||||||
if cached_bufs[0].alive():
|
buf_map: Dict[_CacheCollector._Placeholder, RawBuffer] = {}
|
||||||
# Since the placeholder is alive (someone holds refed RawBuffer) to avoid hazards when this output buffer could be used as input on the other launch (e.g., LSTM),
|
|
||||||
# we allocate a backing buffer and and use it until the penultimate entry (the last entry is 100% safe to use the original RawBuffer).
|
for j,(_,bufs,_) in enumerate(self.cache):
|
||||||
if self.last_buftype.get(self._buftype_key(cached_bufs[0]), -1) < j or self.last_placeholder_index[cached_bufs[0]] == j:
|
for buf in bufs:
|
||||||
# Safe to use the original buffer when all inputs buffers of the same size and dtype are behind or this is the last usage of this buffer as output.
|
if buf.__class__ is not _CacheCollector._Placeholder: continue
|
||||||
placeholder_mapper[cached_bufs[0]] = cached_bufs[0].ref()
|
if buf.ref() is not None: buf_map[buf] = buf.ref() # rawbufs that are referenced are not replaced but are used as is.
|
||||||
elif cached_bufs[0] not in placeholder_mapper:
|
else: buf_usage_bounds[buf] = buf_usage_bounds.get(buf, (j, j))[0], j
|
||||||
placeholder_mapper[cached_bufs[0]] = cached_bufs[0].alloc_rawbuf() # Allocating a backing buffer.
|
|
||||||
elif cached_bufs[0] not in placeholder_mapper:
|
# The query list contains a query for every placeholder that should be replaced with the actual rawbuffer. Queries are served from the largest to the smallest.
|
||||||
placeholder_mapper[cached_bufs[0]] = cached_bufs[0].alloc_rawbuf()
|
# For each query, find any rawbuffer that is free within the query timeframe or allocate a new one.
|
||||||
cache_result.append((p, [placeholder_mapper.get(buf, buf) for buf in cached_bufs], var_vals))
|
query_list = sorted([(buf.size*buf.dtype.itemsize, buf_usage_bounds[buf][0], buf_usage_bounds[buf][1], buf) for buf in buf_usage_bounds.keys()], key=lambda x: x[0], reverse=True)
|
||||||
self.cache, self.placeholders, self.last_buftype, self.last_placeholder_index, self.freed_buffers, self.circular_signatures = None, {}, {}, {}, defaultdict(list), set()
|
for _, start, end, buf in query_list:
|
||||||
|
pool_idx = next((i for i,(with_buf, usages) in enumerate(rawbuf_pool) if self._can_substitute(buf, with_buf) and self._no_intersect(start,end,usages)), -1)
|
||||||
|
if pool_idx == -1:
|
||||||
|
rawbuf_pool.append((buf.alloc_rawbuf(), []))
|
||||||
|
pool_idx = len(rawbuf_pool) - 1
|
||||||
|
buf_map[buf] = rawbuf_pool[pool_idx][0]
|
||||||
|
rawbuf_pool[pool_idx][1].append((start, end))
|
||||||
|
|
||||||
|
cache_result = [(p, [buf_map.get(buf, buf) for buf in cached_bufs], var_vals) for p, cached_bufs, var_vals in self.cache]
|
||||||
|
self.cache = None
|
||||||
return cache_result
|
return cache_result
|
||||||
def _mark_output_buffer(self, output_buffer): self.circular_signatures.add(self._get_signature(output_buffer))
|
def _no_intersect(self, start:int, end:int, usages:List[Tuple[int, int]]): return all(en < start or end < st for st, en in usages)
|
||||||
def _on_buf_free(self, underlying_buf):
|
def _can_substitute(self, buf, with_buf): return buf._device==with_buf._device and (buf.size*buf.dtype.itemsize<=with_buf.size*with_buf.dtype.itemsize if not isinstance(buf.dtype, ImageDType) and not isinstance(with_buf.dtype, ImageDType) else buf.size==with_buf.size and buf.dtype==with_buf.dtype and buf.dtype.shape==with_buf.dtype.shape)
|
||||||
if underlying_buf not in self.placeholders: return
|
def _mark_output_buffer(self, output_buffer): self.circular_signatures.add(ref(output_buffer))
|
||||||
self.freed_placeholders[self._buftype_key(self.placeholders[underlying_buf])].append(self.placeholders[underlying_buf])
|
|
||||||
self.placeholders.pop(underlying_buf)
|
|
||||||
def _get_signature(self, buf): return buf._buf if getattr(buf, '_buf', None) is not None and getattr(buf, '_allocator', None) is not None else buf
|
|
||||||
def _buftype_key(self, buf): return (buf.size, buf.dtype, buf._device, buf.dtype.shape if hasattr(buf.dtype, 'shape') else None)
|
|
||||||
CacheCollector = _CacheCollector()
|
CacheCollector = _CacheCollector()
|
||||||
|
|||||||
@@ -92,8 +92,6 @@ class LRUAllocator:
|
|||||||
self.buffer_info[newbuf] = (size, dtype, device)
|
self.buffer_info[newbuf] = (size, dtype, device)
|
||||||
return newbuf
|
return newbuf
|
||||||
def _free_buffer(self, buf_to_free):
|
def _free_buffer(self, buf_to_free):
|
||||||
from tinygrad.jit import CacheCollector
|
|
||||||
CacheCollector._on_buf_free(buf_to_free)
|
|
||||||
self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
|
self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free)
|
||||||
GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
|
GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free)
|
||||||
self.buffer_info.pop(buf_to_free)
|
self.buffer_info.pop(buf_to_free)
|
||||||
|
|||||||
Reference in New Issue
Block a user