mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
fix tc_up in search (#13438)
* tensor_core is missing from Scheduler
* test upcast max

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -1,9 +1,13 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from tinygrad.helpers import BEAM, Timing, CI, Context
|
||||
from tinygrad import Variable, Tensor
|
||||
from tinygrad.helpers import BEAM, Timing, CI, prod
|
||||
from tinygrad import Variable, Device, Tensor
|
||||
from tinygrad.nn import Conv2d
|
||||
from tinygrad.uop.ops import AxisType
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
from tinygrad.codegen.opt.search import get_kernel_actions
|
||||
|
||||
def rand(*shape):
  """Build a Tensor of the given shape from uniform [0, 1) float32 numpy data."""
  data = np.random.rand(*shape).astype(np.float32)
  return Tensor(data)
|
||||
@@ -75,5 +79,27 @@ class TestBeamSearch(unittest.TestCase):
|
||||
a = (a + a) * a
|
||||
a.realize()
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
|
||||
def test_tc_up(self):
|
||||
tc = Device[Device.DEFAULT].renderer.tensor_cores[0]
|
||||
size = max(tc.dims[0], tc.dims[1]) * 8
|
||||
a, b = Tensor.rand(size, size, dtype=tc.dtype_in), Tensor.rand(size, size, dtype=tc.dtype_in)
|
||||
ast = a.matmul(b, dtype=tc.dtype_out).schedule()[-1].ast
|
||||
s = Scheduler(ast, Device[Device.DEFAULT].renderer)
|
||||
s.apply_opt(Opt(OptOps.TC, 0, (-1, 0, 1)))
|
||||
up = prod([x for x, t in zip(s.full_shape, s.axis_types) if t in (AxisType.UPCAST, AxisType.UNROLL)])
|
||||
actions = get_kernel_actions(s, include_0=False, max_up=int(up))
|
||||
upcasted = [s for s in actions.values() if any(opt.op in (OptOps.UPCAST, OptOps.UNROLL) for opt in s.applied_opts)]
|
||||
assert len(upcasted) > 0, f"expected upcast/unroll actions after TC with max_up={up}, but got none"
|
||||
|
||||
def test_max_up(self):
|
||||
a = Tensor.rand(16, 16)
|
||||
ast = a.schedule()[-1].ast
|
||||
s = Scheduler(ast, Device[Device.DEFAULT].renderer)
|
||||
for max_up in (2, 4):
|
||||
actions = get_kernel_actions(s, include_0=False, max_up=max_up)
|
||||
for up_opts in [s.applied_opts for s in actions.values() if any(opt.op in (OptOps.UPCAST, OptOps.UNROLL) for opt in s.applied_opts)]:
|
||||
assert len([opt for opt in up_opts if opt.arg > max_up]) == 0 and len([op for op in up_opts if op.arg <= max_up]) > 0
|
||||
|
||||
# Allow running this test file directly (in addition to pytest discovery).
if __name__ == '__main__':
  unittest.main()
|
||||
|
||||
@@ -45,6 +45,7 @@ class Scheduler:
|
||||
ret = Scheduler(self.ast, self.ren)
|
||||
ret.dont_use_locals = self.dont_use_locals
|
||||
ret.applied_opts = self.applied_opts[:]
|
||||
if hasattr(self, 'tensor_core'): ret.tensor_core = self.tensor_core
|
||||
return ret
|
||||
|
||||
kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
|
||||
@@ -307,6 +308,7 @@ class Scheduler:
|
||||
reduce_ranges = [x for x in UOp.sink(*reduceop.src[1:]).toposort() if x.op is Ops.RANGE and x.arg[0] not in tc_reduce_axes]
|
||||
if len(reduce_ranges): tc_uop = UOp(Ops.REDUCE, tc_uop.dtype, (tc_uop,)+tuple(reduce_ranges), Ops.ADD)
|
||||
self.ast = self.ast.substitute({reduceop: tc_uop})
|
||||
self.tensor_core = tc
|
||||
return axes
|
||||
return None
|
||||
|
||||
|
||||
@@ -93,8 +93,8 @@ def _ensure_buffer_alloc(bufs:list[Buffer]) -> list[Buffer]: return [buf.ensure_
|
||||
# *** external API ***
|
||||
|
||||
# get dictionary of all possible actions
|
||||
def get_kernel_actions(s:Scheduler, include_0=True) -> dict[int, Scheduler]:
|
||||
acted, max_up, max_lcl = {0:s} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256), getenv("BEAM_LOCAL_MAX", 1024)
|
||||
def get_kernel_actions(s:Scheduler, include_0=True, max_up:int|None=None) -> dict[int, Scheduler]:
|
||||
acted, max_up, max_lcl = {0:s} if include_0 else {}, getenv("BEAM_UPCAST_MAX", 256) if max_up is None else max_up, getenv("BEAM_LOCAL_MAX", 1024)
|
||||
kernel_actions = actions.copy()
|
||||
|
||||
for i,a in enumerate(kernel_actions):
|
||||
|
||||
Reference in New Issue
Block a user