ALLOW_DEVICE_USAGE=0 in codegen (#14238)

This commit is contained in:
qazal
2026-01-20 01:15:16 -05:00
committed by GitHub
parent 0243f4a0f1
commit dddd0e384f
3 changed files with 10 additions and 5 deletions

View File

@@ -1,6 +1,6 @@
from typing import cast
import itertools
from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey
from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey, Context
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
from tinygrad.renderer import Renderer, ProgramSpec
@@ -26,6 +26,7 @@ pm_syntactic_sugar = PatternMatcher([
lambda i1,i2: i2.replace(src=i1.src+i2.src[1:]) if isinstance(i1.dtype, PtrDType) and not isinstance(i2.dtype, PtrDType) else None),
])
@Context(ALLOW_DEVICE_USAGE=0)
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
if ren is None: ren = Renderer()

View File

@@ -6,7 +6,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
from tinygrad.device import Buffer
from tinygrad.dtype import dtypes, ImageDType
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
from tinygrad.helpers import IMAGE, ALLOW_TF32, count
from tinygrad.helpers import IMAGE, ALLOW_TF32, count, Context
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
from tinygrad.codegen.simplify import pm_flatten_range
from tinygrad.renderer import Renderer
@@ -342,7 +342,9 @@ def apply_opts(ast:UOp, ren:Renderer) -> UOp:
elif BEAM >= 1:
from tinygrad.codegen.opt.search import beam_search
rawbufs = bufs_from_ast(ast, ren.device)
k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
# beam search may open devices
with Context(ALLOW_DEVICE_USAGE=1):
k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
# NOTE: hand_coded_optimizations doesn't support multiblock opts yet

View File

@@ -20,10 +20,12 @@ class _Device:
def _canonicalize(self, device:str) -> str:
  # uppercase only the part before the first ":"; everything after it is kept exactly as given
  backend = device.split(":", 1)[0].upper()
  remainder = device[len(backend):]
  # a trailing ":0" is dropped from the result
  return re.sub(r":0$", "", backend + remainder)
# NOTE: you can't cache canonicalize in case Device.DEFAULT changes
def canonicalize(self, device:str|None) -> str: return self._canonicalize(device if device is not None else Device.DEFAULT)
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
def __getitem__(self, ix:str) -> Compiled:
  # canonicalize first, so both the check and the lookup below see the same device string
  ix = self.canonicalize(ix)
  backend = ix.split(":")[0]
  # when ALLOW_DEVICE_USAGE is falsy, only the backends listed here may be looked up
  assert ALLOW_DEVICE_USAGE or backend in ["DISK", "TINYFS", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
  return self.__get_canonicalized_item(ix)
@functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix:str) -> Compiled:
assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "TINYFS", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
base = (__package__ or __name__).split('.')[0] # tinygrad
x = ix.split(":")[0].lower()
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.runtime.ops_{x}')) \