train value net, improve API, add BCE (#2047)

* api cleanups, BCE losses

* valuenet

* fixup examples

* learning okay

* add valuenet runner

* net improvements

* net improvements

* 40% win rate
Author:       George Hotz
Date:         2023-10-12 07:56:38 -07:00
Committed by: GitHub
parent 0ba629c7b9
commit c5edb3c374
11 changed files with 212 additions and 121 deletions

View File

@@ -6,6 +6,8 @@ from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions
from tinygrad.helpers import ansilen, DEBUG, getenv
from tinygrad.graph import print_tree
from tinygrad.lazy import vars_from_ast
from tinygrad.shape.symbolic import sym_infer
import shelve
global_db = shelve.open("/tmp/greedy_cache")
@@ -59,10 +61,9 @@ if __name__ == "__main__":
else:
while 1:
acted_lins = get_linearizer_actions(lin)
tm, gflops = time_linearizer(lin, rawbufs)
timed_lins = {k:time_linearizer(v, rawbufs)[0] for k,v in acted_lins.items()}
timed_lins = {k:time_linearizer(v, rawbufs) for k,v in acted_lins.items()}
opts = sorted(timed_lins.items(), key=lambda x: x[1])
if len(opts) == 0 or opts[0][1] >= tm: break # we are done
if opts[0][0] == 0: break # we are done
lin = acted_lins[opts[0][0]]
if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", lin.colored_shape())
global_db[str(lin.ast)] = lin.applied_opts
@@ -71,7 +72,8 @@ if __name__ == "__main__":
# benchmark the programs
choices = []
for lin in lins:
tm, gflops = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, should_copy=False)
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, should_copy=False)
gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm
choices.append((tm, gflops, lin))
# print all kernels

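Note on the API change above: since time_linearizer now returns only the runtime in seconds, callers that still want a GFLOPS figure recompute it from the kernel's symbolic FLOP count. A minimal sketch of that pattern, using only names that appear in this diff (lin and rawbufs are assumed to be set up as in the script):

from tinygrad.codegen.search import time_linearizer
from tinygrad.lazy import vars_from_ast
from tinygrad.shape.symbolic import sym_infer

tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, should_copy=False)  # seconds
var_vals = {k: k.min for k in vars_from_ast(lin.ast)}      # bind symbolic shape vars to their minimum
gflops = sym_infer(lin.info.flops, var_vals) * 1e-9 / tm   # FLOPs resolved symbolically, divided by seconds
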
View File

@@ -1,43 +0,0 @@
import numpy as np
import random
from copy import deepcopy
from tqdm import tqdm
from tinygrad.helpers import dedup, ansilen
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.tensor import Tensor
from tinygrad.codegen.search import get_linearizer_actions, time_linearizer, bufs_from_lin, actions
#from extra.optimization.pretrain import PolicyNet
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
if __name__ == "__main__":
# load worlds
ast_strs = load_worlds()
for ep_num,ast_str in enumerate(ast_strs):
print("\nEPISODE", ep_num)
lin = ast_str_to_lin(ast_str)
linhc = deepcopy(lin)
linhc.hand_coded_optimizations()
if not all(x in actions for x in linhc.applied_opts):
print("skipping", linhc.colored_shape())
continue
rawbufs = bufs_from_lin(lin)
tm1, gf1 = time_linearizer(linhc, rawbufs)
print(f"{tm1:10.2f}", linhc.colored_shape(), f"with {len(linhc.applied_opts)} actions from {len(actions)} action space")
while 1:
tm, gflops = time_linearizer(lin, rawbufs)
print(f"{tm:10.2f}", lin.colored_shape())
acted_lins = get_linearizer_actions(lin)
if len(acted_lins) == 0: break
best_tm, best_lin = tm, lin
for l in list(acted_lins.values()):
tm, gflops = time_linearizer(l, rawbufs)
if tm < best_tm: best_tm, best_lin = tm, l
if lin == best_lin: break
lin = best_lin

View File

@@ -35,6 +35,8 @@ from tinygrad.shape.symbolic import Node
MAX_DIMS = 16
def lin_to_feats(lin):
assert lin.shape_len < MAX_DIMS, "too many dims"
all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"]
lc = [all_colors.index(x) for x in lin.colors()]
#my_sts = dedup([(x.shape == lin.full_shape, x.real_strides()) for x in lin.sts[1:]])

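The two added lines turn the kernel's per-dimension color labels into integer indices so they can be fed to the nets as features. A tiny illustration of the mapping (the colors() output here is made up for the example):

all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"]
# e.g. a kernel whose dims are colored ["cyan", "yellow", "red"] becomes [1, 6, 4]
lc = [all_colors.index(x) for x in ["cyan", "yellow", "red"]]
assert lc == [1, 6, 4]
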
View File

@@ -7,7 +7,7 @@ from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_lo
from tinygrad.codegen.search import actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, assert_same_lin
INNER = 256
INNER = 32
class PolicyNet:
def __init__(self):
self.l1 = Linear(240,INNER)

View File

@@ -0,0 +1,88 @@
from tinygrad.codegen.linearizer import Linearizer
from tqdm import tqdm, trange
import math
import random
from tinygrad.tensor import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.optim import Adam
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
# stuff needed to unpack a kernel
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import Variable
inf, nan = float('inf'), float('nan')
from tinygrad.codegen.optimizer import Opt, OptOps
from extra.optimization.helpers import lin_to_feats, MAX_DIMS
# NOTE: this is not real value of the state, it's just a prediction of the runtime
INNER = 512
class ValueNet:
def __init__(self):
self.l1 = Linear(240,INNER)
self.l2 = Linear(INNER,INNER)
self.l3 = Linear(INNER,INNER)
self.l4 = Linear(INNER,1)
def __call__(self, x):
x = self.l1(x).relu()
x = self.l2(x).relu()
x = self.l3(x).relu()
return self.l4(x)
if __name__ == "__main__":
net = ValueNet()
optim = Adam(get_parameters(net))
TEST_SIZE = 256
dset = open("/tmp/logtm").read().strip().split("\n")
random.seed(1337)
random.shuffle(dset)
X,Y = [], []
for i,x in enumerate(tqdm(dset)):
ast, opts, tms = eval(x)
lin = Linearizer(ast)
for o in opts: lin.apply_opt(o)
if lin.shape_len >= MAX_DIMS: continue
if min(tms) == float('inf'): continue
X.append(lin_to_feats(lin))
Y.append([math.log(min(tms))])
print(f"got {len(X)} samples")
X_test,Y_test = Tensor(X[-TEST_SIZE:]), Tensor(Y[-TEST_SIZE:])
X,Y = X[:-TEST_SIZE], Y[:-TEST_SIZE]
def get_minibatch(X,Y,bs):
xs, ys = [], []
for _ in range(bs):
sel = random.randint(0, len(X)-1)
xs.append(X[sel])
ys.append(Y[sel])
return Tensor(xs), Tensor(ys)
Tensor.no_grad, Tensor.training = False, True
losses = []
test_losses = []
test_loss = float('inf')
for i in (t:=trange(2000)):
x,y = get_minibatch(X,Y,bs=256)
out = net(x)
loss = (out-y).square().mean()
optim.zero_grad()
loss.backward()
optim.step()
t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}")
losses.append(loss.numpy().item())
test_losses.append(test_loss)
if i % 10: test_loss = (net(X_test)-Y_test).square().mean().numpy().item()
safe_save(get_state_dict(net), "/tmp/valuenet.safetensors")
import matplotlib.pyplot as plt
plt.plot(losses[200:])
plt.plot(test_losses[200:])
plt.show()

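Because the net is trained on math.log(min(tms)), its output is a predicted log-runtime; exponentiating it gives seconds again. A minimal sketch of loading the trained weights and scoring a single kernel this way (this mirrors what the runner below does; lin is assumed to be an existing Linearizer):

import math
from tinygrad.tensor import Tensor
from tinygrad.nn.state import safe_load, load_state_dict
from extra.optimization.helpers import lin_to_feats
from extra.optimization.pretrain_valuenet import ValueNet

net = ValueNet()
load_state_dict(net, safe_load("/tmp/valuenet.safetensors"))
pred_log_tm = net(Tensor([lin_to_feats(lin)])).numpy().item()  # predicted log-runtime
pred_tm = math.exp(pred_log_tm)                                # back to seconds
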
View File

@@ -3,9 +3,9 @@ import numpy as np
import math, random
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.codegen.search import actions, bufs_from_lin, time_linearizer
from tinygrad.codegen.search import actions, bufs_from_lin, time_linearizer, get_linearizer_actions
from tinygrad.nn.optim import Adam
from extra.optimization.pretrain import PolicyNet
from extra.optimization.pretrain_policynet import PolicyNet
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
if __name__ == "__main__":
@@ -21,7 +21,7 @@ if __name__ == "__main__":
Tensor.no_grad, Tensor.training = True, False
lin = ast_str_to_lin(random.choice(ast_strs))
rawbufs = bufs_from_lin(lin)
tm = last_tm = base_tm = time_linearizer(lin, rawbufs)[0]
tm = last_tm = base_tm = time_linearizer(lin, rawbufs)
# take actions
feats, acts, rews = [], [], []
@@ -29,6 +29,13 @@ if __name__ == "__main__":
feat = lin_to_feats(lin)
feats.append(feat)
probs = net(Tensor([feat])).exp()[0].numpy()
# mask valid actions
valid_action_mask = np.zeros((len(actions)+1), dtype=np.float32)
for x in get_linearizer_actions(lin): valid_action_mask[x] = 1
probs *= valid_action_mask
probs /= sum(probs)
act = np.random.choice(len(probs), p=probs)
acts.append(act)
if act == 0:
@@ -36,7 +43,7 @@ if __name__ == "__main__":
break
try:
lin.apply_opt(actions[act-1])
tm = time_linearizer(lin, rawbufs)[0]
tm = time_linearizer(lin, rawbufs)
if math.isinf(tm): raise Exception("failed")
rews.append(((last_tm-tm)/base_tm))
last_tm = tm
@@ -50,7 +57,9 @@ if __name__ == "__main__":
print(f"***** EPISODE {len(rews)} steps, {sum(rews):5.2f} reward, {base_tm*1e6:12.2f} -> {tm*1e6:12.2f} : {lin.colored_shape()}")
all_feats += feats
all_acts += acts
all_rews += np.cumsum(rews).tolist()
# rewards to go
for i in range(len(rews)-2, -1, -1): rews[i] += rews[i+1]
all_rews += rews
BS = 32
if len(all_feats) >= BS:
@@ -65,6 +74,3 @@ if __name__ == "__main__":
all_feats = all_feats[BS:]
all_acts = all_acts[BS:]
all_rews = all_rews[BS:]
#print(rews)

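The replaced line changes how per-step rewards are credited: np.cumsum accumulated rewards forward, so each step was credited with what happened before it, while the new loop computes rewards-to-go, crediting each action with everything that follows it. A small worked example of the difference:

rews = [1.0, 2.0, 3.0]

# old: np.cumsum(rews).tolist() -> [1.0, 3.0, 6.0]   (credit for the past)

# new: rewards-to-go, summed from the end backwards -> [6.0, 5.0, 3.0]   (credit for the future)
for i in range(len(rews)-2, -1, -1): rews[i] += rews[i+1]
assert rews == [6.0, 5.0, 3.0]
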
View File

@@ -0,0 +1,66 @@
import numpy as np
import math
import random
np.set_printoptions(suppress=True)
from copy import deepcopy
from tinygrad.helpers import getenv, colored
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, actions, get_linearizer_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
from extra.optimization.pretrain_policynet import PolicyNet
from extra.optimization.pretrain_valuenet import ValueNet
VALUE = getenv("VALUE")
if __name__ == "__main__":
if VALUE:
net = ValueNet()
load_state_dict(net, safe_load("/tmp/valuenet.safetensors"))
else:
net = PolicyNet()
load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
ast_strs = load_worlds()
# real randomness
random.seed()
random.shuffle(ast_strs)
wins = 0
for ep_num,ast_str in enumerate(ast_strs):
print("\nEPISODE", ep_num, f"win {wins*100/max(1,ep_num):.2f}%")
lin = ast_str_to_lin(ast_str)
rawbufs = bufs_from_lin(lin)
linhc = deepcopy(lin)
linhc.hand_coded_optimizations()
tmhc = time_linearizer(linhc, rawbufs)
print(f"{tmhc*1e6:10.2f} HC ", linhc.colored_shape())
pred_time = float('nan')
tm = float('inf')
while 1:
if VALUE:
acts,feats = [], []
for k,v in get_linearizer_actions(lin).items():
acts.append(k)
feats.append(lin_to_feats(v))
preds = net(Tensor(feats))
pred_time = math.exp(preds.numpy().min())
act = acts[preds.numpy().argmin()]
else:
probs = net(Tensor([lin_to_feats(lin)]))
dist = probs.exp().numpy()
act = dist.argmax()
if act == 0: break
try:
lin.apply_opt(actions[act-1])
except Exception:
print("FAILED")
break
tm = time_linearizer(lin, rawbufs)
print(f"{tm*1e6:10.2f} {pred_time*1e6:10.2f}", lin.colored_shape())
print(f"{colored('BEAT', 'green') if tm < tmhc else colored('lost', 'red')} hand coded {tmhc/tm:5.2f}x")
wins += int(tm < tmhc)

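With VALUE set, each step scores every legal successor kernel with the value net and greedily takes the one with the smallest predicted log-runtime; exp of that prediction is what gets printed as the predicted time. Condensed into a standalone helper (a sketch reusing the runner's names; pick_best_by_value is not a function in the repo):

import math
from tinygrad.tensor import Tensor

def pick_best_by_value(net, lin):
  acts, feats = [], []
  for k, v in get_linearizer_actions(lin).items():
    acts.append(k)
    feats.append(lin_to_feats(v))
  preds = net(Tensor(feats)).numpy()
  return acts[preds.argmin()], math.exp(preds.min())   # (chosen action key, predicted seconds)
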
View File

@@ -1,37 +0,0 @@
import numpy as np
np.set_printoptions(suppress=True)
from copy import deepcopy
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
from extra.optimization.pretrain import PolicyNet
if __name__ == "__main__":
net = PolicyNet()
load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
ast_strs = load_worlds()
for ep_num,ast_str in enumerate(ast_strs):
print("\nEPISODE", ep_num)
lin = ast_str_to_lin(ast_str)
rawbufs = bufs_from_lin(lin)
linhc = deepcopy(lin)
linhc.hand_coded_optimizations()
tm, gflops = time_linearizer(linhc, rawbufs)
print(f"{tm:10.2f}", linhc.colored_shape())
while 1:
probs = net(Tensor([lin_to_feats(lin)]))
dist = probs.exp().numpy()
act = dist.argmax()
if act == 0: break
try:
lin.apply_opt(actions[act-1])
except Exception:
print("FAILED")
break
tm, gflops = time_linearizer(lin, rawbufs)
print(f"{tm:10.2f}", lin.colored_shape())

View File

@@ -1215,6 +1215,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
def test_binary_crossentropy(self):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)), lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)), lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)), lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)), lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
if __name__ == '__main__':
np.random.seed(1337)
unittest.main(verbosity=2)

View File

@@ -1,9 +1,8 @@
from typing import Tuple, Dict, List, cast, DefaultDict, Optional
from typing import Dict, List, cast, DefaultDict, Optional
from copy import deepcopy
from tinygrad.lazy import vars_from_ast
from tinygrad.ops import Device, Compiled, MemBuffer
from tinygrad.helpers import prod
from tinygrad.shape.symbolic import sym_infer
from tinygrad.helpers import prod, getenv
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.runtime.lib import RawBuffer
from collections import defaultdict
@@ -29,39 +28,35 @@ actions = [
Opt(op=OptOps.GROUPTOP, axis=2, amt=16), Opt(op=OptOps.GROUPTOP, axis=2, amt=256)]
device:Compiled = cast(Compiled, Device[Device.DEFAULT])
# returns time(s) and GFLOPS
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, cnt=3, should_copy=True) -> Tuple[float, float]:
# returns time in seconds
logtm = open(getenv("LOGTM", ""),"a") if getenv("LOGTM", "") else None
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, cnt=3, should_copy=True) -> float:
if should_copy: lin = deepcopy(lin) # TODO: remove the need for this
var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
try:
lin.linearize()
prg = device.to_program(lin)
real_global_size = prg.global_size[:]
prg.global_size = [1,1,1]
tm = prg(rawbufs, var_vals, force_wait=True)
if allow_test_size:
test_global_size = prg.global_size[:]
while prod(test_global_size) > 16384:
for j in range(2,-1,-1):
if test_global_size[j] > 1:
test_global_size[j] //= 2
break
factor = prod(prg.global_size) / prod(test_global_size)
prg.global_size = test_global_size
else:
factor = 1
tms = [prg(rawbufs, var_vals, force_wait=True)*factor for _ in range(cnt)]
prg.global_size = real_global_size
except Exception:
print("FAILED")
print(lin.ast)
print(lin.applied_opts)
return float('inf'), 0
if allow_test_size:
test_global_size = real_global_size[:]
while prod(test_global_size) > 16384:
for j in range(2,-1,-1):
if test_global_size[j] > 1:
test_global_size[j] //= 2
break
factor = prod(real_global_size) / prod(test_global_size)
prg.global_size = test_global_size
else:
prg.global_size = real_global_size
factor = 1
tm = min([prg(rawbufs, var_vals, force_wait=True) for _ in range(cnt)])
tm *= factor
gflops = sym_infer(lin.info.flops, var_vals)*1e-9/tm
return tm, gflops
tms = [float('inf')]
if logtm: logtm.write(str((lin.ast, lin.applied_opts, tms))+"\n")
return min(tms)
# get (scrap) buffers for timing the linearizer
def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
@@ -75,7 +70,7 @@ def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
# get dictionary of all possible actions
def get_linearizer_actions(lin:Linearizer) -> Dict[int, Linearizer]:
acted_lins = {}
acted_lins = {0:deepcopy(lin)}
for i,a in enumerate(actions):
lin2 = deepcopy(lin)
try:
@@ -85,7 +80,7 @@ def get_linearizer_actions(lin:Linearizer) -> Dict[int, Linearizer]:
if c in {"magenta", "yellow"}: up *= s
if c in {"cyan", "green", "white"}: lcl *= s
if up > 256 or lcl > 256: continue
acted_lins[i] = lin2
acted_lins[i+1] = lin2
except Exception:
pass
return acted_lins

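Two behavioral notes on this file: get_linearizer_actions now reserves key 0 for the unchanged kernel, so real actions are shifted by one (callers apply actions[act-1], and act == 0 means "stop"); and when the LOGTM environment variable names a file, every timing run appends an (ast, applied_opts, tms) line to it, which is the format the value-net pretraining script above parses from /tmp/logtm. A minimal sketch of a greedy step under the new convention, reusing names from the diffs above:

acted_lins = get_linearizer_actions(lin)                 # {0: lin unchanged, i+1: lin with actions[i] applied}
timed = {k: time_linearizer(v, rawbufs) for k, v in acted_lins.items()}
best_act = min(timed, key=timed.get)
if best_act != 0:                                        # 0 means the untouched kernel is already fastest
  lin = acted_lins[best_act]                             # or apply actions[best_act-1] in place, as rl.py does
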
View File

@@ -734,6 +734,12 @@ class Tensor:
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
def binary_crossentropy(self, y:Tensor) -> Tensor:
return (-y*self.log() - (1-y)*(1-self).log()).mean()
def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
return (self.maximum(0) - y * self + (1 + self.abs().__neg__().exp()).log()).mean()
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
loss_mask = Y != ignore_index
y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
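The logits variant added above is the standard numerically stable rewrite of "sigmoid, then binary cross-entropy": max(x,0) - x*y + log(1 + exp(-|x|)). That identity is exactly what the new cross-checking tests in test_ops.py verify against torch. A small NumPy sketch, independent of tinygrad, confirming the two forms agree:

import numpy as np

def bce(p, y):         # plain binary cross-entropy on probabilities
  return (-y*np.log(p) - (1-y)*np.log(1-p)).mean()

def bce_logits(x, y):  # stable form: max(x,0) - x*y + log(1 + exp(-|x|))
  return (np.maximum(x, 0) - x*y + np.log1p(np.exp(-np.abs(x)))).mean()

x = np.random.randn(32, 10)
y = (np.random.rand(32, 10) > 0.5).astype(np.float64)
assert np.allclose(bce(1/(1+np.exp(-x)), y), bce_logits(x, y))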