mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
use ALLOW_DEVICE_USAGE context variable instead of MainProcess check (#10693)
* use DISALLOW_DEVICE_OPEN context variable instead of MainProcess check * device usage can be disallowed
This commit is contained in:
@@ -2,9 +2,9 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass, replace
|
||||
from collections import defaultdict
|
||||
from typing import Optional, Any, Generic, TypeVar, Iterator, Generator
|
||||
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
|
||||
import importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
|
||||
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
|
||||
cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE
|
||||
cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE
|
||||
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
@@ -22,8 +22,7 @@ class _Device:
|
||||
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
|
||||
@functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def __get_canonicalized_item(self, ix:str) -> Compiled:
|
||||
cpn = multiprocessing.current_process().name
|
||||
assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
|
||||
assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
|
||||
x = ix.split(":")[0].lower()
|
||||
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x}')) \
|
||||
if (cname.lower() == x + "device")][0](ix)
|
||||
|
||||
@@ -81,8 +81,10 @@ def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tup
|
||||
if hasattr(signal, "alarm"): signal.alarm(0)
|
||||
return x[0], ret
|
||||
|
||||
# workers should ignore ctrl c
|
||||
def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
# workers should not open devices and should ignore ctrl c
|
||||
def _init_worker():
|
||||
Context(ALLOW_DEVICE_USAGE=0).__enter__()
|
||||
signal.signal(signal.SIGINT, signal.SIG_IGN)
|
||||
|
||||
def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() if buf is not None else buf for buf in bufs]
|
||||
|
||||
|
||||
@@ -119,6 +119,7 @@ DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0)
|
||||
DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
|
||||
QUANTIZE, VALIDATE_WITH_CPU, IGNORE_OOB = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("IGNORE_OOB", 1)
|
||||
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
|
||||
ALLOW_DEVICE_USAGE = ContextVar("ALLOW_DEVICE_USAGE", 1)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Metadata:
|
||||
|
||||
Reference in New Issue
Block a user