Tensor.mod (#8458)

it's a python style mod. possibly could be cleaner with a floor div

relaxed the vmin for MOD slightly to account for C-style mod of negative operands; it's more correct and might fix other bugs
This commit is contained in:
chenyu
2024-12-31 11:31:42 -05:00
committed by GitHub
parent ae00fa3b28
commit f3fdec940d
8 changed files with 38 additions and 10 deletions

View File

@@ -63,6 +63,8 @@ Elementwise ops operate on a per element basis. They don't change the shape of t
::: tinygrad.Tensor.sub
::: tinygrad.Tensor.mul
::: tinygrad.Tensor.div
::: tinygrad.Tensor.idiv
::: tinygrad.Tensor.mod
::: tinygrad.Tensor.xor
::: tinygrad.Tensor.lshift
::: tinygrad.Tensor.rshift

View File

@@ -81,7 +81,7 @@ def get_run_onnx(onnx_model: ModelProto):
tensor_methods = {
op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan",
"Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh",
"Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf")
"Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod")
}
# these values are expected to be python consts

View File

@@ -83,8 +83,11 @@ backend_test.exclude('test_dequantizelinear_uint4_cpu')
# we don't support indexes
backend_test.exclude('test_nonzero_*')
# no support for mod
backend_test.exclude('test_mod_*')
# no support for fmod
backend_test.exclude('test_mod_int64_fmod_cpu')
backend_test.exclude('test_mod_mixed_sign_float16_cpu')
backend_test.exclude('test_mod_mixed_sign_float32_cpu')
backend_test.exclude('test_mod_mixed_sign_float64_cpu')
# no boolean ops (2d, 3d, 4d)
backend_test.exclude('test_bitshift_*')

View File

@@ -558,6 +558,13 @@ class TestOps(unittest.TestCase):
helper_test_op([()], lambda x: x/2)
helper_test_op([()], lambda x: 2/x)
def test_mod(self):
helper_test_op(None, lambda x,y: x%y, Tensor.mod, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]])
helper_test_op(None, lambda x,y: x%y, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]])
helper_test_op(None, lambda x: x%2, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8]])
helper_test_op(None, lambda x: x%3, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8]])
helper_test_op(None, lambda x: 100%x, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8]])
def test_mul_naninf(self):
helper_test_op([(45,65)], lambda x: x*math.inf)
helper_test_op([(45,65)], lambda x: x*-math.inf)

View File

@@ -126,8 +126,8 @@ class TestVminVmaxDivMod(unittest.TestCase):
# vmin and vmax for modulo of a variable with a range crossing zero
x = UOp.variable('x', -10, 10)
uop = x % 4
self.assertEqual(uop.vmin, 0) # modulo always positive or zero when divisor is positive
self.assertEqual(uop.vmax, 3) # max possible mod is 3 when dividing by 4
self.assertEqual(uop.vmin, -3)
self.assertEqual(uop.vmax, 3)
class TestVminVmaxVConst(unittest.TestCase):
def test_vmin_vmax_vconst_single_element(self):

View File

@@ -113,6 +113,9 @@ class Mul(Function):
class IDiv(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x // y
class Mod(Function):
def forward(self, x:UOp, y:UOp) -> UOp: return x % y
# ************* ternary ops *************
class Where(Function):

View File

@@ -35,6 +35,7 @@ class SimpleMathTrait:
def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse)
def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse)
def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse)
def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse)
def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x))
def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP))
@@ -45,6 +46,7 @@ class SimpleMathTrait:
def __mul__(self, x): return self.mul(x)
def __truediv__(self, x): return self.div(x)
def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv
def __mod__(self, x): return self.mod(x)
def __and__(self, x): return self.bitwise_and(x)
def __or__(self, x): return self.bitwise_or(x)
def __xor__(self, x): return self.xor(x)
@@ -57,6 +59,7 @@ class SimpleMathTrait:
def __rand__(self, x): return self.bitwise_and(x, True)
def __ror__(self, x): return self.bitwise_or(x, True)
def __rxor__(self, x): return self.xor(x, True)
def __rmod__(self, x): return self.mod(x, True)
def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x))
def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self)
@@ -77,10 +80,6 @@ class MathTrait(SimpleMathTrait):
def __rlshift__(self, x): return self.lshift(x, True)
def __rrshift__(self, x): return self.rshift(x, True)
# not in Tensor
def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x))
def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self)
def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x))
def minimum(self, x): return -(-self).maximum(-x)
def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y))
@@ -598,7 +597,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# SHL/SHR on consts only
if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2]
if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2]
if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1
if self.op is Ops.MOD and s1_vmin > 0:
return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1)
if self.op is Ops.IDIV:
if s1_vmin == s1_vmax: # min/max are equal in a CONST
if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin

View File

@@ -3166,6 +3166,19 @@ class Tensor(SimpleMathTrait):
numerator, denominator = self._broadcasted(x, reverse)
return numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal()
def mod(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
Mod `self` by `x`.
Equivalent to `self % x`.
Supports broadcasting to a common shape, type promotion, and integer inputs.
```python exec="true" source="above" session="tensor" result="python"
print(Tensor([-4, 7, 5, 4, -7, 8]).mod(Tensor([2, -3, 8, -2, 3, 5])).numpy())
```
"""
a, b = self._broadcasted(x, reverse)
return (r := F.Mod.apply(a, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0)))
def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor:
"""
Computes bitwise xor of `self` and `x`.