make //0 return 0 in python_alu (#8131)

on master it raises because it cannot truncate inf to int, which crashes valid expression like `(t > 0).where(1//t, t)`.
This commit is contained in:
chenyu
2024-12-09 19:32:06 -05:00
committed by GitHub
parent f83d715f41
commit 917deb88a4
2 changed files with 10 additions and 3 deletions

View File

@@ -505,11 +505,18 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x,y: x//y, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
helper_test_op(None, lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
helper_test_op(None, lambda x: x//2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
torch_idiv, tiny_idiv = functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv
helper_test_op(None, torch_idiv, tiny_idiv, forward_only=True, vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True,
vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
if is_dtype_supported(dtypes.uint64):
x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
np.testing.assert_equal(x.numpy(), 2**64 - 1)
# 1 // 0 is device dependent, but it should not raise
Tensor([1]).idiv(1).realize()
if not (CI and (Device.DEFAULT=="LLVM" or getenv("PTX"))): # TODO: crashed in CI
# ... because if might be in a where branch that the output is well defined
t = Tensor([-1, 0, 1, 2])
np.testing.assert_equal((t > 0).where(1//t, t).numpy(), [-1, 0, 1, 0])
def test_scalar_div(self):
helper_test_op([(45,65)], lambda x: x/255)
helper_test_op([(45,65)], lambda x: x/1)