minor cleanups / remove that op (#2905)

Author: George Hotz
Date: 2023-12-21 18:24:20 -08:00
Committed by: GitHub
Parent: fd0ba33b38
Commit: 4432cb17bb


```diff
@@ -90,7 +90,8 @@ class LazyBuffer:
     if self.base == self and not self.realized and self.op == LoadOps.COPY and self.srcs[0].device == device: return self.srcs[0]
     # const doesn't have to be copied (issues with disk tensor)
-    if self.is_unrealized_const(): return self.const(self.base.arg)._view(self.st)
+    if self.is_unrealized_const():
+      return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
     # if it's a shrink, do the shrink before the copy with CONTIGUOUS
     # TODO: why is this required on WEBGPU?
```
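The const fix is about device targeting: `self.const(...)` derives the new buffer from `self` and so stays on the source device, while the replacement issues the CONST loadop directly on the destination `device`, matching the disk-tensor note in the comment. A toy sketch of the distinction (illustrative classes, not the real tinygrad API):

```python
# Toy model (not tinygrad) of the device-targeting issue this hunk fixes.
from dataclasses import dataclass

@dataclass
class Buf:
  device: str
  arg: float
  # mirrors the old path: a const derived from self lands on self.device
  def const(self, val: float) -> "Buf": return Buf(self.device, val)

def copy_to_device_old(b: Buf, device: str) -> Buf:
  return b.const(b.arg)       # `device` is ignored; result stays on b.device

def copy_to_device_new(b: Buf, device: str) -> Buf:
  return Buf(device, b.arg)   # like loadop(LoadOps.CONST, ..., device, ...)

src = Buf("DISK", 1.0)
assert copy_to_device_old(src, "GPU").device == "DISK"  # wrong target
assert copy_to_device_new(src, "GPU").device == "GPU"   # correct
```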
```diff
@@ -100,7 +101,7 @@ class LazyBuffer:
     # copy the base and apply the shapetracker on the new device
     return create_lazybuffer(device, self.base.st, self.dtype, LoadOps.COPY, srcs=(self.base,))._view(self.st)
 
-  def e(self:LazyBuffer, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
+  def e(self, op:Union[LoadOps, UnaryOps, BinaryOps, TernaryOps], *in_srcs:LazyBuffer, arg:Optional[Any]=None) -> LazyBuffer:
     srcs: List[LazyBuffer] = []
     for s in (self,)+in_srcs:
       if s == s.base and s.base.contiguous_child and (root:=s.base.contiguous_child[0]()) is not None:
```
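Only the signature changes here (dropping the redundant `self:LazyBuffer` annotation), but the unchanged loop is worth a note: `s.base.contiguous_child[0]()` reads as a weakref dereference, so the walrus-plus-`is not None` check forwards the op to the contiguous root only while that root is still alive. A sketch of that pattern, assuming `contiguous_child[0]` holds a `weakref.ref` (names here are illustrative, not tinygrad's):

```python
# Sketch of the weakref-alive check used in the source-substitution loop.
import weakref

class Node:
  def __init__(self): self.contiguous_child = None

base, child = Node(), Node()
base.contiguous_child = (weakref.ref(child), None)  # (ref, saved shapetracker)

# calling the ref returns the object if alive, else None
if base.contiguous_child and (root := base.contiguous_child[0]()) is not None:
  print("root alive; the op can target it instead of base")

del child
assert base.contiguous_child[0]() is None  # CPython frees it immediately here
```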
```diff
@@ -110,34 +111,16 @@ class LazyBuffer:
     assert all_same(dts:=[x.dtype.scalar() for x in (srcs if op != TernaryOps.WHERE else srcs[1:])]), f"all dtypes must match {dts} on {op}"
     if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
     out_dtype = srcs[-1].dtype if op != BinaryOps.CMPLT else dtypes.bool
-    # if possible, reduce the amount of compute done by not computing on padded areas
-    if op in {BinaryOps.MUL, BinaryOps.ADD}:
-      m0, m1, out_mask = srcs[0].st.views[-1].mask, srcs[1].st.views[-1].mask, None
-      if m0 is not None and m0 == m1: out_mask = m0  # if they match, it works for both MUL and ADD
-      elif op == BinaryOps.MUL: out_mask = m0 or m1  # MUL only needs one mask
-      if out_mask is not None:
-        shrink_srcs = tuple(x.shrink(out_mask) for x in srcs)  # remove the mask from the inputs
-        ret = create_lazybuffer(self.device, ShapeTracker.from_shape(shrink_srcs[0].shape), out_dtype, op, arg, shrink_srcs)
-        return ret.pad(tuple([(p[0], s-p[1]) for s,p in zip(self.shape, out_mask)]))
     return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
 
   # *** reduce ops ***
 
-  def _reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
+  def _reduce_op(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == tuple(new_shape): return self
     unbound_new_shape = tuple(s.unbind()[0] if not isinstance(s, int) else s for s in new_shape)
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))
 
-  def r(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
-    # if possible, reduce the amount of compute done by not computing on padded areas
-    if op == ReduceOps.SUM and (out_mask:=self.st.views[-1].mask) is not None:
-      new_input = self.shrink(out_mask)
-      new_new_shape = tuple(ns if s != ns else os for s,os,ns in zip(self.shape, new_input.shape, new_shape))
-      pad_back = tuple([(p[0], s-p[1]) if ns != 1 else (0,0) for s,ns,p in zip(self.shape, new_shape, out_mask)])
-      return new_input.r(op, new_new_shape).pad(pad_back)
+  def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
     if not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768):
       return self._reduce_op(op, new_shape)
```
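This hunk is the "remove that op" of the title: both deleted blocks implemented the same trick of shrinking masked (padded) regions away before computing and padding the result back afterward. The trick is valid because zero padding contributes nothing to an ADD of identically masked inputs, a MUL where either input is masked, or a SUM reduce. A numpy illustration (not tinygrad) of the equivalence the deleted code relied on:

```python
# numpy demo: compute-on-shrunk-then-pad-back equals computing directly
# on the zero-padded tensors, for MUL and for a SUM reduce.
import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
p = np.pad(x, ((1, 1), (0, 0)))                  # valid region: rows 1..2

# MUL: only one operand needs the mask, since 0 * anything == 0
direct_mul = p * p
shrunk_mul = np.pad(p[1:3] * p[1:3], ((1, 1), (0, 0)))
assert np.array_equal(direct_mul, shrunk_mul)

# SUM reduce: padded rows add zero, so skip them and re-pad the result
direct_sum = p.sum(axis=1, keepdims=True)
shrunk_sum = np.pad(p[1:3].sum(axis=1, keepdims=True), ((1, 1), (0, 0)))
assert np.array_equal(direct_sum, shrunk_sum)
```

The cleanup trades this saved compute for a simpler `e`/`r`: every elementwise op is now a single `create_lazybuffer` over the full shape, and `r` goes straight to the split heuristic.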
```diff
@@ -150,16 +133,16 @@ class LazyBuffer:
   # *** movement ops ***
 
-  def _view(self:LazyBuffer, new_st:ShapeTracker) -> LazyBuffer:
+  def _view(self, new_st:ShapeTracker) -> LazyBuffer:
     if new_st.contiguous and self.base.shape == new_st.shape: return self.base
     return create_lazybuffer(self.device, new_st, self.dtype, base=self.base)
 
-  def reshape(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer: return self._view(self.st.reshape(arg))
-  def pad(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: return self._view(self.st.pad(arg))
-  def expand(self:LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer: return self._view(self.st.expand(arg))
-  def permute(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: return self._view(self.st.permute(arg))
-  def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: return self._view(self.st.shrink(arg))
-  def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer: return self._view(self.st.stride(arg))
+  def reshape(self, arg:Tuple[sint, ...]): return self._view(self.st.reshape(arg))
+  def pad(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.pad(arg))
+  def expand(self, arg:Tuple[sint, ...]): return self._view(self.st.expand(arg))
+  def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
+  def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
+  def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
 
   # *** schedule creation ***
```
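After the cleanup, all six movement ops read as one-liners delegating to `_view`, which makes it explicit that they only rewrite the ShapeTracker and never touch data. A hedged usage sketch; the Tensor-level method names and the `lazydata` attribute are assumptions based on tinygrad around the time of this commit:

```python
# Hedged sketch: chained movement ops compose into one ShapeTracker view;
# nothing is computed or copied until the tensor is realized.
from tinygrad.tensor import Tensor

t = Tensor.ones(2, 3)
v = t.reshape(3, 2).permute(1, 0).pad(((1, 1), (0, 0)))  # three metadata-only views
print(v.shape)        # (4, 3)
print(v.lazydata.st)  # the composed ShapeTracker behind the view
```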