mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
only warn once
This commit is contained in:
@@ -6,15 +6,18 @@ import numpy as np
|
||||
# **** start with two base classes ****
|
||||
|
||||
class Tensor:
|
||||
did_float_warning = False
|
||||
|
||||
def __init__(self, data):
|
||||
if isinstance(data, list):
|
||||
data = np.array(data, dtype=np.float32)
|
||||
elif not isinstance(data, np.ndarray):
|
||||
raise TypeError("Error constructing tensor with %r" % data)
|
||||
|
||||
if data.dtype != np.float32:
|
||||
if data.dtype != np.float32 and not Tensor.did_float_warning:
|
||||
# warning? float64 is actually needed for numerical jacobian
|
||||
print("warning, %r isn't float32" % (data.shape,))
|
||||
Tensor.did_float_warning = True
|
||||
|
||||
self.data = data
|
||||
self.grad = None
|
||||
|
||||
Reference in New Issue
Block a user