mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
helpers.diskcache_clear (#4436)
drop all tables in diskcache. added a unit test but disabled it by default because it will drop all cache...
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
import unittest
|
import unittest
|
||||||
import pickle
|
import pickle
|
||||||
from tinygrad.helpers import diskcache_get, diskcache_put, diskcache
|
from tinygrad.helpers import diskcache_get, diskcache_put, diskcache, diskcache_clear
|
||||||
|
|
||||||
def remote_get(table,q,k): q.put(diskcache_get(table, k))
|
def remote_get(table,q,k): q.put(diskcache_get(table, k))
|
||||||
def remote_put(table,k,v): diskcache_put(table, k, v)
|
def remote_put(table,k,v): diskcache_put(table, k, v)
|
||||||
@@ -81,5 +81,28 @@ class DiskCache(unittest.TestCase):
|
|||||||
diskcache_put(table, "key", "test")
|
diskcache_put(table, "key", "test")
|
||||||
self.assertEqual(diskcache_get(table, "key"), "test")
|
self.assertEqual(diskcache_get(table, "key"), "test")
|
||||||
|
|
||||||
|
@unittest.skip("disabled by default because this drops cache table")
|
||||||
|
def test_clear_cache(self):
|
||||||
|
# clear cache to start
|
||||||
|
diskcache_clear()
|
||||||
|
tables = [f"test_clear_cache:{i}" for i in range(3)]
|
||||||
|
for table in tables:
|
||||||
|
# check no entries
|
||||||
|
self.assertIsNone(diskcache_get(table, "k"))
|
||||||
|
for table in tables:
|
||||||
|
diskcache_put(table, "k", "test")
|
||||||
|
# check insertion
|
||||||
|
self.assertEqual(diskcache_get(table, "k"), "test")
|
||||||
|
|
||||||
|
diskcache_clear()
|
||||||
|
for table in tables:
|
||||||
|
# check no entries again
|
||||||
|
self.assertIsNone(diskcache_get(table, "k"))
|
||||||
|
|
||||||
|
# calling multiple times is fine
|
||||||
|
diskcache_clear()
|
||||||
|
diskcache_clear()
|
||||||
|
diskcache_clear()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -155,6 +155,11 @@ def db_connection():
|
|||||||
if DEBUG >= 7: _db_connection.set_trace_callback(print)
|
if DEBUG >= 7: _db_connection.set_trace_callback(print)
|
||||||
return _db_connection
|
return _db_connection
|
||||||
|
|
||||||
|
def diskcache_clear():
|
||||||
|
cur = db_connection().cursor()
|
||||||
|
drop_tables = cur.execute("SELECT 'DROP TABLE IF EXISTS ' || quote(name) || ';' FROM sqlite_master WHERE type = 'table';").fetchall()
|
||||||
|
cur.executescript("\n".join([s[0] for s in drop_tables]))
|
||||||
|
|
||||||
def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
|
def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
|
||||||
if CACHELEVEL == 0: return None
|
if CACHELEVEL == 0: return None
|
||||||
if isinstance(key, (str,int)): key = {"key": key}
|
if isinstance(key, (str,int)): key = {"key": key}
|
||||||
|
|||||||
Reference in New Issue
Block a user