move time_linearizer to extra.optimization.helpers [pr] (#9048)

no longer used in tinygrad
This commit is contained in:
chenyu
2025-02-12 15:49:58 -05:00
committed by GitHub
parent c15486cf39
commit f4f56d7c15
12 changed files with 41 additions and 36 deletions

View File

@@ -100,3 +100,24 @@ def lin_to_feats(lin:Kernel, use_sts=True):
else:
assert len(ret) == 274, f"wrong len {len(ret)}"
return ret
from tinygrad.device import Device, Buffer
from tinygrad.engine.search import _ensure_buffer_alloc, _time_program
from tinygrad.helpers import to_function_name, CACHELEVEL, diskcache_put
def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501
key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
"max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)
dev = Device[lin.opts.device]
assert dev.compiler is not None
rawbufs = _ensure_buffer_alloc(rawbufs)
var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
p = lin.to_program()
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
return min(tms)

View File

@@ -3,10 +3,10 @@ import numpy as np
import math, random
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import actions, bufs_from_lin, time_linearizer, get_kernel_actions
from tinygrad.engine.search import actions, bufs_from_lin, get_kernel_actions
from tinygrad.nn.optim import Adam
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer
if __name__ == "__main__":
net = PolicyNet()

View File

@@ -1,11 +1,11 @@
import argparse
from extra.optimization.helpers import ast_str_to_lin
from extra.optimization.helpers import ast_str_to_lin, time_linearizer
from tinygrad import dtypes
from tinygrad.helpers import BEAM, getenv
from tinygrad.device import Device, Compiled
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.engine.search import beam_search, bufs_from_lin
if __name__ == '__main__':

View File

@@ -6,8 +6,8 @@ from copy import deepcopy
from tinygrad.helpers import getenv, colored
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import bufs_from_lin, time_linearizer, actions, get_kernel_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
from tinygrad.engine.search import bufs_from_lin, actions, get_kernel_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.pretrain_valuenet import ValueNet

View File

@@ -1,5 +1,5 @@
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import bufs_from_lin, time_linearizer, get_kernel_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer
from tinygrad.engine.search import bufs_from_lin, get_kernel_actions
if __name__ == "__main__":
ast_strs = load_worlds()