From 9e5e371999a1efc0c4826479935c5cd1ba06d8b8 Mon Sep 17 00:00:00 2001 From: chenyu Date: Tue, 22 Apr 2025 10:32:54 -0400 Subject: [PATCH] make DISABLE_COMPILER_CACHE a ContextVar [pr] (#9983) --- test/unit/test_device.py | 10 ++++------ tinygrad/device.py | 4 ++-- tinygrad/helpers.py | 1 + 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/test/unit/test_device.py b/test/unit/test_device.py index b4dd37bcc2..1f20cebd49 100644 --- a/test/unit/test_device.py +++ b/test/unit/test_device.py @@ -1,10 +1,8 @@ #!/usr/bin/env python import unittest -from unittest.mock import patch -import os from tinygrad import Tensor from tinygrad.device import Device, Compiler -from tinygrad.helpers import diskcache_get, diskcache_put, getenv +from tinygrad.helpers import diskcache_get, diskcache_put, getenv, Context class TestDevice(unittest.TestCase): def test_canonicalize(self): @@ -38,20 +36,20 @@ class TestCompiler(unittest.TestCase): def test_compile_cached(self): diskcache_put("key", "123", None) # clear cache getenv.cache_clear() - with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "0"}, clear=True): + with Context(DISABLE_COMPILER_CACHE=0): self.assertEqual(MockCompiler("key").compile_cached("123"), str.encode("123")) self.assertEqual(diskcache_get("key", "123"), str.encode("123")) def test_compile_cached_disabled(self): diskcache_put("disabled_key", "123", None) # clear cache getenv.cache_clear() - with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}, clear=True): + with Context(DISABLE_COMPILER_CACHE=1): self.assertEqual(MockCompiler("disabled_key").compile_cached("123"), str.encode("123")) self.assertIsNone(diskcache_get("disabled_key", "123")) def test_device_compile(self): getenv.cache_clear() - with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}): + with Context(DISABLE_COMPILER_CACHE=1): a = Tensor([0.,1.], device=Device.DEFAULT).realize() (a + 1).realize() diff --git a/tinygrad/device.py b/tinygrad/device.py index 4da9b84dd6..3764ec487a 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -4,7 +4,7 @@ from collections import defaultdict from typing import Optional, Any, Iterator, Generator import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \ - cpu_time_execution, colored, Context, round_up + cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype from tinygrad.renderer import Renderer @@ -292,7 +292,7 @@ class CPUProgram: class CompileError(Exception): pass class Compiler: - def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey + def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if DISABLE_COMPILER_CACHE else cachekey def compile(self, src:str) -> bytes: return src.encode() # NOTE: empty compiler is the default def compile_cached(self, src:str) -> bytes: if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None: diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 87732fa2ea..0fcf90d1b0 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -115,6 +115,7 @@ FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1) PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1) CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1) +DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0) DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0) QUANTIZE, VALIDATE_WITH_CPU = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0)