add device to beam search cache key (#2333)

This commit is contained in:
chenyu
2023-11-16 18:35:08 -05:00
committed by GitHub
parent 75676ab8e1
commit 3b9dd3330c
2 changed files with 2 additions and 2 deletions

View File

@@ -103,7 +103,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size}
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
ret = lin.copy()
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)

View File

@@ -173,7 +173,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
CACHELEVEL = getenv("CACHELEVEL", 2)
VERSION = 6
VERSION = 7
_db_connection = None
def db_connection():
global _db_connection