mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-03 03:05:03 -05:00
* stable diffusion < 324ms * revert swap action * fix tests due to more sum splitting * REDUCEOP_SPLIT_THRESHOLD env var * added from unaligned np test (#2134) * align cpu buffer before copy into cl buffer (#2135) * remove shelve from handcode_resnet50_opt.py (#2139) * Add dictionary keys to reduce db size (#2131) * work * ignore beam cache * dictionary keys are generic * minor db cleanups * fix baseline and extract dataset * fix training * log likelihood * more lin to feats * sts * training policynet * net sort of works * dedup * refactor, stupid new actions * fix uops deduping * BEAM_ESTIMATE --------- Co-authored-by: chenyu <chenyu@fastmail.com> Co-authored-by: imaolo <56898718+imaolo@users.noreply.github.com>
33 lines
1.2 KiB
Python
33 lines
1.2 KiB
Python
from typing import List, Tuple
|
|
from tinygrad.codegen.linearizer import Linearizer
|
|
from tinygrad.features.search import get_linearizer_actions, actions
|
|
|
|
# Module-level cache for the value network: stays None until the first call to
# beam_q_estimate, which constructs the net and loads its weights exactly once.
_net = None
|
|
def beam_q_estimate(beam:List[Tuple[Linearizer, float]]) -> List[Tuple[Linearizer, float]]:
  """Rank every one-action successor of the beam by its predicted time.

  For each ``(lin, tm)`` pair in ``beam``, every entry returned by
  ``get_linearizer_actions(lin, include_0=False)`` becomes one candidate. The
  candidate's feature vector is ``lin_to_feats(lin)`` concatenated with a
  one-hot encoding of the action index, and all candidates are scored by the
  cached network in a single batched forward pass. Column 0 of the network
  output is exponentiated and used to divide the parent's time ``tm``, giving
  the predicted time of the candidate.

  Args:
    beam: pairs of (linearizer, time). The time is the baseline that the
      network's predicted ratio is applied to.

  Returns:
    A list of (candidate linearizer, predicted time) pairs sorted by ascending
    predicted time. Empty when there are no candidates (empty ``beam`` or no
    linearizer has an available action).
  """
  global _net
  if _net is None:
    # Build the network and load its weights only once; later calls reuse the
    # module-level cache.
    from tinygrad.nn.state import load_state_dict, safe_load
    from extra.optimization.pretrain_valuenet import ValueNet
    # Input width = 1021 + one slot per action. 1021 is presumably the length of
    # lin_to_feats' output -- TODO confirm against extra.optimization.helpers.
    _net = ValueNet(1021+len(actions), 2)
    load_state_dict(_net, safe_load("/tmp/qnet.safetensors"), verbose=False)
  # Function-scope imports: needed on every call, not only on the first one.
  from tinygrad.tensor import Tensor
  from tinygrad.helpers import Context
  from extra.optimization.helpers import lin_to_feats
  import numpy as np

  n_actions = len(actions)  # loop-invariant width of the one-hot action vector
  feats, lins, base_tms = [], [], []
  for lin, tm in beam:
    lin_feats = lin_to_feats(lin)
    for a, v in get_linearizer_actions(lin, include_0=False).items():
      # One-hot encode the action; keys are shifted by one because action 0 is
      # excluded via include_0=False.
      acts = np.zeros(n_actions)
      acts[a-1] = 1.0
      feats.append(np.concatenate([lin_feats, acts]))
      lins.append(v)
      base_tms.append(tm)

  # Nothing to score: avoid feeding an empty batch to the network and indexing
  # a malformed prediction array.
  if not feats:
    return []

  # Run inference with BEAM disabled and training mode off.
  # NOTE(review): BEAM=0 presumably keeps the net's own kernels from recursing
  # into beam search -- confirm.
  with Context(BEAM=0), Tensor.train(False):
    preds = _net(Tensor(feats)).numpy()
  # Column 0 is treated as the log of the ratio between parent time and
  # candidate time, so dividing by exp(pred) yields the candidate's time.
  pred_time = np.array(base_tms) / np.exp(preds[:, 0])
  return sorted(zip(lins, pred_time), key=lambda x: x[1])
|