fix mem estimate

This commit is contained in:
George Hotz
2023-09-03 09:47:17 -07:00
parent 0458120cf2
commit eddb140067

View File

@@ -2,7 +2,7 @@ from typing import NamedTuple, Optional, List, Tuple, cast, Dict
import itertools
from tinygrad.ops import LazyOp, MovementOps, FlopCounter, get_lazyop_info, ReduceOps
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import dedup, dtypes, colored, prod, ImageDType, DType
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType
from tinygrad.runtime.lib import buf_is_kernel_arg
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
@@ -42,7 +42,7 @@ class Kernel:
# fetch lazyop info
self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None)
self.mem_estimate: int = sum(x.dtype.itemsize*x.size for x in self.arg_bufs.keys())
# there's only allowed to be one reduceop
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]