correct division dtype casting (#3405)

* 新年快乐

* fix: exclude floordiv onnx tests

* fix: less weird if statements in div

* 龙年大吉

* fix: tempfix onnx div

* fix: use reference impl for div
This commit is contained in:
geohotstan
2024-02-16 08:34:40 +08:00
committed by GitHub
parent 5de660ca0d
commit 5eb4c902f6
4 changed files with 23 additions and 5 deletions

View File

@@ -7,7 +7,7 @@ from extra.onnx import safe_numpy, DTYPE_MAP
import numpy as np
tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "MatMul",
"Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Div", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
"Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
"Elu", "Celu", "Xor", "Round"}
# **************** Free Ops ****************
@@ -31,6 +31,11 @@ def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_t
# **************** Simple Ops ****************
# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py
def Div(x: Tensor, other: Tensor):
  """ONNX Div: elementwise x / other, result cast back to x's dtype.

  Tensor true division can promote the dtype (e.g. int / int yields a float
  dtype), but the ONNX reference implementation keeps the first input's dtype,
  so cast back whenever the division changed it.
  """
  quotient = x / other
  if quotient.dtype != x.dtype: return quotient.cast(x.dtype)
  return quotient
def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
if value is not None: return value
if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)

View File

@@ -484,5 +484,17 @@ class TestAutoCastType(unittest.TestCase):
assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt
def test_div(self):
    """Tensor/Tensor division dtype: integer operands promote to default_float,
    while a float operand's dtype wins (wider float preferred)."""
    cases = [
      (dtypes.int32,   dtypes.int32,   dtypes.default_float),
      (dtypes.int16,   dtypes.int32,   dtypes.default_float),
      (dtypes.float32, dtypes.float16, dtypes.float32),
      (dtypes.int32,   dtypes.float16, dtypes.float16),
    ]
    for lhs_dt, rhs_dt, expected in cases:
      assert (Tensor([1, 2], dtype=lhs_dt) / Tensor([2, 2], dtype=rhs_dt)).dtype == expected
def test_div_const(self):
    """Tensor/scalar division dtype: an int tensor promotes to default_float
    regardless of the scalar's type; a float16 tensor keeps float16."""
    for divisor in (2, 2.0):
      assert (Tensor([1, 2], dtype=dtypes.int32) / divisor).dtype == dtypes.default_float
    for divisor in (2, 2.0):
      assert (Tensor([1, 2], dtype=dtypes.float16) / divisor).dtype == dtypes.float16
if __name__ == '__main__':
unittest.main()

View File

@@ -337,8 +337,8 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65), (45,65)], lambda x,y: x/y)
helper_test_op([(), ()], lambda x,y: x/y)
def test_div_int(self):
helper_test_op(None, lambda x,y: x//y, Tensor.div, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
helper_test_op(None, lambda x: (x/2).to(torch.int), lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
helper_test_op(None, lambda x: x/2, lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
def test_scalar_div(self):
helper_test_op([(45,65)], lambda x: x/255)
helper_test_op([(45,65)], lambda x: x/1)

View File

@@ -836,8 +836,9 @@ class Tensor:
return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
def div(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
x = self._to_const_val(x)
return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) \
else self.mul(1/x)
if x.__class__ is not Tensor and not reverse and x != 0: return self.mul(1/x)
if isinstance(x, Tensor) and dtypes.is_float(x.dtype): return mlops.Div.apply(*self._broadcasted(x, reverse))
return mlops.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse))
def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
def pow(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor: