mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00

much better cache
@@ -30,17 +30,6 @@ MERGE_ELEMENTWISE_OPS = OPT>=2
SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_INTO_CONV_OUTPUTS = OPT>=3, OPT>=3
SHUFFLE_SLICE_OPS = OPT>=4 # NOTE: 0/0 is NaN if you slice, so this can change the output

# this cache is doing something, but it's missing some cases
# TODO: see the test for why; a double reshape isn't cached
def lazycache(func):
  cache = weakref.WeakValueDictionary()
  def wrapper(*args):
    weakargs = tuple(weakref.ref(x) if isinstance(x, LazyBuffer) else x for x in args)
    # NOTE: even though we don't use ret, we need to keep it around or the entry will be dropped from the WeakValueDictionary
    if weakargs not in cache: cache[weakargs] = ret = func(*args)
    return cache[weakargs]
  return wrapper
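
To make the weak-value memoization above concrete, here is a minimal, self-contained sketch of the same pattern; Node and double are hypothetical stand-ins for LazyBuffer and an op builder, not names from this repo.

import weakref

class Node:
  def __init__(self, value): self.value = value

def lazycache(func):
  cache = weakref.WeakValueDictionary()
  def wrapper(*args):
    key = tuple(weakref.ref(x) if isinstance(x, Node) else x for x in args)
    # the assignment to ret holds a strong reference during insertion; without
    # it, the WeakValueDictionary could drop the new entry before we read it back
    if key not in cache: cache[key] = ret = func(*args)
    return cache[key]
  return wrapper

@lazycache
def double(n):
  return Node(n.value * 2)

a = Node(3)
x, y = double(a), double(a)
assert x is y # second call is a cache hit and returns the same object

Entries disappear automatically once no caller holds the result, which is why a cache like this can't grow without bound.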

# **** enumerate supported devices ****

class Device:

@@ -174,11 +163,21 @@ class LazyOp(NamedTuple):

def get_lazybuffers(op:LazyOp) -> List[LazyBuffer]: return functools.reduce(operator.add, [get_lazybuffers(x) if isinstance(x, LazyOp) else [x] for x in op.src], [])
def get_lazyops(op:LazyOp) -> List[LazyOp]: return functools.reduce(operator.add, [get_lazyops(x) for x in op.src if isinstance(x, LazyOp)], [op])
def get_weakop(y:LazyOp) -> LazyOp: return LazyOp(y.op, tuple(get_weakop(x) if isinstance(x, LazyOp) else weakref.ref(x) for x in y.src), y.arg)
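
A small illustration of how these one-liner helpers walk an op tree may help; Buf and the string op names are hypothetical stand-ins (real src entries are LazyBuffers and nested LazyOps), and LazyOp is re-declared locally only so the snippet runs on its own.

import functools, operator
from typing import NamedTuple, Tuple, Any

class Buf: pass

class LazyOp(NamedTuple):
  op: str
  src: Tuple[Any, ...]
  arg: Any = None

def get_lazybuffers(op): return functools.reduce(operator.add, [get_lazybuffers(x) if isinstance(x, LazyOp) else [x] for x in op.src], [])
def get_lazyops(op): return functools.reduce(operator.add, [get_lazyops(x) for x in op.src if isinstance(x, LazyOp)], [op])

a, b = Buf(), Buf()
tree = LazyOp("ADD", (LazyOp("RELU", (a,)), b))
assert get_lazybuffers(tree) == [a, b] # leaves, left to right
assert [o.op for o in get_lazyops(tree)] == ["ADD", "RELU"] # root first, then children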

LAZY = int(os.getenv("LAZY", "1"))

class LazyBuffer:
  lazycache = weakref.WeakValueDictionary()
  def __new__(cls, device, shape, optype, op):
    # loadops aren't cached
    if optype == LoadOps: return super().__new__(cls)
    wop = (device, optype, get_weakop(op)) # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker
    if wop not in LazyBuffer.lazycache: LazyBuffer.lazycache[wop] = ret = super().__new__(cls)
    return LazyBuffer.lazycache[wop]

  def __init__(self, device, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
    if getattr(self, 'device', None) is not None: return # cache hit, we return and don't reinit
    self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
    self.shape = self.st.shape
    self.optype, self.op = optype, op
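
Why does __init__ need that getattr guard? Python calls __init__ on whatever object __new__ returns, including a previously cached instance, so without the guard every cache hit would be re-initialized. A toy sketch of the pattern, with a hypothetical Interned class standing in for LazyBuffer:

import weakref

class Interned:
  _cache = weakref.WeakValueDictionary()
  def __new__(cls, key):
    # as in LazyBuffer.__new__, ret keeps a strong reference alive so the
    # WeakValueDictionary entry survives until we read it back
    if key not in Interned._cache: Interned._cache[key] = ret = super().__new__(cls)
    return Interned._cache[key]
  def __init__(self, key):
    if getattr(self, 'key', None) is not None: return # cache hit, skip reinit
    self.key = key

assert Interned("a") is Interned("a") # same key gives back the same object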

@@ -216,11 +215,9 @@ class LazyBuffer:

def unary_op(x:LazyBuffer, op:UnaryOps) -> LazyBuffer: return elementwise_op(op, x)
def binary_op(x:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer: return elementwise_op(op, x, y)

@lazycache
def reduce_op(x:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
  return LazyBuffer(x.device, tuple(new_shape), ReduceOps, LazyOp(op, (x,), tuple(new_shape)))

@lazycache
def movement_op(x:LazyBuffer, op:MovementOps, arg) -> LazyBuffer:
  # TODO: look into why that copy is needed
  arg = copy(arg)

@@ -245,11 +242,9 @@ class LazyBuffer:

  return ret

@lazycache
def processing_op(x:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
  return LazyBuffer(x.device, C.out_shape, ProcessingOps, LazyOp(op, (x, w), C))

@lazycache
def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffer:
  out_device, out_shape = srcs[0].device, srcs[0].shape