much better cache

This commit is contained in:
George Hotz
2022-07-07 11:32:00 -07:00
parent eb6696c3a5
commit 9ee8426c51

View File

@@ -30,17 +30,6 @@ MERGE_ELEMENTWISE_OPS = OPT>=2
# optimization toggles: each flag becomes True once OPT reaches the level it is compared against
SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_INTO_CONV_OUTPUTS = OPT>=3, OPT>=3
SHUFFLE_SLICE_OPS = OPT>=4 # NOTE: 0/0 is NaN if you slice, so this can change the output
# NOTE: this cache catches some repeated ops, but it still misses some
# TODO: see the test for why; a double reshape isn't cached
def lazycache(func):
  """Memoize `func` without keeping its LazyBuffer arguments or its results alive.

  The cache key is the argument tuple with every LazyBuffer replaced by a weak
  reference to it. Results are stored in a WeakValueDictionary, so an entry
  goes away once nothing else references the cached result.

  Returns a wrapper that accepts positional arguments only.
  """
  cache = weakref.WeakValueDictionary()
  @functools.wraps(func)
  def wrapper(*args):
    weakargs = tuple(weakref.ref(x) if isinstance(x, LazyBuffer) else x for x in args)
    # one lookup instead of "not in" + indexing: the entry is only weakly held,
    # so the two steps could otherwise disagree if it were collected in between
    ret = cache.get(weakargs)
    if ret is None:
      # NOTE: ret holds the only strong reference until we return, keep it around or the entry is deleted
      ret = func(*args)
      cache[weakargs] = ret
    return ret
  return wrapper
# **** enumerate supported devices ****
class Device:
@@ -174,11 +163,21 @@ class LazyOp(NamedTuple):
def get_lazybuffers(op:LazyOp) -> List[LazyBuffer]:
  """Collect every non-LazyOp source reachable from `op`, depth-first and left to right."""
  found = []
  for src in op.src:
    # nested LazyOps are expanded recursively; any other source is kept as a leaf
    found += get_lazybuffers(src) if isinstance(src, LazyOp) else [src]
  return found
def get_lazyops(op:LazyOp) -> List[LazyOp]:
  """List `op` itself followed by every LazyOp nested in its sources, in pre-order."""
  ops = [op]
  for src in op.src:
    # sources that are not LazyOps contribute nothing here
    if isinstance(src, LazyOp):
      ops.extend(get_lazyops(src))
  return ops
def get_weakop(y:LazyOp) -> LazyOp:
  """Return a copy of the LazyOp tree `y` whose non-LazyOp sources are weak references.

  Nested LazyOps are rebuilt recursively; `op` and `arg` are carried over unchanged.
  """
  # NOTE: y is a LazyOp, not a LazyBuffer: it is read through .op/.src/.arg and the recursion passes LazyOp sources
  weak_src = tuple(get_weakop(x) if isinstance(x, LazyOp) else weakref.ref(x) for x in y.src)
  return LazyOp(y.op, weak_src, y.arg)
# LAZY is read from the LAZY environment variable and parsed as an int; it defaults to 1 when unset
LAZY = int(os.getenv("LAZY", "1"))
class LazyBuffer:
lazycache = weakref.WeakValueDictionary()
def __new__(cls, device, shape, optype, op):
  """Return the live cached instance for (device, optype, op), or allocate a new one.

  LoadOps buffers always get a fresh instance. For every other optype the
  instance is looked up in LazyBuffer.lazycache under a key built from the
  device, the optype and the weak-reference form of `op`. `shape` is not
  part of the key.
  """
  # loadops aren't cached
  if optype == LoadOps: return super().__new__(cls)
  wop = (device, optype, get_weakop(op)) # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker
  # one lookup instead of "not in" + indexing: the cache holds instances weakly,
  # so the two steps could otherwise disagree if the entry were collected in between
  ret = LazyBuffer.lazycache.get(wop)
  if ret is None:
    # NOTE: ret keeps the new instance alive until it is returned
    LazyBuffer.lazycache[wop] = ret = super().__new__(cls)
  return ret
def __init__(self, device, shape:Union[ShapeTracker, Tuple[int, ...]], optype:OpType, op:LazyOp):
if getattr(self, 'device', None) is not None: return # cache hit, we return and don't reinit
self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
self.shape = self.st.shape
self.optype, self.op = optype, op
@@ -216,11 +215,9 @@ class LazyBuffer:
def unary_op(x:LazyBuffer, op:UnaryOps) -> LazyBuffer:
  """Apply the unary `op` to `x` by delegating to elementwise_op."""
  return elementwise_op(op, x)
def binary_op(x:LazyBuffer, op:BinaryOps, y:LazyBuffer) -> LazyBuffer:
  """Apply the binary `op` to `x` and `y` (in that order) by delegating to elementwise_op."""
  return elementwise_op(op, x, y)
@lazycache
def reduce_op(x:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
  """Build a ReduceOps LazyBuffer on x's device with `new_shape` as both its shape and its op argument."""
  out_shape = tuple(new_shape)
  # the op records x as its only source and the target shape as its argument
  reduction = LazyOp(op, (x,), out_shape)
  return LazyBuffer(x.device, out_shape, ReduceOps, reduction)
@lazycache
def movement_op(x:LazyBuffer, op:MovementOps, arg) -> LazyBuffer:
# TODO: look into why that copy is needed
arg = copy(arg)
@@ -245,11 +242,9 @@ class LazyBuffer:
return ret
@lazycache
def processing_op(x:LazyBuffer, op:ProcessingOps, w:LazyBuffer, C:ConvArgs) -> LazyBuffer:
  """Build a ProcessingOps LazyBuffer on x's device, shaped C.out_shape, with sources (x, w) and argument C."""
  conv = LazyOp(op, (x, w), C)
  return LazyBuffer(x.device, C.out_shape, ProcessingOps, conv)
@lazycache
def elementwise_op(op:Union[UnaryOps, BinaryOps], *srcs:LazyBuffer) -> LazyBuffer:
out_device, out_shape = srcs[0].device, srcs[0].shape