mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
make DISABLE_COMPILER_CACHE a ContextVar [pr] (#9983)
This commit is contained in:
@@ -1,10 +1,8 @@
|
||||
#!/usr/bin/env python
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
import os
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.device import Device, Compiler
|
||||
from tinygrad.helpers import diskcache_get, diskcache_put, getenv
|
||||
from tinygrad.helpers import diskcache_get, diskcache_put, getenv, Context
|
||||
|
||||
class TestDevice(unittest.TestCase):
|
||||
def test_canonicalize(self):
|
||||
@@ -38,20 +36,20 @@ class TestCompiler(unittest.TestCase):
|
||||
def test_compile_cached(self):
|
||||
diskcache_put("key", "123", None) # clear cache
|
||||
getenv.cache_clear()
|
||||
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "0"}, clear=True):
|
||||
with Context(DISABLE_COMPILER_CACHE=0):
|
||||
self.assertEqual(MockCompiler("key").compile_cached("123"), str.encode("123"))
|
||||
self.assertEqual(diskcache_get("key", "123"), str.encode("123"))
|
||||
|
||||
def test_compile_cached_disabled(self):
|
||||
diskcache_put("disabled_key", "123", None) # clear cache
|
||||
getenv.cache_clear()
|
||||
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}, clear=True):
|
||||
with Context(DISABLE_COMPILER_CACHE=1):
|
||||
self.assertEqual(MockCompiler("disabled_key").compile_cached("123"), str.encode("123"))
|
||||
self.assertIsNone(diskcache_get("disabled_key", "123"))
|
||||
|
||||
def test_device_compile(self):
|
||||
getenv.cache_clear()
|
||||
with patch.dict(os.environ, {"DISABLE_COMPILER_CACHE": "1"}):
|
||||
with Context(DISABLE_COMPILER_CACHE=1):
|
||||
a = Tensor([0.,1.], device=Device.DEFAULT).realize()
|
||||
(a + 1).realize()
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections import defaultdict
|
||||
from typing import Optional, Any, Iterator, Generator
|
||||
import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, ctypes.util, platform, contextlib, sys, re, atexit, pickle, decimal, time
|
||||
from tinygrad.helpers import CI, OSX, LRU, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILE, temp, mv_address, \
|
||||
cpu_time_execution, colored, Context, round_up
|
||||
cpu_time_execution, colored, Context, round_up, DISABLE_COMPILER_CACHE
|
||||
from tinygrad.dtype import DType, ImageDType, PtrDType, dtypes, _to_np_dtype
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
@@ -292,7 +292,7 @@ class CPUProgram:
|
||||
class CompileError(Exception): pass
|
||||
|
||||
class Compiler:
|
||||
def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if getenv("DISABLE_COMPILER_CACHE") else cachekey
|
||||
def __init__(self, cachekey:Optional[str]=None): self.cachekey = None if DISABLE_COMPILER_CACHE else cachekey
|
||||
def compile(self, src:str) -> bytes: return src.encode() # NOTE: empty compiler is the default
|
||||
def compile_cached(self, src:str) -> bytes:
|
||||
if self.cachekey is None or (lib := diskcache_get(self.cachekey, src)) is None:
|
||||
|
||||
@@ -115,6 +115,7 @@ FUSE_ARANGE, FUSE_CONV_BW = ContextVar("FUSE_ARANGE", 0), ContextVar("FUSE_CONV_
|
||||
SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), ContextVar("NO_MEMORY_PLANNER", 0), ContextVar("RING", 1)
|
||||
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
|
||||
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
|
||||
DISABLE_COMPILER_CACHE = ContextVar("DISABLE_COMPILER_CACHE", 0)
|
||||
DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
|
||||
QUANTIZE, VALIDATE_WITH_CPU = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user