mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 15:28:10 -05:00
fix mem estimate
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import NamedTuple, Optional, List, Tuple, cast, Dict
|
||||
import itertools
|
||||
from tinygrad.ops import LazyOp, MovementOps, FlopCounter, get_lazyop_info, ReduceOps
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.helpers import dedup, dtypes, colored, prod, ImageDType, DType
|
||||
from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType
|
||||
from tinygrad.runtime.lib import buf_is_kernel_arg
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
|
||||
|
||||
@@ -42,7 +42,7 @@ class Kernel:
|
||||
|
||||
# fetch lazyop info
|
||||
self.info: FlopCounter = get_lazyop_info(cast(LazyOp, self.ast))
|
||||
self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None)
|
||||
self.mem_estimate: int = sum(x.dtype.itemsize*x.size for x in self.arg_bufs.keys())
|
||||
|
||||
# there's only allowed to be one reduceop
|
||||
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
|
||||
|
||||
Reference in New Issue
Block a user