From d399a4587d992f480f67e8ad704d3f96843e375f Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 20 Jun 2025 15:54:28 -0700 Subject: [PATCH] move mem estimate to ProgramSpec [pr] (#10901) --- tinygrad/opt/kernel.py | 7 +------ tinygrad/renderer/__init__.py | 11 +++++++++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tinygrad/opt/kernel.py b/tinygrad/opt/kernel.py index 5ef60fc614..b475647731 100644 --- a/tinygrad/opt/kernel.py +++ b/tinygrad/opt/kernel.py @@ -566,10 +566,5 @@ class Kernel: self.linearize(name_override, ast_transform) assert self.uops[-1].op is Ops.SINK, "last uop must be sink" src = self.opts.render(self.uops) - # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes - # TODO: these max and min don't work on symbolic, and results are very wrong. - mem_bytes = sum(max(x.src[0].dtype.nbytes() for x in group) - for _, group in itertools.groupby([x for x in self.ast.toposort() if x.op in GroupOp.Buffer and x.src[0].base.op is Ops.DEFINE_GLOBAL], - key=lambda x: (x.op, x.src[0].base.arg))) - return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, mem_bytes, + return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None) diff --git a/tinygrad/renderer/__init__.py b/tinygrad/renderer/__init__.py index 0c4a57e806..14212bd003 100644 --- a/tinygrad/renderer/__init__.py +++ b/tinygrad/renderer/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Optional, Callable, cast -import functools, math +import functools, math, itertools from enum import Enum, auto from dataclasses import dataclass, field, replace from tinygrad.helpers import to_function_name, dedup, prod @@ -88,7 +88,6 @@ class ProgramSpec: ast:UOp # save the base ast (this is method cache key) uops:Optional[list[UOp]]=None applied_opts:Optional[list[Opt]]=None - mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good # filled in from uops (if we have uops) global_size:Optional[list[int]]=None @@ -117,6 +116,14 @@ class ProgramSpec: self.ins = sorted(dedup(self.ins)) self._ran_post_init = True + @functools.cached_property + def mem_estimate(self) -> sint: + # group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes + # TODO: these max and min don't work on symbolic, and results are very wrong. + return sum(max(x.src[0].dtype.nbytes() for x in group) + for _, group in itertools.groupby([x for x in self.ast.toposort() if x.op in {Ops.LOAD, Ops.STORE} and x.src[0].base.op is Ops.DEFINE_GLOBAL], + key=lambda x: (x.op, x.src[0].base.arg))) + @functools.cached_property def estimates(self) -> Estimates: return replace(Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True), mem=self.mem_estimate)