From f90caa4b92cf33407ee60790968c8ce3503db9e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcin=20S=C5=82owik?= Date: Thu, 29 Feb 2024 22:04:21 +0100 Subject: [PATCH] Escape table name in diskcache queries. (#3543) Some devices create cache table names with non-alphanumerical characters, e.g. "compile_hip_gfx1010:xnack-_12". This commit escapes the table name in single quotes s.t. sqlite works (see https://github.com/tinygrad/tinygrad/issues/3538). --- test/unit/test_disk_cache.py | 5 +++++ tinygrad/helpers.py | 6 +++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/test/unit/test_disk_cache.py b/test/unit/test_disk_cache.py index 5e55de3b6b..0fd4300940 100644 --- a/test/unit/test_disk_cache.py +++ b/test/unit/test_disk_cache.py @@ -76,5 +76,10 @@ class DiskCache(unittest.TestCase): self.assertEqual(diskcache_get(table, fancy_key), 5) self.assertEqual(diskcache_get(table, fancy_key3), None) + def test_table_name(self): + table = "test_gfx1010:xnack-" + diskcache_put(table, "key", "test") + self.assertEqual(diskcache_get(table, "key"), "test") + if __name__ == "__main__": unittest.main() diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 2d0f93fcb1..d19236de56 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -141,7 +141,7 @@ def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any: conn = db_connection() cur = conn.cursor() try: - res = cur.execute(f"SELECT val FROM {table}_{VERSION} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values())) + res = cur.execute(f"SELECT val FROM '{table}_{VERSION}' WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values())) except sqlite3.OperationalError: return None # table doesn't exist if (val:=res.fetchone()) is not None: return pickle.loads(val[0]) @@ -156,9 +156,9 @@ def diskcache_put(table:str, key:Union[Dict, str, int], val:Any): if table not in _db_tables: TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"} ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys()) - cur.execute(f"CREATE TABLE IF NOT EXISTS {table}_{VERSION} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))") + cur.execute(f"CREATE TABLE IF NOT EXISTS '{table}_{VERSION}' ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))") _db_tables.add(table) - cur.execute(f"REPLACE INTO {table}_{VERSION} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501 + cur.execute(f"REPLACE INTO '{table}_{VERSION}' ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), )) # noqa: E501 conn.commit() cur.close() return val