Files
tinygrad/test/extra/test_helpers.py
Stan 9b6e57eccd helpers.py: improved test coverage + exception handling (#1165)
* Fixes + improved test coverage for helpers.py

- added exception handling in `proc`; previously, if an exception was thrown, the thread would hang
- made `_early_exec_process` catch any Exception; before, if an exception was thrown before the process was started, it would hang the thread

* Made `_early_exec_process` catch any Exception

 Otherwise, if an exception was thrown before the process was started, it would hang the thread. For example, a type error for an argument passed to `subprocess.check_output`.

* Fixed `from tinygrad.helpers import Timing` import

oops, for some reason my IDE cleaned that import from extra/helpers.

* Fixed import in llama.py

Another one that I skipped by accident, my bad

* Extracted a class for tests of early exec

* Normalize line endings, Windows uses \r\n

* Made `cross_process` not a daemon
2023-07-07 10:26:05 -07:00

57 lines
2.0 KiB
Python

#!/usr/bin/env python
import os, sys, cloudpickle, tempfile, unittest, subprocess
from extra.helpers import enable_early_exec, cross_process, _CloudpickleFunctionWrapper
def normalize_line_endings(s):
  """Return `s` with every CRLF byte pair replaced by a single LF byte."""
  crlf, lf = b'\r\n', b'\n'
  return s.replace(crlf, lf)
class TestEarlyExec(unittest.TestCase):
  """Tests for the callable returned by `enable_early_exec`.

  The callable is invoked with a `(argv, input)` tuple; the tests compare its return value with
  the bytes the child script prints.
  NOTE(review): the second tuple element is presumably forwarded as `input=` to
  `subprocess.check_output` -- confirm against extra/helpers.py.
  """

  def setUp(self) -> None:
    # Every test gets its own early-exec callable.
    self.early_exec = enable_early_exec()

  def early_exec_py_file(self, file_content, exec_args):
    """Run `file_content` as a temporary Python script through `self.early_exec`.

    Args:
      file_content: bytes written verbatim to a temporary `.py` file.
      exec_args: list of extra command-line arguments appended after the script path.

    Returns:
      Whatever `self.early_exec` returns for the command.
    """
    # delete=False: the file is closed (and therefore flushed) on leaving the `with` block,
    # before the child opens it by name; it is removed explicitly in the `finally` below.
    with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp:
      temp.write(file_content)
      temp_path = temp.name
    try:
      # sys.executable rather than a bare "python": the latter is looked up on PATH and may be
      # absent or point at a different interpreter than the one running this suite.
      return self.early_exec(([sys.executable, temp_path] + exec_args, None))
    finally:
      os.remove(temp_path)

  def test_enable_early_exec(self):
    output = self.early_exec_py_file(b'print("Hello, world!")', [])
    self.assertEqual(b"Hello, world!\n", normalize_line_endings(output))

  def test_enable_early_exec_with_arg(self):
    output = self.early_exec_py_file(b'import sys\nprint("Hello, " + sys.argv[1] + "!")', ["world"])
    self.assertEqual(b"Hello, world!\n", normalize_line_endings(output))

  def test_enable_early_exec_process_exception(self):
    # A script that raises is expected to surface as CalledProcessError in the caller.
    with self.assertRaises(subprocess.CalledProcessError):
      self.early_exec_py_file(b'raise Exception("Test exception")', [])

  def test_enable_early_exec_type_exception(self):
    # A str (instead of None) as the second tuple element is expected to surface as TypeError.
    with self.assertRaises(TypeError):
      self.early_exec(([sys.executable], "print('Hello, world!')"))
class TestCrossProcess(unittest.TestCase):
  """Tests for `cross_process` and `_CloudpickleFunctionWrapper` from extra.helpers."""

  def test_cross_process(self):
    # Every value yielded by the generator function must come back, in order.
    def _count_up():
      yield from range(10)

    expected = list(range(10))
    received = list(cross_process(_count_up))
    self.assertEqual(expected, received)

  def test_cross_process_exception(self):
    # A ValueError raised part-way through the generator must reach the consumer.
    def _fail_midway():
      for value in range(10):
        if value == 5:
          raise ValueError("Test exception")
        yield value

    with self.assertRaises(ValueError):
      list(cross_process(_fail_midway))

  def test_CloudpickleFunctionWrapper(self):
    # The wrapper around a local function must survive a cloudpickle round trip and stay callable.
    def add(x, y):
      return x + y

    wrapped = _CloudpickleFunctionWrapper(add)
    payload = cloudpickle.dumps(wrapped)
    restored = cloudpickle.loads(payload)
    self.assertEqual(7, restored(3, 4))
# Run the whole suite when this file is executed directly rather than imported.
if __name__ == '__main__':
  unittest.main()