diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index ecd5662bce..78e58a6c36 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -5,6 +5,7 @@ from dataclasses import replace from tinygrad.ops import UOp, Ops, Variable, sym_infer from tinygrad.device import Device, Buffer, Compiler from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name +from tinygrad.helpers import IGNORE_BEAM_CACHE from tinygrad.dtype import ImageDType, PtrDType from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError from tinygrad.tensor import Tensor @@ -118,7 +119,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]: return acted_lins beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG") -def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=getenv("IGNORE_BEAM_CACHE")) -> Kernel: +def beam_search(lin:Kernel, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value) -> Kernel: global beam_pool key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix} if not disable_cache and CACHELEVEL >= 1 and (val:=diskcache_get("beam_search", key)) is not None: