helpers.py: improved test coverage + exception handling (#1165)

* Fixes + improved test coverage for helpers.py

- added exception handling in `proc`, if an exception was thrown, the thread would hang
- made `_early_exec_process` catch any Exception, before if an exception was thrown before the process was started, it would hand the thread

* Made `_early_exec_process` catch any Exception

 Otherwise, if an exception was thrown before the process was started, it would hang the thread. For example a type error for an argument passed to `subprocess.check_output`

* Fixed `from tinygrad.helpers import Timing` import

oops, for some reason my IDE cleaned that import from extra/helpers.

* Fixed import in llama.py

Another one that I skipped by accident, mybad

* Extracted a class for tests of early exec

* Normalize line endings, windows uses /r/n

* Made `cross_process` not a daemon
This commit is contained in:
Stan
2023-07-07 19:26:05 +02:00
committed by GitHub
parent 8391648822
commit 9b6e57eccd
5 changed files with 70 additions and 37 deletions

View File

@@ -10,9 +10,8 @@ from tqdm import tqdm
np.set_printoptions(linewidth=200)
from typing import Optional, Tuple
from tinygrad.helpers import dtypes, getenv, DEBUG
from tinygrad.helpers import Timing, getenv, DEBUG
from tinygrad.lazy import Device
from extra.helpers import Timing
from tinygrad.tensor import Tensor
from tinygrad.nn import Embedding, Linear
from tinygrad.ops import GlobalCounters

View File

@@ -1,15 +1,13 @@
from tinygrad.helpers import Timing
from typing import Any
import multiprocessing, subprocess
import cloudpickle # type: ignore
import subprocess
import multiprocessing
from typing import Any
def _early_exec_process(qin, qout):
while True:
path, inp = qin.get()
try:
qout.put(subprocess.check_output(path, input=inp))
except subprocess.CalledProcessError as e:
except Exception as e:
qout.put(e)
def enable_early_exec():
@@ -26,32 +24,27 @@ def enable_early_exec():
return early_exec
def proc(itermaker, q) -> None:
for x in itermaker(): q.put(x)
q.put(None)
q.close()
try:
for x in itermaker(): q.put(x)
except Exception as e:
q.put(e)
finally:
q.put(None)
q.close()
class _CloudpickleFunctionWrapper:
def __init__(self, fn):
self.fn = fn
def __getstate__(self):
return cloudpickle.dumps(self.fn)
def __setstate__(self, pfn):
self.fn = cloudpickle.loads(pfn)
def __call__(self, *args, **kwargs) -> Any:
return self.fn(*args, **kwargs)
def __init__(self, fn): self.fn = fn
def __getstate__(self): return cloudpickle.dumps(self.fn)
def __setstate__(self, pfn): self.fn = cloudpickle.loads(pfn)
def __call__(self, *args, **kwargs) -> Any: return self.fn(*args, **kwargs)
def cross_process(itermaker, maxsize=16):
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
p = multiprocessing.Process(target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q))
#p.daemon = True
p.start()
# TODO: write tests and handle exit case
while True:
ret = q.get()
if ret is None: break
yield ret
if isinstance(ret, Exception): raise ret
elif ret is None: break
else: yield ret

View File

@@ -4,7 +4,7 @@
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.helpers import colored
from extra.helpers import Timing
from tinygrad.helpers import Timing
from tinygrad.runtime.ops_gpu import CL
# TODO: support multidevice in cuda

View File

@@ -1,15 +1,57 @@
#!/usr/bin/env python
import multiprocessing
import unittest
from extra.helpers import cross_process
import os, cloudpickle, tempfile, unittest, subprocess
from extra.helpers import enable_early_exec, cross_process, _CloudpickleFunctionWrapper
def normalize_line_endings(s): return s.replace(b'\r\n', b'\n')
class TestEarlyExec(unittest.TestCase):
def setUp(self) -> None:
self.early_exec = enable_early_exec()
def early_exec_py_file(self, file_content, exec_args):
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp:
temp.write(file_content)
temp_path = temp.name
try:
output = self.early_exec((["python", temp_path] + exec_args, None))
return output
finally:
os.remove(temp_path)
def test_enable_early_exec(self):
output = self.early_exec_py_file(b'print("Hello, world!")', [])
self.assertEqual(b"Hello, world!\n", normalize_line_endings(output))
def test_enable_early_exec_with_arg(self):
output = self.early_exec_py_file(b'import sys\nprint("Hello, " + sys.argv[1] + "!")', ["world"])
self.assertEqual(b"Hello, world!\n", normalize_line_endings(output))
def test_enable_early_exec_process_exception(self):
with self.assertRaises(subprocess.CalledProcessError):
self.early_exec_py_file(b'raise Exception("Test exception")', [])
def test_enable_early_exec_type_exception(self):
with self.assertRaises(TypeError):
self.early_exec((["python"], "print('Hello, world!')"))
class TestCrossProcess(unittest.TestCase):
def test_cross_process(self):
def _iterate():
for i in range(3): yield i
ret = cross_process(lambda: _iterate())
assert len(list(ret)) == 3
for i in range(10): yield i
results = list(cross_process(_iterate))
self.assertEqual(list(range(10)), results)
def test_cross_process_exception(self):
def _iterate():
for i in range(10):
if i == 5: raise ValueError("Test exception")
yield i
with self.assertRaises(ValueError): list(cross_process(_iterate))
def test_CloudpickleFunctionWrapper(self):
def add(x, y): return x + y
self.assertEqual(7, cloudpickle.loads(cloudpickle.dumps(_CloudpickleFunctionWrapper(add)))(3, 4))
if __name__ == '__main__':
unittest.main()

View File

@@ -2,12 +2,11 @@ import pathlib
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.state import safe_load, safe_save, get_state_dict
from tinygrad.state import safe_load, safe_save, torch_load, get_state_dict
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_disk import RawDiskBuffer
from extra.helpers import Timing
from tinygrad.helpers import Timing
from extra.utils import fetch_as_file, temp
from tinygrad.state import torch_load, get_state_dict
def compare_weights_both(url):
import torch