diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index 892d72f7b7..dec2ebd16c 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -171,9 +171,11 @@ backend_test.exclude('test_tensorscatter_*') backend_test.exclude('test_l1normalization_*') backend_test.exclude('test_l2normalization_*') backend_test.exclude('test_lpnormalization_*') -backend_test.exclude('test_mod_mixed_sign_float16_cpu') -backend_test.exclude('test_attention_3d_*') -backend_test.exclude('test_attention_4d_*') +backend_test.exclude('test_attention_4d_diff_heads_mask4d_padded_kv_cpu') # needs nonpad_kv_seqlen handling +backend_test.exclude('test_attention_4d_fp16_cpu') # fp16 numerical issues +backend_test.exclude('test_attention_4d_fp16_expanded_cpu') # fp16 numerical issues +backend_test.exclude('test_attention_4d_gqa_with_past_and_present_fp16_cpu') # fp16 numerical issues +backend_test.exclude('test_attention_4d_gqa_with_past_and_present_fp16_expanded_cpu') # fp16 numerical issues # rest of the failing tests diff --git a/tinygrad/nn/onnx.py b/tinygrad/nn/onnx.py index 2498428ec0..fba873d6dc 100644 --- a/tinygrad/nn/onnx.py +++ b/tinygrad/nn/onnx.py @@ -1048,14 +1048,15 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT return output, present def attention_onnx(Q:Tensor, K:Tensor, V:Tensor, attn_mask:Tensor|None=None, past_key:Tensor|None=None, past_value:Tensor|None=None, - is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None, qk_matmul_output_mode:int=0, scale:float|None=None, - softcap:float=0.0, softmax_precision:int|None=None): + nonpad_kv_seqlen:Tensor|None=None, is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None, + qk_matmul_output_mode:int=0, scale:float|None=None, softcap:float=0.0, softmax_precision:int|None=None): + if nonpad_kv_seqlen is not None: raise NotImplementedError("nonpad_kv_seqlen is not supported") input_shape_len = Q.ndim if input_shape_len == 3: assert q_num_heads is not None and kv_num_heads is not None - Q = Q.reshape(Q.shape[0], q_num_heads, Q.shape[1], -1) - K = K.reshape(K.shape[0], kv_num_heads, K.shape[1], -1) - V = V.reshape(V.shape[0], kv_num_heads, V.shape[1], -1) + Q = Q.reshape(Q.shape[0], Q.shape[1], q_num_heads, -1).permute(0, 2, 1, 3) + K = K.reshape(K.shape[0], K.shape[1], kv_num_heads, -1).permute(0, 2, 1, 3) + V = V.reshape(V.shape[0], V.shape[1], kv_num_heads, -1).permute(0, 2, 1, 3) if past_key is not None: K = past_key.cat(K, dim=2) if past_value is not None: V = past_value.cat(V, dim=2)