universal disk cache (#2130)

* caching infra for tinygrad

* non str key

* fix linter

* no shelve in beam search

* beam search caching

* check tensor cores with beam too

* pretty print

* LATEBEAM in stable diffusion
George Hotz authored 2023-10-22 10:56:57 -07:00 (committed by GitHub)
parent ace6b2a151
commit 6dc8eb5bfd
8 changed files with 125 additions and 29 deletions

examples/stable_diffusion.py

@@ -9,7 +9,7 @@ from collections import namedtuple
 from tqdm import tqdm
 from tinygrad.tensor import Tensor
 from tinygrad.ops import Device
-from tinygrad.helpers import dtypes, GlobalCounters, Timing
+from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv
 from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
 from extra.utils import download_file
 from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
@@ -636,14 +636,15 @@ if __name__ == "__main__":
   if args.seed is not None: Tensor._seed = args.seed
   latent = Tensor.randn(1,4,64,64)

-  # this is diffusion
-  for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
-    GlobalCounters.reset()
-    t.set_description("%3d %3d" % (index, timestep))
-    with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
-      latent = do_step(latent, Tensor([timestep]), Tensor([index]), Tensor([args.guidance]))
-      if args.timing: Device[Device.DEFAULT].synchronize()
-  del do_step
+  with Context(BEAM=getenv("LATEBEAM")):
+    # this is diffusion
+    for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
+      GlobalCounters.reset()
+      t.set_description("%3d %3d" % (index, timestep))
+      with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
+        latent = do_step(latent, Tensor([timestep]), Tensor([index]), Tensor([args.guidance]))
+        if args.timing: Device[Device.DEFAULT].synchronize()
+  del do_step

   # upsample latent space to image with autoencoder
   x = model.first_stage_model.post_quant_conv(1/0.18215 * latent)
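
Note on LATEBEAM: Context temporarily overrides the BEAM variable only inside the with-block, so beam search kicks in "late", for the diffusion-step kernels, and not for everything compiled before (weight loading, text encoding). A minimal sketch of the scoping behavior, assuming BEAM is one of tinygrad's ContextVars, which Context sets by keyword name:

  from tinygrad.helpers import Context, getenv  # both are imported in the diff above

  # BEAM keeps its ambient value outside the block; inside, it takes LATEBEAM's value
  with Context(BEAM=getenv("LATEBEAM")):
    pass  # kernels compiled here are beam-searched with width LATEBEAM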

extra/dump_cache.py (new file, 14 lines)

@@ -0,0 +1,14 @@
+import sys
+import sqlite3
+
+if __name__ == "__main__":
+  fn = sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache"
+  conn = sqlite3.connect(fn)
+  cur = conn.cursor()
+  cur.execute("SELECT name FROM sqlite_master WHERE type='table'")
+  for f in cur.fetchall():
+    table = f[0]
+    cur2 = conn.cursor()
+    cur2.execute(f"SELECT COUNT(*) FROM {table}")
+    cnt = cur2.fetchone()[0]
+    print(f"{table:20s} : {cnt}")

test/test_diskcache.py (new file)

@@ -0,0 +1,47 @@
+import unittest
+from tinygrad.helpers import diskcache_get, diskcache_put
+
+def remote_get(q,k): q.put(diskcache_get("test", k))
+def remote_put(k,v): diskcache_put("test", k, v)
+
+class DiskCache(unittest.TestCase):
+  def test_putget(self):
+    diskcache_put("test", "hello", "world")
+    self.assertEqual(diskcache_get("test", "hello"), "world")
+    diskcache_put("test", "hello", "world2")
+    self.assertEqual(diskcache_get("test", "hello"), "world2")
+
+  def test_putcomplex(self):
+    diskcache_put("test", "k", ("complex", 123, "object"))
+    ret = diskcache_get("test", "k")
+    self.assertEqual(ret, ("complex", 123, "object"))
+
+  def test_getotherprocess(self):
+    from multiprocessing import Process, Queue
+    diskcache_put("test", "k", "getme")
+    q = Queue()
+    p = Process(target=remote_get, args=(q,"k"))
+    p.start()
+    p.join()
+    self.assertEqual(q.get(), "getme")
+
+  def test_putotherprocess(self):
+    from multiprocessing import Process
+    p = Process(target=remote_put, args=("k", "remote"))
+    p.start()
+    p.join()
+    self.assertEqual(diskcache_get("test", "k"), "remote")
+
+  def test_no_table(self):
+    self.assertIsNone(diskcache_get("faketable", "k"))
+
+  def test_ret(self):
+    self.assertEqual(diskcache_put("test", "key", ("vvs",)), ("vvs",))
+
+  def test_non_str_key(self):
+    diskcache_put("test", 4, 5)
+    self.assertEqual(diskcache_get("test", 4), 5)
+    self.assertEqual(diskcache_get("test", "4"), 5)
+
+if __name__ == "__main__":
+  unittest.main()
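
test_non_str_key passes because the key column is declared text, so SQLite's type affinity converts the integer 4 to '4' both on insert and on comparison; the int and string forms address the same row. A standalone demonstration of that affinity rule (table name is illustrative):

  import sqlite3

  conn = sqlite3.connect(":memory:")
  conn.execute("CREATE TABLE t (key text PRIMARY KEY, val blob)")
  conn.execute("REPLACE INTO t (key, val) VALUES (?, ?)", (4, b"x"))     # int key stored as text '4'
  print(conn.execute("SELECT val FROM t WHERE key=?", (4,)).fetchone())    # (b'x',)
  print(conn.execute("SELECT val FROM t WHERE key=?", ("4",)).fetchone())  # (b'x',)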

tinygrad/codegen/kernel.py

@@ -164,8 +164,8 @@ class Kernel:
     assert len(colors) == self.shape_len, "colors size mismatch"
     return colors

-  def colored_shape(self, pad=None) -> str:
-    ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
+  def colored_shape(self, pad=None, dense=False) -> str:
+    ret = ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) and not dense else s for s in self.full_shape], self.colors()))
     if pad: ret += ' '*(pad-ansilen(ret))
     return ret

   def printbufs(self, prefix=""):
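
The new dense flag supports the one-line candidate printout in ops.py below: it skips the 4-character padding on integer dims so shapes stay compact. Illustratively, with hypothetical dims:

  dims = [64, 16, 3]
  print(' '.join(f"{s:4d}" for s in dims))  # "  64   16    3" (default, padded)
  print(' '.join(str(s) for s in dims))     # "64 16 3"        (dense=True path)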

tinygrad/codegen/optimizer.py

@@ -160,7 +160,7 @@ class OptimizedKernel(Kernel):
   # ******************** high level optimizers ********************

   # TODO: unify this
-  def apply_tensor_cores(self, use_tensor_cores=1):
+  def apply_tensor_cores(self, use_tensor_cores=1) -> bool:
     # should use HIP tensor cores?
     if use_tensor_cores != 0 and self.opts.device == "HIP" and self.reduceop and self.reduceop.op == ReduceOps.SUM and \
        isinstance(self.reduceop.src[0], LazyOp) and self.reduceop.src[0].op == UnaryOps.CAST and \

tinygrad/features/search.py

@@ -1,7 +1,7 @@
 from typing import Dict, List, cast, DefaultDict, Optional
 from tinygrad.lazy import vars_from_ast
 from tinygrad.ops import Device, Compiled, MemBuffer
-from tinygrad.helpers import prod, getenv, ImageDType, flatten, DEBUG
+from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, diskcache_get, diskcache_put
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.runtime.lib import RawBuffer
 from collections import defaultdict
@@ -20,11 +20,9 @@ actions += [
 ]

 # returns time in seconds
-import shelve
-logtm = shelve.open(getenv("LOGTM", "")) if getenv("LOGTM", "") else None
 def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True, disable_cache=False) -> float:
-  key = str((lin.ast, lin.applied_opts))
-  if should_copy and not disable_cache and logtm is not None and key in logtm: return min(logtm[key]) # pylint: disable=E1135 # NOTE: we check should_copy since this may have side effects
+  key = str((lin.ast, lin.applied_opts, allow_test_size, max_global_size))
+  if should_copy and not disable_cache and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
   if should_copy: lin = lin.copy() # TODO: remove the need for this
   var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
   try:
@@ -57,8 +55,7 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
     #print(lin.ast)
     #print(lin.applied_opts)
     tms = [float('inf')]
-  if logtm is not None: logtm[key] = tms
-  return min(tms)
+  return min(diskcache_put("time_linearizer", key, tms))

 # get (scrap) buffers for timing the linearizer
 def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
@@ -89,7 +86,12 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
       pass
   return acted_lins

-def beam_search(lin, rawbufs, amt):
+def beam_search(lin:Linearizer, rawbufs, amt:int) -> Linearizer:
+  key = str((lin.ast, amt))
+  if (val:=diskcache_get("beam_search", key)) is not None:
+    ret = lin.copy()
+    for o in val: ret.apply_opt(o)
+    return ret
   best_tm = float('inf')
   beam: List[Linearizer] = [lin]
   while 1:
@@ -99,5 +101,6 @@ def beam_search(lin, rawbufs, amt):
     if len(opts) == 0 or best_tm <= opts[0][1]: break # we didn't get faster
     best_tm = opts[0][1]
     beam = [x[0] for x in opts[:amt]]
-    if DEBUG >= 1: print(f"{opts[0][1]*1e6:12.2f} us from {len(opts):3d} actions", beam[0].colored_shape())
+    if DEBUG >= 2: print(f"{opts[0][1]*1e6:12.2f} us from {len(opts):3d} actions", beam[0].colored_shape())
+  diskcache_put("beam_search", key, beam[0].applied_opts)
   return beam[0]
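
Design note: beam_search caches only the winning sequence of applied_opts, keyed on (ast, amt), and on a hit replays those opts onto a fresh copy of the kernel, so the slow kernel-timing runs are skipped entirely. A minimal sketch of that record/replay idea; cached_beam_search and search_fn here are illustrative wrappers, not functions from the commit:

  from tinygrad.helpers import diskcache_get, diskcache_put

  def cached_beam_search(lin, amt, search_fn):
    key = str((lin.ast, amt))
    if (opts := diskcache_get("beam_search", key)) is not None:
      ret = lin.copy()
      for o in opts: ret.apply_opt(o)  # replay recorded decisions on a fresh copy
      return ret
    best = search_fn(lin)              # the expensive part: timing real kernels
    diskcache_put("beam_search", key, best.applied_opts)  # persist decisions, not objects
    return best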

tinygrad/helpers.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, pathlib, hashlib, tempfile
+import os, functools, platform, time, re, contextlib, operator, pathlib, hashlib, tempfile, pickle, sqlite3
 import numpy as np
 from typing import Dict, Tuple, Union, List, NamedTuple, Final, Iterator, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING
 if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
@@ -162,4 +162,33 @@ def cache_compiled(func):
       output_file.write_bytes(func(self, prg, *args, **kwargs))
       output_file.rename(cache_path)
     return cache_path.read_bytes()
-  return wrapper
+  return wrapper
+
+# *** universal database cache ***
+
+_db_connection = None
+def db_connection():
+  global _db_connection
+  if _db_connection is None: _db_connection = sqlite3.connect(getenv("CACHEDB", "/tmp/tinygrad_cache"))
+  return _db_connection
+
+def diskcache_get(table:str, key:str) -> Any:
+  try:
+    res = db_connection().cursor().execute(f"SELECT val FROM {table} WHERE key=?", (key,))
+  except sqlite3.OperationalError:
+    return None  # table doesn't exist
+  if (val:=res.fetchone()) is not None:
+    return pickle.loads(val[0])
+  return None
+
+_db_tables = set()
+def diskcache_put(table:str, key:str, value:Any):
+  conn = db_connection()
+  cur = conn.cursor()
+  if table not in _db_tables:
+    cur.execute(f"CREATE TABLE IF NOT EXISTS {table} (key text PRIMARY KEY, val blob)")
+    _db_tables.add(table)
+  cur.execute(f"REPLACE INTO {table} (key, val) VALUES (?, ?)", (key, pickle.dumps(value)))
+  conn.commit()
+  cur.close()
+  return value
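
The API surface is deliberately small: diskcache_put pickles any picklable value into a per-table SQLite blob and returns the value, so it composes inline (as min(diskcache_put(...)) does in search.py), while diskcache_get returns None for a missing table or key instead of raising. A minimal usage sketch; the table, key, and values are illustrative:

  from tinygrad.helpers import diskcache_get, diskcache_put

  # put returns the stored value, so caching wraps the computation itself
  tms = diskcache_put("example_times", "some_key", [3.1e-05, 2.9e-05])
  print(min(tms))                                    # 2.9e-05
  print(diskcache_get("example_times", "some_key"))  # [3.1e-05, 2.9e-05], unpickled
  print(diskcache_get("missing_table", "k"))         # None, missing tables don't raise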

tinygrad/ops.py

@@ -266,17 +266,19 @@ class Compiled:
     k = Linearizer(ast, self.linearizer_opts)
     assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
     if not getenv("NOOPT"):
-      if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
+      if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
       if BEAM:
         kb = Linearizer(ast, self.linearizer_opts)
         kb.required_optimizations()
         kb.dont_use_locals = bool(getenv("NOLOCALS"))
         from tinygrad.features.search import beam_search, time_linearizer
-        kb = beam_search(kb, rawbuffers, BEAM.value)
-        baseline, beamtime = time_linearizer(k, rawbuffers, allow_test_size=False, disable_cache=True), time_linearizer(kb, rawbuffers, allow_test_size=False, disable_cache=True)
-        if beamtime < baseline:
-          if DEBUG >= 1: print(f"beam search {beamtime*1e6:<12.2f} beat baseline {baseline*1e6:<12.2f} by {baseline/beamtime:.2f}x")
-          k = kb
+        lins = [(f"beam{BEAM.value}", beam_search(kb, rawbuffers, BEAM.value)), (("tc" if used_tensor_cores else "hc"), k)]
+        if used_tensor_cores:
+          lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
+          lins[-1][1].hand_coded_optimizations()
+        timed = sorted([(nm, tk, time_linearizer(tk, rawbuffers, allow_test_size=False, disable_cache=True)) for nm, tk in lins], key=lambda x: x[2])
+        if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(25, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
+        k = timed[0][1]
     return self.to_program(k)

   if getenv("ENABLE_METHOD_CACHE", 1):
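
Instead of comparing one beam candidate against one baseline, the new code times a named list of candidates (beam-searched, tensor-core, and, when tensor cores were applied, a hand-coded fallback) and keeps the fastest; that is also why apply_tensor_cores now returns bool. A minimal sketch of the tournament pattern; the function and argument names are illustrative, not from the commit:

  # generic "time them all, keep the fastest" selection, mirroring the tuples above
  def pick_fastest(candidates, time_fn):
    # candidates: list of (name, kernel) pairs; time_fn returns seconds
    timed = sorted([(nm, k, time_fn(k)) for nm, k in candidates], key=lambda x: x[2])
    print(" < ".join(f"{nm}: {tm*1e6:8.2f} us" for nm, _, tm in timed))  # fastest first
    return timed[0][1]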