mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 23:25:04 -05:00
move div folding from tensor to lazy (#4022)
This commit is contained in:
@@ -138,7 +138,8 @@ class LazyBuffer:
|
||||
if (val := in_srcs[0].base.arg) == 1: return self
|
||||
if val == -1: return self.e(UnaryOps.NEG)
|
||||
if val == 0: return self.const(0)
|
||||
# TODO: DIV
|
||||
if op is BinaryOps.DIV and dtypes.is_float(self.dtype) and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg != 0:
|
||||
return self.e(BinaryOps.MUL, self.const(1 / in_srcs[0].base.arg))
|
||||
|
||||
out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
|
||||
|
||||
@@ -892,10 +892,9 @@ class Tensor:
|
||||
def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: return F.Sub.apply(*self._broadcasted(x, reverse))
|
||||
def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: return F.Mul.apply(*self._broadcasted(x, reverse))
|
||||
def div(self, x:Union[Tensor, ConstType], reverse=False, upcast=True) -> Tensor:
|
||||
x = self._to_const_val(x)
|
||||
if not isinstance(x, Tensor) and not reverse and x != 0 and upcast: return self.mul(1/x)
|
||||
if (isinstance(x, Tensor) and dtypes.is_float(x.dtype)) or not upcast: return F.Div.apply(*self._broadcasted(x, reverse))
|
||||
return F.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse))
|
||||
numerator, denominator = self._broadcasted(x, reverse)
|
||||
if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
|
||||
return F.Div.apply(numerator, denominator)
|
||||
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: return F.Xor.apply(*self._broadcasted(x, reverse))
|
||||
|
||||
def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user