diff --git a/tinygrad/device.py b/tinygrad/device.py
index 5acca1216b..a09fbdf724 100644
--- a/tinygrad/device.py
+++ b/tinygrad/device.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import multiprocessing
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, cast
@@ -24,6 +25,7 @@ class _Device:
   @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
   def __get_canonicalized_item(self, ix:str) -> Compiled:
     if DEBUG >= 1: print(f"opening device {ix} from pid:{os.getpid()}")
+    assert multiprocessing.current_process().name == "MainProcess", "can only open device from parent process"
     x = ix.split(":")[0].upper()
     return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
   @functools.cached_property
@@ -32,7 +34,9 @@ class _Device:
     if device_from_env: return device_from_env
     for device in ["METAL", "HSA", "CUDA", "GPU", "CLANG", "LLVM"]:
       try:
-        if self[device]: return device
+        if self[device]:
+          os.environ[device] = "1" # we set this in environment for spawned children
+          return device
       except Exception: pass
     raise RuntimeError("no usable devices")
 Device = _Device()
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index a1cfae9da6..fbaae63aa8 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -150,7 +150,7 @@ class TinyJit(Generic[ReturnType]):
       # jit capture
       self.expected_names: List[Union[int, str]] = expected_names
       self.expected_lbs: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = expected_lbs
-      with Context(GRAPH=getenv("JITGRAPH", GRAPH.value)):
+      with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
         capturing.append(self)
         self.ret = self.fxn(*args, **kwargs)
         if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])