mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-08 13:45:50 -05:00
use get data dtype
This commit is contained in:
@@ -70,12 +70,13 @@ class Tensor:
|
||||
def shape(self):
|
||||
return self.data.shape
|
||||
|
||||
@staticmethod
|
||||
def _get_data_dtype(data):
|
||||
return data.getdtype() if getattr(data, 'getdtype', None) else data.dtype
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
if getattr(self.data, 'getdtype', None):
|
||||
return self.data.getdtype()
|
||||
else:
|
||||
return self.data.dtype
|
||||
return Tensor._get_data_dtype(self.data)
|
||||
|
||||
# ***** creation helper functions *****
|
||||
|
||||
@@ -143,7 +144,7 @@ class Tensor:
|
||||
if isinstance(data, Device.buffers[device]):
|
||||
return data
|
||||
|
||||
if data.dtype != np.float32 and not Tensor.did_float_warning:
|
||||
if Tensor._get_data_dtype(data) != np.float32 and not Tensor.did_float_warning:
|
||||
# warning? float64 is actually needed for numerical jacobian
|
||||
print(f"warning, {data.shape!r} isn't float32, it's {data.dtype}")
|
||||
Tensor.did_float_warning = True
|
||||
|
||||
Reference in New Issue
Block a user