diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 071c981fbb..211aeaf150 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -138,7 +138,8 @@ class LazyBuffer:
       if (val := in_srcs[0].base.arg) == 1: return self
       if val == -1: return self.e(UnaryOps.NEG)
       if val == 0: return self.const(0)
-      # TODO: DIV
+      if op is BinaryOps.DIV and dtypes.is_float(self.dtype) and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg != 0:
+        return self.e(BinaryOps.MUL, self.const(1 / in_srcs[0].base.arg))
     out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
     return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index d7b6722326..93a9d5581f 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -892,10 +892,9 @@ class Tensor:
   def sub(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: return F.Sub.apply(*self._broadcasted(x, reverse))
   def mul(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: return F.Mul.apply(*self._broadcasted(x, reverse))
   def div(self, x:Union[Tensor, ConstType], reverse=False, upcast=True) -> Tensor:
-    x = self._to_const_val(x)
-    if not isinstance(x, Tensor) and not reverse and x != 0 and upcast: return self.mul(1/x)
-    if (isinstance(x, Tensor) and dtypes.is_float(x.dtype)) or not upcast: return F.Div.apply(*self._broadcasted(x, reverse))
-    return F.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse))
+    numerator, denominator = self._broadcasted(x, reverse)
+    if upcast: numerator, denominator = numerator.cast(least_upper_float(numerator.dtype)), denominator.cast(least_upper_float(denominator.dtype))
+    return F.Div.apply(numerator, denominator)
   def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: return F.Xor.apply(*self._broadcasted(x, reverse))
   def pow(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
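
A minimal sketch (not part of the patch) of how the two paths touched above could be exercised, assuming tinygrad's public Tensor API at this revision; import paths and the default float dtype may differ across versions:

```python
# Hypothetical usage sketch: exercises the division paths changed in the diff.
from tinygrad.tensor import Tensor

# Dividing by a nonzero float constant: the new lazy-level fold rewrites
# x / c into x * (1 / c) when the divisor is an unrealized, unpadded const.
a = Tensor([2.0, 4.0, 6.0])
print((a / 2.0).numpy())   # -> [1. 2. 3.]

# Integer operands: Tensor.div with upcast=True now casts both sides to a
# float dtype (via least_upper_float) before applying F.Div, so true division
# is performed instead of integer division.
b = Tensor([1, 2, 3]) / Tensor([2, 2, 2])
print(b.numpy())           # -> [0.5 1.  1.5]
print(b.dtype)             # expected to be a float dtype (e.g. float32)
```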