mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
move mem estimate to ProgramSpec [pr] (#10901)
This commit is contained in:
@@ -566,10 +566,5 @@ class Kernel:
|
||||
self.linearize(name_override, ast_transform)
|
||||
assert self.uops[-1].op is Ops.SINK, "last uop must be sink"
|
||||
src = self.opts.render(self.uops)
|
||||
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
|
||||
# TODO: these max and min don't work on symbolic, and results are very wrong.
|
||||
mem_bytes = sum(max(x.src[0].dtype.nbytes() for x in group)
|
||||
for _, group in itertools.groupby([x for x in self.ast.toposort() if x.op in GroupOp.Buffer and x.src[0].base.op is Ops.DEFINE_GLOBAL],
|
||||
key=lambda x: (x.op, x.src[0].base.arg)))
|
||||
return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts, mem_bytes,
|
||||
return ProgramSpec(self.name if not name_override else name_override, src, self.opts.device, self.ast, self.uops, self.applied_opts,
|
||||
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Callable, cast
|
||||
import functools, math
|
||||
import functools, math, itertools
|
||||
from enum import Enum, auto
|
||||
from dataclasses import dataclass, field, replace
|
||||
from tinygrad.helpers import to_function_name, dedup, prod
|
||||
@@ -88,7 +88,6 @@ class ProgramSpec:
|
||||
ast:UOp # save the base ast (this is method cache key)
|
||||
uops:Optional[list[UOp]]=None
|
||||
applied_opts:Optional[list[Opt]]=None
|
||||
mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
|
||||
|
||||
# filled in from uops (if we have uops)
|
||||
global_size:Optional[list[int]]=None
|
||||
@@ -117,6 +116,14 @@ class ProgramSpec:
|
||||
self.ins = sorted(dedup(self.ins))
|
||||
self._ran_post_init = True
|
||||
|
||||
@functools.cached_property
|
||||
def mem_estimate(self) -> sint:
|
||||
# group non-local bufs by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
|
||||
# TODO: these max and min don't work on symbolic, and results are very wrong.
|
||||
return sum(max(x.src[0].dtype.nbytes() for x in group)
|
||||
for _, group in itertools.groupby([x for x in self.ast.toposort() if x.op in {Ops.LOAD, Ops.STORE} and x.src[0].base.op is Ops.DEFINE_GLOBAL],
|
||||
key=lambda x: (x.op, x.src[0].base.arg)))
|
||||
|
||||
@functools.cached_property
|
||||
def estimates(self) -> Estimates:
|
||||
return replace(Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True), mem=self.mem_estimate)
|
||||
|
||||
Reference in New Issue
Block a user