From 8c6299bcedaf680cb84171da7fdce1fb0556a3be Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 10 Apr 2025 23:40:16 -0400
Subject: [PATCH] move hand_coded_optimizations to heuristic.py [pr] (#9844)

* move hand_coded_optimizations to heuristic.py [pr]

also folded all long lines

* make a copy and rename self -> k

* fix test
---
 examples/handcode_opt.py                     |   3 +-
 extra/gemm/tvm_gemm.py                       |   1 -
 extra/optimization/get_action_space.py       |   3 +-
 extra/optimization/test_net.py               |   3 +-
 extra/replay_pkl.py                          |   3 +-
 test/external/external_benchmark_hcopt.py    |   3 +-
 test/external/external_benchmark_schedule.py |   3 +-
 test/external/speed_beam_v_hcopt.py          |   5 +-
 test/external/speed_compare_cuda_nv.py       |   5 +-
 test/external/speed_compare_cuda_ptx.py      |   5 +-
 test/test_linearizer.py                      |  19 +--
 test/test_winograd.py                        |   3 +-
 tinygrad/codegen/heuristic.py                | 132 +++++++++++++++++++
 tinygrad/codegen/kernel.py                   | 126 +-----------------
 tinygrad/engine/realize.py                   |   3 +-
 15 files changed, 168 insertions(+), 149 deletions(-)
 create mode 100644 tinygrad/codegen/heuristic.py

diff --git a/examples/handcode_opt.py b/examples/handcode_opt.py
index 48c54b5670..6e3daf6c66 100644
--- a/examples/handcode_opt.py
+++ b/examples/handcode_opt.py
@@ -3,6 +3,7 @@ from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
 from tinygrad import Tensor, Device, dtypes, nn
 from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 from tinygrad.ops import Ops, sym_infer
 from tinygrad.device import Compiled
 from tinygrad.engine.search import beam_search, bufs_from_lin
@@ -83,7 +84,7 @@ if __name__ == "__main__":
 
     # always try hand coded opt
     lin = Kernel(si.ast, opts=device.renderer)
-    lin.hand_coded_optimizations()
+    lin = hand_coded_optimizations(lin)
     lins.append((lin, "HC"))
 
     # maybe try tensor cores
diff --git a/extra/gemm/tvm_gemm.py b/extra/gemm/tvm_gemm.py
index 03fe0f7894..fa6f661557 100644
--- a/extra/gemm/tvm_gemm.py
+++ b/extra/gemm/tvm_gemm.py
@@ -40,7 +40,6 @@ sched = C.schedule()
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.device import CompilerOptions
 lin = Kernel(sched[-1].ast, CompilerOptions(has_local=False, supports_float4=False))
-#lin.hand_coded_optimizations()
 lin.linearize()
 from tinygrad.runtime.ops_cpu import renderer
 src = renderer("mmult", lin.uops)
diff --git a/extra/optimization/get_action_space.py b/extra/optimization/get_action_space.py
index 22d177ef5b..b5b8435042 100644
--- a/extra/optimization/get_action_space.py
+++ b/extra/optimization/get_action_space.py
@@ -2,6 +2,7 @@ import random
 from extra.optimization.helpers import load_worlds, ast_str_to_lin
 from tinygrad.engine.search import actions
 from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 from tinygrad.helpers import tqdm
 
 tactions = set()
@@ -23,7 +24,7 @@ if __name__ == "__main__":
   for ast_str in tqdm(ast_strs):
     lin = ast_str_to_lin(ast_str)
     #if not lin.apply_tensor_cores():
-    lin.hand_coded_optimizations()
+    lin = hand_coded_optimizations(lin)
     test_rebuild(lin)
     # confirm linearize can be called twice
     uops1 = lin.linearize().uops
diff --git a/extra/optimization/test_net.py b/extra/optimization/test_net.py
index 0c5b53f99d..4258d94ca9 100644
--- a/extra/optimization/test_net.py
+++ b/extra/optimization/test_net.py
@@ -7,6 +7,7 @@ from tinygrad.helpers import getenv, colored
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import get_parameters, get_state_dict, safe_save, safe_load, load_state_dict
 from tinygrad.engine.search import bufs_from_lin, actions, get_kernel_actions
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 from extra.optimization.helpers import load_worlds, ast_str_to_lin, lin_to_feats, time_linearizer
 from extra.optimization.extract_policynet import PolicyNet
 from extra.optimization.pretrain_valuenet import ValueNet
@@ -34,7 +35,7 @@ if __name__ == "__main__":
     rawbufs = bufs_from_lin(lin)
 
     linhc = deepcopy(lin)
-    linhc.hand_coded_optimizations()
+    linhc = hand_coded_optimizations(linhc)
     tmhc = time_linearizer(linhc, rawbufs)
     print(f"{tmhc*1e6:10.2f} HC   ", linhc.colored_shape())
 
diff --git a/extra/replay_pkl.py b/extra/replay_pkl.py
index 1965c5bd73..91718654e5 100644
--- a/extra/replay_pkl.py
+++ b/extra/replay_pkl.py
@@ -7,6 +7,7 @@ from tinygrad.engine.jit import TinyJit
 from tinygrad.engine.realize import CompiledRunner, ExecItem, ScheduleItem, lower_schedule_item
 from tinygrad.renderer import ProgramSpec
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 import numpy as np
 
 def move_jit_captured_to_dev(captured, device="DSP"):
@@ -56,7 +57,7 @@ if __name__ == "__main__":
 
         ei.bufs[0].copyin(memoryview(bytearray(b'\x00'*ei.bufs[0].nbytes)))
         GlobalCounters.kernel_count -= 1
-        if not getenv("NOOPT"): k.hand_coded_optimizations()
+        if not getenv("NOOPT"): k = hand_coded_optimizations(k)
         p2 = k.to_program()
         new_ei = replace(ei, prg=CompiledRunner(p2))
         new_ei.run()
diff --git a/test/external/external_benchmark_hcopt.py b/test/external/external_benchmark_hcopt.py
index eb42c7af39..80800d79de 100644
--- a/test/external/external_benchmark_hcopt.py
+++ b/test/external/external_benchmark_hcopt.py
@@ -1,11 +1,12 @@
 import random
 from tinygrad.helpers import getenv
 from tinygrad.engine.search import beam_search, bufs_from_lin
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer
 
 def optimize_kernel(k):
   # TODO: update this
-  return k.hand_coded_optimizations()
+  return hand_coded_optimizations(k)
 
 if __name__ == '__main__':
   hcopt_wins = beam_wins = tie = 0
diff --git a/test/external/external_benchmark_schedule.py b/test/external/external_benchmark_schedule.py
index 5b50669439..8398c48ba6 100644
--- a/test/external/external_benchmark_schedule.py
+++ b/test/external/external_benchmark_schedule.py
@@ -4,6 +4,7 @@ from tinygrad import Tensor, Device, nn
 from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
 from tinygrad.ops import Ops
 from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.devectorizer import full_graph_rewrite
@@ -37,7 +38,7 @@ if __name__ == "__main__":
       if BEAM:
         with Context(DEBUG=max(2, DEBUG.value)): k = beam_search(k, bufs_from_lin(k), BEAM.value)
       elif NOOPT: pass
-      else: k.hand_coded_optimizations()
+      else: k = hand_coded_optimizations(k)
       kernels.append(k)
   with Timing("***** model lower in "):
     uops = [rewrite_shapetracker_with_index(k.get_optimized_ast(), k.opts) for k in kernels]
diff --git a/test/external/speed_beam_v_hcopt.py b/test/external/speed_beam_v_hcopt.py
index c6f0b430fd..f2cadb530c 100644
--- a/test/external/speed_beam_v_hcopt.py
+++ b/test/external/speed_beam_v_hcopt.py
@@ -1,6 +1,7 @@
 from tinygrad import Device
 from tinygrad.helpers import getenv, DEBUG, BEAM
 from tinygrad.engine.search import beam_search, bufs_from_lin
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 from extra.optimization.helpers import load_worlds, ast_str_to_lin, time_linearizer
 
 if __name__ == "__main__":
@@ -20,14 +21,14 @@ if __name__ == "__main__":
 
     k = new_lin()
     # k.required_optimizations()
-    if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
+    if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k = hand_coded_optimizations(k)
     assert BEAM > 0
 
     lins = [(("tc" if used_tensor_cores else "hc"), k)]
     if used_tensor_cores:
       lins.append(("hc", new_lin()))
-      lins[-1][1].hand_coded_optimizations()
+      lins[-1] = ("hc", hand_coded_optimizations(lins[-1][1]))
 
     kb = new_lin()
     # kb.required_optimizations()
     test_rawbuffers = bufs_from_lin(kb)  # allocate scratch buffers for optimization
diff --git a/test/external/speed_compare_cuda_nv.py b/test/external/speed_compare_cuda_nv.py
index 85773b93dc..3d15d433a1 100644
--- a/test/external/speed_compare_cuda_nv.py
+++ b/test/external/speed_compare_cuda_nv.py
@@ -2,6 +2,7 @@ from tinygrad import Device, dtypes
 from tinygrad.helpers import getenv, colorize_float
 from extra.optimization.helpers import load_worlds, ast_str_to_lin
 from test.external.fuzz_linearizer import get_fuzz_rawbufs
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 from tinygrad.engine.search import bufs_from_lin
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.tensor import _to_np_dtype
@@ -21,7 +22,7 @@ if __name__ == "__main__":
   for num,ast in enumerate(ast_strs):
     # cuda compile
     culin = ast_str_to_lin(ast, opts=cudev.renderer)
-    culin.hand_coded_optimizations()
+    culin = hand_coded_optimizations(culin)
     has_bf16 = any(b.dtype == dtypes.bfloat16 for b in culin.membufs)
 
     cuda_prg = CompiledRunner(culin.to_program())
@@ -31,7 +32,7 @@
     rdr = nvdev.renderer
     rdr.device = "NV"
     nvlin = ast_str_to_lin(ast, opts=rdr)
-    nvlin.hand_coded_optimizations()
+    nvlin = hand_coded_optimizations(nvlin)
     nv_prg = CompiledRunner(nvlin.to_program())
     nvbufs = bufs_from_lin(nvlin)
     test_nvbufs = get_fuzz_rawbufs(nvlin) if not has_bf16 else nvbufs
diff --git a/test/external/speed_compare_cuda_ptx.py b/test/external/speed_compare_cuda_ptx.py
index a4a3b8b095..2671f1f53d 100644
--- a/test/external/speed_compare_cuda_ptx.py
+++ b/test/external/speed_compare_cuda_ptx.py
@@ -1,6 +1,7 @@
 import itertools
 from tinygrad import Device
 from tinygrad.engine.realize import CompiledRunner
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 from tinygrad.helpers import getenv, colorize_float
 from extra.optimization.helpers import load_worlds, ast_str_to_lin
 from tinygrad.engine.search import bufs_from_lin
@@ -23,7 +24,7 @@ if __name__ == "__main__":
     # cuda compile
     dev.compiler = CUDACompiler(dev.arch)
     lin = ast_str_to_lin(ast, opts=dev.renderer)
-    lin.hand_coded_optimizations()
+    lin = hand_coded_optimizations(lin)
     cuda_prg = CompiledRunner(lin.to_program())
     bufs = bufs_from_lin(lin)
 
@@ -31,7 +32,7 @@
     # ptx compile
     dev.compiler = PTXCompiler(dev.arch)
     lin = ast_str_to_lin(ast, opts=ptx)
-    lin.hand_coded_optimizations()
+    lin = hand_coded_optimizations(lin)
     lin.linearize()
     ptx_prg = CompiledRunner(lin.to_program())
 
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index c07f08047b..a145edaba5 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -12,6 +12,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX
 from tinygrad.dtype import DType, dtypes
 
@@ -1000,7 +1001,7 @@ class TestLinearizer(unittest.TestCase):
     x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
     r = (x@y).relu()
     k = Kernel(r.schedule()[-1].ast)
-    k.hand_coded_optimizations()
+    k = hand_coded_optimizations(k)
     k.linearize()
 
     stores = [u for u in k.uops if u.op is Ops.STORE]
@@ -1305,7 +1306,7 @@ class TestLinearizer(unittest.TestCase):
     run_schedule(sched)
     np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
     lin = Kernel(sched_copy[-1].ast)
-    lin.hand_coded_optimizations()
+    lin = hand_coded_optimizations(lin)
     lin.linearize()
     assert not any(u.op == Ops.WHERE for u in lin.uops), "found where where where should be folded"
 
@@ -1463,7 +1464,7 @@ class TestFloat4(unittest.TestCase):
 
     s = c.schedule()[0]
     k = Kernel(s.ast)
-    k.hand_coded_optimizations()
+    k = hand_coded_optimizations(k)
     k.linearize()
 
     assert TestFloat4.count_float4(k) == (2, 1)
@@ -1518,7 +1519,7 @@ class TestFloat4(unittest.TestCase):
 
     s = c.schedule()[0]
     k = Kernel(s.ast)
-    k.hand_coded_optimizations() # implicit trigger float4 dim
+    k = hand_coded_optimizations(k) # implicit trigger float4 dim
     k.linearize()
 
     assert TestFloat4.count_float4(k) == (0, 1)
@@ -1731,7 +1732,7 @@ class TestHandCodedOpts(unittest.TestCase):
     s = layer_2.schedule()[-1]
 
     k = Kernel(s.ast)
-    k.hand_coded_optimizations()
+    k = hand_coded_optimizations(k)
     assert len(k.bufs) == 6  # make sure all ops are done in one kernel
     # masked upcast should upcast masked axis of size 7
     # masked upcast should not upcast large (20) last axis
@@ -1744,7 +1745,7 @@ class TestHandCodedOpts(unittest.TestCase):
     s = monster.schedule()[-1]
 
     k = Kernel(s.ast)
-    k.hand_coded_optimizations()
+    k = hand_coded_optimizations(k)
     assert len(k.bufs) == 37  # make sure all ops are done in one kernel
     # should upcast the two Tensor.stacks
     assert k.upcasted >= 2 and k.full_shape[k.shape_len-k.upcasted:k.shape_len].count(6) == 2
@@ -1760,7 +1761,7 @@ class TestHandCodedOpts(unittest.TestCase):
     # collect upcasts of tile transform kernels
     for i, si in enumerate(wino_schedule):
       k = Kernel(si.ast)
-      k.hand_coded_optimizations()
+      k = hand_coded_optimizations(k)
       if k.reduceop is not None: continue  # not a tile transform kernel (there is a gemm reduce kernel)
       if len(k.bufs) < 22: continue  # not a tile transform kernel (there's a permute kernel at the end)
       upcasts.append(tuple(k.full_shape[k.shape_len - k.upcasted:k.shape_len]))
@@ -1772,7 +1773,7 @@ class TestHandCodedOpts(unittest.TestCase):
     backward_schedule = Tensor.schedule(x.grad, w.grad)
     for si in backward_schedule:
       k = Kernel(si.ast)
-      k.hand_coded_optimizations()
+      k = hand_coded_optimizations(k)
       k.linearize()
       if len(k.bufs) < 20: continue  # not a tile transform kernel
       # heuristic number to make sure that at least some upcasts but not too many upcasts are being done
@@ -1859,8 +1860,8 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
 
   # Check correctness of handcoded optimizations.
   k = Kernel(realized_ast)
+  k = hand_coded_optimizations(k)
   lins.append(k)
-  k.hand_coded_optimizations()
   prg = get_prg(k)
   reset_bufs(outbufs)
   prg.exec(real_bufs)
diff --git a/test/test_winograd.py b/test/test_winograd.py
index 36d984678f..d0871f0ed2 100644
--- a/test/test_winograd.py
+++ b/test/test_winograd.py
@@ -3,6 +3,7 @@ from tinygrad import Tensor, GlobalCounters, dtypes
 from tinygrad.ops import Ops
 from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
 from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 
 class TestWinograd(unittest.TestCase):
   def setUp(self):
@@ -26,7 +27,7 @@ class TestWinograd(unittest.TestCase):
       ops = s.ast.toposort
       with Timing(f"linearize {i} with {len(ops):4d} ops: "):
         l = Kernel(s.ast)
-        l.hand_coded_optimizations()
+        l = hand_coded_optimizations(l)
         l.linearize()
       assert len(l.sts) <= 256  # just the current value to prevent regression
       if DEBUG >= 2: print(f"{len(l.sts):4d} shapetrackers with max {max(len(x.views) for x in l.sts)} views")
diff --git a/tinygrad/codegen/heuristic.py b/tinygrad/codegen/heuristic.py
new file mode 100644
index 0000000000..9068fc68e2
--- /dev/null
+++ b/tinygrad/codegen/heuristic.py
@@ -0,0 +1,132 @@
+import itertools
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
+from tinygrad.helpers import getenv, DEBUG, all_int, prod
+from tinygrad.dtype import ImageDType
+from tinygrad.ops import Ops, resolve
+
+def hand_coded_optimizations(k:Kernel) -> Kernel:
+  # make a copy so it does not mutate the input
+  k = k.copy().required_optimizations()
+
+  # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
+  MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
+  if k.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
+      k.reduceop is not None and k.reduceop.arg[0] is Ops.ADD and len(k.full_shape) >= 2 and k.opts.has_shared and \
+      (mulop:=k.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
+    st0, st1 = k.sts[k.bufs.index(mulop.src[0])], k.sts[k.bufs.index(mulop.src[1])]
+    strides0, strides1 = st0.real_strides(), st1.real_strides()
+    def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
+    if strides0[k.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
+      for global_idx in range(k.global_dims):
+        if k.full_shape[k.first_reduce]%MV_THREADS_PER_ROW == 0 and k.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
+          if DEBUG >= 3:
+            print(f"MATVEC: {k.full_shape=} {k.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
+          if MV_THREADS_PER_ROW > 1: k.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
+          if MV_BLOCKSIZE > 1: k.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
+          if MV_ROWS_PER_THREAD > 1: k.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
+          return k
+
+  if k.opts.has_local and k.opts.has_shared and all_int(k.sts[0].shape[:k.first_reduce]):
+    # are we grouping? (requires local shape support)
+    if not [x for x in k.sts[0].unit_stride_axes() if x >= k.first_upcast and k.sts[0].shape[x]%4 == 0] and \
+        k.first_reduce <= 2 and k.first_reduce < k.shape_len and prod(k.sts[0].shape[:k.first_reduce]) <= 2048:
+      # TODO: use 1024 if it's allowed in a smarter way
+      for sz in ([256, 16] if prod(k.sts[0].shape[:k.first_reduce]) <= 32 else [16]):
+        if all(st.shape[k.first_reduce] % sz == 0 or st.shape[k.first_reduce] == 1 for st in k.sts):
+          try: # may fail due to excessive smem usage
+            k.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
+            break
+          except KernelOptError: pass
+
+  # upcast float4 images
+  for buf_index,buf in enumerate(k.bufs):
+    unit_stride_axes_mul_4 = [i for i in k.sts[buf_index].unit_stride_axes(ignore_valid=True) if k.sts[buf_index].shape[i]%4 == 0]
+    if buf.src[0].dtype.__class__ is ImageDType:
+      #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {k.bufs[buf_index]}"
+      if len(unit_stride_axes_mul_4) and all(x < k.first_upcast for x in unit_stride_axes_mul_4):
+        if unit_stride_axes_mul_4[0] < k.first_reduce:
+          k.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
+        else:
+          k.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-k.first_reduce, 4))
+
+  # no more opt if we are grouping
+  if k.group_for_reduces: return k
+
+  # **** below this line need to be optional and benchmarked ****
+
+  # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
+  # to trigger the above bug, remove prod(k.full_shape[k.first_upcast:]) from the below
+  # expression and run test/test_ops.py with IMAGE=2
+  # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
+  # this can be made much smarter
+  to_upcast: list[int] = []
+  # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
+  for axis in range(k.first_reduce):
+    # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
+    # for now skip upcasting here if there is a symbolic axis
+    if isinstance(k.full_shape[axis], int) and k.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in k.sts) and \
+        prod(k.full_shape[k.first_upcast:]) * prod(k.full_shape[j] for j in to_upcast) * k.full_shape[axis] <= 7 * 7:
+      if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
+      to_upcast.append(axis)
+  for axis in to_upcast[::-1]: k.apply_opt(Opt(OptOps.UPCAST, axis, 0))
+
+  # potentially do more upcasts of non reduce axes based on a heuristic
+  is_dsp = k.opts is not None and k.opts.device == "DSP"
+  upcasted_axis: set[int] = set()
+  while resolve(prod(k.sts[0].shape[:k.first_reduce]) >= 1024):
+    xb_choices = []
+    # consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
+    for axis, upcast_amount in itertools.product(range(k.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
+      # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
+      if axis not in upcasted_axis and isinstance(k.full_shape[axis], int) and k.full_shape[axis]%upcast_amount == 0 and \
+          any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in k.upcasted_axis(buf_index)) for buf_index, st in enumerate(k.sts)):
+        xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in k.sts),
+                           sum(st.views[-1].strides[axis] for st in k.sts), axis, upcast_amount))
+    if xb_choices:
+      xb_choices = sorted(xb_choices)
+      if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
+      k.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
+      upcasted_axis.add(xb_choices[0][2])
+    else: break
+
+  # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
+  if k.first_reduce < k.first_upcast and (prod(k.full_shape[k.first_upcast:]) <= 4 or \
+      not any(r for _,_,r in k.upcasted_axis(k.full_buf_index))) and (k.upcasted == 0 or prod(k.full_shape[-k.upcasted:]) < 64):
+    if isinstance(s:=k.full_unupcasted_shape[-1], int) and s <= 32:  # NOTE: cannot loop unroll symbolic axis
+      k.apply_opt(Opt(OptOps.UNROLL, len(k.full_unupcasted_shape)-1-k.first_reduce, 0))
+      # if it's small, upcast a second reduce dimension too
+      if k.first_reduce < k.first_upcast and s <= 3 and isinstance(s2:=k.full_unupcasted_shape[-1], int) and s2 <= 3:
+        k.apply_opt(Opt(OptOps.UNROLL, len(k.full_unupcasted_shape)-1-k.first_reduce, 0))
+    else:
+      for splits in [4]:
+        if k.full_unupcasted_shape[-1]%splits == 0:
+          k.apply_opt(Opt(OptOps.UNROLL, len(k.full_unupcasted_shape)-1-k.first_reduce, splits))
+          break
+
+  # if nothing at all is upcasted and it's easy to, do an upcast
+  for splits in [4]:
+    if k.upcasted == 0 and k.full_unupcasted_shape and k.full_unupcasted_shape[-1] % splits == 0:
+      k.apply_opt(Opt(OptOps.UPCAST, len(k.full_unupcasted_shape)-1, splits))
+
+  # **** local groups ****
+
+  if k.opts.has_local:
+    if getenv("NOLOCALS") and k.local_dims == 0 and not k.group_for_reduces:
+      k.apply_opt(Opt(OptOps.NOLOCALS))
+    else:
+      # prioritize making expand axes local
+      local_axis_ranking = [(any(k.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(k.sts))), axis) \
+                              for axis in range(len(k.full_shape[:k.first_reduce]))]
+      to_local: list[tuple[int, int]] = []
+      for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
+        local_size = prod(sz for _, sz in to_local)
+        local_sz: int|None = next((x for x in ([32] * (axis == 0) + [16,8,4,3,2]) if k.full_shape[axis] % x == 0 and local_size * x <= 128), None)
+        if local_sz is not None: to_local.append((axis, local_sz))
+      deleted_shape = 0
+      for axis, local_sz in sorted(to_local[:3]):
+        axis = axis - deleted_shape
+        will_delete_shape = local_sz == k.full_shape[axis]
+        k.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
+        if will_delete_shape: deleted_shape += 1
+
+  return k
\ No newline at end of file
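
Not part of the diff: a minimal sketch of the new call pattern, assuming tinygrad is importable (the tensor shapes and variable names here are illustrative only). Per the "make a copy" comment above, hand_coded_optimizations no longer mutates its argument; it optimizes a copy and returns it, so callers rebind the result.

    from tinygrad import Tensor, Device
    from tinygrad.codegen.kernel import Kernel
    from tinygrad.codegen.heuristic import hand_coded_optimizations

    # schedule a small matmul+relu and grab its AST (same pattern as examples/handcode_opt.py above)
    out = (Tensor.rand(64, 64) @ Tensor.rand(64, 64)).relu()
    si = out.schedule()[-1]
    k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)
    k_opt = hand_coded_optimizations(k)  # returns an optimized copy; k itself is unchanged
    assert k_opt is not k
    print(k_opt.colored_shape())
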
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index d75ef21bc5..5674826239 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -4,8 +4,7 @@ from dataclasses import dataclass
 from collections import defaultdict
 from typing import Optional, cast, Final, Callable, Sequence
 
-from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, print_uops
-from tinygrad.ops import PatternMatcher
+from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, can_pad, resolve, Variable, sint, graph_rewrite, track_rewrites, print_uops, PatternMatcher
 from tinygrad.spec import type_verify, shape_spec
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, ProgramSpec, Opt, OptOps
@@ -438,129 +437,6 @@ class Kernel:
         if all(x < self.first_upcast for x in unit_stride_axes_mul_4): self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
     return self
 
-  def hand_coded_optimizations(self) -> Kernel:
-    self.required_optimizations()
-
-    # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
-    MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
-    if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
-        self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \
-        (mulop:=self.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD:
-      st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])]
-      strides0, strides1 = st0.real_strides(), st1.real_strides()
-      def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides))
-      if strides0[self.first_reduce] == 1 and not (has_expanded_axis(st0.shape, strides0) and has_expanded_axis(st1.shape, strides1)):
-        for global_idx in range(self.global_dims):
-          if self.full_shape[self.first_reduce]%MV_THREADS_PER_ROW == 0 and self.full_shape[global_idx]%(MV_BLOCKSIZE*MV_ROWS_PER_THREAD) == 0:
-            if DEBUG >= 3:
-              print(f"MATVEC: {self.full_shape=} {self.first_reduce=} {strides0=} {MV_BLOCKSIZE=} {MV_THREADS_PER_ROW=} {MV_ROWS_PER_THREAD=}")
-            if MV_THREADS_PER_ROW > 1: self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
-            if MV_BLOCKSIZE > 1: self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
-            if MV_ROWS_PER_THREAD > 1: self.apply_opt(Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD))
-            return self
-
-    if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
-      # are we grouping? (requires local shape support)
-      if not [x for x in self.sts[0].unit_stride_axes() if x >= self.first_upcast and self.sts[0].shape[x]%4 == 0] and \
-          self.first_reduce <= 2 and self.first_reduce < self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
-        # TODO: use 1024 if it's allowed in a smarter way
-        for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
-          if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
-            try: # may fail due to excessive smem usage
-              self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
-              break
-            except KernelOptError: pass
-
-    # upcast float4 images
-    for buf_index,buf in enumerate(self.bufs):
-      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True) if self.sts[buf_index].shape[i]%4 == 0]
-      if buf.src[0].dtype.__class__ is ImageDType:
-        #assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
-        if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4):
-          if unit_stride_axes_mul_4[0] < self.first_reduce:
-            self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
-          else:
-            self.apply_opt(Opt(OptOps.UNROLL, unit_stride_axes_mul_4[0]-self.first_reduce, 4))
-
-    # no more opt if we are grouping
-    if self.group_for_reduces: return self
-
-    # **** below this line need to be optional and benchmarked ****
-
-    # TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
-    # to trigger the above bug, remove prod(self.full_shape[self.first_upcast:]) from the below
-    # expression and run test/test_ops.py with IMAGE=2
-    # if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
-    # this can be made much smarter
-    to_upcast: list[int] = []
-    # upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
-    for axis in range(self.first_reduce):
-      # we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
-      # for now skip upcasting here if there is a symbolic axis
-      if isinstance(self.full_shape[axis], int) and self.full_shape[axis] <= 7 and any(st.axis_is_masked(axis) for st in self.sts) and \
-          prod(self.full_shape[self.first_upcast:]) * prod(self.full_shape[j] for j in to_upcast) * self.full_shape[axis] <= 7 * 7:
-        if DEBUG >= 4: print(f"upcasting masked axis : {axis}")
-        to_upcast.append(axis)
-    for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
-
-    # potentially do more upcasts of non reduce axes based on a heuristic
-    is_dsp = self.opts is not None and self.opts.device == "DSP"
-    upcasted_axis: set[int] = set()
-    while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
-      xb_choices = []
-      # consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
-      for axis, upcast_amount in itertools.product(range(self.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
-        # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
-        if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501
-          xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501
-      if xb_choices:
-        xb_choices = sorted(xb_choices)
-        if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
-        self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
-        upcasted_axis.add(xb_choices[0][2])
-      else: break
-
-    # if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast.
-    if self.first_reduce < self.first_upcast and (prod(self.full_shape[self.first_upcast:]) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))) and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted:]) < 64): # noqa: E501
-      if isinstance(s:=self.full_unupcasted_shape[-1], int) and s <= 32:  # NOTE: cannot loop unroll symbolic axis
-        self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
-        # if it's small, upcast a second reduce dimension too
-        if self.first_reduce < self.first_upcast and s <= 3 and isinstance(s2:=self.full_unupcasted_shape[-1], int) and s2 <= 3:
-          self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, 0))
-      else:
-        for splits in [4]:
-          if self.full_unupcasted_shape[-1]%splits == 0:
-            self.apply_opt(Opt(OptOps.UNROLL, len(self.full_unupcasted_shape)-1-self.first_reduce, splits))
-            break
-
-    # if nothing at all is upcasted and it's easy to, do an upcast
-    # TODO: this is breaking the tests
-    for splits in [4]:
-      if self.upcasted == 0 and self.full_unupcasted_shape and self.full_unupcasted_shape[-1] % splits == 0:
-        self.apply_opt(Opt(OptOps.UPCAST, len(self.full_unupcasted_shape)-1, splits))
-
-    # **** local groups ****
-
-    if self.opts.has_local:
-      if getenv("NOLOCALS") and self.local_dims == 0 and not self.group_for_reduces:
-        self.apply_opt(Opt(OptOps.NOLOCALS))
-      else:
-        # prioritize making expand axes local
-        local_axis_ranking = [(any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))), axis) for axis in range(len(self.full_shape[:self.first_reduce]))] # noqa: E501
-        to_local: list[tuple[int, int]] = []
-        for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
-          local_size = prod(sz for _, sz in to_local)
-          local_sz: Optional[int] = next((x for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2]) if self.full_shape[axis] % x == 0 and local_size * x <= 128), None) # noqa: E501
-          if local_sz is not None: to_local.append((axis, local_sz))
-        deleted_shape = 0
-        for axis, local_sz in sorted(to_local[:3]):
-          axis = axis - deleted_shape
-          will_delete_shape = local_sz == self.full_shape[axis]
-          self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
-          if will_delete_shape: deleted_shape += 1
-
-    return self
-
   # **** kernel outputs ****
 
   kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index a621b6cdb8..dd5c109383 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -7,6 +7,7 @@ from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
 from tinygrad.codegen.kernel import Kernel
+from tinygrad.codegen.heuristic import hand_coded_optimizations
 from tinygrad.engine.schedule import ScheduleItem
 
 # **************** Program Creation ****************
@@ -15,7 +16,7 @@ logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS
 def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
   k = Kernel(ast, opts=renderer).required_optimizations()
   if not NOOPT:
-    if not k.apply_tensor_cores(getenv("TC", 1)): k.hand_coded_optimizations()
+    if not k.apply_tensor_cores(getenv("TC", 1)): k = hand_coded_optimizations(k)
     if BEAM >= 1:
       from tinygrad.engine.search import beam_search, bufs_from_lin
       kb = Kernel(ast, opts=renderer).required_optimizations()
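
Not part of the diff: a sketch of the rewired get_kernel entry point, under the same assumptions as the sketch above (shapes illustrative). Per the code, NOOPT skips optimization entirely, TC=0 disables tensor cores so the hand-coded path runs, BEAM>=1 additionally runs a beam search, and the MV_* and NOLOCALS getenv knobs read inside hand_coded_optimizations tune or disable individual heuristics.

    from tinygrad import Tensor, Device
    from tinygrad.engine.realize import get_kernel

    si = (Tensor.rand(4, 512) @ Tensor.rand(512, 512)).schedule()[-1]
    # applies hand_coded_optimizations unless NOOPT/TC/BEAM direct otherwise
    k = get_kernel(Device[Device.DEFAULT].renderer, si.ast)
    print(k.to_program().src)
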