int / List[int] data -> dtypes.int32 (#2789)

This commit is contained in:
chenyu
2023-12-16 01:25:44 -05:00
committed by GitHub
parent dad4ee4539
commit c5fa9eb36e
4 changed files with 17 additions and 21 deletions

View File

@@ -226,18 +226,18 @@ class TestHelpers(unittest.TestCase):
class TestTypeSpec(unittest.TestCase):
def test_creation(self):
assert Tensor([]).dtype == Tensor.default_type
# assert Tensor([1]).dtype == dtypes.int
assert Tensor([1]).dtype == dtypes.int
assert Tensor([1.1]).dtype == Tensor.default_type
def test_const_full(self):
assert Tensor.ones([2,3]).dtype == Tensor.default_type
assert Tensor.zeros([2,3]).dtype == Tensor.default_type
assert Tensor.full([2,3], 3.3).dtype == Tensor.default_type
# assert Tensor.full([2,3], 3).dtype == dtypes.int
assert Tensor.full([2,3], 3).dtype == dtypes.int
def test_reduce_0d_default(self):
assert Tensor.ones([2,3,0]).sum(2).dtype == Tensor.default_type
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int
# assert Tensor.ones([2,3,0], dtype=dtypes.int).sum(2).dtype == dtypes.int # requires reduceop acc fix
def test_arange(self):
assert Tensor.arange(5).dtype == dtypes.int32