make flash attention tests run on DEV=NULL EMULATE=AMD_CDNA4 (#14742)

* make flash attention tests run on DEV=NULL EMULATE=AMD_CDNA4

* no if CI, this is just the arch
This commit is contained in:
qazal
2026-02-14 11:24:37 +08:00
committed by GitHub
parent e8bd432bf6
commit 6dc7ea58fd

View File

@@ -5,21 +5,18 @@ from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.realize import get_runner
from tinygrad.engine.schedule import ExecItem
from tinygrad.engine.jit import TinyJit
from tinygrad.helpers import CI
import numpy as np
from extra.thunder.tiny.tk import WARP_THREADS
from extra.thunder.tiny.tk.kernel import Kernel
from extra.thunder.tiny.tk.tiles import ST_16X32, RT_16X32, RT_16X16, TileLayout
@unittest.skipIf(CI or Device.DEFAULT not in ["AMD"], "only amd")
class TestTK(unittest.TestCase):
def setUp(self):
arch = Device["AMD"].arch
arch = getattr(Device[Device.DEFAULT], "renderer").arch
if not arch.startswith("gfx9"):
self.skipTest(f"arch {arch} not supported")
@unittest.skipIf(CI, "no wmma in ci")
def test_simple_matmul(self):
N = 8192
BLOCK_SIZE = 64
@@ -73,7 +70,6 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(c.numpy(), ref.numpy())
@unittest.skipIf(CI, "no wmma in ci")
def test_simple_matmul_transposed(self):
N = 8192
BLOCK_N, BLOCK_M, BLOCK_K = 64, 64, 128