mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix Estimate.from_uops for sliced access (#14695)
"assume all DEFINE_GLOBAL memory is accessed" is wrong for partial loads: accumulate the bytes actually accessed from INDEX ops, then cap the total at the full buffer size. With this change, mem_est never exceeds lds_est.
This commit is contained in:
@@ -68,6 +68,25 @@ class TestMemoryCount(unittest.TestCase):
|
||||
_, mem = get_stats(a.assign(a+a))
|
||||
self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write
|
||||
|
||||
def test_setitem_slice_const(self):
  """Writing a constant into a slice should only count the sliced bytes."""
  buf = Tensor.empty(100, dtype=dtypes.int).realize()
  GlobalCounters.reset()
  buf[20:50] = 3
  # 30 int32 elements written, nothing read
  self.assertEqual(GlobalCounters.global_mem, 30*4)
|
||||
|
||||
def test_setitem_slice_tensor(self):
  """Writing a tensor into a slice counts both the read and the write."""
  dst = Tensor.empty(100, dtype=dtypes.int).realize()
  src = Tensor.empty(30, dtype=dtypes.int).realize()
  GlobalCounters.reset()
  dst[20:50] = src
  # 30 int32 elements read from src + 30 written to dst
  self.assertEqual(GlobalCounters.global_mem, 30*4*2)
|
||||
|
||||
def test_setitem_full(self):
  """Assigning to the whole buffer counts every element as written."""
  buf = Tensor.empty(100, dtype=dtypes.int).realize()
  GlobalCounters.reset()
  buf[:] = 3
  self.assertEqual(GlobalCounters.global_mem, 100*4)  # full buffer written
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "CPU", "test copy to CPU from other device")
|
||||
def test_copyout(self):
|
||||
a = Tensor.empty(32, dtype=dtypes.uint8).to("CPU")
|
||||
|
||||
@@ -172,7 +172,6 @@ class ExecItem:
|
||||
if et is not None: GlobalCounters.time_sum_s += et
|
||||
if DEBUG >= 2:
|
||||
lds_est = sym_infer(self.prg.estimates.lds, var_vals)
|
||||
mem_est = min(mem_est, lds_est) # there can't be more memory accessed than loads/stores. remove this when symbolic is fixed
|
||||
header_color = 'magenta' if jit else ('green' if self.prg.first_run else None)
|
||||
ptm = colored(time_to_str(et, w=9), "yellow" if et > 0.01 else None) if et is not None else ""
|
||||
flops, membw, ldsbw = op_est/(et or 1e-20), mem_est/(et or 1e-20), lds_est/(et or 1e-20)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Callable, cast
|
||||
import functools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.helpers import to_function_name, dedup, prod, DEBUG
|
||||
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher, print_uops, KernelInfo
|
||||
from tinygrad.uop.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, smin, GroupOp, PatternMatcher, print_uops, KernelInfo
|
||||
from tinygrad.dtype import AddrSpace, PtrDType
|
||||
from tinygrad.codegen.opt.tc import TensorCore
|
||||
from tinygrad.codegen.opt import Opt
|
||||
@@ -42,8 +42,10 @@ class Estimates:
|
||||
if u.op in {Ops.LOAD, Ops.STORE}:
|
||||
buf = u
|
||||
while len(buf.src): buf = buf.src[0]
|
||||
if buf.op is Ops.PARAM: # assume all DEFINE_GLOBAL memory is accessed
|
||||
mem[(buf, u.op)] = buf.ptrdtype.size * buf.dtype.itemsize
|
||||
if buf.op is Ops.PARAM:
|
||||
# u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
|
||||
accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
|
||||
mem[(buf, u.op)] = smin(accessed, buf.ptrdtype.nbytes()) if buf.ptrdtype.size != -1 else accessed
|
||||
if u.op is Ops.RANGE:
|
||||
mult_stack.append(mults)
|
||||
mults *= cast(sint, u.src[0].ssimplify())
|
||||
|
||||
Reference in New Issue
Block a user