diff --git a/test/test_ops.py b/test/test_ops.py index ba5af6e52d..17817dff75 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -537,7 +537,8 @@ class TestOps(unittest.TestCase): helper_test_op(None, lambda x,y: x//y, forward_only=True, vals=[[5, 6, 7],[1, 2, 3]]) helper_test_op(None, lambda x: x/2, forward_only=True, vals=[[3, 4, 5]]) helper_test_op(None, lambda x: x//2, forward_only=True, vals=[[3, 4, 5]]) - helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True, vals=[[5, -6, 7],[1, 2, 3]]) + helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True, + vals=[[-4, 7, 5, 4, -7, 8], [2, -3, 8, -2, 3, 5]]) if is_dtype_supported(dtypes.uint64): x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1) np.testing.assert_equal(x.numpy(), 2**64 - 1) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index f7996c0272..e4138845c6 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -44,7 +44,7 @@ class SimpleMathTrait: def __sub__(self, x): return self.sub(x) def __mul__(self, x): return self.mul(x) def __truediv__(self, x): return self.div(x) - def __floordiv__(self, x): return self.idiv(x) + def __floordiv__(self, x): return self.idiv(x) # TODO: idiv is trunc div, not floordiv def __and__(self, x): return self.bitwise_and(x) def __or__(self, x): return self.bitwise_or(x) def __xor__(self, x): return self.xor(x) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index cb02e7e173..5f0aa5fbe1 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3100,10 +3100,10 @@ class Tensor(SimpleMathTrait): Divides `self` by `x`. Equivalent to `self // x`. Supports broadcasting to a common shape, type promotion, and integer inputs. - `idiv` performs integer division. + `idiv` performs integer division (truncate towards zero). ```python exec="true" source="above" session="tensor" result="python" - print(Tensor([1, 4, 10]).idiv(Tensor([2, 3, 4])).numpy()) + print(Tensor([-4, 7, 5, 4, -7, 8]).idiv(Tensor([2, -3, 8, -2, 3, 5])).numpy()) ``` """ return F.IDiv.apply(*self._broadcasted(x, reverse))