mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix wino conv output dtype for half inputs (#7829)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, GlobalCounters
|
||||
from tinygrad import Tensor, GlobalCounters, dtypes
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
@@ -69,5 +69,13 @@ class TestWinograd(unittest.TestCase):
|
||||
self.assertLess(ops_ratio, 2.6) # TODO: there's issues with factorization now
|
||||
self.assertLess(mem_ratio, 10)
|
||||
|
||||
def test_dtype(self):
|
||||
IC, OC, X, Y = 4,4,9,9
|
||||
x,w = Tensor.empty(1,IC,Y,X), Tensor.empty(OC,IC,3,3)
|
||||
self.assertEqual(Tensor.conv2d(x,w).dtype, dtypes.default_float)
|
||||
|
||||
x,w = Tensor.empty(1,IC,Y,X,dtype=dtypes.half), Tensor.empty(OC,IC,3,3,dtype=dtypes.half)
|
||||
self.assertEqual(Tensor.conv2d(x,w).dtype, dtypes.half)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
Reference in New Issue
Block a user