diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index b1c96c8beb..194c9a37ba 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -2,7 +2,7 @@ from typing import NamedTuple, Optional, List, Tuple, cast, Dict import itertools from tinygrad.ops import LazyOp, MovementOps, FlopCounter, get_lazyop_info, ReduceOps from tinygrad.lazy import LazyBuffer -from tinygrad.helpers import dedup, dtypes, colored, prod, ImageDType, DType +from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType from tinygrad.runtime.lib import buf_is_kernel_arg from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape @@ -42,7 +42,7 @@ class Kernel: # fetch lazyop info self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast)) - self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None) + self.mem_estimate: int = sum(x.dtype.itemsize*x.size for x in self.arg_bufs.keys()) # there's only allowed to be one reduceop reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]