mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move time_linearizer to extra.optimization.helpers [pr] (#9048)
no longer used in tinygrad
This commit is contained in:
@@ -100,3 +100,24 @@ def lin_to_feats(lin:Kernel, use_sts=True):
|
||||
else:
|
||||
assert len(ret) == 274, f"wrong len {len(ret)}"
|
||||
return ret
|
||||
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.engine.search import _ensure_buffer_alloc, _time_program
|
||||
from tinygrad.helpers import to_function_name, CACHELEVEL, diskcache_put
|
||||
|
||||
def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
|
||||
key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
|
||||
"max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
|
||||
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
|
||||
|
||||
dev = Device[lin.opts.device]
|
||||
assert dev.compiler is not None
|
||||
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
|
||||
p = lin.to_program()
|
||||
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
|
||||
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
|
||||
|
||||
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
|
||||
return min(tms)
|
||||
|
||||
@@ -3,10 +3,10 @@ import numpy as np
|
||||
import math, random
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
|
||||
from tinygrad.engine.search import actions, bufs_from_lin, time_linearizer, get_kernel_actions
|
||||
from tinygrad.engine.search import actions, bufs_from_lin, get_kernel_actions
|
||||
from tinygrad.nn.optim import Adam
|
||||
from extra.optimization.extract_policynet import PolicyNet
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer
|
||||
|
||||
if __name__ == "__main__":
|
||||
net = PolicyNet()
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import argparse
|
||||
from extra.optimization.helpers import ast_str_to_lin
|
||||
from extra.optimization.helpers import ast_str_to_lin, time_linearizer
|
||||
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.helpers import BEAM, getenv
|
||||
from tinygrad.device import Device, Compiled
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
|
||||
from tinygrad.engine.search import beam_search, bufs_from_lin
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -6,8 +6,8 @@ from copy import deepcopy
|
||||
from tinygrad.helpers import getenv, colored
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
|
||||
from tinygrad.engine.search import bufs_from_lin, time_linearizer, actions, get_kernel_actions
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
|
||||
from tinygrad.engine.search import bufs_from_lin, actions, get_kernel_actions
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer
|
||||
from extra.optimization.extract_policynet import PolicyNet
|
||||
from extra.optimization.pretrain_valuenet import ValueNet
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin
|
||||
from tinygrad.engine.search import bufs_from_lin, time_linearizer, get_kernel_actions
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer
|
||||
from tinygrad.engine.search import bufs_from_lin, get_kernel_actions
|
||||
|
||||
if __name__ == "__main__":
|
||||
ast_strs = load_worlds()
|
||||
|
||||
Reference in New Issue
Block a user