From d0eb1540d5c8217fa5597fa94efb8dd58cecb979 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 5 May 2024 14:19:01 -0400 Subject: [PATCH] helpers.diskcache_clear (#4436) drop all tables in diskcache. added a unit test but disabled it by default because it will drop all cache... --- test/unit/test_disk_cache.py | 25 ++++++++++++++++++++++++- tinygrad/helpers.py | 5 +++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/test/unit/test_disk_cache.py b/test/unit/test_disk_cache.py index 0fd4300940..6ddf1abe71 100644 --- a/test/unit/test_disk_cache.py +++ b/test/unit/test_disk_cache.py @@ -1,6 +1,6 @@ import unittest import pickle -from tinygrad.helpers import diskcache_get, diskcache_put, diskcache +from tinygrad.helpers import diskcache_get, diskcache_put, diskcache, diskcache_clear def remote_get(table,q,k): q.put(diskcache_get(table, k)) def remote_put(table,k,v): diskcache_put(table, k, v) @@ -81,5 +81,28 @@ class DiskCache(unittest.TestCase): diskcache_put(table, "key", "test") self.assertEqual(diskcache_get(table, "key"), "test") + @unittest.skip("disabled by default because this drops cache table") + def test_clear_cache(self): + # clear cache to start + diskcache_clear() + tables = [f"test_clear_cache:{i}" for i in range(3)] + for table in tables: + # check no entries + self.assertIsNone(diskcache_get(table, "k")) + for table in tables: + diskcache_put(table, "k", "test") + # check insertion + self.assertEqual(diskcache_get(table, "k"), "test") + + diskcache_clear() + for table in tables: + # check no entries again + self.assertIsNone(diskcache_get(table, "k")) + + # calling multiple times is fine + diskcache_clear() + diskcache_clear() + diskcache_clear() + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 5bb91b7220..bb7c1c1721 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -155,6 +155,11 @@ def db_connection(): if DEBUG >= 7: _db_connection.set_trace_callback(print) return _db_connection +def diskcache_clear(): + cur = db_connection().cursor() + drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall() + cur.executescript("\n".join([s[0] for s in drop_tables])) + def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any: if CACHELEVEL == 0: return None if isinstance(key, (str,int)): key = {"key": key}