remove numpy from dtype (#4969)

replaced all dtype.np with _to_np_dtype defined in tensor.py.

after this, the only numpy usages are (1) Tensor(np.ndarray), (2) construct .numpy() output, (3) numpy random buffer
This commit is contained in:
chenyu
2024-06-14 15:38:45 -04:00
committed by GitHub
parent 62dc36d371
commit 67e8df4969
17 changed files with 73 additions and 63 deletions

View File

@@ -5,6 +5,7 @@ from typing import Any, List
from tinygrad.helpers import getenv, DEBUG, CI
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
from tinygrad import Device, Tensor, dtypes
from tinygrad.tensor import _to_np_dtype
from hypothesis import given, settings, strategies as strat
from test.helpers import is_dtype_supported, rand_for_dtype
@@ -51,10 +52,10 @@ def _test_cast(a:Tensor, target_dtype:DType):
# TODO: cast between double and half are broken https://github.com/tinygrad/tinygrad/issues/4084
return
_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np)))
_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())
class TestDType(unittest.TestCase):
DTYPE: Any = None
@@ -66,7 +67,8 @@ class TestDType(unittest.TestCase):
def setUp(self):
if self.DTYPE is None: raise unittest.SkipTest("base class")
def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), self.DTYPE.np, np.array(self.DATA, dtype=self.DTYPE.np))
def test_to_np(self):
_test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))
def test_casts_to(self): list(map(
lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE),
@@ -104,7 +106,7 @@ class TestDType(unittest.TestCase):
def test_dtypes_fields(self):
fields = dtypes.fields()
self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None))
self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in fields.values() if _to_np_dtype(value) is not None))
def test_resulting_and_init_dtypes_match(self):
dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))