anyone else let down by the fast conv?

@@ -22,7 +22,7 @@ class TinyBobNet:
 # create a model with a conv layer
 class TinyConvNet:
   def __init__(self):
-    conv = 7
+    conv = 5
     chans = 16
     self.c1 = Tensor(layer_init_uniform(chans,1,conv,conv))
     self.l1 = Tensor(layer_init_uniform(((28-conv+1)**2)*chans, 128))
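
For context (a sketch, not part of the commit): the l1 sizing follows from valid-convolution arithmetic, as the shape calculation below shows.

# Why l1 takes ((28-conv+1)**2)*chans inputs: an unpadded, stride-1
# conv over a 28x28 MNIST image leaves 28-conv+1 pixels per side.
conv, chans = 5, 16
out_side = 28 - conv + 1           # 24
flat_inputs = out_side**2 * chans  # 24*24*16 = 9216 inputs to l1
# at the old conv=7 this was (28-7+1)**2 * 16 = 7744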

@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 import unittest
-from tinygrad.tensor import Tensor, Conv2D
+from tinygrad.tensor import Tensor
 from tinygrad.gradcheck import numerical_jacobian, jacobian, gradcheck
 
 x_init = np.random.randn(1,3).astype(np.float32)

@@ -71,7 +71,7 @@ class TestTinygrad(unittest.TestCase):
     wt = Tensor(w.detach().numpy())
 
     out = torch.nn.functional.conv2d(x,w)
-    ret = Conv2D.apply(Conv2D, xt, wt)
+    ret = Tensor.conv2d(xt, wt)
     np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=1e-5)
 
     out.mean().backward()
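
The last two hunks track an API change: conv2d moves from invoking the Conv2D Function directly to a Tensor.conv2d method, so the Conv2D import goes away. A minimal sketch of the new call style, assuming the 2020-era tinygrad interface where a Tensor wraps a numpy array exposed as .data (shapes here are illustrative, not from the commit):

import numpy as np
import torch
from tinygrad.tensor import Tensor

# random input and 5x5 kernel, mirrored into tinygrad Tensors
x = torch.randn(1, 1, 28, 28, requires_grad=True)
w = torch.randn(8, 1, 5, 5, requires_grad=True)
xt = Tensor(x.detach().numpy())
wt = Tensor(w.detach().numpy())

# old style (removed): ret = Conv2D.apply(Conv2D, xt, wt)
ret = Tensor.conv2d(xt, wt)

# the forward pass should match torch's conv2d
out = torch.nn.functional.conv2d(x, w)
np.testing.assert_allclose(ret.data, out.detach().numpy(), atol=1e-5)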