fix Estimate.from_uops for sliced access (#14695)

"assume all DEFINE_GLOBAL memory is accessed" is wrong for partial loads: accumulate the bytes actually accessed from each INDEX, then cap at the full buffer size. Now mem_est never exceeds lds_est.
This commit is contained in:
chenyu
2026-02-12 11:18:07 -05:00
committed by GitHub
parent 8551fa50d3
commit 56caf6a3a2
3 changed files with 24 additions and 4 deletions

View File

@@ -68,6 +68,25 @@ class TestMemoryCount(unittest.TestCase):
_, mem = get_stats(a.assign(a+a))
self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write
def test_setitem_slice_const(self):
    """Assigning a constant to a 30-element slice should count only the 30 written ints."""
    buf = Tensor.empty(100, dtype=dtypes.int).realize()
    GlobalCounters.reset()
    buf[20:50] = 3
    # 30 int32 elements written -> 30 * 4 bytes of global memory traffic
    self.assertEqual(GlobalCounters.global_mem, 30*4)
def test_setitem_slice_tensor(self):
    """Copying a 30-element tensor into a slice counts 30 reads plus 30 writes."""
    dst = Tensor.empty(100, dtype=dtypes.int).realize()
    src = Tensor.empty(30, dtype=dtypes.int).realize()
    GlobalCounters.reset()
    dst[20:50] = src
    # 30 ints read from src + 30 ints written to dst, 4 bytes each
    self.assertEqual(GlobalCounters.global_mem, 30*4*2)
def test_setitem_full(self):
    """A full-buffer assignment accounts for every element being written."""
    buf = Tensor.empty(100, dtype=dtypes.int).realize()
    GlobalCounters.reset()
    buf[:] = 3
    # all 100 int32 elements written -> 100 * 4 bytes
    self.assertEqual(GlobalCounters.global_mem, 100*4)
@unittest.skipIf(Device.DEFAULT == "CPU", "test copy to CPU from other device")
def test_copyout(self):
a = Tensor.empty(32, dtype=dtypes.uint8).to("CPU")

View File

@@ -172,7 +172,6 @@ class ExecItem:
if et is not None: GlobalCounters.time_sum_s += et
if DEBUG >= 2:
lds_est = sym_infer(self.prg.estimates.lds, var_vals)
mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
header_color = 'magenta' if jit else ('green' if self.prg.first_run else None)
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20)

View File

@@ -3,7 +3,7 @@ from typing import Callable, cast
import functools
from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup, prod, DEBUG
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops, KernelInfo
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, smin, GroupOp, PatternMatcher, print_uops, KernelInfo
from tinygrad.dtype import AddrSpace, PtrDType
from tinygrad.codegen.opt.tc import TensorCore
from tinygrad.codegen.opt import Opt
@@ -42,8 +42,10 @@ class Estimates:
if u.op in {Ops.LOAD, Ops.STORE}:
buf = u
while len(buf.src): buf = buf.src[0]
if buf.op is Ops.PARAM: # assume all DEFINE_GLOBAL memory is accessed
mem[(buf, u.op)] = buf.ptrdtype.size * buf.dtype.itemsize
if buf.op is Ops.PARAM:
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
mem[(buf, u.op)] = smin(accessed, buf.ptrdtype.nbytes()) if buf.ptrdtype.size != -1 else accessed
if u.op is Ops.RANGE:
mult_stack.append(mults)
mults *= cast(sint, u.src[0].ssimplify())