mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-26 15:28:10 -05:00
add device to beam search cache key (#2333)
This commit is contained in:
@@ -103,7 +103,7 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
|
||||
def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
|
||||
|
||||
def beam_search(lin:Linearizer, rawbufs, amt:int, allow_test_size=True) -> Linearizer:
|
||||
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size}
|
||||
key = {"ast": str(lin.ast), "amt": amt, "allow_test_size": allow_test_size, "device": Device.DEFAULT}
|
||||
if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE") and CACHELEVEL >= 1:
|
||||
ret = lin.copy()
|
||||
for o in val[len(lin.applied_opts):]: ret.apply_opt(o)
|
||||
|
||||
@@ -173,7 +173,7 @@ _cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches"
|
||||
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
|
||||
CACHELEVEL = getenv("CACHELEVEL", 2)
|
||||
|
||||
VERSION = 6
|
||||
VERSION = 7
|
||||
_db_connection = None
|
||||
def db_connection():
|
||||
global _db_connection
|
||||
|
||||
Reference in New Issue
Block a user