Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
fix onnx attention permute (#14025)
* fix onnx attention permute
* skip test_attention_4d_fp16_cpu too
test/external/external_test_onnx_backend.py (8 changed lines)
@@ -171,9 +171,11 @@ backend_test.exclude('test_tensorscatter_*')
 backend_test.exclude('test_l1normalization_*')
 backend_test.exclude('test_l2normalization_*')
 backend_test.exclude('test_lpnormalization_*')
 backend_test.exclude('test_mod_mixed_sign_float16_cpu')
-backend_test.exclude('test_attention_3d_*')
-backend_test.exclude('test_attention_4d_*')
+backend_test.exclude('test_attention_4d_diff_heads_mask4d_padded_kv_cpu') # needs nonpad_kv_seqlen handling
+backend_test.exclude('test_attention_4d_fp16_cpu') # fp16 numerical issues
+backend_test.exclude('test_attention_4d_fp16_expanded_cpu') # fp16 numerical issues
+backend_test.exclude('test_attention_4d_gqa_with_past_and_present_fp16_cpu') # fp16 numerical issues
+backend_test.exclude('test_attention_4d_gqa_with_past_and_present_fp16_expanded_cpu') # fp16 numerical issues

 # rest of the failing tests
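For context, this is roughly how exclusion patterns like the ones above typically plug into onnx's backend test runner: a BackendTest instance is built for a backend class, node tests are filtered out by regex, and the surviving cases are exported as unittest classes. The sketch below is an assumption-laden illustration, not a copy of the real test file; MyBackend is a placeholder for the tinygrad-backed implementation that external_test_onnx_backend.py actually supplies.

import onnx.backend.test
from onnx.backend.base import Backend

class MyBackend(Backend):
  """Placeholder backend; the real file wires prepare()/run_node() to tinygrad."""

# build the standard ONNX backend test suite for this (placeholder) backend
backend_test = onnx.backend.test.BackendTest(MyBackend, __name__)

# exclude() takes a regex matched against generated test names, so this pattern
# skips exactly the test named 'test_attention_4d_fp16_cpu'
backend_test.exclude('test_attention_4d_fp16_cpu')

# export the remaining cases as unittest test classes in this module
globals().update(backend_test.enable_report().test_cases)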
@@ -1048,14 +1048,15 @@ def get_onnx_ops() -> dict[str, types.FunctionType|dict[OpSetId, types.FunctionT
   return output, present

 def attention_onnx(Q:Tensor, K:Tensor, V:Tensor, attn_mask:Tensor|None=None, past_key:Tensor|None=None, past_value:Tensor|None=None,
-                   is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None, qk_matmul_output_mode:int=0, scale:float|None=None,
-                   softcap:float=0.0, softmax_precision:int|None=None):
+                   nonpad_kv_seqlen:Tensor|None=None, is_causal:int=0, kv_num_heads:int|None=None, q_num_heads:int|None=None,
+                   qk_matmul_output_mode:int=0, scale:float|None=None, softcap:float=0.0, softmax_precision:int|None=None):
+  if nonpad_kv_seqlen is not None: raise NotImplementedError("nonpad_kv_seqlen is not supported")
   input_shape_len = Q.ndim
   if input_shape_len == 3:
     assert q_num_heads is not None and kv_num_heads is not None
-    Q = Q.reshape(Q.shape[0], q_num_heads, Q.shape[1], -1)
-    K = K.reshape(K.shape[0], kv_num_heads, K.shape[1], -1)
-    V = V.reshape(V.shape[0], kv_num_heads, V.shape[1], -1)
+    Q = Q.reshape(Q.shape[0], Q.shape[1], q_num_heads, -1).permute(0, 2, 1, 3)
+    K = K.reshape(K.shape[0], K.shape[1], kv_num_heads, -1).permute(0, 2, 1, 3)
+    V = V.reshape(V.shape[0], V.shape[1], kv_num_heads, -1).permute(0, 2, 1, 3)

   if past_key is not None: K = past_key.cat(K, dim=2)
   if past_value is not None: V = past_value.cat(V, dim=2)
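A minimal sketch (not from the commit) of why the permute matters when a 3D input packs all heads into its last axis: reshaping straight to (batch, heads, seq, head_dim) misreads the packed layout and scrambles sequence positions across heads, while reshaping to (batch, seq, heads, head_dim) and permuting keeps each position's per-head slices intact. The sizes and variable names below are illustrative assumptions.

# Illustrative only: contrasts the old and fixed head-splitting on made-up sizes.
import numpy as np
from tinygrad import Tensor

B, L, H, D = 2, 5, 4, 8                        # batch, seq_len, num_heads, head_dim
x = Tensor(np.arange(B*L*H*D, dtype=np.float32).reshape(B, L, H*D))  # packed (B, L, H*D)

# old behaviour: reinterprets each (L, H*D) block as (H, L, D), mixing up positions
wrong = x.reshape(B, H, L, D)
# fixed behaviour: split the last axis into (H, D), then move heads ahead of seq
right = x.reshape(B, L, H, D).permute(0, 2, 1, 3)

assert wrong.shape == right.shape == (B, H, L, D)
assert not np.array_equal(wrong.numpy(), right.numpy())  # same shape, different data

# the KV-cache concat in the hunk then operates on the (B, H, L, D) layout,
# appending cached positions along dim=2 as past_key.cat(K, dim=2) does
past = Tensor.zeros(B, H, 3, D)                # pretend 3 cached positions
k_all = past.cat(right, dim=2)
assert k_all.shape == (B, H, 3 + L, D)

Both tensors have the identical (B, H, L, D) shape, which is why the original bug produced wrong attention values rather than a shape error.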