helpers.diskcache_clear (#4436)

drop all tables in diskcache. added a unit test but disabled it by default because it will drop all cache...
This commit is contained in:
chenyu
2024-05-05 14:19:01 -04:00
committed by GitHub
parent 595a6e3069
commit d0eb1540d5
2 changed files with 29 additions and 1 deletions

View File

@@ -1,6 +1,6 @@
import unittest
import pickle
from tinygrad.helpers import diskcache_get, diskcache_put, diskcache
from tinygrad.helpers import diskcache_get, diskcache_put, diskcache, diskcache_clear
def remote_get(table,q,k): q.put(diskcache_get(table, k))
def remote_put(table,k,v): diskcache_put(table, k, v)
@@ -81,5 +81,28 @@ class DiskCache(unittest.TestCase):
diskcache_put(table, "key", "test")
self.assertEqual(diskcache_get(table, "key"), "test")
@unittest.skip("disabled by default because this drops cache table")
def test_clear_cache(self):
# clear cache to start
diskcache_clear()
tables = [f"test_clear_cache:{i}" for i in range(3)]
for table in tables:
# check no entries
self.assertIsNone(diskcache_get(table, "k"))
for table in tables:
diskcache_put(table, "k", "test")
# check insertion
self.assertEqual(diskcache_get(table, "k"), "test")
diskcache_clear()
for table in tables:
# check no entries again
self.assertIsNone(diskcache_get(table, "k"))
# calling multiple times is fine
diskcache_clear()
diskcache_clear()
diskcache_clear()
if __name__ == "__main__":
unittest.main()

View File

@@ -155,6 +155,11 @@ def db_connection():
if DEBUG >= 7: _db_connection.set_trace_callback(print)
return _db_connection
def diskcache_clear():
cur = db_connection().cursor()
drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
cur.executescript("\n".join([s[0] for s in drop_tables]))
def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
if CACHELEVEL == 0: return None
if isinstance(key, (str,int)): key = {"key": key}