From 24723327ac407edcc42117117db60ec327630fe5 Mon Sep 17 00:00:00 2001
From: b1tg <33436708+b1tg@users.noreply.github.com>
Date: Thu, 1 Jan 2026 23:25:08 +0800
Subject: [PATCH] fix tc_up in search (#13438)

* tensor_core is missing from Scheduler

* test upcast max

---------

Co-authored-by: chenyu
---
 extra/optimization/test_beam_search.py | 30 ++++++++++++++++++++++++--
 tinygrad/codegen/opt/postrange.py      |  2 ++
 tinygrad/codegen/opt/search.py         |  4 ++--
 3 files changed, 32 insertions(+), 4 deletions(-)

diff --git a/extra/optimization/test_beam_search.py b/extra/optimization/test_beam_search.py
index f493ec48eb..36aba141b6 100644
--- a/extra/optimization/test_beam_search.py
+++ b/extra/optimization/test_beam_search.py
@@ -1,9 +1,13 @@
 import unittest
 import numpy as np
 
-from tinygrad.helpers import BEAM, Timing, CI, Context
-from tinygrad import Variable, Tensor
+from tinygrad.helpers import BEAM, Timing, CI, prod
+from tinygrad import Variable, Device, Tensor
 from tinygrad.nn import Conv2d
+from tinygrad.uop.ops import AxisType
+from tinygrad.codegen.opt import Opt, OptOps
+from tinygrad.codegen.opt.postrange import Scheduler
+from tinygrad.codegen.opt.search import get_kernel_actions
 
 def rand(*shape): return Tensor(np.random.rand(*shape).astype(np.float32))
 
@@ -75,5 +79,27 @@ class TestBeamSearch(unittest.TestCase):
     a = (a + a) * a
     a.realize()
 
+  @unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
+  def test_tc_up(self):
+    tc = Device[Device.DEFAULT].renderer.tensor_cores[0]
+    size = max(tc.dims[0], tc.dims[1]) * 8
+    a, b = Tensor.rand(size, size, dtype=tc.dtype_in), Tensor.rand(size, size, dtype=tc.dtype_in)
+    ast = a.matmul(b, dtype=tc.dtype_out).schedule()[-1].ast
+    s = Scheduler(ast, Device[Device.DEFAULT].renderer)
+    s.apply_opt(Opt(OptOps.TC, 0, (-1, 0, 1)))
+    up = prod([x for x, t in zip(s.full_shape, s.axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL)])
+    actions = get_kernel_actions(s, include_0=False, max_up=int(up))
+    upcasted = [s for s in actions.values() if any(opt.op in (OptOps.UPCAST, OptOps.UNROLL) for opt in s.applied_opts)]
+    assert len(upcasted) > 0, f"expected upcast/unroll actions after TC with max_up={up}, but got none"
+
+  def test_max_up(self):
+    a = Tensor.rand(16, 16)
+    ast = a.schedule()[-1].ast
+    s = Scheduler(ast, Device[Device.DEFAULT].renderer)
+    for max_up in (2, 4):
+      actions = get_kernel_actions(s, include_0=False, max_up=max_up)
+      for up_opts in [s.applied_opts for s in actions.values() if any(opt.op in (OptOps.UPCAST, OptOps.UNROLL) for opt in s.applied_opts)]:
+        assert len([opt for opt in up_opts if opt.arg > max_up]) == 0 and len([op for op in up_opts if op.arg <= max_up]) > 0
+
 if __name__ == '__main__':
   unittest.main()
diff --git a/tinygrad/codegen/opt/postrange.py b/tinygrad/codegen/opt/postrange.py
index b00bd5bea3..fd86308a95 100644
--- a/tinygrad/codegen/opt/postrange.py
+++ b/tinygrad/codegen/opt/postrange.py
@@ -45,6 +45,7 @@ class Scheduler:
     ret = Scheduler(self.ast, self.ren)
     ret.dont_use_locals = self.dont_use_locals
     ret.applied_opts = self.applied_opts[:]
+    if hasattr(self, 'tensor_core'): ret.tensor_core = self.tensor_core
     return ret
 
   kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
@@ -307,6 +308,7 @@ class Scheduler:
       reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
       if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD)
       self.ast = self.ast.substitute({reduceop: tc_uop})
+      self.tensor_core = tc
       return axes
     return None
 
diff --git a/tinygrad/codegen/opt/search.py b/tinygrad/codegen/opt/search.py
index 18d7ea49bc..13e86e8924 100644
--- a/tinygrad/codegen/opt/search.py
+++ b/tinygrad/codegen/opt/search.py
@@ -93,8 +93,8 @@ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_
 
 # *** external API ***
 
 # get dictionary of all possible actions
-def get_kernel_actions(s:Scheduler, include_0=True) -> dict[int, Scheduler]:
-  acted, max_up, max_lcl = {0:s} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
+def get_kernel_actions(s:Scheduler, include_0=True, max_up:int|None=None) -> dict[int, Scheduler]:
+  acted, max_up, max_lcl = {0:s} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256) if max_up is None else max_up, getenv("BEAM_LOCAL_MAX", 1024)
   kernel_actions = actions.copy()
   for i,a in enumerate(kernel_actions):
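
A minimal usage sketch of the new max_up parameter, assuming a tinygrad
checkout that includes this patch; the 16x16 elementwise kernel and the cap
of 4 are illustrative choices, not prescribed by the change:

    # Cap the upcast/unroll amounts considered when enumerating search actions.
    # With max_up=None, get_kernel_actions falls back to getenv("BEAM_UPCAST_MAX", 256);
    # passing an int overrides that environment-derived default.
    from tinygrad import Tensor, Device
    from tinygrad.codegen.opt import OptOps
    from tinygrad.codegen.opt.postrange import Scheduler
    from tinygrad.codegen.opt.search import get_kernel_actions

    ast = (Tensor.rand(16, 16) + 1).schedule()[-1].ast
    s = Scheduler(ast, Device[Device.DEFAULT].renderer)
    for action in get_kernel_actions(s, include_0=False, max_up=4).values():
      for opt in action.applied_opts:
        if opt.op in (OptOps.UPCAST, OptOps.UNROLL):
          assert opt.arg <= 4  # no candidate exceeds the explicit cap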