mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
train value net, improve API, add BCE (#2047)
* api cleanups, BCE losses * valuenet * fixup examples * learning okay * add valuenet runner * net improvements * net improvements * 40% win rate
This commit is contained in:
@@ -6,6 +6,8 @@ from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions
|
||||
from tinygrad.helpers import ansilen, DEBUG, getenv
|
||||
from tinygrad.graph import print_tree
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
from tinygrad.shape.symbolic import sym_infer
|
||||
|
||||
import shelve
|
||||
global_db = shelve.open("/tmp/greedy_cache")
|
||||
@@ -59,10 +61,9 @@ if __name__ == "__main__":
|
||||
else:
|
||||
while 1:
|
||||
acted_lins = get_linearizer_actions(lin)
|
||||
tm, gflops = time_linearizer(lin, rawbufs)
|
||||
timed_lins = {k:time_linearizer(v, rawbufs)[0] for k,v in acted_lins.items()}
|
||||
timed_lins = {k:time_linearizer(v, rawbufs) for k,v in acted_lins.items()}
|
||||
opts = sorted(timed_lins.items(), key=lambda x: x[1])
|
||||
if len(opts) == 0 or opts[0][1] >= tm: break # we are done
|
||||
if opts[0][0] == 0: break # we are done
|
||||
lin = acted_lins[opts[0][0]]
|
||||
if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", lin.colored_shape())
|
||||
global_db[str(lin.ast)] = lin.applied_opts
|
||||
@@ -71,7 +72,8 @@ if __name__ == "__main__":
|
||||
# benchmark the programs
|
||||
choices = []
|
||||
for lin in lins:
|
||||
tm, gflops = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, should_copy=False)
|
||||
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, should_copy=False)
|
||||
gflops = sym_infer(lin.info.flops, {k:k.min for k in vars_from_ast(lin.ast)})*1e-9/tm
|
||||
choices.append((tm, gflops, lin))
|
||||
|
||||
# print all kernels
|
||||
|
||||
@@ -1,43 +0,0 @@
|
||||
import numpy as np
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from tqdm import tqdm
|
||||
from tinygrad.helpers import dedup, ansilen
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
from tinygrad.codegen.search import get_linearizer_actions, time_linearizer, bufs_from_lin, actions
|
||||
|
||||
#from extra.optimization.pretrain import PolicyNet
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
|
||||
|
||||
if __name__ == "__main__":
|
||||
# load worlds
|
||||
ast_strs = load_worlds()
|
||||
for ep_num,ast_str in enumerate(ast_strs):
|
||||
print("\nEPISODE", ep_num)
|
||||
lin = ast_str_to_lin(ast_str)
|
||||
|
||||
linhc = deepcopy(lin)
|
||||
linhc.hand_coded_optimizations()
|
||||
if not all(x in actions for x in linhc.applied_opts):
|
||||
print("skipping", linhc.colored_shape())
|
||||
continue
|
||||
|
||||
rawbufs = bufs_from_lin(lin)
|
||||
tm1, gf1 = time_linearizer(linhc, rawbufs)
|
||||
print(f"{tm1:10.2f}", linhc.colored_shape(), f"with {len(linhc.applied_opts)} actions from {len(actions)} action space")
|
||||
|
||||
while 1:
|
||||
tm, gflops = time_linearizer(lin, rawbufs)
|
||||
print(f"{tm:10.2f}", lin.colored_shape())
|
||||
acted_lins = get_linearizer_actions(lin)
|
||||
if len(acted_lins) == 0: break
|
||||
|
||||
best_tm, best_lin = tm, lin
|
||||
for l in list(acted_lins.values()):
|
||||
tm, gflops = time_linearizer(l, rawbufs)
|
||||
if tm < best_tm: best_tm, best_lin = tm, l
|
||||
if lin == best_lin: break
|
||||
lin = best_lin
|
||||
@@ -35,6 +35,8 @@ from tinygrad.shape.symbolic import Node
|
||||
|
||||
MAX_DIMS = 16
|
||||
def lin_to_feats(lin):
|
||||
assert lin.shape_len < MAX_DIMS, "too many dims"
|
||||
|
||||
all_colors = ["blue", "cyan", "white", "green", "red", "magenta", "yellow"]
|
||||
lc = [all_colors.index(x) for x in lin.colors()]
|
||||
#my_sts = dedup([(x.shape == lin.full_shape, x.real_strides()) for x in lin.sts[1:]])
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_lo
|
||||
from tinygrad.codegen.search import actions
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, assert_same_lin
|
||||
|
||||
INNER = 256
|
||||
INNER = 32
|
||||
class PolicyNet:
|
||||
def __init__(self):
|
||||
self.l1 = Linear(240,INNER)
|
||||
88
extra/optimization/pretrain_valuenet.py
Normal file
88
extra/optimization/pretrain_valuenet.py
Normal file
@@ -0,0 +1,88 @@
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tqdm import tqdm, trange
|
||||
import math
|
||||
import random
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Linear
|
||||
from tinygrad.nn.optim import Adam
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
|
||||
|
||||
# stuff needed to unpack a kernel
|
||||
from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer
|
||||
from tinygrad.helpers import dtypes
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
inf, nan = float('inf'), float('nan')
|
||||
from tinygrad.codegen.optimizer import Opt, OptOps
|
||||
|
||||
from extra.optimization.helpers import lin_to_feats, MAX_DIMS
|
||||
|
||||
# NOTE: this is not real value of the state, it's just a prediction of the runtime
|
||||
INNER = 512
|
||||
class ValueNet:
|
||||
def __init__(self):
|
||||
self.l1 = Linear(240,INNER)
|
||||
self.l2 = Linear(INNER,INNER)
|
||||
self.l3 = Linear(INNER,INNER)
|
||||
self.l4 = Linear(INNER,1)
|
||||
def __call__(self, x):
|
||||
x = self.l1(x).relu()
|
||||
x = self.l2(x).relu()
|
||||
x = self.l3(x).relu()
|
||||
return self.l4(x)
|
||||
|
||||
if __name__ == "__main__":
|
||||
net = ValueNet()
|
||||
optim = Adam(get_parameters(net))
|
||||
|
||||
TEST_SIZE = 256
|
||||
|
||||
dset = open("/tmp/logtm").read().strip().split("\n")
|
||||
random.seed(1337)
|
||||
random.shuffle(dset)
|
||||
|
||||
X,Y = [], []
|
||||
for i,x in enumerate(tqdm(dset)):
|
||||
ast, opts, tms = eval(x)
|
||||
lin = Linearizer(ast)
|
||||
for o in opts: lin.apply_opt(o)
|
||||
if lin.shape_len >= MAX_DIMS: continue
|
||||
if min(tms) == float('inf'): continue
|
||||
X.append(lin_to_feats(lin))
|
||||
Y.append([math.log(min(tms))])
|
||||
print(f"got {len(X)} samples")
|
||||
|
||||
X_test,Y_test = Tensor(X[-TEST_SIZE:]), Tensor(Y[-TEST_SIZE:])
|
||||
X,Y = X[:-TEST_SIZE], Y[:-TEST_SIZE]
|
||||
|
||||
def get_minibatch(X,Y,bs):
|
||||
xs, ys = [], []
|
||||
for _ in range(bs):
|
||||
sel = random.randint(0, len(X)-1)
|
||||
xs.append(X[sel])
|
||||
ys.append(Y[sel])
|
||||
return Tensor(xs), Tensor(ys)
|
||||
|
||||
Tensor.no_grad, Tensor.training = False, True
|
||||
losses = []
|
||||
test_losses = []
|
||||
test_loss = float('inf')
|
||||
for i in (t:=trange(2000)):
|
||||
x,y = get_minibatch(X,Y,bs=256)
|
||||
out = net(x)
|
||||
loss = (out-y).square().mean()
|
||||
optim.zero_grad()
|
||||
loss.backward()
|
||||
optim.step()
|
||||
t.set_description(f"loss {loss.numpy():7.2f}, test loss {test_loss:7.2f}")
|
||||
losses.append(loss.numpy().item())
|
||||
test_losses.append(test_loss)
|
||||
if i % 10: test_loss = (net(X_test)-Y_test).square().mean().numpy().item()
|
||||
|
||||
safe_save(get_state_dict(net), "/tmp/valuenet.safetensors")
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
plt.plot(losses[200:])
|
||||
plt.plot(test_losses[200:])
|
||||
plt.show()
|
||||
@@ -3,9 +3,9 @@ import numpy as np
|
||||
import math, random
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
|
||||
from tinygrad.codegen.search import actions, bufs_from_lin, time_linearizer
|
||||
from tinygrad.codegen.search import actions, bufs_from_lin, time_linearizer, get_linearizer_actions
|
||||
from tinygrad.nn.optim import Adam
|
||||
from extra.optimization.pretrain import PolicyNet
|
||||
from extra.optimization.pretrain_policynet import PolicyNet
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -21,7 +21,7 @@ if __name__ == "__main__":
|
||||
Tensor.no_grad, Tensor.training = True, False
|
||||
lin = ast_str_to_lin(random.choice(ast_strs))
|
||||
rawbufs = bufs_from_lin(lin)
|
||||
tm = last_tm = base_tm = time_linearizer(lin, rawbufs)[0]
|
||||
tm = last_tm = base_tm = time_linearizer(lin, rawbufs)
|
||||
|
||||
# take actions
|
||||
feats, acts, rews = [], [], []
|
||||
@@ -29,6 +29,13 @@ if __name__ == "__main__":
|
||||
feat = lin_to_feats(lin)
|
||||
feats.append(feat)
|
||||
probs = net(Tensor([feat])).exp()[0].numpy()
|
||||
|
||||
# mask valid actions
|
||||
valid_action_mask = np.zeros((len(actions)+1), dtype=np.float32)
|
||||
for x in get_linearizer_actions(lin): valid_action_mask[x] = 1
|
||||
probs *= valid_action_mask
|
||||
probs /= sum(probs)
|
||||
|
||||
act = np.random.choice(len(probs), p=probs)
|
||||
acts.append(act)
|
||||
if act == 0:
|
||||
@@ -36,7 +43,7 @@ if __name__ == "__main__":
|
||||
break
|
||||
try:
|
||||
lin.apply_opt(actions[act-1])
|
||||
tm = time_linearizer(lin, rawbufs)[0]
|
||||
tm = time_linearizer(lin, rawbufs)
|
||||
if math.isinf(tm): raise Exception("failed")
|
||||
rews.append(((last_tm-tm)/base_tm))
|
||||
last_tm = tm
|
||||
@@ -50,7 +57,9 @@ if __name__ == "__main__":
|
||||
print(f"***** EPISODE {len(rews)} steps, {sum(rews):5.2f} reward, {base_tm*1e6:12.2f} -> {tm*1e6:12.2f} : {lin.colored_shape()}")
|
||||
all_feats += feats
|
||||
all_acts += acts
|
||||
all_rews += np.cumsum(rews).tolist()
|
||||
# rewards to go
|
||||
for i in range(len(rews)-2, -1, -1): rews[i] += rews[i+1]
|
||||
all_rews += rews
|
||||
|
||||
BS = 32
|
||||
if len(all_feats) >= BS:
|
||||
@@ -65,6 +74,3 @@ if __name__ == "__main__":
|
||||
all_feats = all_feats[BS:]
|
||||
all_acts = all_acts[BS:]
|
||||
all_rews = all_rews[BS:]
|
||||
|
||||
#print(rews)
|
||||
|
||||
|
||||
66
extra/optimization/test_net.py
Normal file
66
extra/optimization/test_net.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import numpy as np
|
||||
import math
|
||||
import random
|
||||
np.set_printoptions(suppress=True)
|
||||
from copy import deepcopy
|
||||
from tinygrad.helpers import getenv, colored
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
|
||||
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, actions, get_linearizer_actions
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
|
||||
from extra.optimization.pretrain_policynet import PolicyNet
|
||||
from extra.optimization.pretrain_valuenet import ValueNet
|
||||
|
||||
VALUE = getenv("VALUE")
|
||||
|
||||
if __name__ == "__main__":
|
||||
if VALUE:
|
||||
net = ValueNet()
|
||||
load_state_dict(net, safe_load("/tmp/valuenet.safetensors"))
|
||||
else:
|
||||
net = PolicyNet()
|
||||
load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
|
||||
|
||||
ast_strs = load_worlds()
|
||||
|
||||
# real randomness
|
||||
random.seed()
|
||||
random.shuffle(ast_strs)
|
||||
|
||||
wins = 0
|
||||
for ep_num,ast_str in enumerate(ast_strs):
|
||||
print("\nEPISODE", ep_num, f"win {wins*100/max(1,ep_num):.2f}%")
|
||||
lin = ast_str_to_lin(ast_str)
|
||||
rawbufs = bufs_from_lin(lin)
|
||||
|
||||
linhc = deepcopy(lin)
|
||||
linhc.hand_coded_optimizations()
|
||||
tmhc = time_linearizer(linhc, rawbufs)
|
||||
print(f"{tmhc*1e6:10.2f} HC ", linhc.colored_shape())
|
||||
|
||||
pred_time = float('nan')
|
||||
tm = float('inf')
|
||||
while 1:
|
||||
if VALUE:
|
||||
acts,feats = [], []
|
||||
for k,v in get_linearizer_actions(lin).items():
|
||||
acts.append(k)
|
||||
feats.append(lin_to_feats(v))
|
||||
preds = net(Tensor(feats))
|
||||
pred_time = math.exp(preds.numpy().min())
|
||||
act = acts[preds.numpy().argmin()]
|
||||
else:
|
||||
probs = net(Tensor([lin_to_feats(lin)]))
|
||||
dist = probs.exp().numpy()
|
||||
act = dist.argmax()
|
||||
if act == 0: break
|
||||
try:
|
||||
lin.apply_opt(actions[act-1])
|
||||
except Exception:
|
||||
print("FAILED")
|
||||
break
|
||||
tm = time_linearizer(lin, rawbufs)
|
||||
print(f"{tm*1e6:10.2f} {pred_time*1e6:10.2f}", lin.colored_shape())
|
||||
|
||||
print(f"{colored('BEAT', 'green') if tm < tmhc else colored('lost', 'red')} hand coded {tmhc/tm:5.2f}x")
|
||||
wins += int(tm < tmhc)
|
||||
@@ -1,37 +0,0 @@
|
||||
import numpy as np
|
||||
np.set_printoptions(suppress=True)
|
||||
from copy import deepcopy
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
|
||||
from tinygrad.codegen.search import bufs_from_lin, time_linearizer, actions
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
|
||||
from extra.optimization.pretrain import PolicyNet
|
||||
|
||||
if __name__ == "__main__":
|
||||
net = PolicyNet()
|
||||
load_state_dict(net, safe_load("/tmp/policynet.safetensors"))
|
||||
|
||||
ast_strs = load_worlds()
|
||||
|
||||
for ep_num,ast_str in enumerate(ast_strs):
|
||||
print("\nEPISODE", ep_num)
|
||||
lin = ast_str_to_lin(ast_str)
|
||||
rawbufs = bufs_from_lin(lin)
|
||||
|
||||
linhc = deepcopy(lin)
|
||||
linhc.hand_coded_optimizations()
|
||||
tm, gflops = time_linearizer(linhc, rawbufs)
|
||||
print(f"{tm:10.2f}", linhc.colored_shape())
|
||||
|
||||
while 1:
|
||||
probs = net(Tensor([lin_to_feats(lin)]))
|
||||
dist = probs.exp().numpy()
|
||||
act = dist.argmax()
|
||||
if act == 0: break
|
||||
try:
|
||||
lin.apply_opt(actions[act-1])
|
||||
except Exception:
|
||||
print("FAILED")
|
||||
break
|
||||
tm, gflops = time_linearizer(lin, rawbufs)
|
||||
print(f"{tm:10.2f}", lin.colored_shape())
|
||||
@@ -1215,6 +1215,12 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64), (32,8,16,16)], lambda x,y,z,m: torch.nn.functional.scaled_dot_product_attention(x,y,z,attn_mask=m), lambda x,y,z,m: Tensor.scaled_dot_product_attention(x,y,z,attn_mask=m))
|
||||
helper_test_op([(32,8,16,64), (32,8,16,64), (32,8,16,64)], lambda x,y,z: torch.nn.functional.scaled_dot_product_attention(x,y,z,is_causal=True), lambda x,y,z: Tensor.scaled_dot_product_attention(x,y,z,is_causal=True))
|
||||
|
||||
def test_binary_crossentropy(self):
|
||||
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)), lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
|
||||
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)), lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
|
||||
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy_with_logits(x,torch.clip(y,0,1)), lambda x,y: x.sigmoid().binary_crossentropy(y.clip(0,1)))
|
||||
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)), lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
|
||||
|
||||
if __name__ == '__main__':
|
||||
np.random.seed(1337)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from typing import Tuple, Dict, List, cast, DefaultDict, Optional
|
||||
from typing import Dict, List, cast, DefaultDict, Optional
|
||||
from copy import deepcopy
|
||||
from tinygrad.lazy import vars_from_ast
|
||||
from tinygrad.ops import Device, Compiled, MemBuffer
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.shape.symbolic import sym_infer
|
||||
from tinygrad.helpers import prod, getenv
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.runtime.lib import RawBuffer
|
||||
from collections import defaultdict
|
||||
@@ -29,39 +28,35 @@ actions = [
|
||||
Opt(op=OptOps.GROUPTOP, axis=2, amt=16), Opt(op=OptOps.GROUPTOP, axis=2, amt=256)]
|
||||
device:Compiled = cast(Compiled, Device[Device.DEFAULT])
|
||||
|
||||
# returns time(s) and GFLOPS
|
||||
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, cnt=3, should_copy=True) -> Tuple[float, float]:
|
||||
# returns time in seconds
|
||||
logtm = open(getenv("LOGTM", ""),"a") if getenv("LOGTM", "") else None
|
||||
def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, cnt=3, should_copy=True) -> float:
|
||||
if should_copy: lin = deepcopy(lin) # TODO: remove the need for this
|
||||
var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
|
||||
try:
|
||||
lin.linearize()
|
||||
prg = device.to_program(lin)
|
||||
real_global_size = prg.global_size[:]
|
||||
prg.global_size = [1,1,1]
|
||||
tm = prg(rawbufs, var_vals, force_wait=True)
|
||||
if allow_test_size:
|
||||
test_global_size = prg.global_size[:]
|
||||
while prod(test_global_size) > 16384:
|
||||
for j in range(2,-1,-1):
|
||||
if test_global_size[j] > 1:
|
||||
test_global_size[j] //= 2
|
||||
break
|
||||
factor = prod(prg.global_size) / prod(test_global_size)
|
||||
prg.global_size = test_global_size
|
||||
else:
|
||||
factor = 1
|
||||
tms = [prg(rawbufs, var_vals, force_wait=True)*factor for _ in range(cnt)]
|
||||
prg.global_size = real_global_size
|
||||
except Exception:
|
||||
print("FAILED")
|
||||
print(lin.ast)
|
||||
print(lin.applied_opts)
|
||||
return float('inf'), 0
|
||||
|
||||
if allow_test_size:
|
||||
test_global_size = real_global_size[:]
|
||||
while prod(test_global_size) > 16384:
|
||||
for j in range(2,-1,-1):
|
||||
if test_global_size[j] > 1:
|
||||
test_global_size[j] //= 2
|
||||
break
|
||||
factor = prod(real_global_size) / prod(test_global_size)
|
||||
prg.global_size = test_global_size
|
||||
else:
|
||||
prg.global_size = real_global_size
|
||||
factor = 1
|
||||
|
||||
tm = min([prg(rawbufs, var_vals, force_wait=True) for _ in range(cnt)])
|
||||
tm *= factor
|
||||
gflops = sym_infer(lin.info.flops, var_vals)*1e-9/tm
|
||||
return tm, gflops
|
||||
tms = [float('inf')]
|
||||
if logtm: logtm.write(str((lin.ast, lin.applied_opts, tms))+"\n")
|
||||
return min(tms)
|
||||
|
||||
# get (scrap) buffers for timing the linearizer
|
||||
def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
|
||||
@@ -75,7 +70,7 @@ def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
|
||||
|
||||
# get dictionary of all possible actions
|
||||
def get_linearizer_actions(lin:Linearizer) -> Dict[int, Linearizer]:
|
||||
acted_lins = {}
|
||||
acted_lins = {0:deepcopy(lin)}
|
||||
for i,a in enumerate(actions):
|
||||
lin2 = deepcopy(lin)
|
||||
try:
|
||||
@@ -85,7 +80,7 @@ def get_linearizer_actions(lin:Linearizer) -> Dict[int, Linearizer]:
|
||||
if c in {"magenta", "yellow"}: up *= s
|
||||
if c in {"cyan", "green", "white"}: lcl *= s
|
||||
if up > 256 or lcl > 256: continue
|
||||
acted_lins[i] = lin2
|
||||
acted_lins[i+1] = lin2
|
||||
except Exception:
|
||||
pass
|
||||
return acted_lins
|
||||
|
||||
@@ -734,6 +734,12 @@ class Tensor:
|
||||
if attn_mask is not None and attn_mask.dtype == dtypes.bool: attn_mask = (attn_mask == 0).where(-float("inf"), attn_mask)
|
||||
return (self @ key.transpose(-2,-1) / math.sqrt(self.shape[-1]) + attn_mask).softmax(-1).dropout(dropout_p) @ value
|
||||
|
||||
def binary_crossentropy(self, y:Tensor) -> Tensor:
|
||||
return (-y*self.log() - (1-y)*(1-self).log()).mean()
|
||||
|
||||
def binary_crossentropy_logits(self, y:Tensor) -> Tensor:
|
||||
return (self.maximum(0) - y * self + (1 + self.abs().__neg__().exp()).log()).mean()
|
||||
|
||||
def sparse_categorical_crossentropy(self, Y, ignore_index=-1) -> Tensor:
|
||||
loss_mask = Y != ignore_index
|
||||
y_counter = Tensor.arange(self.shape[-1], dtype=dtypes.int32, requires_grad=False, device=self.device).unsqueeze(0).expand(Y.numel(), self.shape[-1])
|
||||
|
||||
Reference in New Issue
Block a user