mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
added unit tests for 'argfix' (#10678)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import gzip, unittest
|
||||
from tinygrad import Variable
|
||||
from tinygrad.helpers import Context, ContextVar
|
||||
from tinygrad.helpers import Context, ContextVar, argfix
|
||||
from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv, polyN, time_to_str, cdiv, cmod, getbits
|
||||
from tinygrad.tensor import get_shape
|
||||
from tinygrad.codegen.lowerer import get_contraction, get_contraction_with_reduce
|
||||
@@ -352,5 +352,16 @@ class TestGetBits(unittest.TestCase):
|
||||
def test_single_bit(self):
|
||||
self.assertEqual(getbits(0b100000000, 8, 8), 1)
|
||||
|
||||
class TestArgFix(unittest.TestCase):
|
||||
def test_none(self):
|
||||
self.assertEqual(argfix(None), (None, ))
|
||||
self.assertEqual(argfix(None, None), (None, None))
|
||||
def test_positional_arguments(self):
|
||||
self.assertEqual(argfix(1, 2, 3), (1, 2, 3))
|
||||
def test_tuple(self):
|
||||
self.assertEqual(argfix((1., 2., 3.)), (1., 2., 3.))
|
||||
def test_list(self):
|
||||
self.assertEqual(argfix([True, False]), (True, False))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user