mirror of https://github.com/tinygrad/tinygrad.git
Add dictionary keys to reduce db size (#2131)
* work
* ignore beam cache
* dictionary keys are generic
* minor db cleanups
* fix baseline and extract dataset
* fix training
* log likelihood
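
In short: before this change every cache table had the fixed schema (key text PRIMARY KEY, val blob) and callers squashed their whole key into one string; now a key can be a dict, each field becomes its own typed column, and those columns together form the primary key. As a rough sketch (reconstructed from the diff below, not part of the commit message), the time_linearizer table goes from

  CREATE TABLE time_linearizer (key text PRIMARY KEY, val blob)

to

  CREATE TABLE time_linearizer (ast text, opts text, allow_test_size integer, max_global_size integer, val blob,
                                PRIMARY KEY (ast, opts, allow_test_size, max_global_size))
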
@@ -1,5 +1,4 @@
-import sys
-import sqlite3
+import sys, sqlite3, pickle

 if __name__ == "__main__":
   fn = sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache"
@@ -12,3 +11,10 @@ if __name__ == "__main__":
     cur2.execute(f"SELECT COUNT(*) FROM {table}")
     cnt = cur2.fetchone()[0]
     print(f"{table:20s} : {cnt}")
+
+    cur3 = conn.cursor()
+    cur3.execute(f"SELECT * FROM {table} LIMIT 10")
+    for f in cur3.fetchall():
+      v = pickle.loads(f[-1])
+      print("  ", len(f[0]), f[1:-1], v)
+      #print(f"{len(k):10d}, {sk} -> {v}")
extra/optimization/extract_sa_pairs.py (new file, 121 lines)
@@ -0,0 +1,121 @@
import sys, sqlite3, pickle, math
from collections import defaultdict
from tqdm import tqdm, trange
import numpy as np

# stuff needed to unpack a kernel
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.optimizer import Opt, OptOps

# more stuff
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import actions
from extra.optimization.helpers import lin_to_feats
from extra.optimization.pretrain_valuenet import ValueNet
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
import random
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv

def dataset_from_cache(fn):
  conn = sqlite3.connect(fn)
  cur = conn.cursor()
  cur.execute("SELECT * FROM time_linearizer")
  grouped = defaultdict(dict)
  for f in tqdm(cur.fetchall()): grouped[f[0]][f[1:-1]] = pickle.loads(f[-1])

  opts_to_outcome = {}

  for ast,sk in grouped.items():
    cnts = defaultdict(int)
    for sks,tm in sk.items():
      if sks[1] != 1: continue
      opts = eval(sks[0])
      cnts[(len(opts), sks[1])] += 1
      opts_to_outcome[(ast, tuple(opts))] = tm
    #print(cnts)

  S,A,V = [], [], []
  for ast,k in tqdm(opts_to_outcome):
    if len(k) == 0: continue
    old_tm = min(opts_to_outcome[(ast,k[:-1])])
    new_tm = min(opts_to_outcome[(ast,k)])
    if math.isinf(old_tm) or math.isinf(new_tm) or old_tm < 1e-9 or new_tm < 1e-9: continue
    lin = Linearizer(eval(ast))
    for opt in k[:-1]: lin.apply_opt(opt)
    act = k[-1]
    log_ratio = math.log(old_tm/new_tm)
    #print(f"ratio: {old_tm/new_tm:6.2f}x (log {log_ratio:5.2f}) from {str(act):50s} on {lin.colored_shape()}")
    S.append(lin_to_feats(lin))
    A.append(actions.index(act))
    V.append([log_ratio]) # NOTE: i have written the bug many times with this having the wrong dim

  S, A, V = np.array(S), np.array(A), np.array(V, dtype=np.float32)
  X = np.zeros((S.shape[0], S.shape[1]+len(actions)), dtype=np.float32)
  X[:, :S.shape[1]] = S
  X[range(S.shape[0]), S.shape[1]+A] = 1.0
  return X, V
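
To make the feature layout above concrete, here is a tiny standalone illustration with made-up numbers (not part of the commit) of how each row of X is the kernel feature vector with a one-hot encoding of the action appended:

  import numpy as np
  S = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)  # 2 states, 2 features each
  A = np.array([1, 0])                                      # index of the action taken in each state
  n_actions = 3
  X = np.zeros((S.shape[0], S.shape[1]+n_actions), dtype=np.float32)
  X[:, :S.shape[1]] = S                                     # features first...
  X[range(S.shape[0]), S.shape[1]+A] = 1.0                  # ...then the one-hot action
  # X == [[0.1, 0.2, 0., 1., 0.],
  #       [0.3, 0.4, 1., 0., 0.]]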

def log_likelihood(x:Tensor, mu:Tensor, log_sigma:Tensor):
  #print(x.shape, mu.shape, log_sigma.shape)
  #return (x-mu).abs() * (-log_sigma).exp() + log_sigma
  return (x-mu).square() * (-2*log_sigma).exp() / 2 + log_sigma
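
For reference, the expression above is the negative log-likelihood of a Gaussian with mean mu and standard deviation exp(log_sigma), with the constant 0.5*log(2*pi) dropped since it doesn't affect the gradient. A quick numpy sanity check of that identity (not part of the commit):

  import numpy as np
  x, mu, log_sigma = 1.3, 0.8, -0.5
  sigma = np.exp(log_sigma)
  nll = (x-mu)**2 * np.exp(-2*log_sigma) / 2 + log_sigma    # the expression used above
  full = -np.log(np.exp(-(x-mu)**2 / (2*sigma**2)) / (sigma*np.sqrt(2*np.pi)))
  assert np.isclose(nll + 0.5*np.log(2*np.pi), full)
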
if __name__ == "__main__":
  if getenv("REGEN"):
    X,V = dataset_from_cache(sys.argv[1] if len(sys.argv) > 1 else "/tmp/tinygrad_cache")
    safe_save({"X": Tensor(X), "V": Tensor(V)}, "/tmp/dataset")
  else:
    ld = safe_load("/tmp/dataset")
    X,V = ld['X'].numpy(), ld['V'].numpy()

  #ratio = int(0.9*X.shape[0])
  ratio = -512
  X_test, V_test = Tensor(X[ratio:]), Tensor(V[ratio:])
  X,V = X[:ratio], V[:ratio]

  #print(X[0], V[0])
  #print(X[-1], V[-1])
  print(X.shape)

  net = ValueNet(X.shape[1], 2)
  optim = Adam(get_parameters(net))

  def get_minibatch(X,Y,bs):
    xs, ys = [], []
    #random.seed(1337)
    for _ in range(bs):
      sel = random.randint(0, len(X)-1)
      xs.append(X[sel])
      ys.append(Y[sel])
    return Tensor(xs), Tensor(ys)

  Tensor.no_grad, Tensor.training = False, True
  losses = []
  test_losses = []
  test_loss = float('inf')
  for i in (t:=trange(1000)):
    x,y = get_minibatch(X,V,bs=256)
    out = net(x)
    #loss = (out-y).square().mean()
    loss = log_likelihood(y, out[:, 0:1], out[:, 1:2]).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()
    t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}")
    losses.append(loss.numpy().item())
    test_losses.append(test_loss)
    if i % 10: test_loss = (net(X_test)[:, 0:1]-V_test).square().mean().numpy().item()

  safe_save(get_state_dict(net), "/tmp/qnet.safetensors")

  import matplotlib.pyplot as plt
  plt.plot(losses[20:])
  plt.plot(test_losses[20:])
  plt.show()
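
A plausible way to drive this script, based on the env var and paths it reads (the commit itself doesn't document an invocation):

  REGEN=1 python3 extra/optimization/extract_sa_pairs.py /tmp/tinygrad_cache   # rebuild /tmp/dataset from the cache db
  python3 extra/optimization/extract_sa_pairs.py                               # train on /tmp/dataset, write /tmp/qnet.safetensors
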
@@ -21,11 +21,11 @@ from extra.optimization.helpers import lin_to_feats, MAX_DIMS
 # NOTE: this is not real value of the state, it's just a prediction of the runtime
 INNER = 512
 class ValueNet:
-  def __init__(self):
-    self.l1 = Linear(240,INNER)
+  def __init__(self, feats=240, out=1):
+    self.l1 = Linear(feats,INNER)
     self.l2 = Linear(INNER,INNER)
     self.l3 = Linear(INNER,INNER)
-    self.l4 = Linear(INNER,1)
+    self.l4 = Linear(INNER,out)
   def __call__(self, x):
     x = self.l1(x).relu()
     x = self.l2(x).relu()
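
The new defaults keep existing callers working unchanged (240 input features, a single output read as the predicted runtime), while extract_sa_pairs.py above builds the net with two outputs and reads them as (mu, log_sigma). A minimal sketch of the two usages (X here is the dataset built above; not part of the commit):

  runtime_net = ValueNet()                    # prior behavior: 240 feats in, 1 value out
  qnet = ValueNet(feats=X.shape[1], out=2)    # extract_sa_pairs.py: 2 outputs, read as (mu, log_sigma)
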
@@ -1,4 +1,5 @@
 import unittest
+import pickle
 from tinygrad.helpers import diskcache_get, diskcache_put

 def remote_get(q,k): q.put(diskcache_get("test", k))
@@ -43,5 +44,16 @@ class DiskCache(unittest.TestCase):
     self.assertEqual(diskcache_get("test", 4), 5)
     self.assertEqual(diskcache_get("test", "4"), 5)

+  def test_dict_key(self):
+    fancy_key = {"hello": "world", "goodbye": 7, "good": True, "pkl": pickle.dumps("cat")}
+    fancy_key2 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("cat")}
+    fancy_key3 = {"hello": "world", "goodbye": 8, "good": True, "pkl": pickle.dumps("dog")}
+    diskcache_put("test2", fancy_key, 5)
+    self.assertEqual(diskcache_get("test2", fancy_key), 5)
+    diskcache_put("test2", fancy_key2, 8)
+    self.assertEqual(diskcache_get("test2", fancy_key2), 8)
+    self.assertEqual(diskcache_get("test2", fancy_key), 5)
+    self.assertEqual(diskcache_get("test2", fancy_key3), None)
+
 if __name__ == "__main__":
   unittest.main()
@@ -1,7 +1,7 @@
 from typing import Dict, List, cast, DefaultDict, Optional
 from tinygrad.lazy import vars_from_ast
 from tinygrad.ops import Device, Compiled, MemBuffer
-from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, diskcache_get, diskcache_put
+from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, diskcache_get, diskcache_put, getenv
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.runtime.lib import RawBuffer
 from collections import defaultdict
@@ -21,7 +21,7 @@ actions += [

 # returns time in seconds
 def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True, disable_cache=False) -> float:
-  key = str((lin.ast, lin.applied_opts, allow_test_size, max_global_size))
+  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, "max_global_size": max_global_size}
   if should_copy and not disable_cache and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
   if should_copy: lin = lin.copy() # TODO: remove the need for this
   var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
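
Each timing entry is now keyed by four typed columns rather than one long stringified tuple. A rough interactive check of a cached entry (a sketch assuming a populated cache and a Linearizer lin in scope; not part of the commit):

  from tinygrad.helpers import diskcache_get
  key = {"ast": str(lin.ast), "opts": str(lin.applied_opts), "allow_test_size": True, "max_global_size": 65536}
  tms = diskcache_get("time_linearizer", key)   # unpickled list of measured times, or None on a miss
  if tms is not None: print(f"cached best: {min(tms)*1e6:.2f} us")
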
@@ -87,12 +87,12 @@ def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Lineariz
   return acted_lins

 def beam_search(lin:Linearizer, rawbufs, amt:int) -> Linearizer:
-  key = str((lin.ast, amt))
-  if (val:=diskcache_get("beam_search", key)) is not None:
+  key = {"ast": str(lin.ast), "amt": amt}
+  if (val:=diskcache_get("beam_search", key)) is not None and not getenv("IGNORE_BEAM_CACHE"):
     ret = lin.copy()
     for o in val: ret.apply_opt(o)
     return ret
-  best_tm = float('inf')
+  best_tm = time_linearizer(lin, rawbufs) # handle the case where no actions make it faster
   beam: List[Linearizer] = [lin]
   while 1:
     acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin in beam])
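
Two behavioral notes on the hunk above: setting IGNORE_BEAM_CACHE=1 forces the search to rerun even when a cached opt sequence exists, and seeding best_tm with the unmodified kernel's measured time (instead of infinity) means a kernel that no action improves correctly keeps its original form.
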
@@ -166,15 +166,26 @@ def cache_compiled(func):

 # *** universal database cache ***

+CACHEDB = getenv("CACHEDB", "/tmp/tinygrad_cache")
+VERSION = 2
 _db_connection = None
 def db_connection():
   global _db_connection
-  if _db_connection is None: _db_connection = sqlite3.connect(getenv("CACHEDB", "/tmp/tinygrad_cache"))
+  if _db_connection is None:
+    _db_connection = sqlite3.connect(CACHEDB)
+    if DEBUG >= 3: _db_connection.set_trace_callback(print)
+    if diskcache_get("meta", "version") != VERSION:
+      print("cache is out of date, clearing it")
+      os.unlink(CACHEDB)
+      _db_connection = sqlite3.connect(CACHEDB)
+      if DEBUG >= 3: _db_connection.set_trace_callback(print)
+      diskcache_put("meta", "version", VERSION)
   return _db_connection

-def diskcache_get(table:str, key:str) -> Any:
+def diskcache_get(table:str, key:Union[Dict, str, int]) -> Any:
+  if isinstance(key, (str,int)): key = {"key": key}
   try:
-    res = db_connection().cursor().execute(f"SELECT val FROM {table} WHERE key=?", (key,))
+    res = db_connection().cursor().execute(f"SELECT val FROM {table} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}", tuple(key.values()))
   except sqlite3.OperationalError:
     return None # table doesn't exist
   if (val:=res.fetchone()) is not None:
@@ -182,13 +193,16 @@ def diskcache_get(table:str, key:str) -> Any:
   return None

 _db_tables = set()
-def diskcache_put(table:str, key:str, value:Any):
+def diskcache_put(table:str, key:Union[Dict, str, int], val:Any):
+  if isinstance(key, (str,int)): key = {"key": key}
   conn = db_connection()
   cur = conn.cursor()
   if table not in _db_tables:
-    cur.execute(f"CREATE TABLE IF NOT EXISTS {table} (key text PRIMARY KEY, val blob)")
+    TYPES = {str: "text", bool: "integer", int: "integer", float: "numeric", bytes: "blob"}
+    ltypes = ', '.join(f"{k} {TYPES[type(key[k])]}" for k in key.keys())
+    cur.execute(f"CREATE TABLE IF NOT EXISTS {table} ({ltypes}, val blob, PRIMARY KEY ({', '.join(key.keys())}))")
   _db_tables.add(table)
-  cur.execute(f"REPLACE INTO {table} (key, val) VALUES (?, ?)", (key, pickle.dumps(value)))
+  cur.execute(f"REPLACE INTO {table} ({', '.join(key.keys())}, val) VALUES ({', '.join(['?']*len(key.keys()))}, ?)", tuple(key.values()) + (pickle.dumps(val), ))
   conn.commit()
   cur.close()
-  return value
+  return val
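
Putting the two helpers together with the fancy_key from test_dict_key above (hello is text, goodbye and good map to integer, pkl to blob), diskcache_put("test2", fancy_key, 5) and a later diskcache_get("test2", fancy_key) issue roughly the following SQL, reconstructed here from the f-strings above for illustration:

  CREATE TABLE IF NOT EXISTS test2 (hello text, goodbye integer, good integer, pkl blob, val blob, PRIMARY KEY (hello, goodbye, good, pkl))
  REPLACE INTO test2 (hello, goodbye, good, pkl, val) VALUES (?, ?, ?, ?, ?)
  SELECT val FROM test2 WHERE hello=? AND goodbye=? AND good=? AND pkl=?

The composite primary key plus REPLACE is what lets fancy_key2 (same fields except goodbye) live in its own row next to fancy_key, while a lookup with fancy_key3 (never inserted) returns None.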