fix tc_up in search (#13438)

* tensor_core is missing from Scheduler

* test upcast max

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
b1tg
2026-01-01 23:25:08 +08:00
committed by GitHub
parent 9726500de8
commit 24723327ac
3 changed files with 32 additions and 4 deletions

View File

@@ -1,9 +1,13 @@
import unittest
import numpy as np
from tinygrad.helpers import BEAM, Timing, CI, Context
from tinygrad import Variable, Tensor
from tinygrad.helpers import BEAM, Timing, CI, prod
from tinygrad import Variable, Device, Tensor
from tinygrad.nn import Conv2d
from tinygrad.uop.ops import AxisType
from tinygrad.codegen.opt import Opt, OptOps
from tinygrad.codegen.opt.postrange import Scheduler
from tinygrad.codegen.opt.search import get_kernel_actions
def rand(*shape):
  """Return a Tensor of the given shape filled with uniform float32 samples from [0, 1)."""
  data = np.random.rand(*shape).astype(np.float32)
  return Tensor(data)
@@ -75,5 +79,27 @@ class TestBeamSearch(unittest.TestCase):
a = (a + a) * a
a.realize()
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tc_up(self):
tc = Device[Device.DEFAULT].renderer.tensor_cores[0]
size = max(tc.dims[0], tc.dims[1]) * 8
a, b = Tensor.rand(size, size, dtype=tc.dtype_in), Tensor.rand(size, size, dtype=tc.dtype_in)
ast = a.matmul(b, dtype=tc.dtype_out).schedule()[-1].ast
s = Scheduler(ast, Device[Device.DEFAULT].renderer)
s.apply_opt(Opt(OptOps.TC, 0, (-1, 0, 1)))
up = prod([x for x, t in zip(s.full_shape, s.axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL)])
actions = get_kernel_actions(s, include_0=False, max_up=int(up))
upcasted = [s for s in actions.values() if any(opt.op in (OptOps.UPCAST, OptOps.UNROLL) for opt in s.applied_opts)]
assert len(upcasted) > 0, f"expected upcast/unroll actions after TC with max_up={up}, but got none"
def test_max_up(self):
a = Tensor.rand(16, 16)
ast = a.schedule()[-1].ast
s = Scheduler(ast, Device[Device.DEFAULT].renderer)
for max_up in (2, 4):
actions = get_kernel_actions(s, include_0=False, max_up=max_up)
for up_opts in [s.applied_opts for s in actions.values() if any(opt.op in (OptOps.UPCAST, OptOps.UNROLL) for opt in s.applied_opts)]:
assert len([opt for opt in up_opts if opt.arg > max_up]) == 0 and len([op for op in up_opts if op.arg <= max_up]) > 0
# Run all tests in this module when executed directly as a script.
if __name__ == '__main__':
  unittest.main()

View File

@@ -45,6 +45,7 @@ class Scheduler:
ret = Scheduler(self.ast, self.ren)
ret.dont_use_locals = self.dont_use_locals
ret.applied_opts = self.applied_opts[:]
if hasattr(self, 'tensor_core'): ret.tensor_core = self.tensor_core
return ret
kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
@@ -307,6 +308,7 @@ class Scheduler:
reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD)
self.ast = self.ast.substitute({reduceop: tc_uop})
self.tensor_core = tc
return axes
return None

View File

@@ -93,8 +93,8 @@ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_
# *** external API ***
# get dictionary of all possible actions
def get_kernel_actions(s:Scheduler, include_0=True) -> dict[int, Scheduler]:
acted, max_up, max_lcl = {0:s} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
def get_kernel_actions(s:Scheduler, include_0=True, max_up:int|None=None) -> dict[int, Scheduler]:
acted, max_up, max_lcl = {0:s} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256) if max_up is None else max_up, getenv("BEAM_LOCAL_MAX", 1024)
kernel_actions = actions.copy()
for i,a in enumerate(kernel_actions):