fix onnx attention permute (#14025)

* fix onnx attention permute

* skip test_attention_4d_fp16_cpu too
This commit is contained in:
chenyu
2026-01-05 08:58:50 -05:00
committed by GitHub
parent 5cff5698f7
commit 9497ec00f2
2 changed files with 11 additions and 8 deletions

View File

@@ -171,9 +171,11 @@ backend_test.exclude('test_tensorscatter_*')
backend_test.exclude('test_l1normalization_*')
backend_test.exclude('test_l2normalization_*')
backend_test.exclude('test_lpnormalization_*')
backend_test.exclude('test_mod_mixed_sign_float16_cpu')
backend_test.exclude('test_attention_3d_*')
backend_test.exclude('test_attention_4d_*')
backend_test.exclude('test_attention_4d_diff_heads_mask4d_padded_kv_cpu') # needs nonpad_kv_seqlen handling
backend_test.exclude('test_attention_4d_fp16_cpu') # fp16 numerical issues
backend_test.exclude('test_attention_4d_fp16_expanded_cpu') # fp16 numerical issues
backend_test.exclude('test_attention_4d_gqa_with_past_and_present_fp16_cpu') # fp16 numerical issues
backend_test.exclude('test_attention_4d_gqa_with_past_and_present_fp16_expanded_cpu') # fp16 numerical issues
# rest of the failing tests

View File

@@ -1048,14 +1048,15 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
return output, present
def attention_onnx(Q:Tensor, K:Tensor, V:Tensor, attn_mask:Tensor|None=None, past_key:Tensor|None=None, past_value:Tensor|None=None,
is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None, qk_matmul_output_mode:int=0, scale:float|None=None,
softcap:float=0.0, softmax_precision:int|None=None):
nonpad_kv_seqlen:Tensor|None=None, is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None,
qk_matmul_output_mode:int=0, scale:float|None=None, softcap:float=0.0, softmax_precision:int|None=None):
if nonpad_kv_seqlen is not None: raise NotImplementedError("nonpad_kv_seqlen is not supported")
input_shape_len = Q.ndim
if input_shape_len == 3:
assert q_num_heads is not None and kv_num_heads is not None
Q = Q.reshape(Q.shape[0], q_num_heads, Q.shape[1], -1)
K = K.reshape(K.shape[0], kv_num_heads, K.shape[1], -1)
V = V.reshape(V.shape[0], kv_num_heads, V.shape[1], -1)
Q = Q.reshape(Q.shape[0], Q.shape[1], q_num_heads, -1).permute(0, 2, 1, 3)
K = K.reshape(K.shape[0], K.shape[1], kv_num_heads, -1).permute(0, 2, 1, 3)
V = V.reshape(V.shape[0], V.shape[1], kv_num_heads, -1).permute(0, 2, 1, 3)
if past_key is not None: K = past_key.cat(K, dim=2)
if past_value is not None: V = past_value.cat(V, dim=2)