mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
remove numpy from dtype (#4969)
replaced all dtype.np with _to_np_dtype defined in tensor.py. after this, the only numpy usages are (1) Tensor(np.ndarray), (2) construct .numpy() output, (3) numpy random buffer
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Any, List
|
||||
from tinygrad.helpers import getenv, DEBUG, CI
|
||||
from tinygrad.dtype import DType, DTYPES_DICT, ImageDType, PtrDType, least_upper_float, least_upper_dtype
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
from test.helpers import is_dtype_supported, rand_for_dtype
|
||||
|
||||
@@ -51,10 +52,10 @@ def _test_cast(a:Tensor, target_dtype:DType):
|
||||
# TODO: cast between double and half are broken https://github.com/tinygrad/tinygrad/issues/4084
|
||||
return
|
||||
|
||||
_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(target_dtype.np)))
|
||||
_test_op(lambda: a.cast(target_dtype), target_dtype, list(a.numpy().astype(_to_np_dtype(target_dtype))))
|
||||
def _test_bitcast(a:Tensor, target_dtype:DType, target=None):
|
||||
if target_dtype == dtypes.bfloat16: raise unittest.SkipTest("no test for bf16 bitcast yet")
|
||||
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(target_dtype.np).tolist())
|
||||
_test_op(lambda: a.bitcast(target_dtype), target_dtype, target or a.numpy().view(_to_np_dtype(target_dtype)).tolist())
|
||||
|
||||
class TestDType(unittest.TestCase):
|
||||
DTYPE: Any = None
|
||||
@@ -66,7 +67,8 @@ class TestDType(unittest.TestCase):
|
||||
def setUp(self):
|
||||
if self.DTYPE is None: raise unittest.SkipTest("base class")
|
||||
|
||||
def test_to_np(self): _test_to_np(Tensor(self.DATA, dtype=self.DTYPE), self.DTYPE.np, np.array(self.DATA, dtype=self.DTYPE.np))
|
||||
def test_to_np(self):
|
||||
_test_to_np(Tensor(self.DATA, dtype=self.DTYPE), _to_np_dtype(self.DTYPE), np.array(self.DATA, dtype=_to_np_dtype(self.DTYPE)))
|
||||
|
||||
def test_casts_to(self): list(map(
|
||||
lambda dtype: _test_cast(Tensor(self.DATA, dtype=dtype), self.DTYPE),
|
||||
@@ -104,7 +106,7 @@ class TestDType(unittest.TestCase):
|
||||
def test_dtypes_fields(self):
|
||||
fields = dtypes.fields()
|
||||
self.assertTrue(all(isinstance(value, DType) for value in fields.values()))
|
||||
self.assertTrue(all(issubclass(value.np, np.generic) for value in fields.values() if value.np is not None))
|
||||
self.assertTrue(all(issubclass(_to_np_dtype(value), np.generic) for value in fields.values() if _to_np_dtype(value) is not None))
|
||||
|
||||
def test_resulting_and_init_dtypes_match(self):
|
||||
dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))
|
||||
|
||||
Reference in New Issue
Block a user