Add dictionary keys to reduce db size (#2131)

* work

* ignore beam cache

* dictionary keys are generic

* minor db cleanups

* fix baseline and extract dataset

* fix training

* log likelihood
George Hotz
2023-10-24 10:49:22 -04:00
committed by GitHub
parent d5e2fdea22
commit cea2bc7964
6 changed files with 170 additions and 17 deletions
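
The headline change: tinygrad.helpers cache keys may now be dicts. Each dict field is stored as its own typed SQL column, and the fields together form the table's primary key, replacing the single stringified-tuple key used before. A minimal usage sketch of the new API (the table names and values here are illustrative, and this writes to the real cache db at CACHEDB):

from tinygrad.helpers import diskcache_put, diskcache_get

# each dict field becomes a typed column; together they form the primary key
key = {"ast": "str(lin.ast) goes here", "amt": 4}
diskcache_put("demo_table", key, [1.5e-5, 1.7e-5])  # the value is pickled into a blob column
assert diskcache_get("demo_table", key) == [1.5e-5, 1.7e-5]

# plain str/int keys still work: they are wrapped as {"key": <key>} internally
diskcache_put("demo_table2", 4, 5)
assert diskcache_get("demo_table2", 4) == 5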

View File

@@ -1,5 +1,4 @@
-import sys
-import sqlite3
+import sys, sqlite3, pickle
 
 if __name__ == "__main__":
   fn = sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache"
@@ -12,3 +11,10 @@ if __name__ == "__main__":
     cur2.execute(f"SELECT COUNT(*) FROM {table}")
     cnt = cur2.fetchone()[0]
     print(f"{table:20s} : {cnt}")
+    cur3 = conn.cursor()
+    cur3.execute(f"SELECT * FROM {table} LIMIT 10")
+    for f in cur3.fetchall():
+      v = pickle.loads(f[-1])
+      print(" ", len(f[0]), f[1:-1], v)
+      #print(f"{len(k):10d}, {sk} -> {v}")

View File

@@ -0,0 +1,121 @@
+import sys, sqlite3, pickle, math
+from collections import defaultdict
+from tqdm import tqdm, trange
+import numpy as np
+
+# stuff needed to unpack a kernel
+from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
+from tinygrad.helpers import dtypes
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import View
+from tinygrad.shape.symbolic import Variable
+inf, nan = float('inf'), float('nan')
+from tinygrad.codegen.optimizer import Opt, OptOps
+
+# more stuff
+from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.features.search import actions
+from extra.optimization.helpers import lin_to_feats
+from extra.optimization.pretrain_valuenet import ValueNet
+from tinygrad.nn.optim import Adam
+from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
+import random
+from tinygrad.tensor import Tensor
+from tinygrad.helpers import getenv
+
+def dataset_from_cache(fn):
+  conn = sqlite3.connect(fn)
+  cur = conn.cursor()
+  cur.execute("SELECT * FROM time_linearizer")
+  grouped = defaultdict(dict)
+  for f in tqdm(cur.fetchall()): grouped[f[0]][f[1:-1]] = pickle.loads(f[-1])
+
+  opts_to_outcome = {}
+  for ast,sk in grouped.items():
+    cnts = defaultdict(int)
+    for sks,tm in sk.items():
+      if sks[1] != 1: continue
+      opts = eval(sks[0])
+      cnts[(len(opts), sks[1])] += 1
+      opts_to_outcome[(ast, tuple(opts))] = tm
+    #print(cnts)
+
+  S,A,V = [], [], []
+  for ast,k in tqdm(opts_to_outcome):
+    if len(k) == 0: continue
+    old_tm = min(opts_to_outcome[(ast,k[:-1])])
+    new_tm = min(opts_to_outcome[(ast,k)])
+    if math.isinf(old_tm) or math.isinf(new_tm) or old_tm < 1e-9 or new_tm < 1e-9: continue
+    lin = Linearizer(eval(ast))
+    for opt in k[:-1]: lin.apply_opt(opt)
+    act = k[-1]
+    log_ratio = math.log(old_tm/new_tm)
+    #print(f"ratio: {old_tm/new_tm:6.2f}x (log {log_ratio:5.2f}) from {str(act):50s} on {lin.colored_shape()}")
+    S.append(lin_to_feats(lin))
+    A.append(actions.index(act))
+    V.append([log_ratio]) # NOTE: i have written the bug many times with this having the wrong dim
+
+  S, A, V = np.array(S), np.array(A), np.array(V, dtype=np.float32)
+  X = np.zeros((S.shape[0], S.shape[1]+len(actions)), dtype=np.float32)
+  X[:, :S.shape[1]] = S
+  X[range(S.shape[0]), S.shape[1]+A] = 1.0
+  return X, V
+
+def log_likelihood(x:Tensor, mu:Tensor, log_sigma:Tensor):
+  #print(x.shape, mu.shape, log_sigma.shape)
+  #return (x-mu).abs() * (-log_sigma).exp() + log_sigma
+  return (x-mu).square() * (-2*log_sigma).exp() / 2 + log_sigma
+
+if __name__ == "__main__":
+  if getenv("REGEN"):
+    X,V = dataset_from_cache(sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache")
+    safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset")
+  else:
+    ld = safe_load("/tmp/dataset")
+    X,V = ld['X'].numpy(), ld['V'].numpy()
+
+  #ratio = int(0.9*X.shape[0])
+  ratio = -512
+  X_test, V_test = Tensor(X[ratio:]), Tensor(V[ratio:])
+  X,V = X[:ratio], V[:ratio]
+  #print(X[0], V[0])
+  #print(X[-1], V[-1])
+  print(X.shape)
+
+  net = ValueNet(X.shape[1], 2)
+  optim = Adam(get_parameters(net))
+
+  def get_minibatch(X,Y,bs):
+    xs, ys = [], []
+    #random.seed(1337)
+    for _ in range(bs):
+      sel = random.randint(0, len(X)-1)
+      xs.append(X[sel])
+      ys.append(Y[sel])
+    return Tensor(xs), Tensor(ys)
+
+  Tensor.no_grad, Tensor.training = False, True
+  losses = []
+  test_losses = []
+  test_loss = float('inf')
+  for i in (t:=trange(1000)):
+    x,y = get_minibatch(X,V,bs=256)
+    out = net(x)
+    #loss = (out-y).square().mean()
+    loss = log_likelihood(y, out[:, 0:1], out[:, 1:2]).mean()
+    optim.zero_grad()
+    loss.backward()
+    optim.step()
+    t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}")
+    losses.append(loss.numpy().item())
+    test_losses.append(test_loss)
+    if i % 10: test_loss = (net(X_test)[:, 0:1]-V_test).square().mean().numpy().item()
+
+  safe_save(get_state_dict(net), "/tmp/qnet.safetensors")
+
+  import matplotlib.pyplot as plt
+  plt.plot(losses[20:])
+  plt.plot(test_losses[20:])
+  plt.show()
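
The log_likelihood loss in this new script is the Gaussian negative log-likelihood with the constant 0.5*log(2*pi) dropped: for x ~ N(mu, sigma^2), -log p(x) = (x-mu)^2/(2*sigma^2) + log(sigma) + 0.5*log(2*pi), and parameterizing by log_sigma gives exactly the expression in the code. This is also why the net is built with two outputs (ValueNet(X.shape[1], 2)): out[:, 0:1] is read as mu and out[:, 1:2] as log_sigma. A standalone NumPy sanity check, not part of the commit:

import numpy as np

x, mu, log_sigma = 1.3, 0.7, -0.2
sigma = np.exp(log_sigma)
full_nll = 0.5*np.log(2*np.pi) + np.log(sigma) + (x - mu)**2 / (2*sigma**2)
code_nll = (x - mu)**2 * np.exp(-2*log_sigma) / 2 + log_sigma  # the expression in log_likelihood
assert np.isclose(full_nll - code_nll, 0.5*np.log(2*np.pi))    # they differ only by the dropped constant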

View File

@@ -21,11 +21,11 @@ from extra.optimization.helpers import lin_to_feats, MAX_DIMS
 # NOTE: this is not real value of the state, it's just a prediction of the runtime
 INNER = 512
 class ValueNet:
-  def __init__(self):
-    self.l1 = Linear(240,INNER)
+  def __init__(self, feats=240, out=1):
+    self.l1 = Linear(feats,INNER)
     self.l2 = Linear(INNER,INNER)
     self.l3 = Linear(INNER,INNER)
-    self.l4 = Linear(INNER,1)
+    self.l4 = Linear(INNER,out)
   def __call__(self, x):
     x = self.l1(x).relu()
     x = self.l2(x).relu()
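
The widened constructor keeps the old behavior as its default while letting the new training script ask for a two-headed output. A brief instantiation sketch (the feature count below is made up; the script above passes X.shape[1]):

from extra.optimization.pretrain_valuenet import ValueNet

net_old = ValueNet()                   # 240 input features, 1 output: the old runtime predictor
net_new = ValueNet(feats=1000, out=2)  # two outputs, read downstream as (mu, log_sigma)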

View File

@@ -1,4 +1,5 @@
 import unittest
+import pickle
 from tinygrad.helpers import diskcache_get, diskcache_put
 
 def remote_get(q,k): q.put(diskcache_get("test", k))
@@ -43,5 +44,16 @@ class DiskCache(unittest.TestCase):
     self.assertEqual(diskcache_get("test", 4), 5)
     self.assertEqual(diskcache_get("test", "4"), 5)
 
+  def test_dict_key(self):
+    fancy_key = {"hello": "world", "goodbye": 7, "good": True, "pkl": pickle.dumps("cat")}
+    fancy_key2 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("cat")}
+    fancy_key3 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("dog")}
+    diskcache_put("test2", fancy_key, 5)
+    self.assertEqual(diskcache_get("test2", fancy_key), 5)
+    diskcache_put("test2", fancy_key2, 8)
+    self.assertEqual(diskcache_get("test2", fancy_key2), 8)
+    self.assertEqual(diskcache_get("test2", fancy_key), 5)
+    self.assertEqual(diskcache_get("test2", fancy_key3), None)
+
 if __name__ == "__main__":
   unittest.main()

View File

@@ -1,7 +1,7 @@
 from typing import Dict, List, cast, DefaultDict, Optional
 from tinygrad.lazy import vars_from_ast
 from tinygrad.ops import Device, Compiled, MemBuffer
-from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, diskcache_get, diskcache_put
+from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, diskcache_get, diskcache_put, getenv
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.runtime.lib import RawBuffer
 from collections import defaultdict
@@ -21,7 +21,7 @@ actions += [
 
 # returns time in seconds
 def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True, disable_cache=False) -> float:
-  key = str((lin.ast, lin.applied_opts, allow_test_size, max_global_size))
+  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size}
   if should_copy and not disable_cache and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
   if should_copy: lin = lin.copy() # TODO: remove the need for this
   var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
@@ -87,12 +87,12 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
   return acted_lins
 
 def beam_search(lin:Linearizer, rawbufs, amt:int) -> Linearizer:
-  key = str((lin.ast, amt))
-  if (val:=diskcache_get("beam_search", key)) is not None:
+  key = {"ast": str(lin.ast), "amt": amt}
+  if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE"):
     ret = lin.copy()
     for o in val: ret.apply_opt(o)
     return ret
-  best_tm = float('inf')
+  best_tm = time_linearizer(lin, rawbufs) # handle the case where no actions make it faster
   beam: List[Linearizer] = [lin]
   while 1:
     acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin in beam])
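
Two behavioral notes fall out of this hunk: IGNORE_BEAM_CACHE=1 forces the search to rerun even when a cached result exists, and seeding best_tm with the unoptimized kernel's measured time means beam_search can return the kernel unchanged when no action beats it. A hedged driver sketch (lin and rawbufs are assumed to exist already; only the two function calls come from the diff):

from tinygrad.features.search import beam_search, time_linearizer

best = beam_search(lin, rawbufs, amt=4)  # served from the beam_search cache table unless IGNORE_BEAM_CACHE=1
tm = time_linearizer(best, rawbufs)      # seconds; the same call that now seeds best_tm
print(best.applied_opts, tm)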

View File

@@ -166,15 +166,26 @@ def cache_compiled(func):
 # *** universal database cache ***
+CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache")
+VERSION = 2
 
 _db_connection = None
 def db_connection():
   global _db_connection
-  if _db_connection is None: _db_connection = sqlite3.connect(getenv("CACHEDB", "/tmp/tinygrad_cache"))
+  if _db_connection is None:
+    _db_connection = sqlite3.connect(CACHEDB)
+    if DEBUG >= 3: _db_connection.set_trace_callback(print)
+    if diskcache_get("meta", "version") != VERSION:
+      print("cache is out of date, clearing it")
+      os.unlink(CACHEDB)
+      _db_connection = sqlite3.connect(CACHEDB)
+      if DEBUG >= 3: _db_connection.set_trace_callback(print)
+      diskcache_put("meta", "version", VERSION)
   return _db_connection
 
-def diskcache_get(table:str, key:str) -> Any:
+def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
+  if isinstance(key, (str,int)): key = {"key": key}
   try:
-    res = db_connection().cursor().execute(f"SELECT val FROM {table} WHERE key=?", (key,))
+    res = db_connection().cursor().execute(f"SELECT val FROM {table} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
   except sqlite3.OperationalError:
     return None # table doesn't exist
   if (val:=res.fetchone()) is not None:
@@ -182,13 +193,16 @@ def diskcache_get(table:str, key:str) -> Any:
   return None
 
 _db_tables = set()
-def diskcache_put(table:str, key:str, value:Any):
+def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
+  if isinstance(key, (str,int)): key = {"key": key}
   conn = db_connection()
   cur = conn.cursor()
   if table not in _db_tables:
-    cur.execute(f"CREATE TABLE IF NOT EXISTS {table} (key text PRIMARY KEY, val blob)")
+    TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
+    ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
+    cur.execute(f"CREATE TABLE IF NOT EXISTS {table} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
   _db_tables.add(table)
-  cur.execute(f"REPLACE INTO {table} (key, val) VALUES (?, ?)", (key, pickle.dumps(value)))
+  cur.execute(f"REPLACE INTO {table} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))
   conn.commit()
   cur.close()
-  return value
+  return val
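
For a concrete sense of the schema this generates, the CREATE TABLE statement for the time_linearizer key built in search.py can be reconstructed by hand from the f-strings above (illustrative, not captured from a real trace):

# reproducing the f-string logic of diskcache_put for one example key
TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
key = {"ast": "...", "opts": "...", "allow_test_size": True, "max_global_size": 65536}
ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
print(f"CREATE TABLE IF NOT EXISTS time_linearizer ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
# -> CREATE TABLE IF NOT EXISTS time_linearizer (ast text, opts text, allow_test_size integer,
#    max_global_size integer, val blob, PRIMARY KEY (ast, opts, allow_test_size, max_global_size))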