move mem estimate to ProgramSpec [pr] (#10901)

This commit is contained in:
George Hotz
2025-06-20 15:54:28 -07:00
committed by GitHub
parent 92678e59ee
commit d399a4587d
2 changed files with 10 additions and 8 deletions

View File

@@ -566,10 +566,5 @@ class Kernel:
self.linearize(name_override, ast_transform)
assert self.uops[-1].op is Ops.SINK, "last uop must be sink"
src = self.opts.render(self.uops)
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
# TODO: these max and min don't work on symbolic, and results are very wrong.
mem_bytes = sum(max(x.src[0].dtype.nbytes() for x in group)
for _, group in itertools.groupby([x for x in self.ast.toposort() if x.op in GroupOp.Buffer and x.src[0].base.op is Ops.DEFINE_GLOBAL],
key=lambda x: (x.op, x.src[0].base.arg)))
return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, mem_bytes,
return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Optional, Callable, cast
import functools, math
import functools, math, itertools
from enum import Enum, auto
from dataclasses import dataclass, field, replace
from tinygrad.helpers import to_function_name, dedup, prod
@@ -88,7 +88,6 @@ class ProgramSpec:
ast:UOp # save the base ast (this is method cache key)
uops:Optional[list[UOp]]=None
applied_opts:Optional[list[Opt]]=None
mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
# filled in from uops (if we have uops)
global_size:Optional[list[int]]=None
@@ -117,6 +116,14 @@ class ProgramSpec:
self.ins = sorted(dedup(self.ins))
self._ran_post_init = True
@functools.cached_property
def mem_estimate(self) -> sint:
  """Estimate the number of bytes of global memory accessed by this program.

  Every LOAD/STORE in `self.ast` whose `src[0].base` is a DEFINE_GLOBAL is bucketed by `(op, src[0].base.arg)`.
  Each bucket contributes the largest `src[0].dtype.nbytes()` seen among its accesses; the buckets are then summed.

  Returns:
    The summed per-bucket maxima (0 when there are no matching accesses).
  """
  # TODO: these max and min don't work on symbolic, and results are very wrong.
  # NOTE: a dict is used instead of itertools.groupby on purpose. groupby only merges *adjacent* items with equal keys,
  # so on the unsorted toposort order two accesses to the same buffer that are not neighbours would be summed, not maxed.
  largest: dict = {}
  for u in self.ast.toposort():
    # same short-circuit order as the filter it replaces: op check first, then the buffer base check
    if u.op not in {Ops.LOAD, Ops.STORE} or u.src[0].base.op is not Ops.DEFINE_GLOBAL: continue
    key, nbytes = (u.op, u.src[0].base.arg), u.src[0].dtype.nbytes()
    largest[key] = max(largest[key], nbytes) if key in largest else nbytes
  return sum(largest.values())
@functools.cached_property
def estimates(self) -> Estimates:
  """Return the `Estimates` for this program with its `mem` field set to `self.mem_estimate`.

  When `self.uops` is None a default-constructed `Estimates()` is the starting point; otherwise it is built with
  `Estimates.from_uops(self.uops, ignore_indexing=True)`.
  """
  if self.uops is None:
    base = Estimates()
  else:
    base = Estimates.from_uops(self.uops, ignore_indexing=True)
  # replace() hands back a copy, so `base` itself is not modified
  return replace(base, mem=self.mem_estimate)