Files
tinygrad/extra/optimization/run_qnet.py
George Hotz e0201922e3 Q network for pruning BEAM / uops deduping / BEAM_ESTIMATE (#2142)
* stable diffusion < 324ms

* revert swap action

* fix tests due to more sum splitting

* REDUCEOP_SPLIT_THRESHOLD env var

* added from unaligned np test (#2134)

* align cpu buffer before copy into cl buffer (#2135)

* remove shelve from handcode_resnet50_opt.py (#2139)

* Add dictionary keys to reduce db size (#2131)

* work

* ignore beam cache

* dictionary keys are generic

* minor db cleanups

* fix baseline and extract dataset

* fix training

* log likelihood

* more lin to feats

* sts

* training policynet

* net sort of works

* dedup

* refactor, stupid new actions

* fix uops deduping

* BEAM_ESTIMATE

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Co-authored-by: imaolo <56898718+imaolo@users.noreply.github.com>
2023-10-27 10:53:06 -10:00

33 lines
1.2 KiB
Python

from typing import List, Tuple
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import get_linearizer_actions, actions
_net = None  # lazily-loaded ValueNet, cached at module level so weights load only once
def beam_q_estimate(beam:List[Tuple[Linearizer, float]]) -> List[Tuple[Linearizer, float]]:
  """Re-rank BEAM-search candidates with a learned value network.

  For every (linearizer, measured_time) pair in `beam`, each applicable
  optimization action (from get_linearizer_actions) is one-hot encoded and
  concatenated with the linearizer's feature vector.  The net's first output
  is treated as a predicted log-speedup, so the candidate's estimated time is
  base_time / exp(pred).

  Returns:
    List of (new_linearizer, predicted_time) pairs sorted ascending by
    predicted time (fastest first).  Empty list if there is nothing to score.
  """
  global _net
  if _net is None:
    # lazy import + load: only touch /tmp/qnet.safetensors when actually needed
    from tinygrad.nn.state import load_state_dict, safe_load
    from extra.optimization.pretrain_valuenet import ValueNet
    _net = ValueNet(1021+len(actions), 2)
    load_state_dict(_net, safe_load("/tmp/qnet.safetensors"), verbose=False)
  from tinygrad.tensor import Tensor
  from tinygrad.helpers import Context
  from extra.optimization.helpers import lin_to_feats
  import numpy as np
  feats, lins, base_tms = [], [], []
  for lin,tm in beam:
    # features depend only on the source linearizer; compute once per lin
    lin_feats = lin_to_feats(lin)
    for a,v in get_linearizer_actions(lin, include_0=False).items():
      acts = np.zeros(len(actions))
      acts[a-1] = 1.0  # one-hot of the applied action (action ids start at 1 since include_0=False)
      feats.append(np.concatenate([lin_feats, acts]))
      lins.append(v)
      base_tms.append(tm)
  # empty beam or no legal actions: Tensor([]) / preds[:, 0] would fail, so bail early
  if not feats: return []
  # BEAM=0 prevents recursive beam search while running the net; inference mode for the forward pass
  with Context(BEAM=0), Tensor.train(False):
    preds = _net(Tensor(feats)).numpy()
  # preds[:, 0] is assumed to be log(base_time / new_time) -- invert to get predicted time
  pred_time = np.array(base_tms) / np.exp(preds[:, 0])
  return sorted(zip(lins, pred_time), key=lambda x: x[1])