for issue #1555, int64 and int8 in CI=1 ARM64=1 CLANG=1 (#1572)

* fixed for int8,int64, added dtype broadcasting test, passing all CI,ARM64,CLANG tests

* remove shifts
This commit is contained in:
corranr
2023-08-19 05:40:13 +01:00
committed by GitHub
parent ae39cf84ab
commit 68ebbd2954
2 changed files with 2 additions and 2 deletions

View File

@@ -48,6 +48,7 @@ def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype:DType):
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)+Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [2,4,6,8])
_assert_eq(Tensor([1,2,3,4], dtype=a_dtype)*Tensor([1,2,3,4], dtype=b_dtype), target_dtype, [1,4,9,16])
_assert_eq(Tensor([[1,2],[3,4]], dtype=a_dtype)@Tensor.eye(2, dtype=b_dtype), target_dtype, [[1,2],[3,4]])
_assert_eq(Tensor([1,1,1,1], dtype=a_dtype)+Tensor.ones((4,4), dtype=b_dtype), target_dtype, 2*Tensor.ones(4,4).numpy())
class TestBFloat16DType(unittest.TestCase):
def test_bf16_to_float(self):