move ops lds estimate to Program [run_process_replay] (#5872)

This commit is contained in:
George Hotz
2024-08-01 19:12:07 -07:00
committed by GitHub
parent 877e0b4ba0
commit 3995f1ddf1
2 changed files with 11 additions and 7 deletions

View File

@@ -13,7 +13,7 @@ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DE
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import strides_for_shape
from tinygrad.codegen.uops import UOps, flops_mem
from tinygrad.codegen.uops import UOps
from tinygrad.codegen.uopgraph import UOpGraph
from tinygrad.codegen.lowerer import lazyop_to_uop
from enum import Enum, auto
@@ -760,10 +760,9 @@ class Kernel:
else:
global_size, local_size = None, None
ops, mem = flops_mem(self.uops.uops, ignore_indexing=True)
# group non-local MemBuffers by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
# TODO: these max and min don't work on symbolic, and results are very wrong.
mem_bytes = sum(max(x.arg.dtype.itemsize * x.arg.st.real_size() for x in group) for _, group in
itertools.groupby([x for x in self.ast.lazyops if x.op in BufferOps and isinstance(x.arg, MemBuffer) and x.arg.idx >= 0],
key=lambda x: (x.op, x.arg.idx)))
return Program(ansiname, src, self.opts.device, global_size, local_size, self.uops.uops, ops, min(mem, mem_bytes), mem)
return Program(ansiname, src, self.opts.device, global_size, local_size, self.uops.uops, mem_bytes)

View File

@@ -2,7 +2,7 @@ from typing import Optional, List, Tuple, Dict, Any
import functools
from dataclasses import dataclass
from tinygrad.helpers import to_function_name, dedup
from tinygrad.codegen.uops import UOps, UOp
from tinygrad.codegen.uops import UOps, UOp, flops_mem
from tinygrad.shape.symbolic import sym_infer, sint, Variable
from tinygrad.dtype import DType
@@ -22,9 +22,14 @@ class Program:
global_size:Optional[List[int]]=None
local_size:Optional[List[int]]=None
uops:Optional[List[UOp]]=None
op_estimate:sint=0
mem_estimate:sint=0
lds_estimate:sint=0
mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
@property
def op_estimate(self) -> sint:
  """Return the first element of the cached `_ops_lds` pair (the op estimate)."""
  ops_and_lds = self._ops_lds
  return ops_and_lds[0]
@property
def lds_estimate(self) -> sint:
  """Return the second element of the cached `_ops_lds` pair (the lds estimate)."""
  ops_and_lds = self._ops_lds
  return ops_and_lds[1]
@functools.cached_property
def _ops_lds(self) -> Tuple[sint, sint]:
  """Compute the pair that backs `op_estimate` and `lds_estimate`.

  Yields `(0, 0)` when no uops are attached; otherwise delegates to
  `flops_mem` over `self.uops` with `ignore_indexing=True`.
  """
  if self.uops is None:
    return (0, 0)
  return flops_mem(self.uops, ignore_indexing=True)
@functools.cached_property
def vars(self) -> List[Variable]: