diff --git a/test/test_ops.py b/test/test_ops.py index db0952d197..8ecda6248a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -292,7 +292,7 @@ class TestOps(unittest.TestCase): def test_dot(self): helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) helper_test_op([(32,45,65), (32,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-4) - with self.assertRaises(RuntimeError): + with self.assertRaises(AssertionError): a = Tensor(3.14) a.matmul(a) @@ -319,7 +319,7 @@ class TestOps(unittest.TestCase): helper_test_op([(256,256), (256,256)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-3) def test_broadcastdot(self): helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) - with self.assertRaises(RuntimeError): + with self.assertRaises(AssertionError): a = Tensor(3.14) b = Tensor.ones(3,3) a @ b diff --git a/test/unit/test_weak.py b/test/unit/test_weak.py new file mode 100644 index 0000000000..3251257d8a --- /dev/null +++ b/test/unit/test_weak.py @@ -0,0 +1,103 @@ +from tinygrad.helpers import LightWeakSet, LightWeakValueDictionary +import unittest +import time + +CNT = 1000 + +cnt = 0 +class MyObject: + def __init__(self): + global cnt + self.cnt = cnt + cnt += 1 + #print(f"object {self.cnt} created") + #def __del__(self): print(f"object {self.cnt} destroyed") + +class TestWeak(unittest.TestCase): + def test_set_drops(self): + ss = LightWeakSet() + ss.add(MyObject()) + assert len(ss) == 0 + + def test_set_holds(self): + ss = LightWeakSet() + obj = MyObject() + ss.add(obj) + assert len(ss) == 1 + + def test_set_late_drops(self): + ss = LightWeakSet() + obj = MyObject() + ss.add(obj) + assert len(ss) == 1 + del obj + assert len(ss) == 0 + + def test_dict_drops(self): + dd = LightWeakValueDictionary() + dd[0] = MyObject() + assert 0 not in dd + + def test_dict_holds(self): + dd = LightWeakValueDictionary() + dd[0] = ret = MyObject() + assert 0 in dd + + def test_a_myobj_microbench(self): + for _ in range(3): + st = time.perf_counter_ns() + for _ in range(CNT): + obj = MyObject() + et = (time.perf_counter_ns() - st)/CNT + print(f"{et:.2f} ns to create MyObject") + + def test_set_add_microbench(self): + for _ in range(3): + ss = LightWeakSet() + st = time.perf_counter_ns() + for _ in range(CNT): + obj = MyObject() + ss.add(obj) + assert len(ss) == 1 + et = (time.perf_counter_ns() - st)/CNT + print(f"{et:.2f} ns to add to LightWeakSet") + + def test_set_del_microbench(self): + for _ in range(3): + ss = LightWeakSet() + st = time.perf_counter_ns() + for _ in range(CNT): + obj = MyObject() + ss.add(obj) + ss.discard(obj) + assert len(ss) == 0 + et = (time.perf_counter_ns() - st)/CNT + print(f"{et:.2f} ns to add/del from LightWeakSet") + + def test_dict_add_microbench(self): + for _ in range(3): + dd = LightWeakValueDictionary() + st = time.perf_counter_ns() + for i in range(CNT): + obj = MyObject() + dd[i] = obj + assert len(dd) == 1 + et = (time.perf_counter_ns() - st)/CNT + print(f"{et:.2f} ns to add to LightWeakDict") + + def test_dict_check_microbench(self): + for _ in range(3): + dd = LightWeakValueDictionary() + st = time.perf_counter_ns() + for i in range(CNT): + obj = MyObject() + dd[i] = obj + assert i in dd + tst = dd[i] + del obj,tst + assert len(dd) == 0 + et = (time.perf_counter_ns() - st)/CNT + print(f"{et:.2f} ns to add/del from LightWeakDict") + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 7fca7a83dc..842331a9e3 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -11,7 +11,7 @@ ShapeType = Tuple[int, ...] OSX = platform.system() == "Darwin" def dedup(x): return list(dict.fromkeys(x)) # retains list order -def argfix(*x): +def argfix(*x): if x[0].__class__ in {tuple, list}: try: return tuple(x[0]) except IndexError: return tuple() @@ -148,6 +148,7 @@ class LightWeakValueDictionary: if o is None: raise KeyError(key) else: return o + def __len__(self): return len(self.data) + def __delitem__(self, key): del self.data[key] def __setitem__(self, key, value): self.data[key] = KeyedRef(value, self._remove, key) - - def __contains__(self, key): return key in self.data \ No newline at end of file + def __contains__(self, key): return key in self.data diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index b24cef4eea..a135c7306f 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -5,7 +5,7 @@ import sys, importlib, inspect, functools, pathlib from weakref import ref import numpy as np -from tinygrad.helpers import GRAPH, prod, getenv, DType, dtypes, flatten, ImageDType, LightWeakSet, LightWeakValueDictionary +from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, LightWeakSet, LightWeakValueDictionary from tinygrad.runtime.ops_cpu import RawNumpyBuffer from tinygrad.runtime.ops_disk import RawDiskBuffer from tinygrad.shape.shapetracker import MovementOps, ShapeTracker, get_contraction @@ -17,6 +17,7 @@ sys.setrecursionlimit(10000) OPT = getenv("OPT", 2) LAZY = getenv("LAZY", 1) +LAZYCACHE = getenv("LAZYCACHE", 1) # TODO: movement ops that only change shape are really nops. treat them as such REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1 @@ -27,7 +28,7 @@ PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3 def _ast_reduceops(self:LazyBuffer) -> LazyOp: # TODO: this can also corealize a binary op after the reduce, not just before src = self.op.src[0] - if MERGE_ELEMENTWISE_INTO_REDUCE and not src.realized and src.optype == BinaryOps and len(src.children) <= 1: + if MERGE_ELEMENTWISE_INTO_REDUCE and not src.realized and src.optype is BinaryOps and len(src.children) <= 1: src = src.op # type: ignore return LazyOp(self.op.op, (src,), self.op.arg) @@ -65,17 +66,15 @@ def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movement lazycache: LightWeakValueDictionary = LightWeakValueDictionary() def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType): - - # fromcpu aren't cached - if optype == LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}: return LazyBuffer(device, st, optype, op, dtype) - #print("create_lazybuffer", device, shape, optype, op, dtype) - # NOTE: shape should be deterministic. annoying to cache with the ShapeTracker - # get_weakop makes all the LazyBuffers in the op have a weakref - wop = (device, dtype, optype, ref(op)) + # fromcpu aren't cached + if not LAZYCACHE or (optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.RAND, LoadOps.CONST}): return LazyBuffer(device, st, optype, op, dtype) + # wop is the deduping key. i feel this used to compare more deeply + wop = (device, dtype, optype, ref(op)) if wop in lazycache: return lazycache[wop] + lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype) return ret @@ -83,7 +82,7 @@ class LazyBuffer: __slots__ = 'st', 'device', 'shape', 'optype', 'dtype', 'op', 'realized', 'output_buffer', 'children', 'node_id', '__weakref__' __deletable__ = ('op',) def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, src:Optional[RawBuffer]=None): - self.st = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker + self.st: ShapeTracker = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker self.device, self.shape, self.optype, self.dtype = device, self.st.shape, optype, dtype self.realized: Optional[RawBuffer] = src self.output_buffer: Optional[RawBuffer] = None # TODO: do we really need this? or can we just use realized @@ -110,16 +109,15 @@ class LazyBuffer: def realize(self:LazyBuffer) -> LazyBuffer: if not self.realized: # get real ops first - if self.optype in REALIZE_DISPATCHER: - self.op = REALIZE_DISPATCHER[self.optype](self) - elif self.op.op in REALIZE_DISPATCHER: - REALIZE_DISPATCHER[self.op.op](self) + if self.optype is BinaryOps: self.op = _ast_binaryops(self) + elif self.optype is ReduceOps: self.op = _ast_reduceops(self) + elif self.optype is LoadOps: LOAD_OPS_DISPATCHER[cast(LoadOps, self.op.op)](self) # run the ast if we still have to, and log the op if not self.realized: for x in self.op.buffers: x.realize() # HACK: image shape can be wrong, hot cast it back to a normal float - if self.optype != MovementOps and self.dtype.__class__ is ImageDType and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any([self.shape[x]%4 == 0 for x in self.st.unit_stride_axes()])): + if self.dtype.__class__ is ImageDType and self.optype != MovementOps and (prod(self.shape) != prod(cast(ImageDType, self.dtype).shape) or not any([self.shape[x]%4 == 0 for x in self.st.unit_stride_axes()])): if self.op.op == MovementOps.RESHAPE: # put CAST before the final RESHAPE self.op = LazyOp(MovementOps.RESHAPE, (LazyOp(UnaryOps.CAST, self.op.src, dtypes.float32),), self.op.arg) @@ -130,11 +128,11 @@ class LazyBuffer: assert self.realized and isinstance(self.realized, (RawConst, Device[self.device].buffer)), f"device mismatch on realized got {type(self.realized)} expected {self.device}" # HACK: allow hot casting of images - assert self.realized.dtype == self.dtype or self.dtype.name.startswith("image"), f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}" + assert self.realized.dtype == self.dtype or self.dtype.__class__ is ImageDType, f"dtype mismatch on realize got {self.realized.dtype} expected {self.dtype}" self.dtype = self.realized.dtype # log to the graph - if self.realized.__class__ is not RawConst or GRAPH >= 2: + if (DEBUG or GRAPH) and (self.realized.__class__ is not RawConst or GRAPH >= 2): from tinygrad.graph import log_op log_op(self, self.op) @@ -338,15 +336,13 @@ def _realize_const(buffer: LazyBuffer) -> None: else: buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args()) -REALIZE_DISPATCHER: Dict[Any, Callable] = { +LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = { LoadOps.CONTIGUOUS: _realize_contiguous, LoadOps.CUSTOM: _realize_custom, LoadOps.FROM: _realize_from, LoadOps.EMPTY: _realize_empty, LoadOps.RAND: _realize_rand, LoadOps.CONST: _realize_const, - ReduceOps: _ast_reduceops, - BinaryOps: _ast_binaryops, } MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = { diff --git a/tinygrad/nn/image.py b/tinygrad/nn/image.py index 86255a60f5..3eddcccdfe 100644 --- a/tinygrad/nn/image.py +++ b/tinygrad/nn/image.py @@ -7,7 +7,8 @@ base_image_type = (100, 2, "imageh", np.float16) if FLOAT16 else (100, 4, "image def image_dot(self, w): # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) - if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D") + n1, n2 = len(self.shape), len(w.shape) + assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" bs, groups = prod(self.shape[0:-2]), prod(w.shape[0:-2]) cin, cout = w.shape[-2], w.shape[-1] out_shape_t = self.shape[0:-2] + (cout,-1) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 3f755e0180..b100b8b338 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -33,8 +33,11 @@ class LazyOp: self.op = op self.src = src self.arg = arg - # TODO: this hasattr is required because the linearizer's key function maps the buffers to ints - self.buffers = functools.reduce(lambda x,s: (x+s.buffers) if hasattr(s, 'buffers') else x, src, tuple()) + try: + self.buffers = tuple([y for x in src for y in x.buffers]) + except AttributeError: + # NOTE: the linearizer's key function maps the buffers to ints, and LOCAL_BUFFER is used. we don't care about buffers in these cases + pass def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})" def __eq__(self, __value: object) -> bool: @@ -46,10 +49,10 @@ class LazyOp: def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg)) # Any == Union[LazyBuffer, DeviceBuffer] - def map_buffers(self, real_srcs: Dict[Any, Any]): return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg) - def get_lazyops(self) -> List['LazyOp']: return [self] + [item for x in self.src for item in x.get_lazyops()] + def map_buffers(self, real_srcs: Dict[Any, Any]) -> LazyOp: return LazyOp(self.op, tuple([y.map_buffers(real_srcs) for y in self.src]), self.arg) + def get_lazyops(self) -> List[LazyOp]: return [self] + [item for x in self.src for item in x.get_lazyops()] - def replace_with_movement_ops(self: LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer': + def replace_with_movement_ops(self:LazyOp, ops:List[Tuple[MovementOps, Tuple[Any, ...]]]) -> 'LazyBuffer': from tinygrad.lazy import elementwise_op assert self.op in BinaryOps or self.op in UnaryOps return elementwise_op(self.op, *[z.replace_with_movement_ops(ops) for z in self.src], arg=self.arg) # type: ignore diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d3f29b2bf9..7531c8d02b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -468,7 +468,8 @@ class Tensor: return ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))])) def dot(self, w:Tensor) -> Tensor: - if (n1:=len(self.shape))*(n2:=len(w.shape)) == 0: raise RuntimeError(f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D") + n1, n2 = len(self.shape), len(w.shape) + assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1]) w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2)) return (x*w).sum(-1)