diff --git a/test/null/test_uops_stats.py b/test/null/test_uops_stats.py index 72e8b73d3b..a9bf4459ef 100644 --- a/test/null/test_uops_stats.py +++ b/test/null/test_uops_stats.py @@ -68,6 +68,25 @@ class TestMemoryCount(unittest.TestCase): _, mem = get_stats(a.assign(a+a)) self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write + def test_setitem_slice_const(self): + t = Tensor.empty(100, dtype=dtypes.int).realize() + GlobalCounters.reset() + t[20:50] = 3 + self.assertEqual(GlobalCounters.global_mem, 30*4) # 30 elements written + + def test_setitem_slice_tensor(self): + t = Tensor.empty(100, dtype=dtypes.int).realize() + v = Tensor.empty(30, dtype=dtypes.int).realize() + GlobalCounters.reset() + t[20:50] = v + self.assertEqual(GlobalCounters.global_mem, 30*4*2) # 30 read + 30 written + + def test_setitem_full(self): + t = Tensor.empty(100, dtype=dtypes.int).realize() + GlobalCounters.reset() + t[:] = 3 + self.assertEqual(GlobalCounters.global_mem, 100*4) # full buffer written + @unittest.skipIf(Device.DEFAULT == "CPU", "test copy to CPU from other device") def test_copyout(self): a = Tensor.empty(32, dtype=dtypes.uint8).to("CPU") diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index af114609c9..b18546970b 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -172,7 +172,6 @@ class ExecItem: if et is not None: GlobalCounters.time_sum_s += et if DEBUG >= 2: lds_est = sym_infer(self.prg.estimates.lds, var_vals) - mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed header_color = 'magenta' if jit else ('green' if self.prg.first_run else None) ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else "" flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20) diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 3a8234b75e..e6daee76ac 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -3,7 +3,7 @@ from typing import Callable, cast import functools from dataclasses import dataclass, field from tinygrad.helpers import to_function_name, dedup, prod, DEBUG -from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops, KernelInfo +from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, smin, GroupOp, PatternMatcher, print_uops, KernelInfo from tinygrad.dtype import AddrSpace, PtrDType from tinygrad.codegen.opt.tc import TensorCore from tinygrad.codegen.opt import Opt @@ -42,8 +42,10 @@ class Estimates: if u.op in {Ops.LOAD, Ops.STORE}: buf = u while len(buf.src): buf = buf.src[0] - if buf.op is Ops.PARAM: # assume all DEFINE_GLOBAL memory is accessed - mem[(buf, u.op)] = buf.ptrdtype.size * buf.dtype.itemsize + if buf.op is Ops.PARAM: + # u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul) + accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults + mem[(buf, u.op)] = smin(accessed, buf.ptrdtype.nbytes()) if buf.ptrdtype.size != -1 else accessed if u.op is Ops.RANGE: mult_stack.append(mults) mults *= cast(sint, u.src[0].ssimplify())