mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
resolve some merge conflicts
Fix more conflicts. Resolve merge conflicts. Some more build and conflict fixes. Resolve conflicts for 06-fused-attention.py. Resolve merge conflicts for the tutorial group GEMM example. Fixes for some LIT tests. Resolve remaining conflicts in tests. Fix empty kernel set capability 0.
This commit is contained in:
@@ -1316,10 +1316,7 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
|
||||
if is_hip() and (dtype_z == "bfloat16"):
|
||||
pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.')
|
||||
|
||||
<<<<<<< HEAD
|
||||
size = 1024
|
||||
=======
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
# This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
|
||||
if dtype_x.startswith('bfloat'):
|
||||
x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device)
|
||||
@@ -1882,13 +1879,6 @@ layouts = [
|
||||
@pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128], [32, 32]])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
@pytest.mark.parametrize("axis", [0, 1])
|
||||
<<<<<<< HEAD
|
||||
def test_reduce_layouts(M, N, src_layout, axis, device):
|
||||
if is_hip():
|
||||
pytest.skip("test_reduce_layouts is not supported in HIP")
|
||||
|
||||
rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1"
|
||||
=======
|
||||
@pytest.mark.parametrize("reduce2d", [False, True])
|
||||
@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"])
|
||||
@pytest.mark.parametrize("reduce_op", ["sum", "max"])
|
||||
@@ -1907,7 +1897,6 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op,
|
||||
"max": np.max,
|
||||
"sum": np.sum
|
||||
}[reduce_op]
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
rdims_1d = f"{N}" if axis == 0 else f"{M}"
|
||||
rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1"
|
||||
store_range = "%7" if axis == 0 else "%1"
|
||||
@@ -1937,40 +1926,11 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op,
|
||||
#blocked = {blocked}
|
||||
#src = {src_layout}
|
||||
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
|
||||
<<<<<<< HEAD
|
||||
tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
|
||||
=======
|
||||
tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}) {{
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
|
||||
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
|
||||
%2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked>
|
||||
%3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
|
||||
<<<<<<< HEAD
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
|
||||
%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
|
||||
%6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>
|
||||
%7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
|
||||
%8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
|
||||
%9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
|
||||
%10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
|
||||
%11 = tt.splat %arg2 : (!tt.ptr<i32>) -> tensor<{rdims_2d}x!tt.ptr<i32>, #blocked>
|
||||
%12 = tt.addptr %11, {store_range} : tensor<{rdims_2d}x!tt.ptr<i32>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
|
||||
%13 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked>
|
||||
%14 = {GPU_DIALECT}.convert_layout %13 : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #src>
|
||||
%15 = "tt.reduce"(%14) ({{
|
||||
^bb0(%arg3: i32, %arg4: i32):
|
||||
%17 = arith.addi %arg3, %arg4 : i32
|
||||
tt.reduce.return %17 : i32
|
||||
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>
|
||||
%18 = {GPU_DIALECT}.convert_layout %15 : (tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>
|
||||
%19 = tt.expand_dims %18 {{axis = {axis} : i32}} : (tensor<{rdims_1d}xi32, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}xi32, #blocked>
|
||||
tt.store %12, %19 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}xi32, #blocked>
|
||||
tt.return
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
=======
|
||||
%4 = tt.splat %arg0 : (!tt.ptr<{ty}, 1>) -> tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>
|
||||
%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>, tensor<{M}x1xi32, #blocked>
|
||||
%6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>
|
||||
@@ -1986,7 +1946,6 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op,
|
||||
tt.reduce.return %17 : {ty}
|
||||
}}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>
|
||||
""" + epilogue
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
|
||||
@@ -2027,28 +1986,16 @@ def test_store_op(M, src_layout, device):
|
||||
ir = f"""
|
||||
#src = {src_layout}
|
||||
module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
|
||||
<<<<<<< HEAD
|
||||
tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
|
||||
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
|
||||
%1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
|
||||
%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
|
||||
=======
|
||||
tt.func public @kernel(%arg0: !tt.ptr<f32, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32, 1> {{tt.divisibility = 16 : i32}}) {{
|
||||
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
|
||||
%1 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<{M}x!tt.ptr<f32, 1>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
|
||||
%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32, 1>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
|
||||
%4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
|
||||
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
|
||||
%6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
|
||||
<<<<<<< HEAD
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #src>
|
||||
%8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32>, #src>, tensor<{M}x1xi32, #src>
|
||||
=======
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<{M}x1x!tt.ptr<f32, 1>, #src>
|
||||
%8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32, 1>, #src>, tensor<{M}x1xi32, #src>
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
tt.store %8, %4 : tensor<{M}x1xf32, #src>
|
||||
tt.return
|
||||
}}
|
||||
@@ -2092,16 +2039,6 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
|
||||
#dst = {dst_layout}
|
||||
#src = {src_layout}
|
||||
module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
|
||||
<<<<<<< HEAD
|
||||
tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
|
||||
%0 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
|
||||
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
|
||||
%2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
|
||||
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
|
||||
%4 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
=======
|
||||
tt.func public @kernel(%arg0: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}) {{
|
||||
%0 = tt.splat %arg0 : (!tt.ptr<i32, 1>) -> tensor<{M}x!tt.ptr<i32, 1>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
|
||||
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
|
||||
@@ -2110,7 +2047,6 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
|
||||
%4 = tt.splat %arg1 : (!tt.ptr<i32, 1>) -> tensor<{M}x!tt.ptr<i32, 1>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32, 1>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
%7 = {GPU_DIALECT}.convert_layout %3 : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
tt.store %6, %7 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
tt.return
|
||||
@@ -2174,11 +2110,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
|
||||
ir = f"""
|
||||
#src = {src_layout}
|
||||
module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
|
||||
<<<<<<< HEAD
|
||||
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
|
||||
=======
|
||||
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}) {{
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
|
||||
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
|
||||
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
|
||||
@@ -2506,18 +2438,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
|
||||
red_code = ptx[start:end]
|
||||
assert len(red_code) > 0
|
||||
import os
|
||||
<<<<<<< HEAD
|
||||
enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
|
||||
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
|
||||
# skip this check on hopper because there are some functions whose name contain "shared" in ptx.
|
||||
# TODO: we should eliminate these unused functions in ptx code.
|
||||
if not (enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]):
|
||||
=======
|
||||
|
||||
# skip this check on hopper because there are some functions whose name contain "shared" in ptx.
|
||||
# TODO: we should eliminate these unused functions in ptx code.
|
||||
if not (capability[0] >= 9):
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
assert "shared" not in red_code
|
||||
assert "bar.sync" not in red_code
|
||||
# torch result
|
||||
@@ -3766,18 +3690,11 @@ intermediate_layouts = [
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
@pytest.mark.parametrize("interm_layout", intermediate_layouts)
|
||||
@pytest.mark.parametrize("dst_layout", layouts)
|
||||
<<<<<<< HEAD
|
||||
def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device):
|
||||
if is_hip():
|
||||
pytest.skip("test_convert2d is not supported in HIP")
|
||||
|
||||
=======
|
||||
def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
|
||||
if is_hip():
|
||||
pytest.skip("test_convert2d is not supported in HIP")
|
||||
if (M == 1 or N == 1) and interm_layout:
|
||||
pytest.skip("Out of bound access when maxPhase > 1")
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
if str(src_layout) == str(dst_layout):
|
||||
pytest.skip()
|
||||
if 'mma' in str(src_layout) and 'mma' in str(dst_layout):
|
||||
|
||||
@@ -20,7 +20,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
|
||||
pytest.skip('bfloat16 tma not support currently')
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
<<<<<<< HEAD
|
||||
if torch.version.hip is not None:
|
||||
if dtype != torch.float16:
|
||||
pytest.skip("Currently flash attention on AMD gpu is only supported in fp16.")
|
||||
@@ -31,11 +30,9 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
|
||||
|
||||
if capability[0] < 8:
|
||||
pytest.skip("Flash attention only supported for compute capability < 80")
|
||||
=======
|
||||
interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]
|
||||
if not interpreter and capability[0] < 8:
|
||||
pytest.skip("Flash attention only supported for compute capability >= 80")
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
@@ -68,15 +65,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
|
||||
tri_dq, q.grad = q.grad.clone(), None
|
||||
# compare
|
||||
atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
|
||||
<<<<<<< HEAD
|
||||
torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0)
|
||||
if torch.version.hip is None:
|
||||
torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0)
|
||||
torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0)
|
||||
=======
|
||||
torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
|
||||
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
|
||||
if torch.version.hip is None:
|
||||
torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
|
||||
torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
|
||||
|
||||
Reference in New Issue
Block a user