make DISABLE_COMPILER_CACHE a ContextVar [pr] (#9983)

This commit is contained in:
chenyu
2025-04-22 10:32:54 -04:00
committed by GitHub
parent bbc324f5dc
commit 9e5e371999
3 changed files with 7 additions and 8 deletions

View File

@@ -1,10 +1,8 @@
#!/usr/bin/env python
import unittest
from unittest.mock import patch
import os
from tinygrad import Tensor
from tinygrad.device import Device, Compiler
from tinygrad.helpers import diskcache_get, diskcache_put, getenv
from tinygrad.helpers import diskcache_get, diskcache_put, getenv, Context
class TestDevice(unittest.TestCase):
def test_canonicalize(self):
@@ -38,20 +36,20 @@ class TestCompiler(unittest.TestCase):
def test_compile_cached(self):
diskcache_put("key", "123", None) # clear cache
getenv.cache_clear()
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "0"}, clear=True):
with Context(DISABLE_COMPILER_CACHE=0):
self.assertEqual(MockCompiler("key").compile_cached("123"), str.encode("123"))
self.assertEqual(diskcache_get("key", "123"), str.encode("123"))
def test_compile_cached_disabled(self):
diskcache_put("disabled_key", "123", None) # clear cache
getenv.cache_clear()
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}, clear=True):
with Context(DISABLE_COMPILER_CACHE=1):
self.assertEqual(MockCompiler("disabled_key").compile_cached("123"), str.encode("123"))
self.assertIsNone(diskcache_get("disabled_key", "123"))
def test_device_compile(self):
getenv.cache_clear()
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}):
with Context(DISABLE_COMPILER_CACHE=1):
a = Tensor([0.,1.], device=Device.DEFAULT).realize()
(a + 1).realize()

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from typing import Optional, Any, Iterator, Generator
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
cpu_time_execution, colored, Context, round_up
cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
from tinygrad.renderer import Renderer
@@ -292,7 +292,7 @@ class CPUProgram:
class CompileError(Exception): pass
class Compiler:
def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if DISABLE_COMPILER_CACHE else cachekey
def compile(self, src:str) -> bytes: return src.encode() # NOTE: empty compiler is the default
def compile_cached(self, src:str) -> bytes:
if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:

View File

@@ -115,6 +115,7 @@ FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0)
DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
QUANTIZE, VALIDATE_WITH_CPU = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0)