remove SQRT hack in llvm (#9067)

replaced with xpow 0.5 in transcendental. fixed sqrt(0) backward
This commit is contained in:
chenyu
2025-02-13 15:42:34 -05:00
committed by GitHub
parent 947c97e6ff
commit e02e3b94c3
5 changed files with 8 additions and 13 deletions

View File

@@ -1054,7 +1054,7 @@ class TestIndexing(unittest.TestCase):
one = Tensor(1, dtype=dtypes.int64)
# non-scalar indexed with scalars
a = Tensor.randn(2, 3)
a = Tensor.randn(2, 3).realize()
numpy_testing_assert_equal_helper(a[0], a[zero])
numpy_testing_assert_equal_helper(a[0][1], a[zero][one])
numpy_testing_assert_equal_helper(a[0, 1], a[zero, one])
@@ -1066,7 +1066,7 @@ class TestIndexing(unittest.TestCase):
numpy_testing_assert_equal_helper(a[1], a[one.cast(dtypes.int16)])
# scalar indexed with scalar
r = Tensor.randn()
r = Tensor.randn().realize()
with self.assertRaises(IndexError):
r[:]
with self.assertRaises(IndexError):