mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
winograd should be 4 kernels (#3268)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, GlobalCounters
|
||||
from tinygrad.helpers import Timing, CI, Profiling, WINO
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
|
||||
@@ -35,5 +35,12 @@ class TestWinograd(unittest.TestCase):
|
||||
out = Tensor.conv2d(x,w).realize()
|
||||
out.numpy()
|
||||
|
||||
def test_four_kernels(self):
|
||||
x,w = Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize()
|
||||
GlobalCounters.reset()
|
||||
out = Tensor.conv2d(x,w).realize()
|
||||
assert GlobalCounters.kernel_count == 4
|
||||
out.numpy()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
Reference in New Issue
Block a user