Kernel.apply_opts [pr] (#9917)

* Kernel.apply_opts [pr]

updated all `for opt in` loops to use `Kernel.apply_opts`. also updated a few test_linearizer tests to not implicitly depend on hand_coded_optimizations

* not you yet
chenyu
2025-04-17 08:00:56 -04:00
committed by GitHub
parent e2ed673c94
commit f5256e0020
17 changed files with 44 additions and 52 deletions
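
For readers skimming the hunks below: the change is mechanical, replacing the per-opt loop at every call site with the new `Kernel.apply_opts` helper. Here is a minimal runnable sketch of the pattern; the `Kernel` and `Opt` classes in it are illustrative stand-ins, not tinygrad's real implementations (the real `apply_opt` validates the opt against the kernel and can raise `KernelOptError`), and the string OptOps values are placeholders.

```python
from typing import Sequence

class Opt:
  """Stand-in for tinygrad's Opt; the real one carries op, axis, and arg."""
  def __init__(self, op, axis=None, arg=None):
    self.op, self.axis, self.arg = op, axis, arg

class Kernel:
  """Stand-in showing only the methods this commit touches."""
  def __init__(self, ast=None):
    self.ast, self.applied_opts = ast, []
  def apply_opt(self, opt: Opt):
    # the real method validates the opt and may raise KernelOptError
    self.applied_opts.append(opt)
  def apply_opts(self, opts: Sequence[Opt]):
    # the helper added in this commit: just loops over apply_opt
    for opt in opts: self.apply_opt(opt)

k = Kernel()
opts = [Opt(op="UPCAST", axis=0, arg=4), Opt(op="LOCAL", axis=1, arg=16)]
# before this commit, call sites wrote: for opt in opts: k.apply_opt(opt)
k.apply_opts(opts)  # after: one call replaces the loop at every call site
assert len(k.applied_opts) == 2
```

Call-site behavior is unchanged: since `apply_opts` simply loops over `apply_opt`, a `KernelOptError` raised partway still aborts the remaining opts, which is why the fuzzer's `helper_test_lin` can wrap the single `apply_opts` call in the same try/except that previously wrapped the loop.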

View File

@@ -82,7 +82,7 @@ if __name__ == "__main__":
#Opt(op=OptOps.UPCAST, axis=1, arg=4),
Opt(op=OptOps.LOCAL, axis=1, arg=LN),
Opt(op=OptOps.LOCAL, axis=0, arg=LN)]
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program(ast_transform=ast_transform)
if getenv("FAST", 1) and Device.DEFAULT == "AMD":
#src = (pathlib.Path(__file__).parent / "fp32_sgemm_amd" / "src" / "kernel8_batched_gmem.s").read_text()

View File

@@ -55,8 +55,7 @@ def randoms():
def ast_to_cuda_prog(compiler, ast, opts):
k = Kernel(ast)
k.required_optimizations()
- for opt in opts:
-   k.apply_opt(opt)
+ k.apply_opts(opts)
p = k.to_program()
return CUDAProgram(device, p.function_name, compiler.compile(p.src))

View File

@@ -28,7 +28,7 @@ if __name__ == "__main__":
Opt(op=OptOps.LOCAL, axis=1, amt=2),
Opt(op=OptOps.LOCAL, axis=0, amt=2),
]
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
new_src = prg.src
# can mod source here

View File

@@ -51,7 +51,7 @@ def dataset_from_cache(fn):
lin = Kernel(eval(ast))
except Exception:
continue
- for opt in k[:-1]: lin.apply_opt(opt)
+ lin.apply_opts(k[:-1])
act = k[-1]
log_ratio = math.log(old_tm/new_tm)
#print(f"ratio: {old_tm/new_tm:6.2f}x (log {log_ratio:5.2f}) from {str(act):50s} on {lin.colored_shape()}")

View File

@@ -15,8 +15,7 @@ def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str
def kern_str_to_lin(kern_str:str, opts=None):
(ast, applied_opts,) = eval(kern_str)
k = Kernel(ast, opts=opts)
- for opt in applied_opts:
-   k.apply_opt(opt)
+ k.apply_opts(applied_opts)
return k
# load worlds, a dataset of about 12k kernels

View File

@@ -32,7 +32,7 @@ ast = UOp(Ops.SINK, dtypes.void, arg=None, src=(
opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.UNROLL, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0), Opt(op=OptOps.LOCAL, axis=1, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=8), Opt(op=OptOps.LOCAL, axis=2, arg=2)]
k = Kernel(ast)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
bufs = bufs_from_lin(k)
prg = CompiledRunner(k.to_program())

View File

@@ -18,13 +18,12 @@ def helper_test_lin(lin: Kernel, opts, failed_platforms, validate_device, rtol=1
if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return
- for opt in opts:
-   try:
-     lin.apply_opt(opt)
-   except KernelOptError:
-     # it's considered fixed if we invalidated the opts
-     assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
-     return
+ try:
+   lin.apply_opts(opts)
+ except KernelOptError:
+   # it's considered fixed if we invalidated the opts
+   assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
+   return
(msg, rawbufs, var_vals, ground_truth, state1) = compare_linearizer(lin, rtol=rtol, atol=atol)
if msg in ["PASS", "KernelOptError"]:

View File

@@ -34,7 +34,7 @@ class TestNV(unittest.TestCase):
opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=1, arg=32), Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=2), Opt(op=OptOps.LOCAL, axis=0, arg=2)] # noqa: E501
with self.assertRaises(RuntimeError) as cm:
lin = Kernel(ast)
- for opt in opts: lin.apply_opt(opt)
+ lin.apply_opts(opts)
rawbufs = get_fuzz_rawbufs(lin)
prg = CompiledRunner(lin.to_program())
prg(rawbufs, {}, wait=True)

View File

@@ -28,7 +28,7 @@ class TestTrainGpt2Kernel(unittest.TestCase):
opts = [Opt(op=OptOps.LOCAL, axis=0, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=3), Opt(op=OptOps.LOCAL, axis=0, arg=2)]
kernel = Kernel(ast)
- for opt in opts: kernel.apply_opt(opt)
+ kernel.apply_opts(opts)
run_linearizer(kernel)
def test_2(self):
@@ -48,7 +48,7 @@ class TestTrainGpt2Kernel(unittest.TestCase):
opts = [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.LOCAL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.LOCAL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=3, arg=4)]
kernel = Kernel(ast)
- for opt in opts: kernel.apply_opt(opt)
+ kernel.apply_opts(opts)
run_linearizer(kernel)
if __name__ == "__main__":

View File

@@ -54,7 +54,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
opts = [Opt(op=OptOps.UPCAST, axis=3, arg=4), Opt(op=OptOps.UNROLL, axis=1, arg=4), Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)]
kernel = Kernel(ast)
- for opt in opts: kernel.apply_opt(opt)
+ kernel.apply_opts(opts)
p = kernel.to_program()
print(p.src)
@@ -111,7 +111,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
opts = [Opt(op=OptOps.UPCAST, axis=1, arg=4), Opt(op=OptOps.NOLOCALS, axis=None, arg=None)]
kernel = Kernel(ast)
- for opt in opts: kernel.apply_opt(opt)
+ kernel.apply_opts(opts)
p = kernel.to_program()
# ((idx1<1)?read_imagef(data1, smp, (int2)(idx0,0)):(float4)(0.0f,0.0f,0.0f,0.0f))

View File

@@ -36,8 +36,7 @@ if __name__ == "__main__":
with open(args.pkl, 'rb') as file:
(ast, applied_opts,) = pickle.load(file)
lin = Kernel(ast)
- for opt in applied_opts:
-   lin.apply_opt(opt)
+ lin.apply_opts(applied_opts)
test_lins = [lin]
else:

View File

@@ -1001,7 +1001,7 @@ class TestLinearizer(unittest.TestCase):
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = Kernel(r.schedule()[-1].ast)
- k = hand_coded_optimizations(k)
+ k.apply_opts([Opt(op=OptOps.GROUP, axis=0, arg=8), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)])
k.linearize()
stores = [u for u in k.uops if u.op is Ops.STORE]
@@ -1306,7 +1306,6 @@ class TestLinearizer(unittest.TestCase):
run_schedule(sched)
np.testing.assert_equal(a.flatten().numpy(), [1.,1.,1.,1.,2.,2.,2.,2.,1.,1.,1.,1.,1.,1.,1.,1.])
lin = Kernel(sched_copy[-1].ast)
- lin = hand_coded_optimizations(lin)
lin.linearize()
assert not any(u.op == Ops.WHERE for u in lin.uops), "found where where where should be folded"
@@ -1455,8 +1454,6 @@ class TestFloat4(unittest.TestCase):
return (len([uop for uop in k.uops if uop.op is Ops.LOAD and uop.dtype == dtypes.half.vec(4)]),
len([uop for uop in k.uops if uop.op is Ops.STORE and uop.src[-1].dtype == dtypes.half.vec(4)]))
# TODO: express opts below as auto opts
def test_float4_basic(self):
a = Tensor.empty(2, 8).realize()
b = Tensor.empty(2, 8).realize()
@@ -1464,7 +1461,7 @@ class TestFloat4(unittest.TestCase):
s = c.schedule()[0]
k = Kernel(s.ast)
- k = hand_coded_optimizations(k)
+ k.apply_opts([Opt(op=OptOps.UPCAST, axis=0, arg=4)])
k.linearize()
assert TestFloat4.count_float4(k) == (2, 1)
@@ -1511,7 +1508,6 @@ class TestFloat4(unittest.TestCase):
for i in range(len(sizes)):
assert TestFloat4.count_float4(kernel_for_shape(sizes[i], shifts[i]), excepted_upcast_size[i]) == expected_output[i]
@unittest.skipIf(Device.DEFAULT in {"CPU", "LLVM"} and AMX, "CPU with AMX upcasts float up to size 16")
def test_float4_unaligned_load(self):
a = Tensor.empty(9).realize().shrink(((1, 9),))
b = Tensor.empty(9).realize().shrink(((1, 9),))
@@ -1519,7 +1515,7 @@ class TestFloat4(unittest.TestCase):
s = c.schedule()[0]
k = Kernel(s.ast)
- k = hand_coded_optimizations(k) # implicit trigger float4 dim
+ k.apply_opts([Opt(op=OptOps.UPCAST, axis=0, arg=4)])
k.linearize()
assert TestFloat4.count_float4(k) == (0, 1)
@@ -1667,7 +1663,7 @@ class TestFloat4(unittest.TestCase):
((2, 0), [Opt(op=OptOps.UNROLL, axis=0, arg=4)]),
]:
k = Kernel(ast)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
k.linearize()
count = TestFloat4.count_half4(k)
assert count == expected, f"{count=}, {expected=}"
@@ -1697,7 +1693,7 @@ class TestFloat4(unittest.TestCase):
(4, [Opt(op=OptOps.UPCAST, axis=2, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=4)]),
]:
k = Kernel(ast)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
k.linearize()
count = len([uop for uop in k.uops if uop.op is Ops.DEFINE_ACC and uop.dtype == dtypes.float.vec(4)])
assert count == expected, f"{count=}, {expected=}"
@@ -1720,7 +1716,7 @@ class TestFloat4(unittest.TestCase):
(4, [Opt(op=OptOps.LOCAL, axis=1, arg=16), Opt(op=OptOps.UPCAST, axis=1, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=2)]),
]:
k = Kernel(ast)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
k.linearize()
count = len([uop for uop in k.uops if uop.op is Ops.DEFINE_ACC and uop.dtype == dtypes.float.vec(2)])
assert count == expected, f"{count=}, {expected=}"
@@ -1838,8 +1834,7 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:list[Buffer], opts=[]
if apply_tc:
assert k.apply_tensor_cores(1, extra_opts=opts), "no tensor core triggered"
else:
- for opt in opts:
-   k.apply_opt(opt)
+ k.apply_opts(opts)
if expected_color_size is not None:
cs = list(zip(k.colors(), k.full_shape))
assert cs == expected_color_size, f"expected={expected_color_size} got={cs}"

View File

@@ -38,7 +38,7 @@ class TestLinearizerDumb(unittest.TestCase):
opts = [Opt(op=OptOps.TC, axis=2, arg=(-1, 2)), Opt(op=OptOps.UPCAST, axis=2, arg=0), Opt(op=OptOps.UNROLL, axis=1, arg=0)]
k = Kernel(ast, opts=Device["METAL"].renderer)
k.required_optimizations()
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
Device[Device.DEFAULT].compiler.compile_cached(prg.src)
@@ -73,7 +73,7 @@ class TestLinearizerDumb(unittest.TestCase):
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=4), Opt(op=OptOps.LOCAL, axis=0, arg=8)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
k.required_optimizations()
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX"
@@ -91,7 +91,7 @@ class TestLinearizerDumb(unittest.TestCase):
opts = [Opt(op=OptOps.GROUP, axis=0, arg=0), Opt(op=OptOps.PADTO, axis=0, arg=32), Opt(op=OptOps.LOCAL, axis=0, arg=4), Opt(op=OptOps.UPCAST, axis=0, arg=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
k.required_optimizations()
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
if_uops = [u for u in k.uops if u.op is Ops.IF]
@@ -157,7 +157,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0), Opt(op=OptOps.LOCAL, axis=0, arg=3)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 3]
@@ -188,7 +188,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
load_idxs = [x.src[1] for x in k.uops if x.op is Ops.LOAD and x.src[0].arg == 2]
@@ -212,7 +212,7 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=3, arg=0), Opt(op=OptOps.UPCAST, axis=2, arg=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
print(prg.src)
store_idxs = [x.src[1] for x in k.uops if x.op is Ops.STORE]

View File

@@ -17,13 +17,12 @@ def helper_test_lin(lin: Kernel, opts, failed_platforms, rtol=1e-2, atol=1e-2):
if any(b.dtype.base == dtypes.half for b in lin.membufs) and not is_dtype_supported(dtypes.half): return
if any(b.dtype.base == dtypes.bfloat16 for b in lin.membufs) and not is_dtype_supported(dtypes.bfloat16): return
- for opt in opts:
-   try:
-     lin.apply_opt(opt)
-   except KernelOptError:
-     # it's considered fixed if we invalidated the opts
-     assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
-     return
+ try:
+   lin.apply_opts(opts)
+ except KernelOptError:
+   # it's considered fixed if we invalidated the opts
+   assert Device.DEFAULT not in failed_platforms, f"unexpected success on {Device.DEFAULT}"
+   return
compare_result = compare_linearizer(lin, rtol=rtol, atol=atol)
if compare_result[0] in ["PASS", "KernelOptError"]:

View File

@@ -14,7 +14,7 @@ from tinygrad.shape.view import View
def _test_overflow(ast, opts):
lin = Kernel(ast)
- for opt in opts: lin.apply_opt(opt)
+ lin.apply_opts(opts)
lin.linearize()
bufs = bufs_from_lin(lin)
print(bufs)

View File

@@ -40,7 +40,7 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
si = out.schedule()[-1]
k = Kernel(si.ast, opts=Device[Device.DEFAULT].renderer)
#opts = [Opt(op=OptOps.UPCAST, axis=0, arg=128)] #, Opt(op=OptOps.UNROLL, axis=0, arg=4)]
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
if replace_src is not None:
old_name = prg.src.split("__attribute__((noinline)) void ")[1].split("(")[0]
@@ -296,7 +296,7 @@ class TestDSPCache(unittest.TestCase):
opts = [Opt(op=OptOps.UNROLL, axis=0, arg=8), Opt(op=OptOps.UPCAST, axis=1, arg=32), Opt(op=OptOps.UPCAST, axis=0, arg=4)]
with Context(DEVECTORIZE=0, QUANTIZE=1):
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
- for opt in opts: k.apply_opt(opt)
+ k.apply_opts(opts)
prg = k.to_program()
#print(prg.src)

View File

@@ -318,8 +318,7 @@ class Kernel:
self.apply_opt(Opt(OptOps.TC, axis, (tc_select, tc_opt)))
if (tc_opts:=self.tensor_core_opts) is not None:
- if extra_opts is not None:
-   for opt in extra_opts: self.apply_opt(opt)
+ if extra_opts is not None: self.apply_opts(extra_opts)
else:
if AMX: return True # skip hand-coded TC opts if AMX, upcasting will make kernel slower
# hand-coded TC opts
@@ -430,6 +429,9 @@ class Kernel:
if self.simplify_ones() and self.tensor_core_opts:
self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
+ def apply_opts(self, opts:Sequence[Opt]):
+   for opt in opts: self.apply_opt(opt)
def required_optimizations(self) -> Kernel:
if isinstance(self.membufs[0].dtype, ImageDType):
unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]