use ALLOW_DEVICE_USAGE context variable instead of MainProcess check (#10693)

* use DISALLOW_DEVICE_OPEN context variable instead of MainProcess check

* device usage can be disallowed
This commit is contained in:
George Hotz
2025-06-08 00:07:40 -07:00
committed by GitHub
parent dedff0e96c
commit 48eb7d76b1
3 changed files with 8 additions and 6 deletions

View File

@@ -2,9 +2,9 @@ from __future__ import annotations
from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Optional, Any, Generic, TypeVar, Iterator, Generator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
import importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE
cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE, ALLOW_DEVICE_USAGE
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer
@@ -22,8 +22,7 @@ class _Device:
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
@functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix:str) -> Compiled:
cpn = multiprocessing.current_process().name
assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"can only open device {ix} from parent, not {cpn}"
assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
x = ix.split(":")[0].lower()
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x}')) \
if (cname.lower() == x + "device")][0](ix)

View File

@@ -81,8 +81,10 @@ def _try_compile_linearized_w_idx(x:tuple[int,Kernel], compiler:Compiler) -> tup
if hasattr(signal, "alarm"): signal.alarm(0)
return x[0], ret
# workers should ignore ctrl c
def _init_worker(): signal.signal(signal.SIGINT, signal.SIG_IGN)
# workers should not open devices and should ignore ctrl c
def _init_worker():
Context(ALLOW_DEVICE_USAGE=0).__enter__()
signal.signal(signal.SIGINT, signal.SIG_IGN)
def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_allocated() if buf is not None else buf for buf in bufs]

View File

@@ -119,6 +119,7 @@ DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0)
DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
QUANTIZE, VALIDATE_WITH_CPU, IGNORE_OOB = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0), ContextVar("IGNORE_OOB", 1)
CORRECT_DIVMOD_FOLDING, FUSE_OPTIM = ContextVar("CORRECT_DIVMOD_FOLDING", 0), ContextVar("FUSE_OPTIM", 0)
ALLOW_DEVICE_USAGE = ContextVar("ALLOW_DEVICE_USAGE", 1)
@dataclass(frozen=True)
class Metadata: