minor cleanups / remove that op (#2905)
@@ -90,7 +90,8 @@ class LazyBuffer:
     if self.base == self and not self.realized and self.op == LoadOps.COPY and self.srcs[0].device == device: return self.srcs[0]
 
     # const doesn't have to be copied (issues with disk tensor)
-    if self.is_unrealized_const(): return self.const(self.base.arg)._view(self.st)
+    if self.is_unrealized_const():
+      return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
 
     # if it's a shrink, do the shrink before the copy with CONTIGUOUS
     # TODO: why is this required on WEBGPU?
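This hunk stops building the const via `self.const(self.base.arg)` (which lives on the source device) and instead creates the `LoadOps.CONST` directly on the target device before reapplying the view, so no copy is issued. A minimal runnable sketch of that idea, using a hypothetical `ToyBuffer` rather than tinygrad's real classes:

```python
# Toy model of the const copy-avoidance above: an unrealized CONST is just
# (value, device) metadata, so "moving" it can recreate it on the target
# device instead of copying bytes. ToyBuffer is hypothetical.
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class ToyBuffer:
  device: str
  value: float | None = None  # non-None -> behaves like an unrealized CONST

  def is_unrealized_const(self) -> bool:
    return self.value is not None

  def copy_to_device(self, device: str) -> "ToyBuffer":
    if self.is_unrealized_const():
      return replace(self, device=device)  # no data movement needed
    raise NotImplementedError("a realized buffer would issue a real COPY")

print(ToyBuffer("CPU", value=3.0).copy_to_device("GPU"))
# ToyBuffer(device='GPU', value=3.0)
```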
@@ -100,7 +101,7 @@ class LazyBuffer:
     # copy the base and apply the shapetracker on the new device
     return create_lazybuffer(device, self.base.st, self.dtype, LoadOps.COPY, srcs=(self.base,))._view(self.st)
 
-  def e(self:LazyBuffer, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
+  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
     srcs: List[LazyBuffer] = []
     for s in (self,)+in_srcs:
       if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
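Beyond dropping the redundant `self:LazyBuffer` annotation, note the pattern on the loop's last context line: `contiguous_child[0]` holds a weakref, calling it returns the referent (or `None` once it has been collected), and the walrus binds the live object for the branch body. A standalone illustration of that pattern, not tinygrad code:

```python
# (root := ref()) dereferences a weakref: the call returns the target object
# while it is alive and None afterwards, so the walrus guards the branch.
import weakref

class Node: pass

child = Node()
ref = weakref.ref(child)
if (root := ref()) is not None:
  print("alive:", root)  # taken while `child` still has a strong reference

del child                # drop the last strong reference
print(ref())             # None (collected immediately under CPython)
```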
@@ -110,34 +111,16 @@ class LazyBuffer:
     assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
     if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
     out_dtype = srcs[-1].dtype if op != BinaryOps.CMPLT else dtypes.bool
-
-    # if possible, reduce the amount of compute done by not computing on padded areas
-    if op in {BinaryOps.MUL, BinaryOps.ADD}:
-      m0, m1, out_mask = srcs[0].st.views[-1].mask, srcs[1].st.views[-1].mask, None
-      if m0 is not None and m0 == m1: out_mask = m0 # if they match, it works for both MUL and ADD
-      elif op == BinaryOps.MUL: out_mask = m0 or m1 # MUL only needs one mask
-      if out_mask is not None:
-        shrink_srcs = tuple(x.shrink(out_mask) for x in srcs) # remove the mask from the inputs
-        ret = create_lazybuffer(self.device, ShapeTracker.from_shape(shrink_srcs[0].shape), out_dtype, op, arg, shrink_srcs)
-        return ret.pad(tuple([(p[0], s-p[1]) for s,p in zip(self.shape, out_mask)]))
-
     return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
 
   # *** reduce ops ***
 
-  def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
+  def _reduce_op(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == tuple(new_shape): return self
     unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape)
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))
 
-  def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
-    # if possible, reduce the amount of compute done by not computing on padded areas
-    if op == ReduceOps.SUM and (out_mask:=self.st.views[-1].mask) is not None:
-      new_input = self.shrink(out_mask)
-      new_new_shape = tuple(ns if s != ns else os for s,os,ns in zip(self.shape, new_input.shape, new_shape))
-      pad_back = tuple([(p[0], s-p[1]) if ns != 1 else (0,0) for s,ns,p in zip(self.shape, new_shape, out_mask)])
-      return new_input.r(op, new_new_shape).pad(pad_back)
-
+  def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
     if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
       return self._reduce_op(op, new_shape)
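The bulk of this hunk deletes the twin "don't compute on padded areas" fast paths: in `e`, matching masks let ADD/MUL run on the shrunk core and pad the result back; in `r`, a masked SUM reduced the shrunk input and padded the reduce output. Both rewrites rely on zero padding being preserved by these ops, which a quick numpy check confirms (numpy stands in for the lazy buffers here):

```python
# Sanity check of the removed shrink -> compute -> pad-back rewrite: with
# zero padding, MUL and SUM over the padded tensor equal the same op on the
# unpadded core, padded back afterwards.
import numpy as np

core = np.arange(6, dtype=np.float32).reshape(2, 3)
padded = np.pad(core, ((1, 1), (2, 2)))  # zeros outside the masked core

# elementwise MUL: zero padding stays zero, only the core contributes
assert (padded * padded == np.pad(core * core, ((1, 1), (2, 2)))).all()

# SUM reduce over axis 1: reduce the core, then pad the reduced result back
assert np.allclose(padded.sum(axis=1), np.pad(core.sum(axis=1), (1, 1)))
```

With the fast paths gone, `e` always builds the op over the full shape, and `r` goes straight to the split heuristic: a reduce is handed to `_reduce_op` unsplit whenever the reduction factor `prod(self.shape) // prod(new_shape)` is below `REDUCEOP_SPLIT_THRESHOLD` (default 32768).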
@@ -150,16 +133,16 @@ class LazyBuffer:
 
   # *** movement ops ***
 
-  def _view(self:LazyBuffer, new_st:ShapeTracker) -> LazyBuffer:
+  def _view(self, new_st:ShapeTracker) -> LazyBuffer:
     if new_st.contiguous and self.base.shape == new_st.shape: return self.base
     return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)
 
-  def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer: return self._view(self.st.reshape(arg))
-  def pad(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: return self._view(self.st.pad(arg))
-  def expand(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer: return self._view(self.st.expand(arg))
-  def permute(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: return self._view(self.st.permute(arg))
-  def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: return self._view(self.st.shrink(arg))
-  def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: return self._view(self.st.stride(arg))
+  def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
+  def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
+  def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
+  def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
+  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
+  def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
 
   # *** schedule creation ***
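The movement ops lose their annotations but keep their one-line shape: each is pure metadata, producing a new ShapeTracker over the same base buffer, and `_view` collapses a contiguous full-shape view back to the base. A toy model of that zero-copy behavior (hypothetical names, not tinygrad's):

```python
# Movement ops as views: every op returns new shape metadata wrapping the
# same storage; no element is copied at any step.
import math
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ToyView:
  base: list               # shared storage
  shape: Tuple[int, ...]

  def reshape(self, shape: Tuple[int, ...]) -> "ToyView":
    assert math.prod(shape) == math.prod(self.shape)
    return ToyView(self.base, shape)  # same base, new metadata

  def permute(self, order: Tuple[int, ...]) -> "ToyView":
    return ToyView(self.base, tuple(self.shape[i] for i in order))

data = list(range(6))
v = ToyView(data, (2, 3)).permute((1, 0)).reshape((6,))
assert v.base is data  # still the original storage, untouched
```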