diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 070e03f08e..184faf6e66 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -170,9 +170,6 @@ jobs:
       run: |
         PYTHONPATH="." python test/external/dist/test_world.py
         PYTHONPATH="." python test/external/dist/test_collectives.py
-    - if: ${{ matrix.task == 'realworld' }}
-      name: Test KOPT
-      run: PYTHONPATH="." KOPT=1 BUDGET=20 GPU=1 DEBUG=1 python -m pytest -rA -n=auto test/models/test_real_world.py
     - if: ${{ matrix.task == 'realworld' }}
       name: Run GPT2
       run: |
diff --git a/docs/env_vars.md b/docs/env_vars.md
index c6d0afbfe5..ab8352cf87 100644
--- a/docs/env_vars.md
+++ b/docs/env_vars.md
@@ -43,7 +43,6 @@ LLVM | [1] | enable LLVM backend
 LLVMOPT | [1] | enable slightly more expensive LLVM optimizations
 LAZY | [1] | enable lazy operations (this is the default)
 OPT | [1-3] | optimization level
-KOPT | [1-2] | kernel optimization, 1 turns it on, 2 caches the found optimizations
 BUDGET | [#] | kernel optimization search budget
 GRAPH | [1] | create a graph of all operations (requires graphviz)
 GRAPHPATH | [/path/to] | where to put the generated graph
@@ -177,7 +176,6 @@ TORCHCUDA | [1] | enable the torch cuda backend
 
 Variable | Possible Value(s) | Description
 ---|---|---
-KOPT | [1] | enable kernel optimization
 KCACHE | [1] | enable kernel cache
 
 ### test/external/external_test_opt.py
diff --git a/examples/handcode_resnet50_opt.py b/examples/handcode_resnet50_opt.py
index bc76f1a4a6..717ebe48d6 100644
--- a/examples/handcode_resnet50_opt.py
+++ b/examples/handcode_resnet50_opt.py
@@ -3,8 +3,8 @@ from models.resnet import ResNet50
 from tinygrad.tensor import Tensor
 from tinygrad.ops import LoadOps, Device, Compiled
 from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions
-from tinygrad.helpers import ansilen, DEBUG, getenv, flatten
+from tinygrad.features.search import time_linearizer, beam_search
+from tinygrad.helpers import ansilen, DEBUG, getenv
 from tinygrad.graph import print_tree
 from tinygrad.lazy import vars_from_ast
 from tinygrad.shape.symbolic import sym_infer
@@ -62,17 +62,7 @@ if __name__ == "__main__":
       for ao in global_db[str(lin.ast)]:
         lin.apply_opt(ao)
     else:
-      best_tm = float('inf')
-      beam = [lin]
-      while 1:
-        acted_lins = flatten([get_linearizer_actions(lin).items() for lin in beam])
-        timed_lins = [(v,time_linearizer(v, rawbufs)) for k,v in acted_lins if k != 0]
-        opts = sorted(timed_lins, key=lambda x: x[1])
-        if len(opts) == 0 or best_tm <= opts[0][1]: break # we didn't get faster
-        best_tm = opts[0][1]
-        beam = [x[0] for x in opts[:getenv("BEAM")]]
-        if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", beam[0].colored_shape())
-      lin = beam[0]
+      lin = beam_search(lin, rawbufs, getenv("BEAM"))
     global_db[str(lin.ast)] = lin.applied_opts
     lins.append(lin)
diff --git a/extra/optimization/get_action_space.py b/extra/optimization/get_action_space.py
index 91faf2100e..987e0206b0 100644
--- a/extra/optimization/get_action_space.py
+++ b/extra/optimization/get_action_space.py
@@ -1,6 +1,6 @@
 from tqdm import tqdm
 from extra.optimization.helpers import load_worlds, ast_str_to_lin
-from tinygrad.codegen.search import actions
+from tinygrad.features.search import actions
 from tinygrad.codegen.linearizer import Linearizer
 
 tactions = set()
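The examples/handcode_resnet50_opt.py hunk above replaces the hand-rolled beam loop with the shared beam_search helper that this diff adds to tinygrad/features/search.py. A minimal driver sketch of the new API, assuming a prepared LazyOp AST and linearizer opts; the function name and the BEAM default below are illustrative, not part of this change:

```python
# Sketch of the new search API; only names introduced elsewhere in this diff are used.
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import bufs_from_lin, time_linearizer, beam_search
from tinygrad.helpers import getenv

def optimize_ast(ast, linearizer_opts):
  lin = Linearizer(ast, linearizer_opts)              # unoptimized kernel for this AST
  rawbufs = bufs_from_lin(lin)                        # scratch buffers sized from the kernel's membufs
  lin = beam_search(lin, rawbufs, getenv("BEAM", 4))  # greedy beam search over the action space
  print(f"{time_linearizer(lin, rawbufs)*1e3:.2f} ms", lin.colored_shape())
  return lin
```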
diff --git a/extra/optimization/pretrain_policynet.py b/extra/optimization/pretrain_policynet.py
index ffd9f2394d..5eca5dd0f9 100644
--- a/extra/optimization/pretrain_policynet.py
+++ b/extra/optimization/pretrain_policynet.py
@@ -4,7 +4,7 @@ from tinygrad.nn import Linear
 from tinygrad.tensor import Tensor
 from tinygrad.nn.optim import Adam
 from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
-from tinygrad.codegen.search import actions
+from tinygrad.features.search import actions
 from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, assert_same_lin
 
 INNER = 32
diff --git a/extra/optimization/rl.py b/extra/optimization/rl.py
index bf70bee04e..9b2dd2d8d5 100644
--- a/extra/optimization/rl.py
+++ b/extra/optimization/rl.py
@@ -3,7 +3,7 @@ import numpy as np
 import math, random
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
-from tinygrad.codegen.search import actions, bufs_from_lin, time_linearizer, get_linearizer_actions
+from tinygrad.features.search import actions, bufs_from_lin, time_linearizer, get_linearizer_actions
 from tinygrad.nn.optim import Adam
 from extra.optimization.pretrain_policynet import PolicyNet
 from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
diff --git a/extra/optimization/test_net.py b/extra/optimization/test_net.py
index 2c7da7feb3..9dd40af613 100644
--- a/extra/optimization/test_net.py
+++ b/extra/optimization/test_net.py
@@ -6,7 +6,7 @@ from copy import deepcopy
 from tinygrad.helpers import getenv, colored
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
-from tinygrad.codegen.search import bufs_from_lin, time_linearizer, actions, get_linearizer_actions
+from tinygrad.features.search import bufs_from_lin, time_linearizer, actions, get_linearizer_actions
 from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats
 from extra.optimization.pretrain_policynet import PolicyNet
 from extra.optimization.pretrain_valuenet import ValueNet
diff --git a/extra/optimization/test_time_linearizer.py b/extra/optimization/test_time_linearizer.py
index ab3d592034..47af6c52b7 100644
--- a/extra/optimization/test_time_linearizer.py
+++ b/extra/optimization/test_time_linearizer.py
@@ -1,5 +1,5 @@
 from extra.optimization.helpers import load_worlds, ast_str_to_lin
-from tinygrad.codegen.search import bufs_from_lin, time_linearizer, get_linearizer_actions
+from tinygrad.features.search import bufs_from_lin, time_linearizer, get_linearizer_actions
 
 if __name__ == "__main__":
   ast_strs = load_worlds()
diff --git a/test/external/external_test_gpu_ast.py b/test/external/external_test_gpu_ast.py
index ec6b7a15df..0c98552fdf 100644
--- a/test/external/external_test_gpu_ast.py
+++ b/test/external/external_test_gpu_ast.py
@@ -13,9 +13,6 @@ OSX = platform.system() == "Darwin"
 
 def compile_and_test_ast(ast, local_size=None):
   k = CLCodegen(ast)
-  if getenv("KOPT", 0):
-    from extra.kernel_search import apply_optimization
-    apply_optimization(k, ast, 10, getenv("KCACHE", 0))
   prg = k.codegen().build(CLProgram)
   if local_size is not None: prg.local_size = local_size
   for i in range(5): prg(prg.lower(k.bufs))
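The extra/optimization scripts above keep their existing workflow and only import it from the new module path. A minimal sketch of that workflow with the relocated helpers; the function name is illustrative:

```python
# Sketch: time the one-step neighbours of a kernel with the relocated search helpers.
from tinygrad.features.search import bufs_from_lin, time_linearizer, get_linearizer_actions

def best_single_action(lin):
  rawbufs = bufs_from_lin(lin)
  # include_0=True keeps the unmodified kernel in the candidate dict under key 0
  candidates = get_linearizer_actions(lin, include_0=True)
  timed = {i: time_linearizer(c, rawbufs) for i, c in candidates.items()}  # seconds
  best = min(timed, key=timed.get)
  return candidates[best], timed[best]
```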
diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py
index 28cac58fb6..b114ecdcd3 100644
--- a/test/models/test_real_world.py
+++ b/test/models/test_real_world.py
@@ -6,34 +6,12 @@ from tinygrad.nn.state import get_parameters
 from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
 from tinygrad.ops import Device, GlobalCounters, LazyOp, LoadOps
 from tinygrad.helpers import CI, dtypes, getenv, prod
-from tinygrad.features.kopt import kernel_optimize_opts
 from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
 from examples.hlb_cifar10 import SpeedyResNet
 from examples.llama import Transformer as LLaMaTransformer, MODEL_PARAMS as LLAMA_MODEL_PARAMS
 from examples.stable_diffusion import UNetModel
 
-def kopt_search_hook(k, create_k, to_prg, baseline, bufs, var_vals):
-  import nevergrad as ng
-  wanna_output = bufs[0].toCPU().copy()
-  def check_opt(x):
-    try:
-      k = create_k()
-      for o in x: k.apply_opt(o)
-      prg = to_prg(k)
-      first_tm = prg.exec(bufs, var_vals, force_wait=True, optimizing=True)
-      np.testing.assert_allclose(wanna_output, bufs[0].toCPU(), atol=1e-4, rtol=1e-4)
-      return first_tm
-    except Exception:
-      return 10000_000 # 10000 seconds is infinity
-  opts = kernel_optimize_opts(k)
-  if not opts: return "BASELINE"
-  search_space = prod([len(x.choices) for x in opts])
-  budget = getenv("BUDGET", 20) # THIS IS TEST BUDGET
-  optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, budget))
-  recommendation = optimizer.minimize(check_opt)
-  return recommendation.value if recommendation.loss < baseline else "BASELINE"
-
 def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted=False):
   tms = []
   for _ in range(4):
@@ -70,15 +48,9 @@ class TestRealWorld(unittest.TestCase):
   def setUp(self):
     self.old_type = Tensor.default_type
     np.random.seed(2002)
-    # TODO: abstract better to remove this junk
-    if getenv("KOPT"):
-      self.oldfunc = getattr(__import__("tinygrad.features.kopt", fromlist=["kernel_optimize_search"]), "kernel_optimize_search")
-      setattr(__import__("tinygrad.features.kopt", fromlist=["kernel_optimize_search"]), "kernel_optimize_search", kopt_search_hook)
 
   def tearDown(self):
     Tensor.default_type = self.old_type
-    if getenv("KOPT"):
-      setattr(__import__("tinygrad.features.kopt", fromlist=["kernel_optimize_search"]), "kernel_optimize_search", self.oldfunc)
 
   @unittest.skipUnless(not CI, "too big for CI")
   def test_stable_diffusion(self):
@@ -111,7 +83,6 @@ class TestRealWorld(unittest.TestCase):
     def test(t): return model(t, 0).realize()
     helper_test("test_gpt2", lambda: (Tensor([[1,]]),), test, 0.21 if CI else 0.9, 129 if CI else 369, all_jitted=True)
 
-  @unittest.skipIf(getenv("KOPT"), "cifar hangs with KOPT")
   @unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["LLVM"], "needs JIT, too long on CI LLVM")
   def test_train_cifar(self):
     # TODO: with default device
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 0035dbf697..16072c8430 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -1,4 +1,6 @@
+from __future__ import annotations
 from typing import NamedTuple, Optional, List, Tuple, cast, Dict
+from copy import deepcopy
 import itertools
 from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, ReduceOps, MemBuffer, BufferOps, Device, Compiled
 from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, all_int, ansilen
@@ -84,6 +86,9 @@ class Kernel:
     self.global_size: Optional[List[int]] = None
     self.local_size: Optional[List[int]] = None
 
+  def copy(self):
+    return deepcopy(self)
+
   @property
   def membufs(self) -> List[MemBuffer]: return [x for x in self.bufs if isinstance(x, MemBuffer)]
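The Kernel.copy() added above is just a deepcopy of the kernel; the search code in the following hunks uses it to branch a candidate before mutating it. A small sketch of that branch-and-measure pattern; the helper name is hypothetical:

```python
# Sketch: copy a kernel before applying an opt so the original stays untouched.
from tinygrad.features.search import time_linearizer

def try_opt(lin, opt, rawbufs):
  lin2 = lin.copy()        # deepcopy via the new Kernel.copy()
  try:
    lin2.apply_opt(opt)    # raises if the opt is illegal for this shape
  except Exception:
    return None, float('inf')
  return lin2, time_linearizer(lin2, rawbufs)
```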
diff --git a/tinygrad/features/kopt.py b/tinygrad/features/kopt.py
deleted file mode 100644
index b63bacc08a..0000000000
--- a/tinygrad/features/kopt.py
+++ /dev/null
@@ -1,84 +0,0 @@
-from typing import Callable
-import time
-from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.codegen.optimizer import Opt, OptOps
-from tinygrad.helpers import DEBUG, prod, getenv
-from tinygrad.lazy import vars_from_ast
-
-def get_divisors(n, min_div = 1, max_div = 512):
-  if min_div > 1: yield 1
-  for d in range(min_div, min(max_div, n//2) + 1):
-    if n % d == 0: yield d
-
-def kernel_optimize_opts(k:Linearizer):
-  import nevergrad as ng
-  opts = []
-  for i in range(k.first_reduce):
-    # TODO: the upcast always happen first, you might want to reverse this?
-    # TODO: the order of the locals might improve things too
-    opts.append(ng.p.TransitionChoice([Opt(OptOps.UPCAST,i,s) for s in get_divisors(k.full_shape[i], max_div=8)]))
-    opts.append(ng.p.TransitionChoice([Opt(OptOps.LOCAL,i,s) for s in get_divisors(k.full_shape[i], min_div=4)]))
-  for i in range(k.shape_len-k.first_reduce):
-    opts.append(ng.p.TransitionChoice([Opt(OptOps.UNROLL,i,s) for s in get_divisors(k.full_shape[k.first_reduce+i], max_div=8)]))
-  return opts
-
-def kernel_optimize_search(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, baseline, bufs, var_vals):
-  import nevergrad as ng
-  def opt(x):
-    try:
-      k = create_k()
-      for o in x: k.apply_opt(o)
-      prg = to_prg(k)
-      first_tm = prg.exec(bufs, var_vals, force_wait=True, optimizing=True)
-      if baseline*5 < first_tm*1000: return first_tm*1000 # very slow
-      tm = min([first_tm]+[prg.exec(bufs, var_vals, force_wait=True, optimizing=True) for _ in range(2)])*1000
-      return tm
-    except Exception:
-      if DEBUG >= 3:
-        import traceback
-        traceback.print_exc()
-      return 10000_000 # 10000 seconds is infinity
-  opts = kernel_optimize_opts(k)
-  if not opts: return "BASELINE"
-  search_space = prod([len(x.choices) for x in opts])
-  st = time.perf_counter()
-  budget = getenv("BUDGET", 200)
-  optimizer = ng.optimizers.NGOpt(parametrization=ng.p.Tuple(*opts), budget=min(search_space, budget))
-  recommendation = optimizer.minimize(opt)
-  et = time.perf_counter() - st
-  if DEBUG >= 1: print(f"optimizer({et:6.2f} s to search) space {search_space:8d} with tm {recommendation.loss:5.2f} ms vs baseline {baseline:5.2f} ms, a {baseline/recommendation.loss:5.2f}x gain : {k.colored_shape()}")
-  return recommendation.value if recommendation.loss < baseline else "BASELINE"
-
-# optimization
-global_db = None
-def kernel_optimize(k:Linearizer, create_k:Callable[[], Linearizer], to_prg, bufs, key):
-  global global_db
-
-  skey = str(key)
-
-  if getenv("KOPT") == 2 and global_db is None:
-    import shelve
-    global_db = shelve.open("/tmp/kopt_cache")
-
-  if global_db is not None and skey in global_db:
-    choice = global_db[skey]
-  elif k.has_variable_shape():
-    # don't optimize variable shapes
-    choice = "BASELINE"
-  else:
-    var_vals = {k:k.min for k in vars_from_ast(k.ast)}
-    # get baseline
-    def get_baseline():
-      k = create_k()
-      k.hand_coded_optimizations()
-      prg = to_prg(k)
-      return min([prg.exec(bufs, var_vals, force_wait=True, optimizing=True) for _ in range(5)])*1000
-    choice = kernel_optimize_search(k, create_k, to_prg, get_baseline(), bufs, var_vals)
-    if global_db is not None:
-      global_db[skey] = choice
-      global_db.sync()
-
-  if choice == "BASELINE":
-    k.hand_coded_optimizations()
-  else:
-    for o in choice: k.apply_opt(o)
\ No newline at end of file
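With kopt.py deleted, the nevergrad dependency in the optimizer path disappears and search now works over the fixed, discrete action list in tinygrad/features/search.py. Both the removed KOPT=2 cache and the LOGTM log introduced below persist results through the stdlib shelve module; a generic sketch of that pattern, with an illustrative path and key:

```python
# Sketch: the shelve-backed cache pattern shared by the old /tmp/kopt_cache and the new LOGTM log.
import shelve

db = shelve.open("/tmp/example_kernel_cache")   # illustrative path
key = "some-kernel-key"                         # the new code keys on str((lin.ast, lin.applied_opts))
if key in db:
  result = db[key]      # reuse the stored value
else:
  result = [0.001]      # ...measure or search here...
  db[key] = result      # shelve pickles the value to disk
  db.sync()
```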
diff --git a/tinygrad/codegen/search.py b/tinygrad/features/search.py
similarity index 63%
rename from tinygrad/codegen/search.py
rename to tinygrad/features/search.py
index dcbaa23c8f..09b1220456 100644
--- a/tinygrad/codegen/search.py
+++ b/tinygrad/features/search.py
@@ -1,8 +1,7 @@
 from typing import Dict, List, cast, DefaultDict, Optional
-from copy import deepcopy
 from tinygrad.lazy import vars_from_ast
 from tinygrad.ops import Device, Compiled, MemBuffer
-from tinygrad.helpers import prod, getenv, ImageDType, flatten
+from tinygrad.helpers import prod, getenv, ImageDType, flatten, DEBUG
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.runtime.lib import RawBuffer
 from collections import defaultdict
@@ -18,16 +17,18 @@ actions += [
   Opt(op=OptOps.GROUPTOP, axis=1, amt=16), Opt(op=OptOps.GROUPTOP, axis=1, amt=256),
   Opt(op=OptOps.GROUPTOP, axis=2, amt=16), Opt(op=OptOps.GROUPTOP, axis=2, amt=256)
 ]
-device:Compiled = cast(Compiled, Device[Device.DEFAULT])
 
 # returns time in seconds
-logtm = open(getenv("LOGTM", ""),"a") if getenv("LOGTM", "") else None
+import shelve
+logtm = shelve.open(getenv("LOGTM", "")) if getenv("LOGTM", "") else None
 def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=True, max_global_size=65536, cnt=3, should_copy=True) -> float:
-  if should_copy: lin = deepcopy(lin) # TODO: remove the need for this
+  key = str((lin.ast, lin.applied_opts))
+  if should_copy and logtm is not None and key in logtm: return min(logtm[key]) # pylint: disable=E1135 # NOTE: we check should_copy since this may have side effects
+  if should_copy: lin = lin.copy() # TODO: remove the need for this
   var_vals = {k:k.min for k in vars_from_ast(lin.ast)}
   try:
     lin.linearize()
-    prg = device.to_program(lin)
+    prg = cast(Compiled, Device[Device.DEFAULT]).to_program(lin)
     real_global_size = prg.global_size[:]
     if allow_test_size:
       test_global_size = prg.global_size[:]
@@ -41,14 +42,16 @@ def time_linearizer(lin:Linearizer, rawbufs:List[RawBuffer], allow_test_size=Tru
       #print(real_global_size, test_global_size, factor)
     else:
       factor = 1
-    tms = [prg(rawbufs, var_vals, force_wait=True)*factor for _ in range(cnt)]
+    # TODO: this is super broken for var_vals
+    global_size, local_size = prg.launch_dims(var_vals)
+    tms = [prg.clprg(global_size, local_size, *rawbufs, *var_vals.values(), wait=True)*factor for _ in range(cnt)]
     prg.global_size = real_global_size
   except Exception:
     #print("FAILED")
     #print(lin.ast)
     #print(lin.applied_opts)
     tms = [float('inf')]
-  if logtm: logtm.write(str((lin.ast, lin.applied_opts, tms))+"\n")
+  if logtm is not None: logtm[key] = tms
   return min(tms)
 
 # get (scrap) buffers for timing the linearizer
@@ -57,17 +60,17 @@ def bufs_from_lin(lin:Linearizer) -> List[RawBuffer]:
   for x in lin.membufs: bufsts[x.idx].append(x)
   rawbufs:List[Optional[RawBuffer]] = [None]*len(bufsts)
   for k,lx in bufsts.items():
-    rawbufs[k] = device.buffer(prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
+    rawbufs[k] = cast(Compiled, Device[Device.DEFAULT]).buffer(prod(lx[0].dtype.shape) if isinstance(lx[0].dtype, ImageDType) else max(y.st.size() for y in lx), lx[0].dtype)
   assert all(r is not None for r in rawbufs)
   return cast(List[RawBuffer], rawbufs)
 
 # get dictionary of all possible actions
-def get_linearizer_actions(lin:Linearizer) -> Dict[int, Linearizer]:
-  acted_lins = {0:deepcopy(lin)}
+def get_linearizer_actions(lin:Linearizer, include_0=True) -> Dict[int, Linearizer]:
+  acted_lins = {0:lin.copy()} if include_0 else {}
   for i,a in enumerate(actions):
     if a.axis >= lin.shape_len: continue
     if lin.full_shape[a.axis] == a.amt and Opt(a.op, a.axis, 0) in actions: continue
-    lin2 = deepcopy(lin)
+    lin2 = lin.copy()
    try:
      lin2.apply_opt(a)
      up, lcl = 1, 1
@@ -79,3 +82,16 @@ def get_linearizer_actions(lin:Linearizer) -> Dict[int, Linearizer]:
     except Exception:
       pass
   return acted_lins
+
+def beam_search(lin, rawbufs, amt):
+  best_tm = float('inf')
+  beam: List[Linearizer] = [lin]
+  while 1:
+    acted_lins = flatten([get_linearizer_actions(lin, include_0=False).values() for lin in beam])
+    timed_lins = [(v,time_linearizer(v, rawbufs)) for v in acted_lins]
+    opts = sorted(timed_lins, key=lambda x: x[1])
+    if len(opts) == 0 or best_tm <= opts[0][1]: break # we didn't get faster
+    best_tm = opts[0][1]
+    beam = [x[0] for x in opts[:amt]]
+    if DEBUG >= 1: print(f"{opts[0][1]*1e3:10.2f} ms from {len(opts):3d} actions", beam[0].colored_shape())
+  return beam[0]
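Two behavioural notes on the search.py changes above, beyond the rename: time_linearizer now memoizes measurements in the LOGTM shelve keyed by str((lin.ast, lin.applied_opts)) and returns the cached minimum, and beam_search keeps the amt fastest candidates each round, stopping as soon as no child beats the best time found so far. A condensed sketch of that caching behaviour; measure_fn is a placeholder for the actual timing loop:

```python
# Sketch: LOGTM-style memoization around a timing function (condensed from time_linearizer).
import shelve
from tinygrad.helpers import getenv

logtm = shelve.open(getenv("LOGTM", "")) if getenv("LOGTM", "") else None

def cached_time(lin, rawbufs, measure_fn):
  key = str((lin.ast, lin.applied_opts))
  if logtm is not None and key in logtm: return min(logtm[key])  # reuse earlier runs
  tms = measure_fn(lin, rawbufs)          # list of wall-clock times in seconds
  if logtm is not None: logtm[key] = tms
  return min(tms)
```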
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 16059cade4..f7cd930bbb 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -268,8 +268,9 @@ class Compiled:
     from tinygrad.codegen.linearizer import Linearizer
     k = Linearizer(ast, self.linearizer_opts)
     assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
-    from tinygrad.features.kopt import kernel_optimize
-    if getenv("KOPT"): kernel_optimize(k, lambda: Linearizer(ast, self.linearizer_opts), self.to_program, rawbuffers, ast)
+    if getenv("BEAM"):
+      from tinygrad.features.search import beam_search
+      k = beam_search(k, rawbuffers, getenv("BEAM"))
     elif not getenv("NOOPT"):
       if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
     return self.to_program(k)
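Net effect of the ops.py hunk: KOPT is gone and BEAM=<width> becomes the single switch for per-kernel search at compile time, with the NOOPT and TC behaviour unchanged. Condensed as a standalone function; the real logic stays inline in Compiled and this helper name is illustrative:

```python
# Sketch: the new optimization dispatch order in Compiled (condensed, not the literal code).
from tinygrad.helpers import getenv

def pick_optimizations(k, rawbuffers):
  if getenv("BEAM"):                                 # BEAM=<width> enables beam search per kernel
    from tinygrad.features.search import beam_search
    k = beam_search(k, rawbuffers, getenv("BEAM"))
  elif not getenv("NOOPT"):
    if not k.apply_tensor_cores(getenv("TC", 1)):    # try tensor cores first
      k.hand_coded_optimizations()                   # fall back to hand-coded opts
  return k
```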