From 2aa39d03cda5582bc2741690a6e11230d391baf0 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Wed, 18 Dec 2024 23:01:14 -0800
Subject: [PATCH] cleanups from Estimate [pr] (#8329)

---
 tinygrad/codegen/kernel.py    | 5 +++--
 tinygrad/codegen/uopgraph.py  | 3 ++-
 tinygrad/engine/realize.py    | 2 +-
 tinygrad/renderer/__init__.py | 7 +++----
 4 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 29c8254b48..2db9401b8d 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -617,7 +617,7 @@ class Kernel:
       if self.use_tensor_cores == 3:  # for TC=3, emulate the warp addressing with locals
         local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
         st = store_st = ShapeTracker.from_shape(local_shape)
-        local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{i + 1}", st.real_size()))
+        local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), (f"temp{i + 1}", st.real_size()))
         if tc_pattern: store_st = fix_st(store_st, *tc_pattern)
         local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
         srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
@@ -646,7 +646,8 @@ class Kernel:
                             for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
         (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
       st_uop = ShapeTracker.from_shape(local_shape).to_uop()
-      local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
+      local_size = st_uop.arg.real_size()
+      local_buffer = UOp(Ops.DEFINE_LOCAL, op.dtype.ptr(local_size, local=True), (), (f"temp{self.reduceops.index(op)+1}", local_size))
       local_load = UOp(Ops.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, ret)))
       grouped_reduce = UOp(Ops.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], grouped_axes))
       if op is self.reduceops[-1]: return grouped_reduce
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 0b73afba93..503e733efd 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -49,7 +49,8 @@ def fold_expanded(ex, buf):
                    rootsrc[0] if isinstance(rootsrc, tuple) else None)
       else:
         # for non image, we upcast the index pointer
-        new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(local=new_src[0].dtype.local))
+        new_src[0] = new_src[0].cast(new_src[0].dtype.base.vec(fold_length).ptr(size=new_src[0].dtype.size//fold_length,
+                                                                                local=new_src[0].dtype.local))
       # generate the folded new_srcs
       if is_load:
         new_load = UOp(Ops.LOAD, load_1.dtype.vec(fold_length), tuple(new_src))
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 07deaf3ed2..55d0f9e474 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -78,7 +78,7 @@ class BufferCopy(Runner):
   def __init__(self, total_sz, dest_device, src_device):
     if total_sz >= 1e6: name = f"{type(self).__name__[6:].lower()} {total_sz/1e6:7.2f}M, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
     else: name = f"{type(self).__name__[6:].lower()} {total_sz:8d}, {dest_device[:7]:>7s} <- {src_device[:7]:7s}"
-    super().__init__(colored(name, "yellow"), dest_device, Estimates(mem=total_sz))
+    super().__init__(colored(name, "yellow"), dest_device, Estimates(lds=total_sz, mem=total_sz))
   def copy(self, dest, src):
     disk_supports_fast_copyout = src.device.startswith("DISK") and hasattr(src.allocator.dev, 'io_uring') and \
         getattr(src.allocator.dev, 'fd', None) is not None
diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py
index 8d553c3f4d..afbbeb6569 100644
--- a/tinygrad/renderer/__init__.py
+++ b/tinygrad/renderer/__init__.py
@@ -36,7 +36,7 @@ class Estimates:
 
 def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
   flops: sint = 0
-  mem: sint = 0
+  lds: sint = 0
   mults: sint = 1
   mult_stack: List[sint] = []
   dont_count: Set[UOp] = set()
@@ -53,11 +53,10 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
       mults *= (u.src[1] - u.src[0]).ssimplify()
     elif u.op is Ops.ENDRANGE: mults = mult_stack.pop(-1)
     elif u.op is Ops.SPECIAL: mults *= u.arg[1]  # NOTE: we don't push to the mult_stack here, you can't end these
-    elif u.op is Ops.LOAD: mem += u.dtype.itemsize * mults
-    elif u.op is Ops.STORE: mem += u.src[1].dtype.itemsize * mults
+    elif u.op in {Ops.LOAD, Ops.STORE}: lds += u.src[0].dtype.itemsize * mults
     elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
     elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
-  return flops, mem
+  return flops, lds
 
 @dataclass
 class ProgramSpec:
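
For intuition on the renderer hunk: the byte counter is renamed from mem to lds, and both Ops.LOAD and Ops.STORE now count bytes through the pointer dtype (u.src[0].dtype.itemsize), scaled by the trip counts of the enclosing RANGE loops. Below is a minimal standalone sketch of that accounting, not tinygrad code: Op, count_lds, and the byte widths are hypothetical stand-ins with no tinygrad imports.

from enum import Enum, auto
from typing import List, Tuple

class Op(Enum):  # hypothetical stand-in for tinygrad's Ops
  RANGE = auto(); ENDRANGE = auto(); LOAD = auto(); STORE = auto()

def count_lds(uops: List[Tuple[Op, int]]) -> int:
  # arg is the trip count for RANGE and the byte width for LOAD/STORE
  lds, mults, mult_stack = 0, 1, []
  for op, arg in uops:
    if op is Op.RANGE: mult_stack.append(mults); mults *= arg  # entering a loop scales everything inside
    elif op is Op.ENDRANGE: mults = mult_stack.pop()           # leaving a loop restores the multiplier
    elif op in (Op.LOAD, Op.STORE): lds += arg * mults         # loads and stores both count toward lds
  return lds

# a 4-byte load and a 4-byte store inside a 16-iteration loop: 16*(4+4) = 128 bytes
print(count_lds([(Op.RANGE, 16), (Op.LOAD, 4), (Op.STORE, 4), (Op.ENDRANGE, 0)]))

The realize.py hunk follows the same split, with BufferCopy reporting total_sz under both lds and mem, and the kernel.py and uopgraph.py hunks thread buffer sizes into pointer dtypes via ptr(size=..., local=True), so a pointer's size travels with its type instead of living only in the DEFINE_LOCAL arg.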