mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
make //0 return 0 in python_alu (#8131)
on master it raises because it cannot truncate inf to int, which crashes valid expression like `(t > 0).where(1//t, t)`.
This commit is contained in:
@@ -505,11 +505,18 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op(None, lambda x,y: x//y, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
|
||||
helper_test_op(None, lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
|
||||
helper_test_op(None, lambda x: x//2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
|
||||
torch_idiv, tiny_idiv = functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv
|
||||
helper_test_op(None, torch_idiv, tiny_idiv, forward_only=True, vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
|
||||
helper_test_op(None, functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv, forward_only=True,
|
||||
vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
|
||||
if is_dtype_supported(dtypes.uint64):
|
||||
x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
|
||||
np.testing.assert_equal(x.numpy(), 2**64 - 1)
|
||||
# 1 // 0 is device dependent, but it should not raise
|
||||
Tensor([1]).idiv(1).realize()
|
||||
if not (CI and (Device.DEFAULT=="LLVM" or getenv("PTX"))): # TODO: crashed in CI
|
||||
# ... because if might be in a where branch that the output is well defined
|
||||
t = Tensor([-1, 0, 1, 2])
|
||||
np.testing.assert_equal((t > 0).where(1//t, t).numpy(), [-1, 0, 1, 0])
|
||||
|
||||
def test_scalar_div(self):
|
||||
helper_test_op([(45,65)], lambda x: x/255)
|
||||
helper_test_op([(45,65)], lambda x: x/1)
|
||||
|
||||
Reference in New Issue
Block a user