diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index b76c8f5b7d..e9bcaba883 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -90,7 +90,8 @@ class LazyBuffer: if self.base == self and not self.realized and self.op == LoadOps.COPY and self.srcs[0].device == device: return self.srcs[0] # const doesn't have to be copied (issues with disk tensor) - if self.is_unrealized_const(): return self.const(self.base.arg)._view(self.st) + if self.is_unrealized_const(): + return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st) # if it's a shrink, do the shrink before the copy with CONTIGUOUS # TODO: why is this required on WEBGPU? @@ -100,7 +101,7 @@ class LazyBuffer: # copy the base and apply the shapetracker on the new device return create_lazybuffer(device, self.base.st, self.dtype, LoadOps.COPY, srcs=(self.base,))._view(self.st) - def e(self:LazyBuffer, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer: + def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer: srcs: List[LazyBuffer] = [] for s in (self,)+in_srcs: if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None: @@ -110,34 +111,16 @@ class LazyBuffer: assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}" if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool" out_dtype = srcs[-1].dtype if op != BinaryOps.CMPLT else dtypes.bool - - # if possible, reduce the amount of compute done by not computing on padded areas - if op in {BinaryOps.MUL, BinaryOps.ADD}: - m0, m1, out_mask = srcs[0].st.views[-1].mask, srcs[1].st.views[-1].mask, None - if m0 is not None and m0 == m1: out_mask = m0 # if they match, it works for both MUL and ADD - elif op == BinaryOps.MUL: out_mask = m0 or m1 # MUL only needs one mask - if out_mask is not None: - shrink_srcs = tuple(x.shrink(out_mask) for x in srcs) # remove the mask from the inputs - ret = create_lazybuffer(self.device, ShapeTracker.from_shape(shrink_srcs[0].shape), out_dtype, op, arg, shrink_srcs) - return ret.pad(tuple([(p[0], s-p[1]) for s,p in zip(self.shape, out_mask)])) - return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs)) # *** reduce ops *** - def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer: + def _reduce_op(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer: if self.shape == tuple(new_shape): return self unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape) return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,)) - def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer: - # if possible, reduce the amount of compute done by not computing on padded areas - if op == ReduceOps.SUM and (out_mask:=self.st.views[-1].mask) is not None: - new_input = self.shrink(out_mask) - new_new_shape = tuple(ns if s != ns else os for s,os,ns in zip(self.shape, new_input.shape, new_shape)) - pad_back = tuple([(p[0], s-p[1]) if ns != 1 else (0,0) for s,ns,p in zip(self.shape, new_shape, out_mask)]) - return new_input.r(op, new_new_shape).pad(pad_back) - + def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer: # TODO: can we split symbolic shape if the reduce axis is not symbolic? if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768): return self._reduce_op(op, new_shape) @@ -150,16 +133,16 @@ class LazyBuffer: # *** movement ops *** - def _view(self:LazyBuffer, new_st:ShapeTracker) -> LazyBuffer: + def _view(self, new_st:ShapeTracker) -> LazyBuffer: if new_st.contiguous and self.base.shape == new_st.shape: return self.base return create_lazybuffer(self.device, new_st, self.dtype, base=self.base) - def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer: return self._view(self.st.reshape(arg)) - def pad(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: return self._view(self.st.pad(arg)) - def expand(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer: return self._view(self.st.expand(arg)) - def permute(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: return self._view(self.st.permute(arg)) - def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: return self._view(self.st.shrink(arg)) - def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: return self._view(self.st.stride(arg)) + def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg)) + def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg)) + def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg)) + def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg)) + def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg)) + def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg)) # *** schedule creation ***