diff --git a/test/test_ops.py b/test/test_ops.py
index 078001f1ef..19bb431051 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -212,6 +212,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
     helper_test_op([(), ()], torch.maximum, Tensor.maximum)
     helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 0., 3., 4.], [1., 2., 3., 0.]])
+    helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1, 0, 3, 4], [1, 2, 3, 0]], forward_only=True)
   def test_minimum(self):
     helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
     helper_test_op([(), ()], torch.minimum, Tensor.minimum)
@@ -246,6 +247,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div)
     helper_test_op([(), ()], lambda x,y: x/y, Tensor.div)
     helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=[[5],[1]])
+    helper_test_op(None, lambda x: (x/2).to(torch.int), lambda x: x/2, forward_only=True, vals=[[3]])
   def test_div_const(self):
     helper_test_op([(45,65)], lambda x: x/255, lambda x: x/255)
     helper_test_op([(45,65)], lambda x: x/1, lambda x: x/1)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 71d7b27075..cb748ab7ed 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -588,7 +588,7 @@ class Tensor:
   def add(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Add.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x else self
   def sub(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Sub.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x or reverse else self
   def mul(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
-  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x else self.mul(1/x)
+  def div(self, x:Union[Tensor, float], reverse=False) -> Tensor: return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) else self.mul(1/x)
   def pow(self, x:Union[Tensor, float], reverse=False) -> Tensor:
     if x.__class__ is not Tensor and not reverse:
       # simple pow identities