fix wino conv output dtype for half inputs (#7829)

This commit is contained in:
chenyu
2024-11-21 12:13:54 -05:00
committed by GitHub
parent cf1ec90ad4
commit 69e382216d
4 changed files with 13 additions and 104 deletions

View File

@@ -1,5 +1,5 @@
import unittest
from tinygrad import Tensor, GlobalCounters
from tinygrad import Tensor, GlobalCounters, dtypes
from tinygrad.ops import Ops
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
from tinygrad.codegen.kernel import Kernel
@@ -69,5 +69,13 @@ class TestWinograd(unittest.TestCase):
self.assertLess(ops_ratio, 2.6) # TODO: there's issues with factorization now
self.assertLess(mem_ratio, 10)
def test_dtype(self):
IC, OC, X, Y = 4,4,9,9
x,w = Tensor.empty(1,IC,Y,X), Tensor.empty(OC,IC,3,3)
self.assertEqual(Tensor.conv2d(x,w).dtype, dtypes.default_float)
x,w = Tensor.empty(1,IC,Y,X,dtype=dtypes.half), Tensor.empty(OC,IC,3,3,dtype=dtypes.half)
self.assertEqual(Tensor.conv2d(x,w).dtype, dtypes.half)
if __name__ == '__main__':
unittest.main(verbosity=2)