diff --git a/test/test_cache_collector.py b/test/test_cache_collector.py index 2989f97d32..72f6af1240 100644 --- a/test/test_cache_collector.py +++ b/test/test_cache_collector.py @@ -59,8 +59,9 @@ class TestCacheCollector(unittest.TestCase): assert cache[0][1][1] == inps[0], "Input should be on its place." assert cache[1][1][2] == inps[1], "Input should be on its place." assert cache[-1][1][0] == out, "Output does not match." - assert get_bufs_count(cache) == 4, "Should have 4 buffers in total" - assert cache[-1][1][0] == cache[0][1][0], "Should reuse final output buffer as output in 1st kernel" + assert get_bufs_count(cache) == 5, "Should have 5 buffers in total" + # This is not worth added complexity on real models + # assert cache[-1][1][0] == cache[0][1][0], "Should reuse final output buffer as output in 1st kernel" FAKE_GLOBAL_ALLOCATOR = None def test_cache_collector_cycle_avoidance(self): @@ -78,8 +79,8 @@ class TestCacheCollector(unittest.TestCase): assert cache[0][1][1] == inps[0], "Input should be on its place." assert cache[1][1][2] == inps[1], "Input should be on its place." assert cache[-1][1][0] == out, "Output does not match." - assert get_bufs_count(cache) == 6, "Should have 6 buffers in total" assert cache[-1][1][0] != cache[0][1][0] and cache[0][1][0] == cache[3][1][0], "Output buffers from 1st and 4th kernel could not be the same as the 5th." + assert get_bufs_count(cache) == 6, "Should have 6 buffers in total" FAKE_GLOBAL_ALLOCATOR = None def test_cache_collector_all_alive(self): @@ -144,7 +145,7 @@ class TestCacheCollector(unittest.TestCase): assert cache[i][1][0]._device == '1', f"Device does not match {i}, has {cache[i][1][0]._device}." for i in range(3, 6): assert cache[i][1][0]._device == '2', f"Device does not match {i}, has {cache[i][1][0]._device}." - assert get_bufs_count(cache) == 6 + assert get_bufs_count(cache) == 7 FAKE_GLOBAL_ALLOCATOR = None def test_cache_collector_anybufs_inputs(self): @@ -161,7 +162,7 @@ class TestCacheCollector(unittest.TestCase): cache = CacheCollector.finish() assert cache[0][1][1] == inps[0], "Input should be on its place." assert cache[1][1][2] == inps[1], "Input should be on its place." - assert get_bufs_count(cache) == 7 + assert get_bufs_count(cache) == 8 FAKE_GLOBAL_ALLOCATOR = None def test_cache_collector_optimize_when_not_cached_anymore(self): @@ -201,7 +202,7 @@ class TestCacheCollector(unittest.TestCase): assert cache[1][1][2] == inps[1], "Input should be on its place." assert cache[-1][1][0] == out, "Output does not match." assert cache[0][1][0] != cache[3][1][0], "Cannot reuse 4th output buffer, it's an output buffer which might ovewrite itself" - assert get_bufs_count(cache) == 7, "Should have 7 buffers in total" + assert get_bufs_count(cache) == 6, "Should have 6 buffers in total" FAKE_GLOBAL_ALLOCATOR = None if __name__ == "__main__": diff --git a/tinygrad/jit.py b/tinygrad/jit.py index 1fce6044c8..3f1f353ff6 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -2,7 +2,7 @@ from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional, Set from weakref import ref from collections import defaultdict import functools, itertools -from tinygrad.helpers import DEBUG, DType, merge_dicts +from tinygrad.helpers import DEBUG, DType, merge_dicts, ImageDType from tinygrad.ops import RawBuffer, Device, BasicBatchExecutor from tinygrad.tensor import Tensor from tinygrad.shape.shapetracker import ShapeTracker @@ -70,63 +70,47 @@ class TinyJit: class _CacheCollector: class _Placeholder: def __init__(self, buf): self.size, self.dtype, self._device, self.ref, self.buftype = buf.size, buf.dtype, getattr(buf, '_device', None), ref(buf), type(buf) - def alive(self): return self.ref() is not None def alloc_rawbuf(self): return self.buftype(self.size, self.dtype, **({'device':self._device} if self._device is not None else dict())) def __init__(self): self.cache: Optional[List[Tuple[Callable, List[Any], Dict[Any,Any]]]] = None - self.placeholders: Dict[RawBuffer, _CacheCollector._Placeholder] = {} # Rawbuffers are replaced with placeholders to allow freeing of the real buffer while collecting cache. - self.last_buftype: Dict[Tuple[int,...], int] = {} # Last index of the cached entry where a buffer with the shape (shape is a key) is used as input to the prog. - self.last_placeholder_index: Dict[_CacheCollector._Placeholder, int] = {} # Last index where the placeholder is used as output. This allows tracking when we need to stick to the original buffer if it is still alive. - self.freed_placeholders: Dict[Tuple[int,...], List[_CacheCollector._Placeholder]] = defaultdict(list) + self.placeholders: Dict[ref[RawBuffer], _CacheCollector._Placeholder] = {} # Output rawbufs are replaced with placeholders to allow freeing of the real buffer while collecting cache. self.circular_signatures: Set[Any] = set() - def start(self): - self.cache, self.placeholders, self.last_buftype, self.last_placeholder_index, self.freed_buffers, self.circular_signatures = [], {}, {}, {}, defaultdict(list), set() + def start(self): self.cache, self.placeholders, self.circular_signatures = [], {}, set() def add(self, prg, rawbufs, var_vals): if self.cache is None: return - # When we got buffers with the same signature, we can use just 1(max 2, see cycle avoidance below) buffer insted of all of them. - # Current implementation of a signature is an underlying buffer, because if 2 or more different RawBuffers shares the same, all but the very last are dead. - for buf in rawbufs[1:]: - # Check if the input matches any of placeholder to determine if it's existing or newly created input. - # In case of newly created input remove placeholder and capture the whole buffer. - if isinstance(buf, RawBuffer) and self._get_signature(buf) in self.placeholders and self.placeholders[self._get_signature(buf)].ref != ref(buf): - self.placeholders.pop(self._get_signature(buf)) - if isinstance(buf, RawBuffer) and self._get_signature(buf) not in self.placeholders: - self.last_buftype[self._buftype_key(buf)] = len(self.cache) - - # Creating/updating a placeholder for the current output buffer. If we update output, set the ref to point to the new RawBuffer, - # since the previous RawBuffer is dead (overwise we won't get a new RawBuffer with the same signature). Do not care about dead buffers, they 100% could be replaced with any other buffer. - if self._get_signature(rawbufs[0]) in self.placeholders: self.placeholders[self._get_signature(rawbufs[0])].ref = ref(rawbufs[0]) - else: - # This is a new output buffer. Try to reuse any freed placeholders with the same type to "merge" these buffers. - # If this output buffer is "output_buffer", reusage of the output buffer is scary. - plh = self.freed_placeholders[self._buftype_key(rawbufs[0])].pop() if self._get_signature(rawbufs[0]) not in self.circular_signatures and self.freed_placeholders[self._buftype_key(rawbufs[0])] else _CacheCollector._Placeholder(rawbufs[0]) - self.placeholders.setdefault(self._get_signature(rawbufs[0]), plh).ref = ref(rawbufs[0]) - self.last_placeholder_index[self.placeholders[self._get_signature(rawbufs[0])]] = len(self.cache) - self.cache.append((prg,[self.placeholders.get(self._get_signature(x), x) for x in rawbufs],var_vals)) + # Substitute output buffers with placeholders to find the most optimal reusage. + if ref(rawbufs[0]) not in self.placeholders: self.placeholders[ref(rawbufs[0])] = _CacheCollector._Placeholder(rawbufs[0]) + cached_rawbufs = [self.placeholders.get(ref(buf), buf) if isinstance(buf, RawBuffer) and ref(buf) not in self.circular_signatures else buf for buf in rawbufs] + self.cache.append((prg, cached_rawbufs, var_vals)) def finish(self): if self.cache is None: return [] - placeholder_mapper, cache_result = {}, [] - for j,(p,cached_bufs,var_vals) in enumerate(self.cache): - if cached_bufs[0].__class__ is _CacheCollector._Placeholder: - if cached_bufs[0].alive(): - # Since the placeholder is alive (someone holds refed RawBuffer) to avoid hazards when this output buffer could be used as input on the other launch (e.g., LSTM), - # we allocate a backing buffer and and use it until the penultimate entry (the last entry is 100% safe to use the original RawBuffer). - if self.last_buftype.get(self._buftype_key(cached_bufs[0]), -1) < j or self.last_placeholder_index[cached_bufs[0]] == j: - # Safe to use the original buffer when all inputs buffers of the same size and dtype are behind or this is the last usage of this buffer as output. - placeholder_mapper[cached_bufs[0]] = cached_bufs[0].ref() - elif cached_bufs[0] not in placeholder_mapper: - placeholder_mapper[cached_bufs[0]] = cached_bufs[0].alloc_rawbuf() # Allocating a backing buffer. - elif cached_bufs[0] not in placeholder_mapper: - placeholder_mapper[cached_bufs[0]] = cached_bufs[0].alloc_rawbuf() - cache_result.append((p, [placeholder_mapper.get(buf, buf) for buf in cached_bufs], var_vals)) - self.cache, self.placeholders, self.last_buftype, self.last_placeholder_index, self.freed_buffers, self.circular_signatures = None, {}, {}, {}, defaultdict(list), set() + + rawbuf_pool: List[Tuple[RawBuffer, List[Tuple[int, int]]]] = [] + buf_usage_bounds: Dict[_CacheCollector._Placeholder, Tuple[int, int]] = {} + buf_map: Dict[_CacheCollector._Placeholder, RawBuffer] = {} + + for j,(_,bufs,_) in enumerate(self.cache): + for buf in bufs: + if buf.__class__ is not _CacheCollector._Placeholder: continue + if buf.ref() is not None: buf_map[buf] = buf.ref() # rawbufs that are referenced are not replaced but are used as is. + else: buf_usage_bounds[buf] = buf_usage_bounds.get(buf, (j, j))[0], j + + # The query list contains a query for every placeholder that should be replaced with the actual rawbuffer. Queries are served from the largest to the smallest. + # For each query, find any rawbuffer that is free within the query timeframe or allocate a new one. + query_list = sorted([(buf.size*buf.dtype.itemsize, buf_usage_bounds[buf][0], buf_usage_bounds[buf][1], buf) for buf in buf_usage_bounds.keys()], key=lambda x: x[0], reverse=True) + for _, start, end, buf in query_list: + pool_idx = next((i for i,(with_buf, usages) in enumerate(rawbuf_pool) if self._can_substitute(buf, with_buf) and self._no_intersect(start,end,usages)), -1) + if pool_idx == -1: + rawbuf_pool.append((buf.alloc_rawbuf(), [])) + pool_idx = len(rawbuf_pool) - 1 + buf_map[buf] = rawbuf_pool[pool_idx][0] + rawbuf_pool[pool_idx][1].append((start, end)) + + cache_result = [(p, [buf_map.get(buf, buf) for buf in cached_bufs], var_vals) for p, cached_bufs, var_vals in self.cache] + self.cache = None return cache_result - def _mark_output_buffer(self, output_buffer): self.circular_signatures.add(self._get_signature(output_buffer)) - def _on_buf_free(self, underlying_buf): - if underlying_buf not in self.placeholders: return - self.freed_placeholders[self._buftype_key(self.placeholders[underlying_buf])].append(self.placeholders[underlying_buf]) - self.placeholders.pop(underlying_buf) - def _get_signature(self, buf): return buf._buf if getattr(buf, '_buf', None) is not None and getattr(buf, '_allocator', None) is not None else buf - def _buftype_key(self, buf): return (buf.size, buf.dtype, buf._device, buf.dtype.shape if hasattr(buf.dtype, 'shape') else None) + def _no_intersect(self, start:int, end:int, usages:List[Tuple[int, int]]): return all(en < start or end < st for st, en in usages) + def _can_substitute(self, buf, with_buf): return buf._device==with_buf._device and (buf.size*buf.dtype.itemsize<=with_buf.size*with_buf.dtype.itemsize if not isinstance(buf.dtype, ImageDType) and not isinstance(with_buf.dtype, ImageDType) else buf.size==with_buf.size and buf.dtype==with_buf.dtype and buf.dtype.shape==with_buf.dtype.shape) + def _mark_output_buffer(self, output_buffer): self.circular_signatures.add(ref(output_buffer)) CacheCollector = _CacheCollector() diff --git a/tinygrad/runtime/lib.py b/tinygrad/runtime/lib.py index 88bf32b212..5b67467930 100644 --- a/tinygrad/runtime/lib.py +++ b/tinygrad/runtime/lib.py @@ -92,8 +92,6 @@ class LRUAllocator: self.buffer_info[newbuf] = (size, dtype, device) return newbuf def _free_buffer(self, buf_to_free): - from tinygrad.jit import CacheCollector - CacheCollector._on_buf_free(buf_to_free) self.free_space[self.buffer_info[buf_to_free][2]] += self._underlying_buf_memsz(buf_to_free) GlobalCounters.mem_cached -= self._underlying_buf_memsz(buf_to_free) self.buffer_info.pop(buf_to_free)