use get data dtype

This commit is contained in:
George Hotz
2021-11-28 23:28:37 -05:00
parent 3cdc77f526
commit 06838481e8

View File

@@ -70,12 +70,13 @@ class Tensor:
def shape(self):
return self.data.shape
@staticmethod
def _get_data_dtype(data):
return data.getdtype() if getattr(data, 'getdtype', None) else data.dtype
@property
def dtype(self):
if getattr(self.data, 'getdtype', None):
return self.data.getdtype()
else:
return self.data.dtype
return Tensor._get_data_dtype(self.data)
# ***** creation helper functions *****
@@ -143,7 +144,7 @@ class Tensor:
if isinstance(data, Device.buffers[device]):
return data
if data.dtype != np.float32 and not Tensor.did_float_warning:
if Tensor._get_data_dtype(data) != np.float32 and not Tensor.did_float_warning:
# warning? float64 is actually needed for numerical jacobian
print(f"warning, {data.shape!r} isn't float32, it's {data.dtype}")
Tensor.did_float_warning = True