move time_linearizer to extra.optimization.helpers [pr] (#9048)

no longer used in tinygrad
This commit is contained in:
chenyu
2025-02-12 15:49:58 -05:00
committed by GitHub
parent c15486cf39
commit f4f56d7c15
12 changed files with 41 additions and 36 deletions

View File

@@ -5,8 +5,9 @@ from tinygrad import Tensor, Device, dtypes, nn
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import Ops, sym_infer
from tinygrad.device import Compiled
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.engine.search import beam_search, bufs_from_lin
from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA
from extra.optimization.helpers import time_linearizer
def get_sched_resnet():
mdl = ResNet50()

View File

@@ -100,3 +100,24 @@ def lin_to_feats(lin:Kernel, use_sts=True):
else:
assert len(ret) == 274, f"wrong len {len(ret)}"
return ret
from tinygrad.device import Device, Buffer
from tinygrad.engine.search import _ensure_buffer_alloc, _time_program
from tinygrad.helpers import to_function_name, CACHELEVEL, diskcache_put
def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:  # noqa: E501
  """Compile the kernel `lin`, time it with `_time_program`, and return the smallest measured time.

  Args:
    lin: kernel to time; its ast, applied opts, device and suffix form part of the cache key.
    rawbufs: buffers handed to the program (passed through `_ensure_buffer_alloc` first).
    allow_test_size: when true, `max_global_size` is forwarded to `_time_program`; otherwise `None` is forwarded.
    max_global_size: global-size cap forwarded to `_time_program` (only when `allow_test_size` is true).
    cnt: forwarded to `_time_program` as `cnt` (presumably the number of timed runs -- confirm there).
    disable_cache: when true, skip the disk-cache lookup. The result is still written to the cache.
    clear_l2: forwarded to `_time_program`.

  Returns:
    `min()` of the timings produced by `_time_program`, or of the cached timings on a cache hit.

  Raises:
    AssertionError: if the kernel's device has no compiler.
  """
  # The imports added alongside this function bring in diskcache_put but not diskcache_get.
  # Import it locally so the cache lookup below cannot fail with NameError.
  from tinygrad.helpers import diskcache_get

  # NOTE: `cnt` is not part of the key, so timings cached with one `cnt` are reused for any other.
  key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
         "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}
  if not disable_cache and CACHELEVEL >= 2 and (val:=diskcache_get("time_linearizer", key)) is not None: return min(val)

  dev = Device[lin.opts.device]
  assert dev.compiler is not None

  rawbufs = _ensure_buffer_alloc(rawbufs)
  # bind every symbolic variable to the midpoint of its [vmin, vmax] range
  var_vals: dict[Variable, int] = {k:int(k.vmax+k.vmin)//2 for k in lin.ast.variables()}
  p = lin.to_program()
  tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
                      max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))

  # the write is gated only on CACHELEVEL, so it also happens when disable_cache is set
  if CACHELEVEL >= 2: diskcache_put("time_linearizer", key, tms)
  return min(tms)

View File

@@ -3,10 +3,10 @@ import numpy as np
import math, random
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import actions, bufs_from_lin, time_linearizer, get_kernel_actions
from tinygrad.engine.search import actions, bufs_from_lin, get_kernel_actions
from tinygrad.nn.optim import Adam
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer
if __name__ == "__main__":
net = PolicyNet()

View File

@@ -1,11 +1,11 @@
import argparse
from extra.optimization.helpers import ast_str_to_lin
from extra.optimization.helpers import ast_str_to_lin, time_linearizer
from tinygrad import dtypes
from tinygrad.helpers import BEAM, getenv
from tinygrad.device import Device, Compiled
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.engine.search import beam_search, bufs_from_lin
if __name__ == '__main__':

View File

@@ -6,8 +6,8 @@ from copy import deepcopy
from tinygrad.helpers import getenv, colored
from tinygrad.tensor import Tensor
from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
from tinygrad.engine.search import bufs_from_lin, time_linearizer, actions, get_kernel_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
from tinygrad.engine.search import bufs_from_lin, actions, get_kernel_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer
from extra.optimization.extract_policynet import PolicyNet
from extra.optimization.pretrain_valuenet import ValueNet

View File

@@ -1,5 +1,5 @@
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import bufs_from_lin, time_linearizer, get_kernel_actions
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer
from tinygrad.engine.search import bufs_from_lin, get_kernel_actions
if __name__ == "__main__":
ast_strs = load_worlds()

View File

@@ -1,7 +1,7 @@
import random
from tinygrad.helpers import getenv
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import beam_search, bufs_from_lin
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer
def optimize_kernel(k):
# TODO: update this

View File

@@ -1,7 +1,7 @@
from tinygrad import Device
from tinygrad.helpers import getenv, DEBUG, BEAM
from tinygrad.engine.search import beam_search, time_linearizer, bufs_from_lin
from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.engine.search import beam_search, bufs_from_lin
from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer
if __name__ == "__main__":
filter_reduce = bool(getenv("FILTER_REDUCE"))

View File

@@ -1,10 +1,9 @@
import argparse
from collections import defaultdict
from extra.optimization.helpers import kern_str_to_lin
from extra.optimization.helpers import kern_str_to_lin, time_linearizer
from test.external.fuzz_linearizer import compare_linearizer
from tinygrad.helpers import colored
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import time_linearizer
# Use this with the LOGKERNS options to verify that all executed kernels are valid and evaluate to the same ground truth results

View File

@@ -4,8 +4,8 @@ from test.helpers import ast_const
from tinygrad import dtypes, Device
from tinygrad.helpers import CI
from tinygrad.codegen.kernel import Kernel
from tinygrad.engine.search import Opt, OptOps
from tinygrad.engine.search import time_linearizer, bufs_from_lin
from tinygrad.engine.search import Opt, OptOps, bufs_from_lin
from extra.optimization.helpers import time_linearizer
# stuff needed to unpack a kernel
from tinygrad.ops import UOp, Ops

View File

@@ -4,7 +4,7 @@ from test.helpers import ast_const
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.codegen.kernel import Kernel
from tinygrad.ops import UOp, Ops
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.engine.search import bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer
from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes
@@ -12,6 +12,7 @@ from tinygrad.helpers import Context, GlobalCounters
from tinygrad.engine.realize import capturing
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from extra.optimization.helpers import time_linearizer
class TestTimeLinearizer(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "WebGPU timestamps are low precision, tm is 0")

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from dataclasses import replace
from tinygrad.ops import UOp, Ops, Variable, sym_infer
from tinygrad.device import Device, Buffer, Compiler
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored
from tinygrad.helpers import IGNORE_BEAM_CACHE, TC_SEARCH_OVER_SHAPE
from tinygrad.dtype import ImageDType, PtrDType
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
@@ -197,20 +197,3 @@ def optimize_local_size(_prg:Callable, global_size:list[int], rawbufs:list[Buffe
ret = min([(try_exec(local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])
assert not math.isinf(ret[0]), "all optimize_local_size exec failed"
return ret[1]
def time_linearizer(lin:Kernel, rawbufs:list[Buffer], allow_test_size=True, max_global_size=65536, cnt=3, disable_cache=False, clear_l2=False) -> float:  # noqa: E501
  """Time the program generated from `lin` and return the fastest measurement.

  A disk cache entry under the "time_linearizer" table is consulted first (unless `disable_cache`
  is set or CACHELEVEL is below 2). On a miss the kernel is lowered with `lin.to_program()`,
  compiled by the device compiler and timed by `_time_program`. The timings are stored back to
  the cache whenever CACHELEVEL >= 2, and their minimum is returned.
  """
  cache_key = {"ast": lin.ast.key, "opts": str(lin.applied_opts), "allow_test_size": allow_test_size,
               "max_global_size": max_global_size, "clear_l2": clear_l2, "device": lin.opts.device, "suffix": lin.opts.suffix}

  # cache lookup: only when caching is both allowed by the caller and enabled globally
  if not disable_cache and CACHELEVEL >= 2:
    cached = diskcache_get("time_linearizer", cache_key)
    if cached is not None:
      return min(cached)

  device = Device[lin.opts.device]
  assert device.compiler is not None
  rawbufs = _ensure_buffer_alloc(rawbufs)

  # every symbolic variable gets the midpoint of its range
  var_vals: dict[Variable, int] = {}
  for var in lin.ast.variables():
    var_vals[var] = int(var.vmax+var.vmin)//2

  prog = lin.to_program()
  lib = device.compiler.compile(prog.src)
  size_cap = max_global_size if allow_test_size else None
  timings = _time_program(prog, lib, var_vals, rawbufs, max_global_size=size_cap, clear_l2=clear_l2, cnt=cnt,
                          name=to_function_name(lin.name))

  if CACHELEVEL >= 2:
    diskcache_put("time_linearizer", cache_key, timings)
  return min(timings)