diff --git a/examples/handcode_opt.py b/examples/handcode_opt.py index b0f0ca6238..48c54b5670 100644 --- a/examples/handcode_opt.py +++ b/examples/handcode_opt.py @@ -5,8 +5,9 @@ from tinygrad import Tensor, Device, dtypes, nn from tinygrad.codegen.kernel import Kernel from tinygrad.ops import Ops, sym_infer from tinygrad.device import Compiled -from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin +from tinygrad.engine.search import beam_search, bufs_from_lin from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA +from extra.optimization.helpers import time_linearizer def get_sched_resnet(): mdl = ResNet50() diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py index 85a070f42b..dc6defb2a6 100644 --- a/extra/optimization/helpers.py +++ b/extra/optimization/helpers.py @@ -100,3 +100,24 @@ def lin_to_feats(lin:Kernel, use_sts=True): else: assert len(ret) == 274, f"wrong len {len(ret)}" return ret + +from tinygrad.device import Device, Buffer +from tinygrad.engine.search import _ensure_buffer_alloc, _time_program +from tinygrad.helpers import to_function_name, CACHELEVEL, diskcache_put + +def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501 + key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, + "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix} + if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val) + + dev = Device[lin.opts.device] + assert dev.compiler is not None + + rawbufs = _ensure_buffer_alloc(rawbufs) + var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()} + p = lin.to_program() + tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs, + max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) + + if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms) + return min(tms) diff --git a/extra/optimization/rl.py b/extra/optimization/rl.py index df0b791a71..232002c217 100644 --- a/extra/optimization/rl.py +++ b/extra/optimization/rl.py @@ -3,10 +3,10 @@ import numpy as np import math, random from tinygrad.tensor import Tensor from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict -from tinygrad.engine.search import actions, bufs_from_lin, time_linearizer, get_kernel_actions +from tinygrad.engine.search import actions, bufs_from_lin, get_kernel_actions from tinygrad.nn.optim import Adam from extra.optimization.extract_policynet import PolicyNet -from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats +from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer if __name__ == "__main__": net = PolicyNet() diff --git a/extra/optimization/search.py b/extra/optimization/search.py index 9b8aec9bf5..469487da58 100644 --- a/extra/optimization/search.py +++ b/extra/optimization/search.py @@ -1,11 +1,11 @@ import argparse -from extra.optimization.helpers import ast_str_to_lin +from extra.optimization.helpers import ast_str_to_lin, time_linearizer from tinygrad import dtypes from tinygrad.helpers import BEAM, getenv from tinygrad.device import Device, Compiled from tinygrad.codegen.kernel import Kernel -from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin +from tinygrad.engine.search import beam_search, bufs_from_lin if __name__ == '__main__': diff --git a/extra/optimization/test_net.py b/extra/optimization/test_net.py index 5984444076..0c5b53f99d 100644 --- a/extra/optimization/test_net.py +++ b/extra/optimization/test_net.py @@ -6,8 +6,8 @@ from copy import deepcopy from tinygrad.helpers import getenv, colored from tinygrad.tensor import Tensor from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict -from tinygrad.engine.search import bufs_from_lin, time_linearizer, actions, get_kernel_actions -from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats +from tinygrad.engine.search import bufs_from_lin, actions, get_kernel_actions +from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer from extra.optimization.extract_policynet import PolicyNet from extra.optimization.pretrain_valuenet import ValueNet diff --git a/extra/optimization/test_time_linearizer.py b/extra/optimization/test_time_linearizer.py index effabf1e1e..4bfb1f0f03 100644 --- a/extra/optimization/test_time_linearizer.py +++ b/extra/optimization/test_time_linearizer.py @@ -1,5 +1,5 @@ -from extra.optimization.helpers import load_worlds, ast_str_to_lin -from tinygrad.engine.search import bufs_from_lin, time_linearizer, get_kernel_actions +from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer +from tinygrad.engine.search import bufs_from_lin, get_kernel_actions if __name__ == "__main__": ast_strs = load_worlds() diff --git a/test/external/external_benchmark_hcopt.py b/test/external/external_benchmark_hcopt.py index 252b0531c6..eb42c7af39 100644 --- a/test/external/external_benchmark_hcopt.py +++ b/test/external/external_benchmark_hcopt.py @@ -1,7 +1,7 @@ import random from tinygrad.helpers import getenv -from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin -from extra.optimization.helpers import load_worlds, ast_str_to_lin +from tinygrad.engine.search import beam_search, bufs_from_lin +from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer def optimize_kernel(k): # TODO: update this diff --git a/test/external/speed_beam_v_hcopt.py b/test/external/speed_beam_v_hcopt.py index 447a891ea7..c6f0b430fd 100644 --- a/test/external/speed_beam_v_hcopt.py +++ b/test/external/speed_beam_v_hcopt.py @@ -1,7 +1,7 @@ from tinygrad import Device from tinygrad.helpers import getenv, DEBUG, BEAM -from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin -from extra.optimization.helpers import load_worlds, ast_str_to_lin +from tinygrad.engine.search import beam_search, bufs_from_lin +from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer if __name__ == "__main__": filter_reduce = bool(getenv("FILTER_REDUCE")) diff --git a/test/external/verify_kernel.py b/test/external/verify_kernel.py index e8c99e7415..5a6fb23ba0 100644 --- a/test/external/verify_kernel.py +++ b/test/external/verify_kernel.py @@ -1,10 +1,9 @@ import argparse from collections import defaultdict -from extra.optimization.helpers import kern_str_to_lin +from extra.optimization.helpers import kern_str_to_lin, time_linearizer from test.external.fuzz_linearizer import compare_linearizer from tinygrad.helpers import colored from tinygrad.codegen.kernel import Kernel -from tinygrad.engine.search import time_linearizer # Use this with the LOGKERNS options to verify that all executed kernels are valid and evaluate to the same ground truth results diff --git a/test/test_linearizer_overflows.py b/test/test_linearizer_overflows.py index f5fb749956..4b6b871f22 100644 --- a/test/test_linearizer_overflows.py +++ b/test/test_linearizer_overflows.py @@ -4,8 +4,8 @@ from test.helpers import ast_const from tinygrad import dtypes, Device from tinygrad.helpers import CI from tinygrad.codegen.kernel import Kernel -from tinygrad.engine.search import Opt, OptOps -from tinygrad.engine.search import time_linearizer, bufs_from_lin +from tinygrad.engine.search import Opt, OptOps, bufs_from_lin +from extra.optimization.helpers import time_linearizer # stuff needed to unpack a kernel from tinygrad.ops import UOp, Ops diff --git a/test/test_search.py b/test/test_search.py index d9b7241cd2..da2d55cfad 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -4,7 +4,7 @@ from test.helpers import ast_const from tinygrad.codegen.kernel import Opt, OptOps from tinygrad.codegen.kernel import Kernel from tinygrad.ops import UOp, Ops -from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search +from tinygrad.engine.search import bufs_from_lin, actions, beam_search from tinygrad.device import Device, Buffer from tinygrad.tensor import Tensor from tinygrad.dtype import dtypes @@ -12,6 +12,7 @@ from tinygrad.helpers import Context, GlobalCounters from tinygrad.engine.realize import capturing from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View +from extra.optimization.helpers import time_linearizer class TestTimeLinearizer(unittest.TestCase): @unittest.skipIf(Device.DEFAULT == "WEBGPU", "WebGPU timestamps are low precision, tm is 0") diff --git a/tinygrad/engine/search.py b/tinygrad/engine/search.py index 30e0efe8c6..e1a2d9569d 100644 --- a/tinygrad/engine/search.py +++ b/tinygrad/engine/search.py @@ -4,7 +4,7 @@ from collections import defaultdict from dataclasses import replace from tinygrad.ops import UOp, Ops, Variable, sym_infer from tinygrad.device import Device, Buffer, Compiler -from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name +from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE from tinygrad.dtype import ImageDType, PtrDType from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError @@ -197,20 +197,3 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))]) assert not math.isinf(ret[0]), "all optimize_local_size exec failed" return ret[1] - -def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float: # noqa: E501 - key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size, - "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix} - if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val) - - dev = Device[lin.opts.device] - assert dev.compiler is not None - - rawbufs = _ensure_buffer_alloc(rawbufs) - var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()} - p = lin.to_program() - tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs, - max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name)) - - if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms) - return min(tms)