diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index 4e6e908327..c22ed5bb73 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -32,7 +32,7 @@ def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jit
   # TODO: jit should expose this correctly with graph
   kernels_used = len(model.jit_cache) if hasattr(model, "jit_cache") else None
   print(f"{nm}: used {mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
-  assert mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB"
+  assert mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB - {mem_used/1e9:.2f} GB used"
   assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
   if all_jitted:
     assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used <= GlobalCounters.kernel_count and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" # noqa: E501
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index aacc7b7f66..19bc2069cb 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -228,11 +228,14 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
   def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is Ops.REDUCE_AXIS else self.arg
-  @functools.cached_property
+  @property
   def toposort(self) -> Dict[UOp, None]:
-    nodes: Dict[UOp, None] = {}
-    # NOTE: this is a lot faster than the comprehension in parents
-    for parent in self.src: nodes.update(parent.toposort)
-    nodes[self] = None
-    return nodes
+    @functools.lru_cache(None)
+    def _toposort(u:UOp):
+      nodes: Dict[UOp, None] = {}
+      # NOTE: this is a lot faster than the comprehension in parents
+      for parent in u.src: nodes.update(_toposort(parent))
+      nodes[u] = None
+      return nodes
+    return _toposort(self)
   @functools.cached_property
   def tuplize(self:UOp) -> Tuple[int, Any, Optional[DType], Tuple]:
@@ -261,8 +264,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   def shape(self) -> Tuple[sint, ...]: return unwrap(self.st).shape
   @property
   def size(self) -> int: return self.arg[1][1] if self.op is Ops.BUFFER else unwrap(self.st).size
-  @property
-  def nbytes(self) -> int: return self.size*self.dtype.itemsize
 
   # *** uop evaluation ***
 
@@ -356,7 +357,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
   @property
   def base(self) -> UOp: return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
   def view(self, new_st:ShapeTracker) -> UOp:
-    if self.st is None: return UOp(Ops.VIEW, self.dtype, (self,), new_st)
+    if self.st is None: return UOp(Ops.VIEW, self.dtype.base if not isinstance(self.dtype, ImageDType) else self.dtype, (self,), new_st)
     ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
     # instant folding rules
     if self.st.size == 0 or (new_st.views[-1].mask is not None and any((x[1]-x[0]) == 0 for x in new_st.views[-1].mask)): return ret.const_like(0)