mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move ops lds estimate to Program [run_process_replay] (#5872)
This commit is contained in:
@@ -13,7 +13,7 @@ from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, DE
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.shape.view import strides_for_shape
|
||||
from tinygrad.codegen.uops import UOps, flops_mem
|
||||
from tinygrad.codegen.uops import UOps
|
||||
from tinygrad.codegen.uopgraph import UOpGraph
|
||||
from tinygrad.codegen.lowerer import lazyop_to_uop
|
||||
from enum import Enum, auto
|
||||
@@ -760,10 +760,9 @@ class Kernel:
|
||||
else:
|
||||
global_size, local_size = None, None
|
||||
|
||||
ops, mem = flops_mem(self.uops.uops, ignore_indexing=True)
|
||||
# group non-local MemBuffers by the op type (LOAD or STORE) and the buffer arg. take the max access of that buffer in bytes
|
||||
# TODO: these max and min don't work on symbolic, and results are very wrong.
|
||||
mem_bytes = sum(max(x.arg.dtype.itemsize * x.arg.st.real_size() for x in group) for _, group in
|
||||
itertools.groupby([x for x in self.ast.lazyops if x.op in BufferOps and isinstance(x.arg, MemBuffer) and x.arg.idx >= 0],
|
||||
key=lambda x: (x.op, x.arg.idx)))
|
||||
return Program(ansiname, src, self.opts.device, global_size, local_size, self.uops.uops, ops, min(mem, mem_bytes), mem)
|
||||
return Program(ansiname, src, self.opts.device, global_size, local_size, self.uops.uops, mem_bytes)
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Optional, List, Tuple, Dict, Any
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import to_function_name, dedup
|
||||
from tinygrad.codegen.uops import UOps, UOp
|
||||
from tinygrad.codegen.uops import UOps, UOp, flops_mem
|
||||
from tinygrad.shape.symbolic import sym_infer, sint, Variable
|
||||
from tinygrad.dtype import DType
|
||||
|
||||
@@ -22,9 +22,14 @@ class Program:
|
||||
global_size:Optional[List[int]]=None
|
||||
local_size:Optional[List[int]]=None
|
||||
uops:Optional[List[UOp]]=None
|
||||
op_estimate:sint=0
|
||||
mem_estimate:sint=0
|
||||
lds_estimate:sint=0
|
||||
mem_estimate:sint=0 # TODO: get this from the load/store uops once min/max are good
|
||||
|
||||
@property
|
||||
def op_estimate(self) -> sint: return self._ops_lds[0]
|
||||
@property
|
||||
def lds_estimate(self) -> sint: return self._ops_lds[1]
|
||||
@functools.cached_property
|
||||
def _ops_lds(self) -> Tuple[sint, sint]: return (0,0) if self.uops is None else flops_mem(self.uops, ignore_indexing=True)
|
||||
|
||||
@functools.cached_property
|
||||
def vars(self) -> List[Variable]:
|
||||
|
||||
Reference in New Issue
Block a user