add test_onnx_ops.py (#8569)

* boom

* fix webgpu

* use exact variable names in tests so that AI can read them more easily

* add tags for specific test names, e.g. to test a specific dtype

* fix ruff

* astype everything

* dtype in array creation

* just arange

* is 67% considered fixed?

* move test up

* small cleanups

* share function

* add qgemm as well

* add qgemm too

* make sure qgemm comes out as int

* take out qgemm for now

* fixed test

* add correct qgemm

* addressing feedback here too, early naive fix for now

* simplify bias and c to be minimalistic enough to test correctness

* refactored qlinearops

* maybe these asserts aren't the best..

* fix test

* updated tests to cover new ops

* try to add to CI

* move test_onnx_ops into test/external/

* more attention tests

* qlinear_add atol=1

* attention still not fullllllly correct

* it is what it is

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Author: geohotstan
Date: 2025-02-25 05:15:22 +08:00
Committed by: GitHub
Parent: 56288243e6
Commit: f0b24d230c
5 changed files with 230 additions and 38 deletions

.github/workflows/test.yml

@@ -385,6 +385,8 @@ jobs:
run: CPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
- name: Test ONNX (LLVM)
run: LLVM=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
- name: Test Additional ONNX Ops (CPU)
run: CPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_ops.py
- name: Run CLOUD=1 Test
run: |
CLOUDDEV=CPU CLOUD=1 python3 test/test_tiny.py
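
The added step can be run locally as-is; because the new test file ends in unittest.main(), a single case can also be selected with a dotted test name (example test taken from the new file below):

CPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_ops.py TestMainOnnxOps.test_reshape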


@@ -287,6 +287,7 @@ def get_onnx_ops():
Softmax = {1:Softmax_1, 13:Softmax_13}
def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
def BiasGelu(x: Tensor, bias: Tensor, approximate: str | None = None) -> Tensor: return Gelu(x + bias, approximate)
def FastGelu(x:Tensor, bias:Tensor|None=None):
# this is tanh approximated
return (x + bias).gelu() if bias is not None else x.gelu()
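
For reference: with approximate=None, the BiasGelu added above computes the exact erf-based GELU of x + bias. A minimal NumPy sketch of the same math (illustrative only; assumes SciPy for a vectorized erf):

import numpy as np
from scipy.special import erf

def bias_gelu_ref(x: np.ndarray, bias: np.ndarray) -> np.ndarray:
  y = x + bias  # bias broadcasts over the last dimension
  return 0.5 * y * (1.0 + erf(y / np.sqrt(2.0)))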
@@ -548,8 +549,11 @@ def get_onnx_ops():
def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
x = x + skip + bias
return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
x = x + skip
if bias is not None: x = x + bias
ret = x.layernorm(eps=epsilon) * gamma
if beta is not None: ret = ret + beta
return ret, None, None, x
def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor,
segment_embedding:Tensor, gamma=None, beta=None, mask:Tensor|None=None,
position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0):
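
For reference, the rewritten SkipLayerNormalization computes layernorm(x + skip [+ bias]) * gamma [+ beta] and also returns the pre-norm sum as its fourth output. A NumPy sketch of the same math (illustrative only, not the tinygrad code):

import numpy as np

def skip_layer_norm_ref(x, skip, gamma, beta=None, bias=None, epsilon=1e-12):
  s = x + skip
  if bias is not None: s = s + bias  # this sum is also the fourth output
  mu, var = s.mean(-1, keepdims=True), s.var(-1, keepdims=True)  # population variance (correction=0)
  ret = (s - mu) / np.sqrt(var + epsilon) * gamma
  if beta is not None: ret = ret + beta
  return ret, None, None, s  # (output, mean, inv_std_var, input_skip_bias_sum)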
@@ -616,44 +620,49 @@ def get_onnx_ops():
base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)
def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None,
relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None,
mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None,
qkv_hidden_sizes:list[int]|None=None, scale:float|None=None, unidirectional:int|None=None):
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
assert num_heads is not None # required
assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
assert relative_position_bias is do_rotary is past_sequence_length is mask_filter_value is past_present_share_buffer is scale is None, \
"functionality not supported yet" # TODO strange params
hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
def Attention(x:Tensor, weights:Tensor, bias:Tensor|None=None, mask_index:Tensor|None=None, past:Tensor|None=None, attention_bias:Tensor|None=None,
past_sequence_length:Tensor|None=None, do_rotary:int=0, mask_filter_value:float=-10000.0, num_heads:int|None=None,
past_present_share_buffer:int|None=None, qkv_hidden_sizes:list[int]|None=None, rotary_embedding_dim:int|None=None,
scale:float|None=None, unidirectional:int=0):
assert not do_rotary and not attention_bias, "TODO"
if qkv_hidden_sizes is None: qkv_hidden_sizes = [weights.shape[1] // 3] * 3
qkv = x.linear(weights, bias)
q, k, v = qkv.split(qkv_hidden_sizes, dim=2)
if unidirectional: # gpt-style
assert hidden_size == v_hidden_size
xqkv = x.linear(weights, bias)
xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
else: # bert-style
wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size:]) if bias is not None else None
xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
batch_size, seq_len, _ = x.shape
q_head_size, k_head_size, v_head_size = (sz // num_heads for sz in qkv_hidden_sizes)
q, k, v = (x.reshape(batch_size, seq_len, num_heads, hsz).transpose(1, 2) for x, hsz in zip((q, k, v), (q_head_size, k_head_size, v_head_size)))
present = None
if past is not None:
xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
k, v = past[0].cat(k, dim=2), past[1].cat(v, dim=2)
present = k.stack(v)
def attn(query, key, value, attn_mask):
query_length, key_length = query.shape[-2], key.shape[-2]
cdim = max(query_length, key_length) + 1
attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
# This is where Tensor.scaled_dot_product_attention differs:
causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
masked = Tensor.where(causal_mask, attn_weights, -math.inf)
if attn_mask is not None: masked = masked + attn_mask
return masked.softmax(-1) @ value
if scale is None: scale = 1.0 / math.sqrt(q_head_size)
attn_scores = q @ k.transpose(-1, -2) * scale
bsz, _, seq_len, _ = xq.shape
out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
return out, present if past is not None else out
if mask_index is not None:
assert 4 >= mask_index.ndim >= 1, f"{mask_index.ndim=}"
if mask_index.ndim != 1: mask = mask_index.bool()
else:
if mask_index.shape[0] == batch_size:
mask = Tensor.arange(attn_scores.shape[-1], requires_grad=False, device=mask_index.device).unsqueeze(0) < mask_index.unsqueeze(1)
elif mask_index.shape[0] == 2*batch_size:
end_positions = mask_index[:batch_size]
start_positions = mask_index[batch_size:]
arange = Tensor.arange(seq_len).unsqueeze(0)
mask = (arange < end_positions.unsqueeze(1)) & (arange >= start_positions.unsqueeze(1))
else: raise NotImplementedError("mask_index with shape (3 * batch_size + 2) is not implemented")
while mask.ndim < 4: mask = mask.unsqueeze(1)
attn_scores = mask.where(attn_scores, mask_filter_value)
if unidirectional:
causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool).tril()
attn_scores = causal_mask.where(attn_scores, mask_filter_value)
output = attn_scores.softmax(-1) @ v
output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
return output, present
# ***** Indexing Ops *****
def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]
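
To make the new 1-D mask_index handling concrete: shape (batch_size,) carries right-padding lengths, while shape (2*batch_size,) packs end positions followed by start positions. A NumPy sketch of that branch (illustrative only, not the shipped code):

import numpy as np

def mask_from_index(mask_index: np.ndarray, batch_size: int, seq_len: int) -> np.ndarray:
  positions = np.arange(seq_len)[None, :]  # (1, seq_len)
  if mask_index.shape[0] == batch_size:  # right padding: valid length per batch element
    return positions < mask_index[:, None]
  if mask_index.shape[0] == 2 * batch_size:  # left padding: (end positions, start positions)
    ends, starts = mask_index[:batch_size], mask_index[batch_size:]
    return (positions < ends[:, None]) & (positions >= starts[:, None])
  raise NotImplementedError("only the 1-D length-style cases are sketched here")

# e.g. lengths [2, 3] with seq_len=4 -> [[True, True, False, False], [True, True, True, False]]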

extra/onnx_helpers.py

@@ -16,7 +16,6 @@ def get_example_inputs(graph_inputs:dict[str, OnnxValue]):
def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
run_onnx = OnnxRunner(onnx.load(onnx_file))
tinygrad_out = run_onnx(inputs)
ort_options = ort.SessionOptions()
ort_options.log_severity_level = 3
@@ -26,8 +25,10 @@ def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
out_values = ort_sess.run(out_names, np_inputs)
ort_out = dict(zip(out_names, out_values))
assert len(tinygrad_out) == len(ort_out) and tinygrad_out.keys() == ort_out.keys()
tinygrad_out = run_onnx(inputs)
assert tinygrad_out.keys() == ort_out.keys()
for k in tinygrad_out.keys():
tiny_v, onnx_v = tinygrad_out[k], ort_out[k]
if tiny_v is None: assert tiny_v == onnx_v
if tiny_v is None: assert onnx_v is None, f"{k}: {tiny_v=}, {onnx_v=}"
else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tinygrad_out.keys()}")
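
A minimal usage sketch of this helper (hypothetical model file and input name; the real driver is helper_test_single_op in the new test file below):

import numpy as np
from extra.onnx_helpers import validate

# assumes "model.onnx" takes a single float32 input named "x"
inputs = {"x": np.random.randn(2, 3).astype(np.float32)}
validate("model.onnx", inputs, rtol=1e-3, atol=1e-6)  # raises if tinygrad and ORT disagree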

setup.py

@@ -54,6 +54,7 @@ setup(name='tinygrad',
"pillow",
"onnx==1.16.0",
"onnx2torch",
"onnxruntime",
"opencv-python",
"tabulate",
"tqdm",

test/external/external_test_onnx_ops.py (new file, 179 lines)

@@ -0,0 +1,179 @@
# inputs, attributes, and outputs for tests are found here:
# https://github.com/onnx/onnx/blob/main/docs/Operators.md
# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md
from typing import Any
import unittest, onnx, tempfile
import numpy as np
from extra.onnx_helpers import validate
class TestOnnxOps(unittest.TestCase):
DOMAIN = None
def helper_test_single_op(self, op:str, inps:dict[str, np.ndarray], opts:dict[str, Any], outs:list[str], rtol=1e-3, atol=1e-6):
onnx_inputs = [onnx.helper.make_tensor_value_info(name, onnx.helper.np_dtype_to_tensor_dtype(arr.dtype), arr.shape) for name, arr in inps.items()]
onnx_outputs = [onnx.helper.make_empty_tensor_value_info(name) for name in outs]
nodes = [onnx.helper.make_node(op, list(inps), list(outs), domain=self.DOMAIN, **opts)]
graph = onnx.helper.make_graph(nodes, f"test_{op.lower()}", onnx_inputs, onnx_outputs)
model = onnx.helper.make_model(graph, producer_name=f"test_{op.lower()}")
with tempfile.NamedTemporaryFile() as tmp:
onnx.save(model, tmp.name)
validate(tmp.name, inps, rtol, atol)
class TestMainOnnxOps(TestOnnxOps):
DOMAIN = ""
def test_reshape(self):
inputs = {"in": np.arange(6, dtype=np.float32), "shape": np.array([2,3], dtype=np.int64)}
attributes = {}
outputs = ["out"]
self.helper_test_single_op("Reshape", inputs, attributes, outputs)
def test_qlinear_conv(self):
for dtype, zero_point in [(np.uint8, 128), (np.int8, 0)]:
for b in (np.ones([32], dtype=np.int32), np.zeros([32], dtype=np.int32)):
with self.subTest(dtype=dtype, zero_point=zero_point):
dtype_min, dtype_max = np.iinfo(dtype).min, np.iinfo(dtype).max
inputs = {
"x": np.random.randint(dtype_min, dtype_max + 1, [1, 3, 224, 224], dtype=dtype),
"x_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
"x_zero_point": np.array(zero_point, dtype=dtype),
"w": np.random.randint(dtype_min, dtype_max + 1, [32, 3, 3, 3], dtype=dtype),
"w_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
"w_zero_point": np.array(zero_point, dtype=dtype),
"y_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
"y_zero_point": np.array(zero_point, dtype=dtype),
"b": b
}
attributes = {'auto_pad': 'NOTSET', 'dilations': (1, 1), 'group': 1, 'kernel_shape': (3, 3), 'pads': (1, 1, 1, 1), 'strides': (2, 2)}
outputs = ["out"]
self.helper_test_single_op("QLinearConv", inputs, attributes, outputs, atol=1)
def test_qlinear_matmul(self):
for dtype, zero_point in [(np.uint8, 128), (np.int8, 0)]:
with self.subTest(dtype=dtype, zero_point=zero_point):
dtype_min, dtype_max = np.iinfo(dtype).min, np.iinfo(dtype).max
inputs = {
"A": np.random.randint(dtype_min, dtype_max + 1, [10, 10], dtype=dtype),
"A_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
"A_zero_point": np.array(zero_point, dtype=dtype),
"B": np.random.randint(dtype_min, dtype_max + 1, [10, 10], dtype=dtype),
"B_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
"B_zero_point": np.array(zero_point, dtype=dtype),
"Y_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
"Y_zero_point": np.array(zero_point, dtype=dtype)
}
attributes = {}
outputs = ["Y"]
self.helper_test_single_op("QLinearMatMul", inputs, attributes, outputs, atol=1)
class TestContribOnnxOps(TestOnnxOps):
DOMAIN = "com.microsoft"
def test_attention(self):
batch_size, seq_len, input_hidden_size = 2, 8, 256
num_heads, head_size = 4, 64
hidden_size = num_heads * head_size
v_hidden_size = hidden_size
# for mask_index
right_padding_mask = np.random.randint(1, seq_len + 1, size=(batch_size,), dtype=np.int32)
end_positions = np.random.randint(1, seq_len + 1, size=(batch_size,), dtype=np.int32)
start_positions = np.array([np.random.randint(0, end) for end in end_positions], dtype=np.int32)
left_padding_mask = np.concatenate([end_positions, start_positions])
base_inps = {
"input": np.random.randn(batch_size, seq_len, input_hidden_size).astype(np.float32),
"weights": np.random.randn(input_hidden_size, hidden_size * 3).astype(np.float32),
# bias is required in ORT (segfaults otherwise), even though the docs say it's optional
"bias": np.random.randn(hidden_size * 2 + v_hidden_size).astype(np.float32),
}
base_opts = {"num_heads": num_heads}
test_cases = [
({}, {}),
({}, {"scale": 0.1}),
({}, {"scale": 1.0}),
({}, {"unidirectional": 1}),
({"mask_index": right_padding_mask}, {}),
({"mask_index": left_padding_mask}, {}),
({"mask_index": np.random.randint(0, seq_len, size=(batch_size, seq_len), dtype=np.int32)}, {"mask_filter_value": -5000.0}),
({"mask_index": np.random.randint(0, seq_len, size=(batch_size, seq_len, seq_len), dtype=np.int32)}, {"mask_filter_value": -np.inf}),
# BUG: when `mask_index` is used with `unidirectional`, the first value must be True,
# otherwise this triggers a different ORT behavior where leading consecutive Falses are turned True
# e.g. mask_index = [[0, 0, 1, 0, 1, 1, 1, 1], [0, 0, 1, 0, 1, 1, 1, 1]]
# will need mask[:, :, 0:1, 0:1] = True
({"mask_index": np.array([[1, 0, 1, 0, 1, 1, 1, 1], [1, 0, 1, 0, 1, 1, 1, 1]], dtype=np.int32)}, {"unidirectional": 1}),
({ "weights": np.random.randn(input_hidden_size, hidden_size + hidden_size + 128).astype(np.float32),
"bias": np.random.randn(hidden_size + hidden_size + 128).astype(np.float32)},
{"qkv_hidden_sizes": [hidden_size, hidden_size, 128]}),
# TODO: past is not tested. ORT gives type error for input
]
for i, (extra_inps, extra_opts) in enumerate(test_cases):
with self.subTest(f"test_attention_{i}"):
inps = {**base_inps, **extra_inps}
opts = {**base_opts, **extra_opts}
outputs = ["output", "present"] if "past" in inps else ["output"]
self.helper_test_single_op("Attention", inps, opts, outputs, atol=1e-4)
def test_skip_layer_normalization(self):
shape = (2, 8, 32)
for has_beta in [True, False]:
for has_bias in [True, False]:
with self.subTest(has_beta=has_beta, has_bias=has_bias):
hidden_size = shape[-1]
inputs = {
"input": np.random.randn(*shape).astype(np.float32),
"skip": np.random.randn(*shape).astype(np.float32),
"gamma": np.random.randn(hidden_size).astype(np.float32),
}
if has_beta: inputs["beta"] = np.random.randn(hidden_size).astype(np.float32)
if has_bias: inputs["bias"] = np.random.randn(hidden_size).astype(np.float32)
attributes = {"epsilon": 1e-12}
outputs = ["output", "mean", "inv_std_var", "input_skip_bias_sum"]
self.helper_test_single_op("SkipLayerNormalization", inputs, attributes, outputs)
def test_bias_gelu(self):
shape = (2,3,4)
inputs = {
"A": np.random.randn(*shape).astype(np.float32),
"B": np.random.randn(shape[-1]).astype(np.float32)
}
attributes = {}
outputs = ["C"]
self.helper_test_single_op("BiasGelu", inputs, attributes, outputs)
def test_qlinear_add(self):
for dtype, zero_point in [(np.uint8, 128), (np.int8, 0)]:
with self.subTest(dtype=dtype, zero_point=zero_point):
dtype_min, dtype_max = np.iinfo(dtype).min, np.iinfo(dtype).max
inputs = {
"A": np.random.randint(dtype_min, dtype_max + 1, [10, 10], dtype=dtype),
"A_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
"A_zero_point": np.array(zero_point, dtype=dtype),
"B": np.random.randint(dtype_min, dtype_max + 1, [10, 10], dtype=dtype),
"B_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
"B_zero_point": np.array(zero_point, dtype=dtype),
"C_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
"C_zero_point": np.array(zero_point, dtype=dtype)
}
attributes = {}
outputs = ["C"]
self.helper_test_single_op("QLinearAdd", inputs, attributes, outputs, atol=1)
def test_qlinear_global_average_pool(self):
for dtype, zero_point in [(np.uint8, 128), (np.int8, 0)]:
with self.subTest(dtype=dtype, zero_point=zero_point):
dtype_min, dtype_max = np.iinfo(dtype).min, np.iinfo(dtype).max
inputs = {
"X": np.random.randint(dtype_min, dtype_max + 1, [1, 3, 32, 32], dtype=dtype),
"x_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
"x_zero_point": np.array(zero_point, dtype=dtype),
"y_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
"y_zero_point": np.array(zero_point, dtype=dtype)
}
attributes = {"channels_last": 0}
outputs = ["C"]
self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs, atol=1)
if __name__ == "__main__":
unittest.main()
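
A note on the atol=1 used by the QLinear* tests above: requantization rounds to the nearest integer of the output dtype, so tinygrad and ORT can legitimately disagree by one quantized step. A NumPy reference for the dequantize -> compute -> requantize pattern these ops share (illustrative only, not what either runtime literally executes):

import numpy as np

def qlinear_matmul_ref(A, A_scale, A_zero_point, B, B_scale, B_zero_point, Y_scale, Y_zero_point):
  a = (A.astype(np.int32) - np.int32(A_zero_point)) * A_scale  # dequantize to float
  b = (B.astype(np.int32) - np.int32(B_zero_point)) * B_scale
  y = a @ b / Y_scale + Y_zero_point  # requantize
  info = np.iinfo(A.dtype)
  return np.clip(np.rint(y), info.min, info.max).astype(A.dtype)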