mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
15 lines
413 B
Python
15 lines
413 B
Python
import unittest
|
|
from tinygrad.tensor import Tensor
|
|
|
|
class TestConv(unittest.TestCase):
    """Numerical sanity checks for ``Tensor.conv2d``."""

    def test_simple(self):
        """Convolve an all-ones input with all-ones 3x3 filters, padding 1.

        Because every input and weight is 1, each output value equals
        12 (input channels) times the number of the 9 kernel taps that land
        inside the image rather than on the padding:

        * interior pixels: 9 taps  -> 108
        * non-corner edge pixels: 6 taps -> 72
        * corner pixels: 4 taps -> 48

        ``unittest`` assertion methods are used instead of bare ``assert`` so
        the checks still run under ``python -O`` and report useful failures.
        """
        x = Tensor.ones(1, 12, 128, 256)
        w = Tensor.ones(32, 12, 3, 3)
        ret = x.conv2d(w, padding=(1, 1)).numpy()

        # A 3x3 kernel with padding 1 (default stride 1) keeps the spatial
        # size; the channel dimension becomes the number of filters (32).
        self.assertEqual(ret.shape, (1, 32, 128, 256))

        # Interior: all 9 taps are in bounds -> 12 * 9.
        # (It's not 108 around the padding, hence the 1:-1 slices.)
        self.assertTrue((ret[:, :, 1:-1, 1:-1] == 108).all())

        # Edges excluding corners: one kernel row/column falls on the
        # padding, leaving 6 in-bounds taps -> 12 * 6.
        self.assertTrue((ret[:, :, 0, 1:-1] == 72).all())
        self.assertTrue((ret[:, :, -1, 1:-1] == 72).all())
        self.assertTrue((ret[:, :, 1:-1, 0] == 72).all())
        self.assertTrue((ret[:, :, 1:-1, -1] == 72).all())

        # Corners: only a 2x2 patch of taps is in bounds -> 12 * 4.
        for r, c in ((0, 0), (0, -1), (-1, 0), (-1, -1)):
            with self.subTest(corner=(r, c)):
                self.assertTrue((ret[:, :, r, c] == 48).all())
|
|
|
|
if __name__ == '__main__':
    # Running this file directly executes every TestCase defined above.
    unittest.main()