mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
make flash attention tests run on DEV=NULL EMULATE=AMD_CDNA4 (#14742)
* make flash attention tests run on DEV=NULL EMULATE=AMD_CDNA4
* no if CI, this is just the arch
This commit is contained in:
@@ -5,21 +5,18 @@ from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.engine.realize import get_runner
|
||||
from tinygrad.engine.schedule import ExecItem
|
||||
from tinygrad.engine.jit import TinyJit
|
||||
from tinygrad.helpers import CI
|
||||
import numpy as np
|
||||
|
||||
from extra.thunder.tiny.tk import WARP_THREADS
|
||||
from extra.thunder.tiny.tk.kernel import Kernel
|
||||
from extra.thunder.tiny.tk.tiles import ST_16X32, RT_16X32, RT_16X16, TileLayout
|
||||
|
||||
@unittest.skipIf(CI or Device.DEFAULT not in ["AMD"], "only amd")
|
||||
class TestTK(unittest.TestCase):
|
||||
def setUp(self):
|
||||
arch = Device["AMD"].arch
|
||||
arch = getattr(Device[Device.DEFAULT], "renderer").arch
|
||||
if not arch.startswith("gfx9"):
|
||||
self.skipTest(f"arch {arch} not supported")
|
||||
|
||||
@unittest.skipIf(CI, "no wmma in ci")
|
||||
def test_simple_matmul(self):
|
||||
N = 8192
|
||||
BLOCK_SIZE = 64
|
||||
@@ -73,7 +70,6 @@ class TestTK(unittest.TestCase):
|
||||
|
||||
np.testing.assert_allclose(c.numpy(), ref.numpy())
|
||||
|
||||
@unittest.skipIf(CI, "no wmma in ci")
|
||||
def test_simple_matmul_transposed(self):
|
||||
N = 8192
|
||||
BLOCK_N, BLOCK_M, BLOCK_K = 64, 64, 128
|
||||
|
||||
Reference in New Issue
Block a user