mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Tensor.mod (#8458)
It's a Python-style mod; it could possibly be cleaner with a floor div. Also relaxed the vmin for MOD slightly to account for C-style mod of negatives; that's more correct and might fix other bugs.
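For context: Python-style mod takes the sign of the divisor, while C-style (truncated) remainder takes the sign of the dividend. A minimal plain-Python sketch of the difference (c_style_rem is an illustrative helper, not tinygrad code):

def c_style_rem(a: int, b: int) -> int:
  # int() truncates toward zero, matching C integer division
  return a - int(a / b) * b

for a, b in [(-4, 2), (7, -3), (-7, 3)]:
  print(a % b, c_style_rem(a, b))  # 0 0, then -2 1, then 2 -1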
@@ -63,6 +63,8 @@ Elementwise ops operate on a per element basis. They don't change the shape of the tensor.
 ::: tinygrad.Tensor.sub
 ::: tinygrad.Tensor.mul
 ::: tinygrad.Tensor.div
+::: tinygrad.Tensor.idiv
+::: tinygrad.Tensor.mod
 ::: tinygrad.Tensor.xor
 ::: tinygrad.Tensor.lshift
 ::: tinygrad.Tensor.rshift
@@ -81,7 +81,7 @@ def get_run_onnx(onnx_model: ModelProto):
   tensor_methods = {
     op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
                              "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh",
-                             "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf")
+                             "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")
   }

   # these values are expected to be python consts
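Loosely, this mapping lets the ONNX runner dispatch a node to the Tensor method with the same lower-cased name; now "Mod" resolves to the new Tensor.mod. A hedged sketch of the idea (run_node is a made-up name for illustration, not the actual get_run_onnx internals):

from tinygrad import Tensor

tensor_methods = {op: op.lower() for op in ("Neg", "Mod")}  # abridged

def run_node(op_type: str, *inputs: Tensor) -> Tensor:
  # look up e.g. "Mod" -> Tensor.mod and apply it to the node's inputs
  return getattr(Tensor, tensor_methods[op_type])(*inputs)

print(run_node("Mod", Tensor([-4, 7, 5]), Tensor([2, -3, 8])).numpy())  # Python-style: 0, -2, 5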
test/external/external_test_onnx_backend.py
@@ -83,8 +83,11 @@ backend_test.exclude('test_dequantizelinear_uint4_cpu')
 # we don't support indexes
 backend_test.exclude('test_nonzero_*')

-# no support for mod
-backend_test.exclude('test_mod_*')
+# no support for fmod
+backend_test.exclude('test_mod_int64_fmod_cpu')
+backend_test.exclude('test_mod_mixed_sign_float16_cpu')
+backend_test.exclude('test_mod_mixed_sign_float32_cpu')
+backend_test.exclude('test_mod_mixed_sign_float64_cpu')

 # no boolean ops (2d, 3d, 4d)
 backend_test.exclude('test_bitshift_*')
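The blanket mod exclusion narrows to just the fmod cases: ONNX's Mod with fmod=1 asks for C-style behavior (remainder takes the dividend's sign), which the Python-style Tensor.mod intentionally doesn't implement. A quick illustration with numpy standing in for the two conventions (an assumption for illustration; the tests themselves run through ONNX):

import numpy as np
print(np.fmod(-7, 3))  # -1, C-style fmod: sign of the dividend
print(np.mod(-7, 3))   #  2, Python-style: sign of the divisor, what Tensor.mod matches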
@@ -558,6 +558,13 @@ class TestOps(unittest.TestCase):
     helper_test_op([()], lambda x: x/2)
     helper_test_op([()], lambda x: 2/x)

+  def test_mod(self):
+    helper_test_op(None, lambda x,y: x%y, Tensor.mod, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]])
+    helper_test_op(None, lambda x,y: x%y, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]])
+    helper_test_op(None, lambda x: x%2, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8]])
+    helper_test_op(None, lambda x: x%3, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8]])
+    helper_test_op(None, lambda x: 100%x, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8]])
+
   def test_mul_naninf(self):
     helper_test_op([(45,65)], lambda x: x*math.inf)
     helper_test_op([(45,65)], lambda x: x*-math.inf)
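For reference, the lambda side of these checks runs in torch, whose % also follows Python's sign-of-the-divisor convention, so the expected values for the paired vals are just plain Python mod:

xs = [-4, 7, 5, 4, -7, 8]
ys = [2, -3, 8, -2, 3, 5]
print([a % b for a, b in zip(xs, ys)])  # [0, -2, 5, 0, 2, 3]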
@@ -126,8 +126,8 @@ class TestVminVmaxDivMod(unittest.TestCase):
     # vmin and vmax for modulo of a variable with a range crossing zero
     x = UOp.variable('x', -10, 10)
     uop = x % 4
-    self.assertEqual(uop.vmin, 0)  # modulo always positive or zero when divisor is positive
-    self.assertEqual(uop.vmax, 3)  # max possible mod is 3 when dividing by 4
+    self.assertEqual(uop.vmin, -3)
+    self.assertEqual(uop.vmax, 3)

 class TestVminVmaxVConst(unittest.TestCase):
   def test_vmin_vmax_vconst_single_element(self):
@@ -113,6 +113,9 @@ class Mul(Function):
 class IDiv(Function):
   def forward(self, x:UOp, y:UOp) -> UOp: return x // y

+class Mod(Function):
+  def forward(self, x:UOp, y:UOp) -> UOp: return x % y
+
 # ************* ternary ops *************

 class Where(Function):
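Note that Mod, like IDiv next to it, defines only forward: there is no backward pass, which is why the new test_mod checks above all pass forward_only=True.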
@@ -35,6 +35,7 @@ class SimpleMathTrait:
   def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
   def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
   def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
+  def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
   def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
   def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
@@ -45,6 +46,7 @@ class SimpleMathTrait:
   def __mul__(self, x): return self.mul(x)
   def __truediv__(self, x): return self.div(x)
   def __floordiv__(self, x): return self.idiv(x)  # TODO: idiv is trunc div, not floordiv
+  def __mod__(self, x): return self.mod(x)
   def __and__(self, x): return self.bitwise_and(x)
   def __or__(self, x): return self.bitwise_or(x)
   def __xor__(self, x): return self.xor(x)
@@ -57,6 +59,7 @@ class SimpleMathTrait:
   def __rand__(self, x): return self.bitwise_and(x, True)
   def __ror__(self, x): return self.bitwise_or(x, True)
   def __rxor__(self, x): return self.xor(x, True)
+  def __rmod__(self, x): return self.mod(x, True)

   def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
   def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
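The reverse flag is what makes const % tensor work: Python evaluates 100 % t by calling t.__rmod__(100), which forwards to t.mod(100, reverse=True) so the constant is broadcast as the dividend. For example:

from tinygrad import Tensor

t = Tensor([-4, 7, 5, 4, -7, 8])
print((100 % t).numpy())  # elementwise 100 % t with Python-style signs: 0, 2, 0, 0, -5, 4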
@@ -77,10 +80,6 @@ class MathTrait(SimpleMathTrait):
   def __rlshift__(self, x): return self.lshift(x, True)
   def __rrshift__(self, x): return self.rshift(x, True)

-  # not in Tensor
-  def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x))
-  def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self)
-
   def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
   def minimum(self, x): return -(-self).maximum(-x)
   def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
@@ -598,7 +597,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
     # SHL/SHR on consts only
     if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
    if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
-    if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
+    if self.op is Ops.MOD and s1_vmin > 0:
+      return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
     if self.op is Ops.IDIV:
       if s1_vmin == s1_vmax: # min/max are equal in a CONST
         if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin
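This is the vmin relaxation from the commit message: Ops.MOD lowers to a C-style remainder, so when the dividend can be negative the result can dip as low as -(s1_vmax-1). A brute-force sanity check of the bound (plain Python, with truncating division standing in for the backend's MOD):

def trunc_rem(a: int, b: int) -> int:
  return a - int(a / b) * b  # truncates toward zero, like C

s0_vmin, s0_vmax, s1_vmin, s1_vmax = -10, 10, 2, 4
lo, hi = (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
assert all(lo <= trunc_rem(a, b) <= hi
           for a in range(s0_vmin, s0_vmax+1) for b in range(s1_vmin, s1_vmax+1))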
@@ -3166,6 +3166,19 @@ class Tensor(SimpleMathTrait):
     numerator, denominator = self._broadcasted(x, reverse)
     return numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()

+  def mod(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
+    """
+    Mod `self` by `x`.
+    Equivalent to `self % x`.
+    Supports broadcasting to a common shape, type promotion, and integer inputs.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor([-4, 7, 5, 4, -7, 8]).mod(Tensor([2, -3, 8, -2, 3, 5])).numpy())
+    ```
+    """
+    a, b = self._broadcasted(x, reverse)
+    return (r := F.Mod.apply(a, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
+
   def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
     """
     Computes bitwise xor of `self` and `x`.
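The return line is the Python-style fix-up: F.Mod yields a remainder r with the dividend's sign (C-style, per the backend), and adding b exactly when r is nonzero and disagrees in sign with b converts it to the sign-of-the-divisor convention. The same correction in plain Python (py_mod_from_trunc is an illustrative helper):

def py_mod_from_trunc(a: int, b: int) -> int:
  r = a - int(a / b) * b  # C-style remainder: sign of a
  if (r < 0 and b > 0) or (r > 0 and b < 0): r += b  # shift into b's sign
  return r

assert all(py_mod_from_trunc(a, b) == a % b for a in range(-9, 10) for b in (-3, -2, 2, 3))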