mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
don't open devices from children (#4425)
* don't open devices from children * correct way to do this * fix Device.DEFAULT and add back JITBEAM
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import multiprocessing
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Dict, Tuple, ClassVar, cast
|
||||
@@ -24,6 +25,7 @@ class _Device:
|
||||
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def __get_canonicalized_item(self, ix:str) -> Compiled:
|
||||
if DEBUG >= 1: print(f"opening device {ix} from pid:{os.getpid()}")
|
||||
assert multiprocessing.current_process().name == "MainProcess", "can only open device from parent process"
|
||||
x = ix.split(":")[0].upper()
|
||||
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
|
||||
@functools.cached_property
|
||||
@@ -32,7 +34,9 @@ class _Device:
|
||||
if device_from_env: return device_from_env
|
||||
for device in ["METAL", "HSA", "CUDA", "GPU", "CLANG", "LLVM"]:
|
||||
try:
|
||||
if self[device]: return device
|
||||
if self[device]:
|
||||
os.environ[device] = "1" # we set this in environment for spawned children
|
||||
return device
|
||||
except Exception: pass
|
||||
raise RuntimeError("no usable devices")
|
||||
Device = _Device()
|
||||
|
||||
@@ -150,7 +150,7 @@ class TinyJit(Generic[ReturnType]):
|
||||
# jit capture
|
||||
self.expected_names: List[Union[int, str]] = expected_names
|
||||
self.expected_lbs: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]] = expected_lbs
|
||||
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value)):
|
||||
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
|
||||
capturing.append(self)
|
||||
self.ret = self.fxn(*args, **kwargs)
|
||||
if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
|
||||
|
||||
Reference in New Issue
Block a user