fix tests

This commit is contained in:
George Hotz
2020-10-29 08:19:07 -07:00
parent 9ae3e9daf3
commit 5e7e359706
2 changed files with 5 additions and 3 deletions

View File

@@ -77,7 +77,7 @@ class TestConvSpeed(unittest.TestCase):
x = Tensor.randn(128, 1, 28, 28)
x = x.conv2d(c1).relu().avg_pool2d()
x = x.conv2d(c2).relu().max_pool2d()
x = x.reshape(Tensor(np.array((x.shape[0], -1))))
x = x.reshape(shape=(x.shape[0], -1))
out = x.dot(l1).logsoftmax()
out = out.mean()
et1 = time.time()

View File

@@ -13,9 +13,11 @@ class Tensor:
elif type(data) != np.ndarray:
print("error constructing tensor with %r" % data)
assert(False)
if data.dtype != np.float32:
# warning? float64 is actually needed for numerical jacobian
pass
# only float32
self.data = data.astype(np.float32)
self.data = data
self.grad = None
# internal variables used for autograd graph construction