diff --git a/test/testextra/test_tk.py b/test/testextra/test_tk.py index 2fae9efdbe..70b26c7a2a 100644 --- a/test/testextra/test_tk.py +++ b/test/testextra/test_tk.py @@ -5,21 +5,18 @@ from tinygrad.uop.ops import UOp, Ops from tinygrad.engine.realize import get_runner from tinygrad.engine.schedule import ExecItem from tinygrad.engine.jit import TinyJit -from tinygrad.helpers import CI import numpy as np from extra.thunder.tiny.tk import WARP_THREADS from extra.thunder.tiny.tk.kernel import Kernel from extra.thunder.tiny.tk.tiles import ST_16X32, RT_16X32, RT_16X16, TileLayout -@unittest.skipIf(CI or Device.DEFAULT not in ["AMD"], "only amd") class TestTK(unittest.TestCase): def setUp(self): - arch = Device["AMD"].arch + arch = getattr(Device[Device.DEFAULT], "renderer").arch if not arch.startswith("gfx9"): self.skipTest(f"arch {arch} not supported") - @unittest.skipIf(CI, "no wmma in ci") def test_simple_matmul(self): N = 8192 BLOCK_SIZE = 64 @@ -73,7 +70,6 @@ class TestTK(unittest.TestCase): np.testing.assert_allclose(c.numpy(), ref.numpy()) - @unittest.skipIf(CI, "no wmma in ci") def test_simple_matmul_transposed(self): N = 8192 BLOCK_N, BLOCK_M, BLOCK_K = 64, 64, 128