mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
ALLOW_DEVICE_USAGE=0 in codegen (#14238)
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from typing import cast
|
||||
import itertools
|
||||
from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey
|
||||
from tinygrad.helpers import DEVECTORIZE, TRANSCENDENTAL, SPEC, DEBUG, getenv, TracingKey, Context
|
||||
from tinygrad.uop.ops import PatternMatcher, graph_rewrite, UOp, pm_lower_index_dtype, Ops, UPat, track_rewrites, KernelInfo, pyrender
|
||||
from tinygrad.uop.spec import type_verify, program_spec, kernel_spec
|
||||
from tinygrad.renderer import Renderer, ProgramSpec
|
||||
@@ -26,6 +26,7 @@ pm_syntactic_sugar = PatternMatcher([
|
||||
lambda i1,i2: i2.replace(src=i1.src+i2.src[1:]) if isinstance(i1.dtype, PtrDType) and not isinstance(i2.dtype, PtrDType) else None),
|
||||
])
|
||||
|
||||
@Context(ALLOW_DEVICE_USAGE=0)
|
||||
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
|
||||
if ren is None: ren = Renderer()
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.uop.ops import axis_letters, axis_colors, axis_to_pos
|
||||
from tinygrad.device import Buffer
|
||||
from tinygrad.dtype import dtypes, ImageDType
|
||||
from tinygrad.helpers import colored, BEAM, getenv, DEBUG, to_function_name, NOOPT, argsort, round_up, prod, merge_dicts, get_single_element, flatten
|
||||
from tinygrad.helpers import IMAGE, ALLOW_TF32, count
|
||||
from tinygrad.helpers import IMAGE, ALLOW_TF32, count, Context
|
||||
from tinygrad.codegen.opt import Opt, OptOps, KernelOptError, check
|
||||
from tinygrad.codegen.simplify import pm_flatten_range
|
||||
from tinygrad.renderer import Renderer
|
||||
@@ -342,7 +342,9 @@ def apply_opts(ast:UOp, ren:Renderer) -> UOp:
|
||||
elif BEAM >= 1:
|
||||
from tinygrad.codegen.opt.search import beam_search
|
||||
rawbufs = bufs_from_ast(ast, ren.device)
|
||||
k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
# beam search may open devices
|
||||
with Context(ALLOW_DEVICE_USAGE=1):
|
||||
k = beam_search(k, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
elif not NOOPT and (ast.arg is None or ast.arg.applied_opts == ()):
|
||||
from tinygrad.codegen.opt.heuristic import hand_coded_optimizations
|
||||
# NOTE: hand_coded_optimizations doesn't support multiblock opts yet
|
||||
|
||||
@@ -20,10 +20,12 @@ class _Device:
|
||||
def _canonicalize(self, device:str) -> str: return re.sub(r":0$", "", (d:=device.split(":", 1)[0].upper()) + device[len(d):])
|
||||
# NOTE: you can't cache canonicalize in case Device.DEFAULT changes
|
||||
def canonicalize(self, device:str|None) -> str: return self._canonicalize(device if device is not None else Device.DEFAULT)
|
||||
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
|
||||
def __getitem__(self, ix:str) -> Compiled:
|
||||
ix = self.canonicalize(ix)
|
||||
assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "TINYFS", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
|
||||
return self.__get_canonicalized_item(ix)
|
||||
@functools.cache # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||
def __get_canonicalized_item(self, ix:str) -> Compiled:
|
||||
assert ALLOW_DEVICE_USAGE or ix.split(":")[0] in ["DISK", "TINYFS", "NPY", "PYTHON"], f"usage of device {ix} disallowed"
|
||||
base = (__package__ or __name__).split('.')[0] # tinygrad
|
||||
x = ix.split(":")[0].lower()
|
||||
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'{base}.runtime.ops_{x}')) \
|
||||
|
||||
Reference in New Issue
Block a user