mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
helpers.py: improved test coverage + exception handling (#1165)
* Fixes + improved test coverage for helpers.py - added exception handling in `proc`, if an exception was thrown, the thread would hang - made `_early_exec_process` catch any Exception, before if an exception was thrown before the process was started, it would hand the thread * Made `_early_exec_process` catch any Exception Otherwise, if an exception was thrown before the process was started, it would hang the thread. For example a type error for an argument passed to `subprocess.check_output` * Fixed `from tinygrad.helpers import Timing` import oops, for some reason my IDE cleaned that import from extra/helpers. * Fixed import in llama.py Another one that I skipped by accident, mybad * Extracted a class for tests of early exec * Normalize line endings, windows uses /r/n * Made `cross_process` not a daemon
This commit is contained in:
@@ -10,9 +10,8 @@ from tqdm import tqdm
|
||||
np.set_printoptions(linewidth=200)
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from tinygrad.helpers import dtypes, getenv, DEBUG
|
||||
from tinygrad.helpers import Timing, getenv, DEBUG
|
||||
from tinygrad.lazy import Device
|
||||
from extra.helpers import Timing
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Embedding, Linear
|
||||
from tinygrad.ops import GlobalCounters
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
from tinygrad.helpers import Timing
|
||||
from typing import Any
|
||||
import multiprocessing, subprocess
|
||||
import cloudpickle # type: ignore
|
||||
import subprocess
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
|
||||
def _early_exec_process(qin, qout):
|
||||
while True:
|
||||
path, inp = qin.get()
|
||||
try:
|
||||
qout.put(subprocess.check_output(path, input=inp))
|
||||
except subprocess.CalledProcessError as e:
|
||||
except Exception as e:
|
||||
qout.put(e)
|
||||
|
||||
def enable_early_exec():
|
||||
@@ -26,32 +24,27 @@ def enable_early_exec():
|
||||
return early_exec
|
||||
|
||||
def proc(itermaker, q) -> None:
|
||||
for x in itermaker(): q.put(x)
|
||||
q.put(None)
|
||||
q.close()
|
||||
try:
|
||||
for x in itermaker(): q.put(x)
|
||||
except Exception as e:
|
||||
q.put(e)
|
||||
finally:
|
||||
q.put(None)
|
||||
q.close()
|
||||
|
||||
class _CloudpickleFunctionWrapper:
|
||||
def __init__(self, fn):
|
||||
self.fn = fn
|
||||
|
||||
def __getstate__(self):
|
||||
return cloudpickle.dumps(self.fn)
|
||||
|
||||
def __setstate__(self, pfn):
|
||||
self.fn = cloudpickle.loads(pfn)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> Any:
|
||||
return self.fn(*args, **kwargs)
|
||||
def __init__(self, fn): self.fn = fn
|
||||
def __getstate__(self): return cloudpickle.dumps(self.fn)
|
||||
def __setstate__(self, pfn): self.fn = cloudpickle.loads(pfn)
|
||||
def __call__(self, *args, **kwargs) -> Any: return self.fn(*args, **kwargs)
|
||||
|
||||
def cross_process(itermaker, maxsize=16):
|
||||
q: multiprocessing.Queue = multiprocessing.Queue(maxsize)
|
||||
# multiprocessing uses pickle which cannot dump lambdas, so use cloudpickle.
|
||||
p = multiprocessing.Process(target=proc, args=(_CloudpickleFunctionWrapper(itermaker), q))
|
||||
#p.daemon = True
|
||||
p.start()
|
||||
|
||||
# TODO: write tests and handle exit case
|
||||
while True:
|
||||
ret = q.get()
|
||||
if ret is None: break
|
||||
yield ret
|
||||
if isinstance(ret, Exception): raise ret
|
||||
elif ret is None: break
|
||||
else: yield ret
|
||||
2
test/external/external_multi_gpu.py
vendored
2
test/external/external_multi_gpu.py
vendored
@@ -4,7 +4,7 @@
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.helpers import colored
|
||||
from extra.helpers import Timing
|
||||
from tinygrad.helpers import Timing
|
||||
from tinygrad.runtime.ops_gpu import CL
|
||||
|
||||
# TODO: support multidevice in cuda
|
||||
|
||||
@@ -1,15 +1,57 @@
|
||||
#!/usr/bin/env python
|
||||
import multiprocessing
|
||||
import unittest
|
||||
from extra.helpers import cross_process
|
||||
import os, cloudpickle, tempfile, unittest, subprocess
|
||||
from extra.helpers import enable_early_exec, cross_process, _CloudpickleFunctionWrapper
|
||||
|
||||
def normalize_line_endings(s): return s.replace(b'\r\n', b'\n')
|
||||
|
||||
class TestEarlyExec(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.early_exec = enable_early_exec()
|
||||
|
||||
def early_exec_py_file(self, file_content, exec_args):
|
||||
with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp:
|
||||
temp.write(file_content)
|
||||
temp_path = temp.name
|
||||
try:
|
||||
output = self.early_exec((["python", temp_path] + exec_args, None))
|
||||
return output
|
||||
finally:
|
||||
os.remove(temp_path)
|
||||
|
||||
def test_enable_early_exec(self):
|
||||
output = self.early_exec_py_file(b'print("Hello, world!")', [])
|
||||
self.assertEqual(b"Hello, world!\n", normalize_line_endings(output))
|
||||
|
||||
def test_enable_early_exec_with_arg(self):
|
||||
output = self.early_exec_py_file(b'import sys\nprint("Hello, " + sys.argv[1] + "!")', ["world"])
|
||||
self.assertEqual(b"Hello, world!\n", normalize_line_endings(output))
|
||||
|
||||
def test_enable_early_exec_process_exception(self):
|
||||
with self.assertRaises(subprocess.CalledProcessError):
|
||||
self.early_exec_py_file(b'raise Exception("Test exception")', [])
|
||||
|
||||
def test_enable_early_exec_type_exception(self):
|
||||
with self.assertRaises(TypeError):
|
||||
self.early_exec((["python"], "print('Hello, world!')"))
|
||||
|
||||
class TestCrossProcess(unittest.TestCase):
|
||||
|
||||
def test_cross_process(self):
|
||||
def _iterate():
|
||||
for i in range(3): yield i
|
||||
|
||||
ret = cross_process(lambda: _iterate())
|
||||
assert len(list(ret)) == 3
|
||||
for i in range(10): yield i
|
||||
results = list(cross_process(_iterate))
|
||||
self.assertEqual(list(range(10)), results)
|
||||
|
||||
def test_cross_process_exception(self):
|
||||
def _iterate():
|
||||
for i in range(10):
|
||||
if i == 5: raise ValueError("Test exception")
|
||||
yield i
|
||||
with self.assertRaises(ValueError): list(cross_process(_iterate))
|
||||
|
||||
def test_CloudpickleFunctionWrapper(self):
|
||||
def add(x, y): return x + y
|
||||
self.assertEqual(7, cloudpickle.loads(cloudpickle.dumps(_CloudpickleFunctionWrapper(add)))(3, 4))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -2,12 +2,11 @@ import pathlib
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.state import safe_load, safe_save, get_state_dict
|
||||
from tinygrad.state import safe_load, safe_save, torch_load, get_state_dict
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.runtime.ops_disk import RawDiskBuffer
|
||||
from extra.helpers import Timing
|
||||
from tinygrad.helpers import Timing
|
||||
from extra.utils import fetch_as_file, temp
|
||||
from tinygrad.state import torch_load, get_state_dict
|
||||
|
||||
def compare_weights_both(url):
|
||||
import torch
|
||||
|
||||
Reference in New Issue
Block a user