support using str to specify dtype (#5897)

* support using str to specify dtype

in Tensor creation and args into `cast` and `bitcast`, and acc_dtype

* more tests
This commit is contained in:
chenyu
2024-08-04 12:56:28 -04:00
committed by GitHub
parent 4f9221e8dd
commit c67e9887f7
4 changed files with 45 additions and 23 deletions

View File

@@ -395,6 +395,23 @@ class TestTypeSpec(unittest.TestCase):
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
shell=True, check=True)
def test_dtype_str_arg(self):
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
tested = 0
for dtype_str, dtype in [
("bool", dtypes.bool), ("int8", dtypes.int8), ("int", dtypes.int), ("uint32", dtypes.uint32), ("float32", dtypes.float32)]:
np.testing.assert_equal(Tensor(n, dtype=dtype_str).numpy(), Tensor(n, dtype=dtype).numpy())
np.testing.assert_equal(Tensor(n).cast(dtype_str).numpy(), Tensor(n).cast(dtype).numpy())
if dtype.itemsize == 4:
np.testing.assert_equal(Tensor(n).bitcast(dtype_str).numpy(), Tensor(n).bitcast(dtype).numpy())
tested += 1
assert tested == 3
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="nonexistdtype")
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="")
np.testing.assert_equal(Tensor(n).sum(acc_dtype="int16").numpy(), Tensor(n).sum(acc_dtype=dtypes.int16).numpy())
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
def test_creation(self, default_int, default_float):
dtypes.default_int, dtypes.default_float = default_int, default_float