added unit tests for 'argfix' (#10678)

This commit is contained in:
ihar
2025-06-07 21:17:10 -05:00
committed by GitHub
parent 74b849b5e1
commit 40c1479267

View File

@@ -1,6 +1,6 @@
import gzip, unittest
from tinygrad import Variable
from tinygrad.helpers import Context, ContextVar
from tinygrad.helpers import Context, ContextVar, argfix
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
from tinygrad.tensor import get_shape
from tinygrad.codegen.lowerer import get_contraction, get_contraction_with_reduce
@@ -352,5 +352,16 @@ class TestGetBits(unittest.TestCase):
def test_single_bit(self):
self.assertEqual(getbits(0b100000000, 8, 8), 1)
class TestArgFix(unittest.TestCase):
def test_none(self):
self.assertEqual(argfix(None), (None, ))
self.assertEqual(argfix(None, None), (None, None))
def test_positional_arguments(self):
self.assertEqual(argfix(1, 2, 3), (1, 2, 3))
def test_tuple(self):
self.assertEqual(argfix((1., 2., 3.)), (1., 2., 3.))
def test_list(self):
self.assertEqual(argfix([True, False]), (True, False))
if __name__ == '__main__':
unittest.main()