don't open devices from children (#4425)

* don't open devices from children

* correct way to do this

* fix Device.DEFAULT and add back JITBEAM
This commit is contained in:
George Hotz
2024-05-04 10:35:40 -07:00
committed by GitHub
parent fa17dcaf07
commit cf33afa778
2 changed files with 6 additions and 2 deletions

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
import multiprocessing
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, cast
@@ -24,6 +25,7 @@ class _Device:
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix:str) -> Compiled:
if DEBUG >= 1: print(f"opening device {ix} from pid:{os.getpid()}")
assert multiprocessing.current_process().name == "MainProcess", "can only open device from parent process"
x = ix.split(":")[0].upper()
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
@functools.cached_property
@@ -32,7 +34,9 @@ class _Device:
if device_from_env: return device_from_env
for device in ["METAL", "HSA", "CUDA", "GPU", "CLANG", "LLVM"]:
try:
if self[device]: return device
if self[device]:
os.environ[device] = "1" # we set this in environment for spawned children
return device
except Exception: pass
raise RuntimeError("no usable devices")
Device = _Device()

View File

@@ -150,7 +150,7 @@ class TinyJit(Generic[ReturnType]):
# jit capture
self.expected_names: List[Union[int, str]] = expected_names
self.expected_lbs: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = expected_lbs
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value)):
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
capturing.append(self)
self.ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])