From f5256e002007fa477072c13ab13872132c83ae57 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 17 Apr 2025 08:00:56 -0400
Subject: [PATCH] Kernel.apply_opts [pr] (#9917)

* Kernel.apply_opts [pr]

updated all `for opt in`. also updated a few test_linearizer tests to not
implicitly depend on hand_coded_optimization

* not you yet
---
 extra/gemm/amd_matmul.py                      |  2 +-
 extra/gemm/max_matmul.py                      |  3 +--
 extra/gemm/tinygrad_nv_matmul.py              |  2 +-
 extra/optimization/extract_sa_pairs.py        |  2 +-
 extra/optimization/helpers.py                 |  3 +--
 test/external/external_debug_metal_sd_conv.py |  2 +-
 .../external_test_hcq_fuzz_failures.py        | 13 ++++++-------
 test/external/external_test_nv.py             |  2 +-
 test/external/external_test_train_gpt2.py     |  4 ++--
 test/external/external_test_valid_remove.py   |  4 ++--
 test/external/verify_kernel.py                |  3 +--
 test/test_linearizer.py                       | 19 +++++++------------
 test/test_linearizer_dumb.py                  | 12 ++++++------
 test/test_linearizer_failures.py              | 13 ++++++-------
 test/test_linearizer_overflows.py             |  2 +-
 test/test_quantize_onnx.py                    |  4 ++--
 tinygrad/codegen/kernel.py                    |  6 ++++--
 17 files changed, 44 insertions(+), 52 deletions(-)

diff --git a/extra/gemm/amd_matmul.py b/extra/gemm/amd_matmul.py
index 16a624affc..613a4cfe81 100644
--- a/extra/gemm/amd_matmul.py
+++ b/extra/gemm/amd_matmul.py
@@ -82,7 +82,7 @@ if __name__ == "__main__":
           #Opt(op=OptOps.UPCAST, axis=1, arg=4),
           Opt(op=OptOps.LOCAL, axis=1, arg=LN), Opt(op=OptOps.LOCAL, axis=0, arg=LN)]
-  for opt in opts: k.apply_opt(opt)
+  k.apply_opts(opts)
   prg = k.to_program(ast_transform=ast_transform)
 
   if getenv("FAST", 1) and Device.DEFAULT == "AMD":
     #src = (pathlib.Path(__file__).parent / "fp32_sgemm_amd" / "src" / "kernel8_batched_gmem.s").read_text()
diff --git a/extra/gemm/max_matmul.py b/extra/gemm/max_matmul.py
index 8b6195d7c3..37c346bf4c 100644
--- a/extra/gemm/max_matmul.py
+++ b/extra/gemm/max_matmul.py
@@ -55,8 +55,7 @@ def randoms():
 
 def ast_to_cuda_prog(compiler, ast, opts):
   k = Kernel(ast)
   k.required_optimizations()
-  for opt in opts:
-    k.apply_opt(opt)
+  k.apply_opts(opts)
   p = k.to_program()
   return CUDAProgram(device, p.function_name, compiler.compile(p.src))
diff --git a/extra/gemm/tinygrad_nv_matmul.py b/extra/gemm/tinygrad_nv_matmul.py
index 4d54a0062d..ab8f078ade 100644
--- a/extra/gemm/tinygrad_nv_matmul.py
+++ b/extra/gemm/tinygrad_nv_matmul.py
@@ -28,7 +28,7 @@ if __name__ == "__main__":
     Opt(op=OptOps.LOCAL, axis=1, amt=2),
     Opt(op=OptOps.LOCAL, axis=0, amt=2),
   ]
-  for opt in opts: k.apply_opt(opt)
+  k.apply_opts(opts)
   prg = k.to_program()
   new_src = prg.src
   # can mod source here
diff --git a/extra/optimization/extract_sa_pairs.py b/extra/optimization/extract_sa_pairs.py
index 82f6eb002e..db4ac7e444 100644
--- a/extra/optimization/extract_sa_pairs.py
+++ b/extra/optimization/extract_sa_pairs.py
@@ -51,7 +51,7 @@ def dataset_from_cache(fn):
       lin = Kernel(eval(ast))
     except Exception:
       continue
-    for opt in k[:-1]: lin.apply_opt(opt)
+    lin.apply_opts(k[:-1])
     act = k[-1]
     log_ratio = math.log(old_tm/new_tm)
     #print(f"ratio: {old_tm/new_tm:6.2f}x (log {log_ratio:5.2f}) from {str(act):50s} on {lin.colored_shape()}")
diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py
index 9c68cc3b78..e8e7c41c39 100644
--- a/extra/optimization/helpers.py
+++ b/extra/optimization/helpers.py
@@ -15,8 +15,7 @@ def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str
 def kern_str_to_lin(kern_str:str, opts=None):
   (ast, applied_opts,) = eval(kern_str)
   k = Kernel(ast, opts=opts)
-  for opt in applied_opts:
-    k.apply_opt(opt)
+  k.apply_opts(applied_opts)
   return k
 
 # load worlds, a dataset of about 12k kernels
diff --git a/test/external/external_debug_metal_sd_conv.py b/test/external/external_debug_metal_sd_conv.py
index f0f3db7971..2577db097c 100644
--- a/test/external/external_debug_metal_sd_conv.py
+++ b/test/external/external_debug_metal_sd_conv.py
@@ -32,7 +32,7 @@ ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
 opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=2)]
 k = Kernel(ast)
-for opt in opts: k.apply_opt(opt)
+k.apply_opts(opts)
 
 bufs = bufs_from_lin(k)
 prg = CompiledRunner(k.to_program())
diff --git a/test/external/external_test_hcq_fuzz_failures.py b/test/external/external_test_hcq_fuzz_failures.py
index 3f434e65bb..dd3c1e5a46 100644
--- a/test/external/external_test_hcq_fuzz_failures.py
+++ b/test/external/external_test_hcq_fuzz_failures.py
@@ -18,13 +18,12 @@ def helper_test_lin(lin: Kernel, opts, failed_platforms, validate_device, rtol=1
   if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
   if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return
 
-  for opt in opts:
-    try:
-      lin.apply_opt(opt)
-    except KernelOptError:
-      # it's considered fixed if we invalidated the opts
-      assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
-      return
+  try:
+    lin.apply_opts(opts)
+  except KernelOptError:
+    # it's considered fixed if we invalidated the opts
+    assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
+    return
 
   (msg, rawbufs, var_vals, ground_truth, state1) = compare_linearizer(lin, rtol=rtol, atol=atol)
   if msg in ["PASS", "KernelOptError"]:
diff --git a/test/external/external_test_nv.py b/test/external/external_test_nv.py
index 8be629055d..43db2a0a46 100644
--- a/test/external/external_test_nv.py
+++ b/test/external/external_test_nv.py
@@ -34,7 +34,7 @@ class TestNV(unittest.TestCase):
     opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.LOCAL, axis=0, arg=2)] # noqa: E501
     with self.assertRaises(RuntimeError) as cm:
       lin = Kernel(ast)
-      for opt in opts: lin.apply_opt(opt)
+      lin.apply_opts(opts)
       rawbufs = get_fuzz_rawbufs(lin)
       prg = CompiledRunner(lin.to_program())
       prg(rawbufs, {}, wait=True)
diff --git a/test/external/external_test_train_gpt2.py b/test/external/external_test_train_gpt2.py
index df7546e6b2..7845308bca 100644
--- a/test/external/external_test_train_gpt2.py
+++ b/test/external/external_test_train_gpt2.py
@@ -28,7 +28,7 @@ class TestTrainGpt2Kernel(unittest.TestCase):
     opts = [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=3), Opt(op=OptOps.LOCAL, axis=0, arg=2)]
 
     kernel = Kernel(ast)
-    for opt in opts: kernel.apply_opt(opt)
+    kernel.apply_opts(opts)
     run_linearizer(kernel)
 
   def test_2(self):
@@ -48,7 +48,7 @@ class TestTrainGpt2Kernel(unittest.TestCase):
     opts = [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=4)]
 
     kernel = Kernel(ast)
-    for opt in opts: kernel.apply_opt(opt)
+    kernel.apply_opts(opts)
     run_linearizer(kernel)
 
 if __name__ == "__main__":
diff --git a/test/external/external_test_valid_remove.py b/test/external/external_test_valid_remove.py
index d3cc22fca9..3ccfa23e64 100644
--- a/test/external/external_test_valid_remove.py
+++ b/test/external/external_test_valid_remove.py
@@ -54,7 +54,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
     opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)]
 
     kernel = Kernel(ast)
-    for opt in opts: kernel.apply_opt(opt)
+    kernel.apply_opts(opts)
     p = kernel.to_program()
     print(p.src)
@@ -111,7 +111,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
     opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)]
 
     kernel = Kernel(ast)
-    for opt in opts: kernel.apply_opt(opt)
+    kernel.apply_opts(opts)
     p = kernel.to_program()
 
     # ((idx1<1)?read_imagef(data1, smp, (int2)(idx0,0)):(float4)(0.0f,0.0f,0.0f,0.0f))
diff --git a/test/external/verify_kernel.py b/test/external/verify_kernel.py
index 5a6fb23ba0..a56b3d7fe6 100644
--- a/test/external/verify_kernel.py
+++ b/test/external/verify_kernel.py
@@ -36,8 +36,7 @@ if __name__ == "__main__":
     with open(args.pkl, 'rb') as file:
      (ast, applied_opts,) = pickle.load(file)
      lin = Kernel(ast)
-      for opt in applied_opts:
-        lin.apply_opt(opt)
+      lin.apply_opts(applied_opts)
      test_lins = [lin]
   else:
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index a145edaba5..a74a748497 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -1001,7 +1001,7 @@ class TestLinearizer(unittest.TestCase):
     x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
     r = (x@y).relu()
     k = Kernel(r.schedule()[-1].ast)
-    k = hand_coded_optimizations(k)
+    k.apply_opts([Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)])
     k.linearize()
 
     stores = [u for u in k.uops if u.op is Ops.STORE]
@@ -1306,7 +1306,6 @@ class TestLinearizer(unittest.TestCase):
     run_schedule(sched)
     np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
     lin = Kernel(sched_copy[-1].ast)
-    lin = hand_coded_optimizations(lin)
     lin.linearize()
     assert not any(u.op == Ops.WHERE for u in lin.uops), "found where where where should be folded"
 
@@ -1455,8 +1454,6 @@ class TestFloat4(unittest.TestCase):
     return (len([uop for uop in k.uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]),
             len([uop for uop in k.uops if uop.op is Ops.STORE and uop.src[-1].dtype == dtypes.half.vec(4)]))
 
-  # TODO: express opts below as auto opts
-
   def test_float4_basic(self):
     a = Tensor.empty(2, 8).realize()
     b = Tensor.empty(2, 8).realize()
@@ -1464,7 +1461,7 @@
     s = c.schedule()[0]
 
     k = Kernel(s.ast)
-    k = hand_coded_optimizations(k)
+    k.apply_opts([Opt(op=OptOps.UPCAST, axis=0, arg=4)])
     k.linearize()
 
     assert TestFloat4.count_float4(k) == (2, 1)
@@ -1511,7 +1508,6 @@
     for i in range(len(sizes)):
       assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), excepted_upcast_size[i]) == expected_output[i]
 
-  @unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
   def test_float4_unaligned_load(self):
     a = Tensor.empty(9).realize().shrink(((1, 9),))
     b = Tensor.empty(9).realize().shrink(((1, 9),))
@@ -1519,7 +1515,7 @@
     s = c.schedule()[0]
 
     k = Kernel(s.ast)
-    k = hand_coded_optimizations(k) # implicit trigger float4 dim
+    k.apply_opts([Opt(op=OptOps.UPCAST, axis=0, arg=4)])
     k.linearize()
 
     assert TestFloat4.count_float4(k) == (0, 1)
@@ -1667,7 +1663,7 @@
       ((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
     ]:
       k = Kernel(ast)
-      for opt in opts: k.apply_opt(opt)
+      k.apply_opts(opts)
       k.linearize()
       count = TestFloat4.count_half4(k)
       assert count == expected, f"{count=}, {expected=}"
@@ -1697,7 +1693,7 @@
       (4, [Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]),
     ]:
       k = Kernel(ast)
-      for opt in opts: k.apply_opt(opt)
+      k.apply_opts(opts)
       k.linearize()
       count = len([uop for uop in k.uops if uop.op is Ops.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)])
       assert count == expected, f"{count=}, {expected=}"
@@ -1720,7 +1716,7 @@
       (4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]),
     ]:
       k = Kernel(ast)
-      for opt in opts: k.apply_opt(opt)
+      k.apply_opts(opts)
       k.linearize()
       count = len([uop for uop in k.uops if uop.op is Ops.DEFINE_ACC and uop.dtype == dtypes.float.vec(2)])
       assert count == expected, f"{count=}, {expected=}"
@@ -1838,8 +1834,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
     if apply_tc:
       assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered"
     else:
-      for opt in opts:
-        k.apply_opt(opt)
+      k.apply_opts(opts)
     if expected_color_size is not None:
       cs = list(zip(k.colors(), k.full_shape))
       assert cs == expected_color_size, f"expected={expected_color_size} got={cs}"
diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py
index 85e88bb6d9..11216155ca 100644
--- a/test/test_linearizer_dumb.py
+++ b/test/test_linearizer_dumb.py
@@ -38,7 +38,7 @@ class TestLinearizerDumb(unittest.TestCase):
     opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)]
     k = Kernel(ast, opts=Device["METAL"].renderer)
     k.required_optimizations()
-    for opt in opts: k.apply_opt(opt)
+    k.apply_opts(opts)
     prg = k.to_program()
     print(prg.src)
     Device[Device.DEFAULT].compiler.compile_cached(prg.src)
@@ -73,7 +73,7 @@ class TestLinearizerDumb(unittest.TestCase):
     opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
     k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
     k.required_optimizations()
-    for opt in opts: k.apply_opt(opt)
+    k.apply_opts(opts)
     prg = k.to_program()
     print(prg.src)
     assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX"
@@ -91,7 +91,7 @@ class TestLinearizerDumb(unittest.TestCase):
     opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)]
     k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
     k.required_optimizations()
-    for opt in opts: k.apply_opt(opt)
+    k.apply_opts(opts)
     prg = k.to_program()
     print(prg.src)
     if_uops = [u for u in k.uops if u.op is Ops.IF]
@@ -157,7 +157,7 @@ class TestLinearizerDumb(unittest.TestCase):
             UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
     opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=3)]
     k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-    for opt in opts: k.apply_opt(opt)
+    k.apply_opts(opts)
     prg = k.to_program()
     print(prg.src)
     load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 3]
@@ -188,7 +188,7 @@ class TestLinearizerDumb(unittest.TestCase):
             UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
     opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0)]
     k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-    for opt in opts: k.apply_opt(opt)
+    k.apply_opts(opts)
     prg = k.to_program()
     print(prg.src)
     load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 2]
@@ -212,7 +212,7 @@ class TestLinearizerDumb(unittest.TestCase):
             UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
     opts = [Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=0)]
     k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-    for opt in opts: k.apply_opt(opt)
+    k.apply_opts(opts)
     prg = k.to_program()
     print(prg.src)
     store_idxs = [x.src[1] for x in k.uops if x.op is Ops.STORE]
diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py
index caa7afacc2..b02b28ff25 100644
--- a/test/test_linearizer_failures.py
+++ b/test/test_linearizer_failures.py
@@ -17,13 +17,12 @@ def helper_test_lin(lin: Kernel, opts, failed_platforms, rtol=1e-2, atol=1e-2):
   if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
   if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return
 
-  for opt in opts:
-    try:
-      lin.apply_opt(opt)
-    except KernelOptError:
-      # it's considered fixed if we invalidated the opts
-      assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
-      return
+  try:
+    lin.apply_opts(opts)
+  except KernelOptError:
+    # it's considered fixed if we invalidated the opts
+    assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
+    return
 
   compare_result = compare_linearizer(lin, rtol=rtol, atol=atol)
   if compare_result[0] in ["PASS", "KernelOptError"]:
diff --git a/test/test_linearizer_overflows.py b/test/test_linearizer_overflows.py
index 4b6b871f22..091d0758d4 100644
--- a/test/test_linearizer_overflows.py
+++ b/test/test_linearizer_overflows.py
@@ -14,7 +14,7 @@ from tinygrad.shape.view import View
 
 def _test_overflow(ast, opts):
   lin = Kernel(ast)
-  for opt in opts: lin.apply_opt(opt)
+  lin.apply_opts(opts)
   lin.linearize()
   bufs = bufs_from_lin(lin)
   print(bufs)
diff --git a/test/test_quantize_onnx.py b/test/test_quantize_onnx.py
index 15ae65cee8..0c1492b562 100644
--- a/test/test_quantize_onnx.py
+++ b/test/test_quantize_onnx.py
@@ -40,7 +40,7 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
   si = out.schedule()[-1]
   k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)
   #opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128)] #, Opt(op=OptOps.UNROLL, axis=0, arg=4)]
-  for opt in opts: k.apply_opt(opt)
+  k.apply_opts(opts)
   prg = k.to_program()
   if replace_src is not None:
     old_name = prg.src.split("__attribute__((noinline)) void ")[1].split("(")[0]
@@ -296,7 +296,7 @@ class TestDSPCache(unittest.TestCase):
     opts = [Opt(op=OptOps.UNROLL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
 
     with Context(DEVECTORIZE=0, QUANTIZE=1):
       k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
-      for opt in opts: k.apply_opt(opt)
+      k.apply_opts(opts)
       prg = k.to_program()
       #print(prg.src)
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 5674826239..680d70cbed 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -318,8 +318,7 @@ class Kernel:
       self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))
 
     if (tc_opts:=self.tensor_core_opts) is not None:
-      if extra_opts is not None:
-        for opt in extra_opts: self.apply_opt(opt)
+      if extra_opts is not None: self.apply_opts(extra_opts)
       else:
         if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
         # hand-coded TC opts
@@ -430,6 +429,9 @@ class Kernel:
     if self.simplify_ones() and self.tensor_core_opts:
       self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
 
+  def apply_opts(self, opts:Sequence[Opt]):
+    for opt in opts: self.apply_opt(opt)
+
   def required_optimizations(self) -> Kernel:
     if isinstance(self.membufs[0].dtype, ImageDType):
       unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
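
Usage note (appended for review, not part of the patch): a minimal sketch of how the new
Kernel.apply_opts call replaces the old `for opt in opts: k.apply_opt(opt)` loops. The tensor
shapes and the opt list below are illustrative only, and the import paths assume Opt/OptOps
still live next to Kernel in tinygrad.codegen.kernel in this tree.

    from tinygrad import Tensor, Device
    from tinygrad.codegen.kernel import Kernel, Opt, OptOps

    # build a Kernel from the last scheduled kernel of a small matmul (example only)
    out = (Tensor.rand(4, 16) @ Tensor.rand(16, 16)).relu()
    k = Kernel(out.schedule()[-1].ast, opts=Device[Device.DEFAULT].renderer)

    # one call applies the whole list in order; an invalid opt raises KernelOptError,
    # exactly as the per-opt apply_opt loop did before this change
    k.apply_opts([Opt(op=OptOps.UPCAST, axis=0, arg=4)])
    prg = k.to_program()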