[CI] Add back pre-commit to nvidia CI job (#2159)

This commit is contained in:
Zahi Moudallal
2023-08-22 18:11:03 -07:00
committed by GitHub
parent 6a65c894fe
commit 5282ed890d
5 changed files with 17 additions and 12 deletions

View File

@@ -2313,9 +2313,9 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.tf32.tf32', ptx)
elif in_dtype == 'float16' and out_dtype == tl.float32:
if capability[0] == 7 and capability[1] == 5: # Turing
assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx)
assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f32.f16.f16', ptx)
else:
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx)
assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f32.f16.f16', ptx)
elif in_dtype == 'float16' and out_dtype == tl.float16:
if capability[0] == 7 and capability[1] == 5: # Turing
assert re.search(r'mma.sync.aligned.m\d+n\d+k8(?:.row.col)?.f16.f16.f16', ptx)
@@ -2331,6 +2331,7 @@ def test_dot_mulbroadcastred(in_dtype, device):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Requires sm >= 80 to run")
@triton.jit
def kernel(Z, X, Y,
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
@@ -2430,6 +2431,7 @@ def test_constexpr(literal, dtype_str, device):
def test_dot_without_load(dtype_str, device):
capability = torch.cuda.get_device_capability()
allow_tf32 = capability[0] > 7
@triton.jit
def _kernel(out, ALLOW_TF32: tl.constexpr):
a = GENERATE_TEST_HERE
@@ -2900,6 +2902,8 @@ def test_call(type, num_ctas, device):
# -------------
# TODO(Keren): if_exp_dynamic
@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_and_static"])
def test_if(if_type, device):
@@ -3084,7 +3088,7 @@ def test_inline_asm(num_ctas, device):
@pytest.mark.parametrize("num_ctas", num_ctas_list)
def test_inline_asm_packed(num_ctas, device):
check_cuda_only(device)
@triton.jit
def kernel(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))