From 06838481e8aaaf2aa828c6b50cb152e357fafa58 Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sun, 28 Nov 2021 23:28:37 -0500 Subject: [PATCH] use get data dtype --- tinygrad/tensor.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 5ae260c28b..3490fc9fd2 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -70,12 +70,13 @@ class Tensor: def shape(self): return self.data.shape + @staticmethod + def _get_data_dtype(data): + return data.getdtype() if getattr(data, 'getdtype', None) else data.dtype + @property def dtype(self): - if getattr(self.data, 'getdtype', None): - return self.data.getdtype() - else: - return self.data.dtype + return Tensor._get_data_dtype(self.data) # ***** creation helper functions ***** @@ -143,7 +144,7 @@ class Tensor: if isinstance(data, Device.buffers[device]): return data - if data.dtype != np.float32 and not Tensor.did_float_warning: + if Tensor._get_data_dtype(data) != np.float32 and not Tensor.did_float_warning: # warning? float64 is actually needed for numerical jacobian print(f"warning, {data.shape!r} isn't float32, it's {data.dtype}") Tensor.did_float_warning = True