mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
support using str to specify dtype (#5897)
* support using str to specify dtype in Tensor creation and args into `cast` and `bitcast`, and acc_dtype * more tests
This commit is contained in:
@@ -395,6 +395,23 @@ class TestTypeSpec(unittest.TestCase):
|
||||
subprocess.run(['DEFAULT_FLOAT=TYPO python3 -c "from tinygrad import dtypes"'],
|
||||
shell=True, check=True)
|
||||
|
||||
def test_dtype_str_arg(self):
|
||||
n = np.random.normal(0, 1, (10, 10)).astype(np.float32)
|
||||
tested = 0
|
||||
for dtype_str, dtype in [
|
||||
("bool", dtypes.bool), ("int8", dtypes.int8), ("int", dtypes.int), ("uint32", dtypes.uint32), ("float32", dtypes.float32)]:
|
||||
np.testing.assert_equal(Tensor(n, dtype=dtype_str).numpy(), Tensor(n, dtype=dtype).numpy())
|
||||
np.testing.assert_equal(Tensor(n).cast(dtype_str).numpy(), Tensor(n).cast(dtype).numpy())
|
||||
if dtype.itemsize == 4:
|
||||
np.testing.assert_equal(Tensor(n).bitcast(dtype_str).numpy(), Tensor(n).bitcast(dtype).numpy())
|
||||
tested += 1
|
||||
assert tested == 3
|
||||
|
||||
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="nonexistdtype")
|
||||
with self.assertRaises(AttributeError): Tensor([1, 2, 3], dtype="")
|
||||
|
||||
np.testing.assert_equal(Tensor(n).sum(acc_dtype="int16").numpy(), Tensor(n).sum(acc_dtype=dtypes.int16).numpy())
|
||||
|
||||
@given(strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
|
||||
def test_creation(self, default_int, default_float):
|
||||
dtypes.default_int, dtypes.default_float = default_int, default_float
|
||||
|
||||
Reference in New Issue
Block a user