move div folding from tensor to lazy (#4022)

This commit is contained in:
chenyu
2024-03-31 18:07:39 -04:00
committed by GitHub
parent 7fa233e8c9
commit 276ef8eb87
2 changed files with 5 additions and 5 deletions

View File

@@ -138,7 +138,8 @@ class LazyBuffer:
if (val := in_srcs[0].base.arg) == 1: return self
if val == -1: return self.e(UnaryOps.NEG)
if val == 0: return self.const(0)
# TODO: DIV
if op is BinaryOps.DIV and dtypes.is_float(self.dtype) and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg != 0:
return self.e(BinaryOps.MUL, self.const(1 / in_srcs[0].base.arg))
out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))

View File

@@ -892,10 +892,9 @@ class Tensor:
def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """Elementwise subtraction via the Sub autograd function.

  `self` and `x` are first broadcast against each other; `reverse` is
  forwarded to `_broadcasted`, which presumably swaps operand order so the
  result is x - self (confirm against `_broadcasted`).
  """
  lhs, rhs = self._broadcasted(x, reverse)
  return F.Sub.apply(lhs, rhs)
def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """Elementwise multiplication via the Mul autograd function.

  Operands are broadcast by `_broadcasted`; `reverse` is forwarded to it
  (presumably controls operand order — irrelevant for the product's value
  but kept for signature symmetry with sub/div).
  """
  lhs, rhs = self._broadcasted(x, reverse)
  return F.Mul.apply(lhs, rhs)
def div(self, x:Union[Tensor, ConstType], reverse=False, upcast=True) -> Tensor:
  """Elementwise division via the Div autograd function.

  The diff rendering lost its +/- markers: the pre-commit body (const-val
  fast paths and conditional float upcast) was left concatenated ahead of
  the post-commit body, making the replacement unreachable dead code after
  the old unconditional return. This is the post-commit version only — the
  commit moves the reciprocal const-folding of DIV into the lazy layer, so
  the tensor-level fast paths are intentionally gone.

  `reverse` is forwarded to `_broadcasted` (presumably yields x / self;
  confirm against `_broadcasted`). When `upcast` is true, both operands are
  cast with `least_upper_float` so integer inputs produce float division.
  """
  numerator, denominator = self._broadcasted(x, reverse)
  # Upcast both sides to a common float dtype; skip to keep integer semantics.
  if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
  return F.Div.apply(numerator, denominator)
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
  """Elementwise XOR via the Xor autograd function.

  Operands are broadcast by `_broadcasted` with `reverse` forwarded
  (XOR is symmetric, so operand order does not change the result).
  """
  lhs, rhs = self._broadcasted(x, reverse)
  return F.Xor.apply(lhs, rhs)
def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: