track the size in the lazybuffer (#3044)

George Hotz
2024-01-08 13:44:55 -08:00
committed by GitHub
parent c003be7309
commit 47d67da830
3 changed files with 7 additions and 7 deletions


@@ -5,7 +5,7 @@ from typing import Union, Optional, Any, Tuple, List, Set, Dict
 from tinygrad.dtype import dtypes, DType, ImageDType
 from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
 from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
-from tinygrad.shape.symbolic import sint, Variable
+from tinygrad.shape.symbolic import sint, Variable, Node
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer, Device
 from tinygrad.graph import log_lazybuffer
@@ -35,6 +35,7 @@ class LazyBuffer:
                base:Optional[LazyBuffer]=None):
     assert isinstance(device, str) and device == Device.canonicalize(device)
     self.device, self.st, self.dtype, self.shape = device, st, dtype, st.shape
+    self.size = prod([x.max if isinstance(x, Node) else x for x in self.shape])
     if base is None:
       # properties on base
       self.op, self.arg, self.srcs = op, arg, srcs  # this is a LazyOp, except the src is LazyBuffers and not LazyOps
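
The new `self.size` field is the element count of the buffer with every symbolic dimension folded to its upper bound. A minimal standalone sketch of that computation (the `Node` class below is a simplified stand-in for `tinygrad.shape.symbolic.Node`, not the real one):

import math

class Node:
    # stand-in: a symbolic dimension known only by its bounds
    def __init__(self, name, vmin, vmax): self.name, self.min, self.max = name, vmin, vmax

shape = (Node("i", 1, 10), 4)  # one symbolic dim, one concrete dim

# mirrors: self.size = prod([x.max if isinstance(x, Node) else x for x in self.shape])
size = math.prod(x.max if isinstance(x, Node) else x for x in shape)
print(size)  # 40: the symbolic dim is counted at its worst-case bound
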
@@ -63,7 +64,7 @@ class LazyBuffer:
     return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)

   def contiguous(self):
-    if not self.st.contiguous or self.st.size() != self.base.st.size() or self.is_unrealized_const():
+    if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
       ret = self.e(LoadOps.CONTIGUOUS)
       sti = self.st.invert(self.base.shape)
       if sti is not None: self.base.contiguous_child = ref(ret), sti
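
Comparing element counts is what makes this check work: a view whose `size` differs from its base's `size` cannot alias the base's memory, so `contiguous()` must materialize it. A toy illustration with concrete numbers (shapes chosen arbitrarily):

base_size = 4 * 4   # prod of a 4x4 base's shape
view_size = 2 * 2   # prod of a 2x2 shrunk view's shape
needs_copy = view_size != base_size
print(needs_copy)   # True: the shrink selects fewer elements than the base holds
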
@@ -91,7 +92,7 @@ class LazyBuffer:
     if self.device == device: return self

     # double COPY = one COPY
-    if self.st.contiguous and self.st.size() == self.base.st.size() and not self.base.realized and self.base.op == LoadOps.COPY:
+    if self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op == LoadOps.COPY:
       return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)

     # const doesn't have to be copied (issues with disk tensor)
@@ -233,7 +234,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe
   allbufs[buf] = None
   if buf.op in LoadOps: realizes.add(buf.base)
   if buf.op == LoadOps.COPY:
-    assert buf.srcs[0].st.contiguous and buf.srcs[0].st.size() == buf.srcs[0].base.st.size(), "can only copy contig"
+    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
     realizes.add(buf.srcs[0].base)
   for x in buf.srcs: _recurse_lb(x, realizes, allbufs, simple_pads)


@@ -2,7 +2,7 @@ from typing import List, Dict, Optional, cast
 from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
 from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats, InterpretedASTRunner
 from tinygrad.graph import print_tree, realized_lazybuffer
-from tinygrad.helpers import prod, colored, getenv
+from tinygrad.helpers import colored, getenv
 from tinygrad.shape.symbolic import Variable

 # *** schedule running ***
@@ -39,7 +39,7 @@ def run_schedule(schedule:List[ScheduleItem]):
       # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
       si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
-        Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype,
+        Buffer(si.out.device, si.out.size, si.out.dtype,
               "PLACEHOLDER" if isinstance(prg, InterpretedASTRunner) else None)
       del si.out.srcs
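
Since `LazyBuffer.size` already folds symbolic dimensions to their `.max`, the inlined `prod(...)` here was redundant; the allocated buffer is always large enough for any assignment of the variables. Rough arithmetic for a hypothetical output of shape `(Variable("i", 1, 32), 128)` in float32:

max_elements = 32 * 128       # si.out.size: the symbolic dim at its upper bound
nbytes = max_elements * 4     # float32 is 4 bytes per element
print(max_elements, nbytes)   # 4096 16384
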


@@ -75,7 +75,6 @@ class ShapeTracker:
   @property
   def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape

-  @functools.lru_cache(maxsize=None)  # NOTE: this keeps all ShapeTrackers alive
   def size(self) -> int:
     if 0 in self.shape: return 0
     ret = self.expr_idxs()[0].max