mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
correct division dtype casting (#3405)
* Happy New Year
* fix: exclude floordiv onnx tests
* fix: less weird if statements in div
* Good luck in the Year of the Dragon
* fix: tempfix onnx div
* fix: use reference impl for div
@@ -7,7 +7,7 @@ from extra.onnx import safe_numpy, DTYPE_MAP
 import numpy as np
 
 tensor_methods = {"Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Relu", "Sigmoid", "MatMul",
-                  "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Div", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
+                  "Floor", "Ceil", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", "Softsign", "Asinh", "Acosh", "Atanh",
                   "Elu", "Celu", "Xor", "Round"}
 
 # **************** Free Ops ****************
@@ -31,6 +31,11 @@ def CastLike(x: Tensor, target_type: Tensor, saturate=1): return x.cast(target_t
 
 # **************** Simple Ops ****************
 
+# https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_div.py
+def Div(x: Tensor, other: Tensor):
+  ret = x/other
+  return ret if ret.dtype == x.dtype else ret.cast(x.dtype)
+
 def Constant(value: Tensor=None, value_float=None, value_floats=None, value_int=None, value_ints=None, value_string=None, value_strings=None):
   if value is not None: return value
   if value_float is not None: return Tensor(value_float, dtype=dtypes.float32, requires_grad=False)
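Note: a minimal sketch of why the cast-back in the new Div handler matters (import paths are assumed from the tinygrad layout of this change, not shown in the diff):

from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes

a = Tensor([5, 6, 7], dtype=dtypes.int32)
b = Tensor([2, 2, 2], dtype=dtypes.int32)

raw = a / b                     # `/` now promotes int/int to a float dtype
assert raw.dtype == dtypes.default_float

onnx_out = raw.cast(a.dtype)    # the handler casts back to x.dtype, per the
assert onnx_out.dtype == dtypes.int32  # linked ONNX reference implementation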
@@ -484,5 +484,17 @@ class TestAutoCastType(unittest.TestCase):
     assert Tensor([1, 2], dtype=dt).maximum(3).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
     assert Tensor([1, 2], dtype=dt).maximum(True).dtype == dt
 
+  def test_div(self):
+    assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
+    assert (Tensor([1, 2], dtype=dtypes.int16) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float
+    assert (Tensor([1, 2], dtype=dtypes.float32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float32
+    assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.float16)).dtype == dtypes.float16
+
+  def test_div_const(self):
+    assert (Tensor([1, 2], dtype=dtypes.int32) / 2).dtype == dtypes.default_float
+    assert (Tensor([1, 2], dtype=dtypes.int32) / 2.0).dtype == dtypes.default_float
+    assert (Tensor([1, 2], dtype=dtypes.float16) / 2).dtype == dtypes.float16
+    assert (Tensor([1, 2], dtype=dtypes.float16) / 2.0).dtype == dtypes.float16
+
 if __name__ == '__main__':
   unittest.main()
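The rule these tests pin down, summarized in one place (a hedged restatement of the assertions above, not new API): division always produces a float result; two int operands promote to dtypes.default_float, mixed floats keep the widest float, and a float divisor wins over an int dividend.

from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes

def div_dtype(a, b): return (Tensor([1], dtype=a) / Tensor([1], dtype=b)).dtype

assert div_dtype(dtypes.int32, dtypes.int32) == dtypes.default_float  # int/int -> float
assert div_dtype(dtypes.float32, dtypes.float16) == dtypes.float32    # widest float wins
assert div_dtype(dtypes.int32, dtypes.float16) == dtypes.float16      # float side wins over int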
@@ -337,8 +337,8 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65), (45,65)], lambda x,y: x/y)
     helper_test_op([(), ()], lambda x,y: x/y)
   def test_div_int(self):
-    helper_test_op(None, lambda x,y: x//y, Tensor.div, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
-    helper_test_op(None, lambda x: (x/2).to(torch.int), lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
+    helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
+    helper_test_op(None, lambda x: x/2, lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
   def test_scalar_div(self):
     helper_test_op([(45,65)], lambda x: x/255)
     helper_test_op([(45,65)], lambda x: x/1)
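Why the torch reference side changes too (a small standalone check, assuming torch's default dtype of float32): torch also promotes int/int true division to float, so once tinygrad does the same, the old floordiv / .to(torch.int) workaround is no longer needed for the two references to agree.

import torch

a = torch.tensor([5, 6, 7], dtype=torch.int32)
b = torch.tensor([1, 2, 3], dtype=torch.int32)
assert (a / b).dtype == torch.float32  # true division promotes, like tinygrad now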
@@ -836,8 +836,9 @@ class Tensor:
     return mlops.Mul.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or x != 1.0 else self
   def div(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
     x = self._to_const_val(x)
-    return mlops.Div.apply(*self._broadcasted(x, reverse)) if x.__class__ is Tensor or reverse or not x or not dtypes.is_float(self.dtype) \
-      else self.mul(1/x)
+    if x.__class__ is not Tensor and not reverse and x != 0: return self.mul(1/x)
+    if isinstance(x, Tensor) and dtypes.is_float(x.dtype): return mlops.Div.apply(*self._broadcasted(x, reverse))
+    return mlops.Div.apply(*self.cast(least_upper_float(self.dtype))._broadcasted(x, reverse))
   def xor(self, x:Tensor, reverse=False) -> Tensor: return mlops.Xor.apply(*self._broadcasted(x, reverse))
 
   def pow(self, x:Union[Tensor, Scalar], reverse=False) -> Tensor:
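A reading aid for the three branches in the new Tensor.div (hedged; this mirrors the diff rather than restating the implementation):

from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes

# 1) plain nonzero scalar divisor, not reversed: rewritten as
#    multiply-by-reciprocal, i.e. .mul(1/2) here
_ = Tensor([2.0, 4.0]) / 2

# 2) float tensor divisor: direct mlops.Div; _broadcasted handles promotion
_ = Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.float16)

# 3) everything else (int tensor divisor, reverse, scalar zero): cast self to
#    least_upper_float(self.dtype) first, which is what makes int/int float
assert (Tensor([1, 2], dtype=dtypes.int32) / Tensor([2, 2], dtype=dtypes.int32)).dtype == dtypes.default_float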