mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
universal disk cache (#2130)
* caching infra for tinygrad
* non str key
* fix linter
* no shelve in beam search
* beam search caching
* check tensor cores with beam too
* pretty print
* LATEBEAM in stable diffusion
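The change adds two helpers to tinygrad.helpers, backed by a sqlite database at /tmp/tinygrad_cache (overridable with the CACHEDB environment variable). A minimal usage sketch, mirroring the unit tests added below; the table and key names here are only examples:

from tinygrad.helpers import diskcache_put, diskcache_get

# store any picklable value under (table, key); the value is returned for chaining
diskcache_put("example_table", "some_key", ("opts", 1, 2, 3))

# read it back; returns None if the key (or the whole table) doesn't exist yet
print(diskcache_get("example_table", "some_key"))  # -> ('opts', 1, 2, 3)
print(diskcache_get("missing_table", "some_key"))  # -> None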
@@ -9,7 +9,7 @@ from collections import namedtuple
 from tqdm import tqdm
 from tinygrad.tensor import Tensor
 from tinygrad.ops import Device
-from tinygrad.helpers import dtypes, GlobalCounters, Timing
+from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv
 from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
 from extra.utils import download_file
 from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
@@ -636,14 +636,15 @@ if __name__ == "__main__":
   if args.seed is not None: Tensor._seed = args.seed
   latent = Tensor.randn(1,4,64,64)
 
-  # this is diffusion
-  for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
-    GlobalCounters.reset()
-    t.set_description("%3d %3d" % (index, timestep))
-    with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
-      latent = do_step(latent, Tensor([timestep]), Tensor([index]), Tensor([args.guidance]))
-      if args.timing: Device[Device.DEFAULT].synchronize()
-  del do_step
+  with Context(BEAM=getenv("LATEBEAM")):
+    # this is diffusion
+    for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
+      GlobalCounters.reset()
+      t.set_description("%3d %3d" % (index, timestep))
+      with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
+        latent = do_step(latent, Tensor([timestep]), Tensor([index]), Tensor([args.guidance]))
+        if args.timing: Device[Device.DEFAULT].synchronize()
+    del do_step
 
   # upsample latent space to image with autoencoder
   x = model.first_stage_model.post_quant_conv(1/0.18215 * latent)
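Note on the hunk above: the diffusion loop now runs inside Context(BEAM=getenv("LATEBEAM")), so exporting LATEBEAM=<beam width> turns beam search on only while the denoising kernels are compiled; code outside the with block, such as the autoencoder decode below it, keeps the default optimization path.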
extra/dump_cache.py (new file, +14 lines)
@@ -0,0 +1,14 @@
+import sys
+import sqlite3
+
+if __name__ == "__main__":
+  fn = sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache"
+  conn = sqlite3.connect(fn)
+  cur = conn.cursor()
+  cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
+  for f in cur.fetchall():
+    table = f[0]
+    cur2 = conn.cursor()
+    cur2.execute(f"SELECT COUNT(*) FROM {table}")
+    cnt = cur2.fetchone()[0]
+    print(f"{table:20s} : {cnt}")
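extra/dump_cache.py is a small inspection tool: python extra/dump_cache.py [db_path] (default /tmp/tinygrad_cache) lists every table in the cache database with its row count, which is an easy way to watch tables such as time_linearizer and beam_search fill up after a BEAM run.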
test/unit/test_disk_cache.py (new file, +47 lines)
@@ -0,0 +1,47 @@
+import unittest
+from tinygrad.helpers import diskcache_get, diskcache_put
+
+def remote_get(q,k): q.put(diskcache_get("test", k))
+def remote_put(k,v): diskcache_put("test", k, v)
+
+class DiskCache(unittest.TestCase):
+  def test_putget(self):
+    diskcache_put("test", "hello", "world")
+    self.assertEqual(diskcache_get("test", "hello"), "world")
+    diskcache_put("test", "hello", "world2")
+    self.assertEqual(diskcache_get("test", "hello"), "world2")
+
+  def test_putcomplex(self):
+    diskcache_put("test", "k", ("complex", 123, "object"))
+    ret = diskcache_get("test", "k")
+    self.assertEqual(ret, ("complex", 123, "object"))
+
+  def test_getotherprocess(self):
+    from multiprocessing import Process, Queue
+    diskcache_put("test", "k", "getme")
+    q = Queue()
+    p = Process(target=remote_get, args=(q,"k"))
+    p.start()
+    p.join()
+    self.assertEqual(q.get(), "getme")
+
+  def test_putotherprocess(self):
+    from multiprocessing import Process
+    p = Process(target=remote_put, args=("k", "remote"))
+    p.start()
+    p.join()
+    self.assertEqual(diskcache_get("test", "k"), "remote")
+
+  def test_no_table(self):
+    self.assertIsNone(diskcache_get("faketable", "k"))
+
+  def test_ret(self):
+    self.assertEqual(diskcache_put("test", "key", ("vvs",)), ("vvs",))
+
+  def test_non_str_key(self):
+    diskcache_put("test", 4, 5)
+    self.assertEqual(diskcache_get("test", 4), 5)
+    self.assertEqual(diskcache_get("test", "4"), 5)
+
+if __name__ == "__main__":
+  unittest.main()
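These tests run standalone (python test/unit/test_disk_cache.py, since the file calls unittest.main()) and cover plain round-trips, pickled tuples, reads and writes from another process, missing tables, the put return value, and non-string keys.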
@@ -164,8 +164,8 @@ class Kernel:
     assert len(colors) == self.shape_len, "colors size mismatch"
     return colors
 
-  def colored_shape(self, pad=None) -> str:
-    ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
+  def colored_shape(self, pad=None, dense=False) -> str:
+    ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors()))
     if pad: ret += ' '*(pad-ansilen(ret))
     return ret
   def printbufs(self, prefix=""):
@@ -160,7 +160,7 @@ class OptimizedKernel(Kernel):
   # ******************** high level optimizers ********************
 
   # TODO: unify this
-  def apply_tensor_cores(self, use_tensor_cores=1):
+  def apply_tensor_cores(self, use_tensor_cores=1) -> bool:
     # should use HIP tensor cores?
     if use_tensor_cores != 0 and self.opts.device == "HIP" and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
       isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and \
@@ -1,7 +1,7 @@
 from typing import Dict, List, cast, DefaultDict, Optional
 from tinygrad.lazy import vars_from_ast
 from tinygrad.ops import Device, Compiled, MemBuffer
-from tinygrad.helpers import prod, getenv, ImageDType, flatten, DEBUG
+from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, diskcache_get, diskcache_put
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.runtime.lib import RawBuffer
 from collections import defaultdict
@@ -20,11 +20,9 @@ actions += [
 ]
 
 # returns time in seconds
-import shelve
-logtm = shelve.open(getenv("LOGTM", "")) if getenv("LOGTM", "") else None
 def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True, disable_cache=False) -> float:
-  key = str((lin.ast, lin.applied_opts))
-  if should_copy and not disable_cache and logtm is not None and key in logtm: return min(logtm[key]) # pylint: disable=E1135 # NOTE: we check should_copy since this may have side effects
+  key = str((lin.ast, lin.applied_opts, allow_test_size, max_global_size))
+  if should_copy and not disable_cache and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
   if should_copy: lin = lin.copy() # TODO: remove the need for this
   var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
   try:
@@ -57,8 +55,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
     #print(lin.ast)
     #print(lin.applied_opts)
     tms = [float('inf')]
-  if logtm is not None: logtm[key] = tms
-  return min(tms)
+  return min(diskcache_put("time_linearizer", key, tms))
 
 # get (scrap) buffers for timing the linearizer
 def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
@@ -89,7 +86,12 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
       pass
   return acted_lins
 
-def beam_search(lin, rawbufs, amt):
+def beam_search(lin:Linearizer, rawbufs, amt:int) -> Linearizer:
+  key = str((lin.ast, amt))
+  if (val:=diskcache_get("beam_search", key)) is not None:
+    ret = lin.copy()
+    for o in val: ret.apply_opt(o)
+    return ret
   best_tm = float('inf')
   beam: List[Linearizer] = [lin]
   while 1:
@@ -99,5 +101,6 @@ def beam_search(lin, rawbufs, amt):
     if len(opts) == 0 or best_tm <= opts[0][1]: break # we didn't get faster
     best_tm = opts[0][1]
     beam = [x[0] for x in opts[:amt]]
-    if DEBUG >= 1: print(f"{opts[0][1]*1e6:12.2f} us from {len(opts):3d} actions", beam[0].colored_shape())
+    if DEBUG >= 2: print(f"{opts[0][1]*1e6:12.2f} us from {len(opts):3d} actions", beam[0].colored_shape())
+  diskcache_put("beam_search", key, beam[0].applied_opts)
   return beam[0]
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, pathlib, hashlib, tempfile
+import os, functools, platform, time, re, contextlib, operator, pathlib, hashlib, tempfile, pickle, sqlite3
 import numpy as np
 from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
 if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
@@ -162,4 +162,33 @@ def cache_compiled(func):
       output_file.write_bytes(func(self, prg, *args, **kwargs))
       output_file.rename(cache_path)
     return cache_path.read_bytes()
   return wrapper
+
+# *** universal database cache ***
+
+_db_connection = None
+def db_connection():
+  global _db_connection
+  if _db_connection is None: _db_connection = sqlite3.connect(getenv("CACHEDB", "/tmp/tinygrad_cache"))
+  return _db_connection
+
+def diskcache_get(table:str, key:str) -> Any:
+  try:
+    res = db_connection().cursor().execute(f"SELECT val FROM {table} WHERE key=?", (key,))
+  except sqlite3.OperationalError:
+    return None # table doesn't exist
+  if (val:=res.fetchone()) is not None:
+    return pickle.loads(val[0])
+  return None
+
+_db_tables = set()
+def diskcache_put(table:str, key:str, value:Any):
+  conn = db_connection()
+  cur = conn.cursor()
+  if table not in _db_tables:
+    cur.execute(f"CREATE TABLE IF NOT EXISTS {table} (key text PRIMARY KEY, val blob)")
+    _db_tables.add(table)
+  cur.execute(f"REPLACE INTO {table} (key, val) VALUES (?, ?)", (key, pickle.dumps(value)))
+  conn.commit()
+  cur.close()
+  return value
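Design notes on the helpers above: a single sqlite connection is opened lazily per process and reused; values are pickled into a blob column, so anything picklable can be cached; REPLACE INTO gives upsert semantics on the primary key; and because the key column is declared text, sqlite's type affinity converts non-string keys, which is why the test can store under the integer 4 and read it back with "4".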
@@ -266,17 +266,19 @@ class Compiled:
       k = Linearizer(ast, self.linearizer_opts)
       assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
       if not getenv("NOOPT"):
-        if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
+        if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
         if BEAM:
           kb = Linearizer(ast, self.linearizer_opts)
           kb.required_optimizations()
           kb.dont_use_locals = bool(getenv("NOLOCALS"))
           from tinygrad.features.search import beam_search, time_linearizer
-          kb = beam_search(kb, rawbuffers, BEAM.value)
-          baseline, beamtime = time_linearizer(k, rawbuffers, allow_test_size=False, disable_cache=True), time_linearizer(kb, rawbuffers, allow_test_size=False, disable_cache=True)
-          if beamtime < baseline:
-            if DEBUG >= 1: print(f"beam search {beamtime*1e6:<12.2f} beat baseline {baseline*1e6:<12.2f} by {baseline/beamtime:.2f}x")
-            k = kb
+          lins = [(f"beam{BEAM.value}", beam_search(kb, rawbuffers, BEAM.value)), (("tc" if used_tensor_cores else "hc"), k)]
+          if used_tensor_cores:
+            lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
+            lins[-1][1].hand_coded_optimizations()
+          timed = sorted([(nm, tk, time_linearizer(tk, rawbuffers, allow_test_size=False, disable_cache=True)) for nm, tk in lins], key=lambda x: x[2])
+          if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(25, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
+          k = timed[0][1]
       return self.to_program(k)
 
     if getenv("ENABLE_METHOD_CACHE", 1):
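With this change, a BEAM run no longer compares beam search against a single baseline: it times the beam kernel against the hand-coded (and, when tensor cores were applied, tensor-core) kernel via time_linearizer(..., disable_cache=True), picks the fastest, and at DEBUG >= 1 prints the candidates sorted by time with their dense colored shapes.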