diff --git a/docs/tensor/elementwise.md b/docs/tensor/elementwise.md index 3d3858ad79..fae9de79f0 100644 --- a/docs/tensor/elementwise.md +++ b/docs/tensor/elementwise.md @@ -63,6 +63,8 @@ Elementwise ops operate on a per element basis. They don't change the shape of t ::: tinygrad.Tensor.sub ::: tinygrad.Tensor.mul ::: tinygrad.Tensor.div +::: tinygrad.Tensor.idiv +::: tinygrad.Tensor.mod ::: tinygrad.Tensor.xor ::: tinygrad.Tensor.lshift ::: tinygrad.Tensor.rshift diff --git a/extra/onnx.py b/extra/onnx.py index 63c5141036..bfc397ead2 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -81,7 +81,7 @@ def get_run_onnx(onnx_model: ModelProto): tensor_methods = { op:op.lower() for op in ("Neg", "Reciprocal", "Pow", "Sqrt", "Sign", "Abs", "Exp", "Log", "Mish", "Sin", "Cos", "Tan", "Asin", "Acos", "Atan", "Relu", "Sigmoid", "MatMul", "Floor", "Ceil", "IsInf", "IsNaN", "Softplus", "HardSwish", "Where", "Mul", "Sinh", "Cosh", "Tanh", - "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf") + "Softsign", "Asinh", "Acosh", "Atanh", "Elu", "Celu", "Selu", "Xor", "Round", "Erf", "Mod") } # these values are expected to be python consts diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index 13a07b8853..dc0252320c 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -83,8 +83,11 @@ backend_test.exclude('test_dequantizelinear_uint4_cpu') # we don't support indexes backend_test.exclude('test_nonzero_*') -# no support for mod -backend_test.exclude('test_mod_*') +# no support for fmod +backend_test.exclude('test_mod_int64_fmod_cpu') +backend_test.exclude('test_mod_mixed_sign_float16_cpu') +backend_test.exclude('test_mod_mixed_sign_float32_cpu') +backend_test.exclude('test_mod_mixed_sign_float64_cpu') # no boolean ops (2d, 3d, 4d) backend_test.exclude('test_bitshift_*') diff --git a/test/test_ops.py b/test/test_ops.py index 17817dff75..b9e989b54d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -558,6 +558,13 @@ class TestOps(unittest.TestCase): helper_test_op([()], lambda x: x/2) helper_test_op([()], lambda x: 2/x) + def test_mod(self): + helper_test_op(None, lambda x,y: x%y, Tensor.mod, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]]) + helper_test_op(None, lambda x,y: x%y, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]]) + helper_test_op(None, lambda x: x%2, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8]]) + helper_test_op(None, lambda x: x%3, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8]]) + helper_test_op(None, lambda x: 100%x, forward_only=True, vals=[[-4, 7, 5, 4, -7, 8]]) + def test_mul_naninf(self): helper_test_op([(45,65)], lambda x: x*math.inf) helper_test_op([(45,65)], lambda x: x*-math.inf) diff --git a/test/unit/test_uop_vmin_vmax.py b/test/unit/test_uop_vmin_vmax.py index f687df3749..008958140a 100644 --- a/test/unit/test_uop_vmin_vmax.py +++ b/test/unit/test_uop_vmin_vmax.py @@ -126,8 +126,8 @@ class TestVminVmaxDivMod(unittest.TestCase): # vmin and vmax for modulo of a variable with a range crossing zero x = UOp.variable('x', -10, 10) uop = x % 4 - self.assertEqual(uop.vmin, 0) # modulo always positive or zero when divisor is positive - self.assertEqual(uop.vmax, 3) # max possible mod is 3 when dividing by 4 + self.assertEqual(uop.vmin, -3) + self.assertEqual(uop.vmax, 3) class TestVminVmaxVConst(unittest.TestCase): def test_vmin_vmax_vconst_single_element(self): diff --git a/tinygrad/function.py b/tinygrad/function.py index 96b53d10a3..5527870711 100644 --- a/tinygrad/function.py +++ b/tinygrad/function.py @@ -113,6 +113,9 @@ class Mul(Function): class IDiv(Function): def forward(self, x:UOp, y:UOp) -> UOp: return x // y +class Mod(Function): + def forward(self, x:UOp, y:UOp) -> UOp: return x % y + # ************* ternary ops ************* class Where(Function): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 5e64352a41..ab834d9432 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -35,6 +35,7 @@ class SimpleMathTrait: def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse) def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse) def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse) + def mod(self, x, reverse=False): return self._binop(Ops.MOD, x, reverse) def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x)) def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP)) @@ -45,6 +46,7 @@ class SimpleMathTrait: def __mul__(self, x): return self.mul(x) def __truediv__(self, x): return self.div(x) def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv + def __mod__(self, x): return self.mod(x) def __and__(self, x): return self.bitwise_and(x) def __or__(self, x): return self.bitwise_or(x) def __xor__(self, x): return self.xor(x) @@ -57,6 +59,7 @@ class SimpleMathTrait: def __rand__(self, x): return self.bitwise_and(x, True) def __ror__(self, x): return self.bitwise_or(x, True) def __rxor__(self, x): return self.xor(x, True) + def __rmod__(self, x): return self.mod(x, True) def __lt__(self, x): return self.alu(Ops.CMPLT, self.ufix(x)) def __gt__(self, x): return self.ufix(x).alu(Ops.CMPLT, self) @@ -77,10 +80,6 @@ class MathTrait(SimpleMathTrait): def __rlshift__(self, x): return self.lshift(x, True) def __rrshift__(self, x): return self.rshift(x, True) - # not in Tensor - def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x)) - def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self) - def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x)) def minimum(self, x): return -(-self).maximum(-x) def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y)) @@ -598,7 +597,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass): # SHL/SHR on consts only if self.op is Ops.SHL and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] << t[2], t[1] << t[2] if self.op is Ops.SHR and s1_vmin == s1_vmax and all_int(t:=(s0_vmin, s0_vmax, s1_vmin)): return t[0] >> t[2], t[1] >> t[2] - if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1 + if self.op is Ops.MOD and s1_vmin > 0: + return (0, s1_vmax-1) if s0_vmin >= 0 else (-(s1_vmax-1), s1_vmax-1) if self.op is Ops.IDIV: if s1_vmin == s1_vmax: # min/max are equal in a CONST if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 32f1f486ab..48556ccfda 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3166,6 +3166,19 @@ class Tensor(SimpleMathTrait): numerator, denominator = self._broadcasted(x, reverse) return numerator.cast(least_upper_float(numerator.dtype)) * denominator.cast(least_upper_float(denominator.dtype)).reciprocal() + def mod(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: + """ + Mod `self` by `x`. + Equivalent to `self % x`. + Supports broadcasting to a common shape, type promotion, and integer inputs. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor([-4, 7, 5, 4, -7, 8]).mod(Tensor([2, -3, 8, -2, 3, 5])).numpy()) + ``` + """ + a, b = self._broadcasted(x, reverse) + return (r := F.Mod.apply(a, b)) + b * (((r < 0) & (b > 0)) | ((r > 0) & (b < 0))) + def xor(self, x:Union[Tensor, ConstType], reverse=False) -> Tensor: """ Computes bitwise xor of `self` and `x`.