mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[CI] Add back pre-commit to nvidia CI job (#2159)
This commit is contained in:
@@ -2313,9 +2313,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx)
|
||||
elif in_dtype == 'float16' and out_dtype == tl.float32:
|
||||
if capability[0] == 7 and capability[1] == 5: # Turing
|
||||
assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx)
|
||||
assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx)
|
||||
else:
|
||||
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx)
|
||||
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx)
|
||||
elif in_dtype == 'float16' and out_dtype == tl.float16:
|
||||
if capability[0] == 7 and capability[1] == 5: # Turing
|
||||
assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx)
|
||||
@@ -2331,6 +2331,7 @@ def test_dot_mulbroadcastred(in_dtype, device):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
pytest.skip("Requires sm >= 80 to run")
|
||||
|
||||
@triton.jit
|
||||
def kernel(Z, X, Y,
|
||||
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
|
||||
@@ -2430,6 +2431,7 @@ def test_constexpr(literal, dtype_str, device):
|
||||
def test_dot_without_load(dtype_str, device):
|
||||
capability = torch.cuda.get_device_capability()
|
||||
allow_tf32 = capability[0] > 7
|
||||
|
||||
@triton.jit
|
||||
def _kernel(out, ALLOW_TF32: tl.constexpr):
|
||||
a = GENERATE_TEST_HERE
|
||||
@@ -2900,6 +2902,8 @@ def test_call(type, num_ctas, device):
|
||||
# -------------
|
||||
|
||||
# TODO(Keren): if_exp_dynamic
|
||||
|
||||
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_and_static"])
|
||||
def test_if(if_type, device):
|
||||
|
||||
@@ -3084,7 +3088,7 @@ def test_inline_asm(num_ctas, device):
|
||||
@pytest.mark.parametrize("num_ctas", num_ctas_list)
|
||||
def test_inline_asm_packed(num_ctas, device):
|
||||
check_cuda_only(device)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
|
||||
Reference in New Issue
Block a user