mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
fix tc_up in search (#13438)
* tensor_core is missing from Scheduler
* test upcast max
---------
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -1,9 +1,13 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from tinygrad.helpers import BEAM, Timing, CI, Context
|
||||
from tinygrad import Variable, Tensor
|
||||
from tinygrad.helpers import BEAM, Timing, CI, prod
|
||||
from tinygrad import Variable, Device, Tensor
|
||||
from tinygrad.nn import Conv2d
|
||||
from tinygrad.uop.ops import AxisType
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
from tinygrad.codegen.opt.postrange import Scheduler
|
||||
from tinygrad.codegen.opt.search import get_kernel_actions
|
||||
|
||||
def rand(*shape):
  """Build a float32 Tensor of the requested shape from numpy's uniform random generator."""
  values = np.random.rand(*shape)
  return Tensor(values.astype(np.float32))
|
||||
@@ -75,5 +79,27 @@ class TestBeamSearch(unittest.TestCase):
|
||||
a = (a + a) * a
|
||||
a.realize()
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.tensor_cores, "test requires tensor cores")
def test_tc_up(self):
  """After applying a TC opt, get_kernel_actions must still yield at least one UPCAST/UNROLL action.

  max_up is set to the product of the sizes of the axes that are already UPCAST or UNROLL
  once the tensor-core opt has been applied.
  """
  renderer = Device[Device.DEFAULT].renderer
  tc = renderer.tensor_cores[0]
  size = max(tc.dims[0], tc.dims[1]) * 8

  # square matmul in the tensor core's input dtype; the last scheduled item carries the kernel ast
  lhs = Tensor.rand(size, size, dtype=tc.dtype_in)
  rhs = Tensor.rand(size, size, dtype=tc.dtype_in)
  ast = lhs.matmul(rhs, dtype=tc.dtype_out).schedule()[-1].ast

  sched = Scheduler(ast, renderer)
  sched.apply_opt(Opt(OptOps.TC, 0, (-1, 0, 1)))

  # product of the axis sizes that are already upcast/unrolled after the TC opt
  widened_axes = (AxisType.UPCAST, AxisType.UNROLL)
  up = prod([dim for dim, axis_type in zip(sched.full_shape, sched.axis_types) if axis_type in widened_axes])

  actions = get_kernel_actions(sched, include_0=False, max_up=int(up))
  widening_ops = (OptOps.UPCAST, OptOps.UNROLL)
  upcasted = [cand for cand in actions.values() if any(opt.op in widening_ops for opt in cand.applied_opts)]
  assert len(upcasted) > 0, f"expected upcast/unroll actions after TC with max_up={up}, but got none"
|
||||
|
||||
def test_max_up(self):
  """For each max_up, no action containing an UPCAST/UNROLL opt may carry an opt arg above max_up."""
  ast = Tensor.rand(16, 16).schedule()[-1].ast
  sched = Scheduler(ast, Device[Device.DEFAULT].renderer)
  widening_ops = (OptOps.UPCAST, OptOps.UNROLL)

  for max_up in (2, 4):
    actions = get_kernel_actions(sched, include_0=False, max_up=max_up)
    for cand in actions.values():
      up_opts = cand.applied_opts
      # only candidates that actually applied an upcast/unroll are checked
      if not any(opt.op in widening_ops for opt in up_opts):
        continue
      over_limit = [opt for opt in up_opts if opt.arg > max_up]
      assert not over_limit and any(opt.arg <= max_up for opt in up_opts)
|
||||
|
||||
# run the test suite when this file is executed as a script
if __name__ == "__main__":
  unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user