s/get_linearizer/get_kernel [run_process_replay] (#5467)

This commit is contained in:
chenyu
2024-07-13 20:32:22 -04:00
committed by GitHub
parent 0345577032
commit 28972418c4
11 changed files with 25 additions and 25 deletions

View File

@@ -3,7 +3,7 @@ import numpy as np
import math, random
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import actions, bufs_from_lin, time_linearizer, get_linearizer_actions
from tinygrad.engine.search import actions, bufs_from_lin, time_linearizer, get_kernel_actions
from tinygrad.nn.optim import Adam
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
@@ -32,7 +32,7 @@ if __name__ == "__main__":
# mask valid actions
valid_action_mask = np.zeros((len(actions)+1), dtype=np.float32)
for x in get_linearizer_actions(lin): valid_action_mask[x] = 1
for x in get_kernel_actions(lin): valid_action_mask[x] = 1
probs *= valid_action_mask
probs /= sum(probs)

View File

@@ -1,6 +1,6 @@
from typing import List, Tuple
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import get_linearizer_actions, actions
from tinygrad.engine.search import get_kernel_actions, actions
_net = None
def beam_q_estimate(beam:List[Tuple[Kernel, float]]) -> List[Tuple[Kernel, float]]:
@@ -19,7 +19,7 @@ def beam_q_estimate(beam:List[Tuple[Kernel, float]]) -> List[Tuple[Kernel, float
base_tms = []
for lin,tm in beam:
lin_feats = lin_to_feats(lin)
for a,v in get_linearizer_actions(lin, include_0=False).items():
for a,v in get_kernel_actions(lin, include_0=False).items():
acts = np.zeros(len(actions))
acts[a-1] = 1.0
feats.append(np.concatenate([lin_feats, acts]))

View File

@@ -6,7 +6,7 @@ from copy import deepcopy
from tinygrad.helpers import getenv, colored
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import bufs_from_lin, time_linearizer, actions, get_linearizer_actions
from tinygrad.engine.search import bufs_from_lin, time_linearizer, actions, get_kernel_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.pretrain_valuenet import ValueNet
@@ -43,7 +43,7 @@ if __name__ == "__main__":
while 1:
if VALUE:
acts,feats = [], []
for k,v in get_linearizer_actions(lin).items():
for k,v in get_kernel_actions(lin).items():
acts.append(k)
feats.append(lin_to_feats(v))
preds = net(Tensor(feats))

View File

@@ -1,5 +1,5 @@
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import bufs_from_lin, time_linearizer, get_linearizer_actions
from tinygrad.engine.search import bufs_from_lin, time_linearizer, get_kernel_actions
if __name__ == "__main__":
ast_strs = load_worlds()
@@ -9,7 +9,7 @@ if __name__ == "__main__":
test_tm = time_linearizer(lin, rawbufs)
if test_tm < 1e-2: continue
print(f"EXAMPLE {i}")
acted_lins = get_linearizer_actions(lin)
acted_lins = get_kernel_actions(lin)
ok_avg, short_avg = 0, 0
for k,v in acted_lins.items():
tm1 = time_linearizer(v, rawbufs)