diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 5ae260c28b..3490fc9fd2 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -70,12 +70,13 @@ class Tensor: def shape(self): return self.data.shape + @staticmethod + def _get_data_dtype(data): + return data.getdtype() if getattr(data, 'getdtype', None) else data.dtype + @property def dtype(self): - if getattr(self.data, 'getdtype', None): - return self.data.getdtype() - else: - return self.data.dtype + return Tensor._get_data_dtype(self.data) # ***** creation helper functions ***** @@ -143,7 +144,7 @@ class Tensor: if isinstance(data, Device.buffers[device]): return data - if data.dtype != np.float32 and not Tensor.did_float_warning: + if Tensor._get_data_dtype(data) != np.float32 and not Tensor.did_float_warning: # warning? float64 is actually needed for numerical jacobian print(f"warning, {data.shape!r} isn't float32, it's {data.dtype}") Tensor.did_float_warning = True