diff --git a/tinygrad/codegen/__init__.py b/tinygrad/codegen/__init__.py index e7ba52a044..ecb5b62e51 100644 --- a/tinygrad/codegen/__init__.py +++ b/tinygrad/codegen/__init__.py @@ -1,6 +1,6 @@ from typing import cast import itertools -from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey +from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey, Context from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender from tinygrad.uop.spec import type_verify, program_spec, kernel_spec from tinygrad.renderer import Renderer, ProgramSpec @@ -26,6 +26,7 @@ pm_syntactic_sugar = PatternMatcher([ lambda i1,i2: i2.replace(src=i1.src+i2.src[1:]) if isinstance(i1.dtype, PtrDType) and not isinstance(i2.dtype, PtrDType) else None), ]) +@Context(ALLOW_DEVICE_USAGE=0) def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp: if ren is None: ren = Renderer() diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py index e0ea718385..9179cd3cb8 100644 --- a/tinygrad/codegen/opt/postrange.py +++ b/tinygrad/codegen/opt/postrange.py @@ -6,7 +6,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos from tinygrad.device import Buffer from tinygrad.dtype import dtypes, ImageDType from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten -from tinygrad.helpers import IMAGE, ALLOW_TF32, count +from tinygrad.helpers import IMAGE, ALLOW_TF32, count, Context from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check from tinygrad.codegen.simplify import pm_flatten_range from tinygrad.renderer import Renderer @@ -342,7 +342,9 @@ def apply_opts(ast:UOp, ren:Renderer) -> UOp: elif BEAM >= 1: from tinygrad.codegen.opt.search import beam_search rawbufs = bufs_from_ast(ast, ren.device) - k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1))) + # beam search may open devices + with Context(ALLOW_DEVICE_USAGE=1): + k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1))) elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()): from tinygrad.codegen.opt.heuristic import hand_coded_optimizations # NOTE: hand_coded_optimizations doesn't support multiblock opts yet diff --git a/tinygrad/device.py b/tinygrad/device.py index bb2707394b..40889c7bd4 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -20,10 +20,12 @@ class _Device: def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):]) # NOTE: you can't cache canonicalize in case Device.DEFAULT changes def canonicalize(self, device:str|None) -> str: return self._canonicalize(device if device is not None else Device.DEFAULT) - def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix)) + def __getitem__(self, ix:str) -> Compiled: + ix = self.canonicalize(ix) + assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "TINYFS", "NPY", "PYTHON"], f"usage of device {ix} disallowed" + return self.__get_canonicalized_item(ix) @functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none def __get_canonicalized_item(self, ix:str) -> Compiled: - assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "TINYFS", "NPY", "PYTHON"], f"usage of device {ix} disallowed" base = (__package__ or __name__).split('.')[0] # tinygrad x = ix.split(":")[0].lower() ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.runtime.ops_{x}')) \