mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
track the size in the lazybuffer (#3044)
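As the hunks below show, each LazyBuffer now stores its element count once at construction instead of recomputing it from the ShapeTracker on every use: every st.size() comparison becomes a read of that field, a now-unused prod import is dropped, and a leaky size cache on ShapeTracker is removed.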
tinygrad/lazy.py

@@ -5,7 +5,7 @@ from typing import Union, Optional, Any, Tuple, List, Set, Dict
 from tinygrad.dtype import dtypes, DType, ImageDType
 from tinygrad.helpers import prod, merge_dicts, flatten, getenv, dedup, DEBUG, all_int, all_same
 from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
-from tinygrad.shape.symbolic import sint, Variable
+from tinygrad.shape.symbolic import sint, Variable, Node
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer, Device
 from tinygrad.graph import log_lazybuffer
@@ -35,6 +35,7 @@ class LazyBuffer:
                base:Optional[LazyBuffer]=None):
     assert isinstance(device, str) and device == Device.canonicalize(device)
     self.device, self.st, self.dtype, self.shape = device, st, dtype, st.shape
+    self.size = prod([x.max if isinstance(x, Node) else x for x in self.shape])
     if base is None:
       # properties on base
       self.op, self.arg, self.srcs = op, arg, srcs  # this is a LazyOp, except the src is LazyBuffers and not LazyOps
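The added line above is the core of the change: every LazyBuffer now carries its element count, computed once at construction, with any symbolic dimension counted at its upper bound. A minimal sketch of that computation, assuming this era's symbolic API (Variable(name, min, max) nodes exposing .max):

from tinygrad.helpers import prod
from tinygrad.shape.symbolic import Variable, Node

# a shape mixing a concrete dim with a symbolic one, 1 <= batch <= 32
shape = (4, Variable("batch", 1, 32))
# symbolic dims count at their maximum, so size is always a plain int
size = prod([x.max if isinstance(x, Node) else x for x in shape])
print(size)  # 128

This is also why the first hunk adds Node to the symbolic imports.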
@@ -63,7 +64,7 @@ class LazyBuffer:
     return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(self.shape)).expand(self.shape)

   def contiguous(self):
-    if not self.st.contiguous or self.st.size() != self.base.st.size() or self.is_unrealized_const():
+    if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const():
       ret = self.e(LoadOps.CONTIGUOUS)
       sti = self.st.invert(self.base.shape)
       if sti is not None: self.base.contiguous_child = ref(ret), sti
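contiguous now compares the two precomputed sizes instead of calling ShapeTracker.size() twice per check. A hypothetical standalone predicate restating the new condition (the name is illustrative, not tinygrad API):

def needs_contiguous_copy(lb) -> bool:
  # copy when the view is non-contiguous, covers a different number of
  # elements than its base, or is an unrealized const
  return not lb.st.contiguous or lb.size != lb.base.size or lb.is_unrealized_const()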
@@ -91,7 +92,7 @@
     if self.device == device: return self

     # double COPY = one COPY
-    if self.st.contiguous and self.st.size() == self.base.st.size() and not self.base.realized and self.base.op == LoadOps.COPY:
+    if self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op == LoadOps.COPY:
       return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape)

     # const doesn't have to be copied (issues with disk tensor)
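The double-COPY fold relies on the same size test: only a full contiguous view of an unrealized COPY may pull from the original source directly. A hypothetical helper restating the rule above (illustrative, not tinygrad API):

def fold_double_copy(lb, device):
  # fold only if lb is a full contiguous view of an unrealized COPY
  if lb.st.contiguous and lb.size == lb.base.size and not lb.base.realized and lb.base.op == LoadOps.COPY:
    return lb.base.srcs[0].copy_to_device(device).reshape(lb.st.shape)
  return None  # no fold; the caller issues a normal copy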
@@ -233,7 +234,7 @@ def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffe
   allbufs[buf] = None
   if buf.op in LoadOps: realizes.add(buf.base)
   if buf.op == LoadOps.COPY:
-    assert buf.srcs[0].st.contiguous and buf.srcs[0].st.size() == buf.srcs[0].base.st.size(), "can only copy contig"
+    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
     realizes.add(buf.srcs[0].base)
   for x in buf.srcs: _recurse_lb(x, realizes, allbufs, simple_pads)

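The scheduler's precondition is unchanged in spirit: a COPY source must still be a full contiguous view of its base. Only how the two element counts are obtained changes, from two ShapeTracker.size() calls to the cached size fields.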
tinygrad/realize.py

@@ -2,7 +2,7 @@ from typing import List, Dict, Optional, cast
 from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
 from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats, InterpretedASTRunner
 from tinygrad.graph import print_tree, realized_lazybuffer
-from tinygrad.helpers import prod, colored, getenv
+from tinygrad.helpers import colored, getenv
 from tinygrad.shape.symbolic import Variable

 # *** schedule running ***
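With the element count stored on each LazyBuffer, the scheduler no longer multiplies shapes out itself, which is why prod drops out of this import.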
@@ -39,7 +39,7 @@ def run_schedule(schedule:List[ScheduleItem]):

       # we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
       si.out.realized = si.out.output_buffer if si.out.output_buffer is not None else \
-        Buffer(si.out.device, prod((s if isinstance(s, int) else s.max for s in si.out.shape)), si.out.dtype,
+        Buffer(si.out.device, si.out.size, si.out.dtype,
               "PLACEHOLDER" if isinstance(prg, InterpretedASTRunner) else None)
       del si.out.srcs

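Because size was computed with every Variable at its maximum, the single allocation above is always large enough for any later binding of the symbolic dimensions. A worked example of that arithmetic, with hypothetical dimensions:

from tinygrad.helpers import prod
from tinygrad.shape.symbolic import Variable, Node

out_shape = (Variable("n", 1, 64), 128)  # n is bound per run, at most 64
# the deleted line recomputed this product on every allocation;
# si.out.size now carries the same worst-case value, precomputed
alloc = prod([s.max if isinstance(s, Node) else s for s in out_shape])
print(alloc)  # 8192 elements, enough for any n from 1 to 64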
tinygrad/shape/shapetracker.py

@@ -75,7 +75,6 @@ class ShapeTracker:
   @property
   def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape

-  @functools.lru_cache(maxsize=None) # NOTE: this keeps all ShapeTrackers alive
   def size(self) -> int:
     if 0 in self.shape: return 0
     ret = self.expr_idxs()[0].max
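The deleted decorator is the memory half of this commit: functools.lru_cache on a method keys its cache on self, so every ShapeTracker that ever answered size() stayed alive, exactly as the deleted NOTE warned. A minimal self-contained demonstration of the gotcha (not tinygrad code):

import functools

class Leaky:
  @functools.lru_cache(maxsize=None)
  def size(self):
    # the cache key is (self,), so the global cache holds a strong
    # reference to every instance this method is ever called on
    return 42

Tracking the size on the LazyBuffer keeps the O(1) lookup without pinning ShapeTrackers in a process-wide cache.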