From 40c147926727ea6404995ab8e8dbaecadae1e807 Mon Sep 17 00:00:00 2001 From: ihar Date: Sat, 7 Jun 2025 21:17:10 -0500 Subject: [PATCH] added unit tests for 'argfix' (#10678) --- test/unit/test_helpers.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/unit/test_helpers.py b/test/unit/test_helpers.py index e820e31fde..b9c0616737 100644 --- a/test/unit/test_helpers.py +++ b/test/unit/test_helpers.py @@ -1,6 +1,6 @@ import gzip, unittest from tinygrad import Variable -from tinygrad.helpers import Context, ContextVar +from tinygrad.helpers import Context, ContextVar, argfix from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits from tinygrad.tensor import get_shape from tinygrad.codegen.lowerer import get_contraction, get_contraction_with_reduce @@ -352,5 +352,16 @@ class TestGetBits(unittest.TestCase): def test_single_bit(self): self.assertEqual(getbits(0b100000000, 8, 8), 1) +class TestArgFix(unittest.TestCase): + def test_none(self): + self.assertEqual(argfix(None), (None, )) + self.assertEqual(argfix(None, None), (None, None)) + def test_positional_arguments(self): + self.assertEqual(argfix(1, 2, 3), (1, 2, 3)) + def test_tuple(self): + self.assertEqual(argfix((1., 2., 3.)), (1., 2., 3.)) + def test_list(self): + self.assertEqual(argfix([True, False]), (True, False)) + if __name__ == '__main__': unittest.main()