From f0b24d230cbe0aad509316e13e84152f66bc6d00 Mon Sep 17 00:00:00 2001
From: geohotstan <135171913+geohotstan@users.noreply.github.com>
Date: Tue, 25 Feb 2025 05:15:22 +0800
Subject: [PATCH] add test_onnx_ops.py (#8569)

* boom
* fix webgpu
* use exact variable names in test so that AI can read easier
* add tag for specific test name like test a specific dtype
* fix ruff
* astype everything
* dtype in array creation
* just arange
* is 67% considered fixed?
* move test up
* small cleanups
* share function
* add qgemm as well
* add qgemm too
* make sure qgemm comes out as int
* take out qgemm for now
* fixed test
* add correct qgemm
* addressing feedback here too, early naive fix for now
* simplify bias and c to be minimalistic enough to test correctness
* refactored qlinearops
* maybe these asserts aren't the best..
* fix test
* updated tests to cover new ops
* try to add to CI
* move test_onnx_ops into testextra/
* more attention tests
* qlinear_add atol=1
* attention still not fullllllly correct
* it is what it is

---------

Co-authored-by: chenyu
---
 .github/workflows/test.yml              |   2 +
 extra/onnx.py                           |  79 ++++++-----
 extra/onnx_helpers.py                   |   7 +-
 setup.py                                |   1 +
 test/external/external_test_onnx_ops.py | 179 ++++++++++++++++++++++++
 5 files changed, 230 insertions(+), 38 deletions(-)
 create mode 100644 test/external/external_test_onnx_ops.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index ffbc735144..670ba2bf2e 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -385,6 +385,8 @@ jobs:
       run: CPU=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
     - name: Test ONNX (LLVM)
       run: LLVM=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
+    - name: Test Additional ONNX Ops (CPU)
+      run: CPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_ops.py
     - name: Run CLOUD=1 Test
       run: |
         CLOUDDEV=CPU CLOUD=1 python3 test/test_tiny.py
diff --git a/extra/onnx.py b/extra/onnx.py
index 96f8e07de0..bf704dc514 100644
--- a/extra/onnx.py
+++ b/extra/onnx.py
@@ -287,6 +287,7 @@ def get_onnx_ops():
   Softmax = {1:Softmax_1, 13:Softmax_13}
   def HardSigmoid(x:Tensor, alpha:float=0.2, beta:float=0.5): return (alpha*x + beta).clip(0, 1)
   def Gelu(x:Tensor, approximate:str|None=None): return x.gelu() if approximate == "tanh" else 0.5 * x * (1 + (x/math.sqrt(2)).erf())
+  def BiasGelu(x: Tensor, bias: Tensor, approximate: str | None = None) -> Tensor: return Gelu(x + bias, approximate)
   def FastGelu(x:Tensor, bias:Tensor|None=None):
     # this is tanh approximated
     return (x + bias).gelu() if bias is not None else x.gelu()
@@ -548,8 +549,11 @@ def get_onnx_ops():
   def MeanVarianceNormalization(x:Tensor, axis:list[int]=[0,2,3]):
     return (x - x.mean(axis, keepdim=True)) / (x.std(axis, keepdim=True, correction=0) + 1e-9)
   def SkipLayerNormalization(x:Tensor, skip:Tensor, gamma:Tensor, beta:Tensor|None=None, bias:Tensor|None=None, epsilon:float=1e-12):
-    x = x + skip + bias
-    return x.layernorm(eps=epsilon) * gamma + beta, None, None, x
+    x = x + skip
+    if bias is not None: x = x + bias
+    ret = x.layernorm(eps=epsilon) * gamma
+    if beta is not None: ret = ret + beta
+    return ret, None, None, x

   def EmbedLayerNormalization(input_ids: Tensor, segment_ids:Tensor, word_embedding:Tensor, position_embedding:Tensor, segment_embedding:Tensor,
                               gamma=None, beta=None, mask:Tensor|None=None, position_ids:Tensor|None=None, epsilon=1e-12, mask_index_type=0):
@@ -616,44 +620,49 @@ def get_onnx_ops():
     base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1)
     return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1)

-  def Attention(x:Tensor, weights, bias:Tensor, mask_index:Tensor|None=None, past:Tensor|None=None,
-                relative_position_bias:Tensor|None=None, past_sequence_length:Tensor|None=None, do_rotary:int|None=None,
-                mask_filter_value:float|None=None, num_heads:int|None=None, past_present_share_buffer:int|None=None,
-                qkv_hidden_sizes:list[int]|None=None, scale:float|None=None, unidirectional:int|None=None):
-    # https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Attention
-    assert num_heads is not None # required
-    assert (qkv_hidden_sizes is None and past is not None) or (qkv_hidden_sizes is not None)
-    assert relative_position_bias is do_rotary is past_sequence_length is mask_filter_value is past_present_share_buffer is scale is None, \
-      "functionality not supported yet" # TODO strange params
-    hidden_size, v_hidden_size = qkv_hidden_sizes[1:] if qkv_hidden_sizes is not None else 2*(weights.shape[1] // 3,)
+  def Attention(x:Tensor, weights:Tensor, bias:Tensor|None=None, mask_index:Tensor|None=None, past:Tensor|None=None, attention_bias:Tensor|None=None,
+                past_sequence_length:Tensor|None=None, do_rotary:int=0, mask_filter_value:float=-10000.0, num_heads:int|None=None,
+                past_present_share_buffer:int|None=None, qkv_hidden_sizes:list[int]|None=None, rotary_embedding_dim:int|None=None,
+                scale:float|None=None, unidirectional:int=0):
+    assert not do_rotary and not attention_bias, "TODO"
+    if qkv_hidden_sizes is None: qkv_hidden_sizes = [weights.shape[1] // 3] * 3
+    qkv = x.linear(weights, bias)
+    q, k, v = qkv.split(qkv_hidden_sizes, dim=2)

-    if unidirectional: # gpt-style
-      assert hidden_size == v_hidden_size
-      xqkv = x.linear(weights, bias)
-      xq, xk, xv = [xqkv.shrink([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
-    else: # bert-style
-      wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
-      bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size]) if bias is not None else None
-      xq, xk, xv = [x.linear(w, b) for w, b in zip((wq, wk, wv), (bq, bk, bv))]
-    xq, xk, xv = [x.reshape(x.shape[0], x.shape[1], num_heads, -1).transpose(1, 2) for x in (xq, xk, xv)]
+    batch_size, seq_len, _ = x.shape
+    q_head_size, k_head_size, v_head_size = (sz // num_heads for sz in qkv_hidden_sizes)
+    q, k, v = (x.reshape(batch_size, seq_len, num_heads, hsz).transpose(1, 2) for x, hsz in zip((q, k, v), (q_head_size, k_head_size, v_head_size)))

+    present = None
     if past is not None:
-      xk, xv = Tensor.cat(past[0], xk, dim=-2), Tensor.cat(past[1], xv, dim=-2)
-      present = Tensor.cat(xk.unsqueeze(0), xv.unsqueeze(0))
+      k, v = past[0].cat(k, dim=2), past[1].cat(v, dim=2)
+      present = k.stack(v)

-    def attn(query, key, value, attn_mask):
-      query_length, key_length = query.shape[-2], key.shape[-2]
-      cdim = max(query_length, key_length) + 1
-      attn_weights = query @ key.transpose(-1, -2) / math.sqrt(value.shape[-1])
-      # This is where Tensor.scaled_dot_product_attention differs:
-      causal_mask = Tensor.ones((cdim, cdim), requires_grad=False, dtype=dtypes.bool).tril(0)[key_length - query_length : key_length, :key_length]
-      masked = Tensor.where(causal_mask, attn_weights, -math.inf)
-      if attn_mask is not None: masked = masked + attn_mask
-      return masked.softmax(-1) @ value
+    if scale is None: scale = 1.0 / math.sqrt(q_head_size)
+    attn_scores = q @ k.transpose(-1, -2) * scale

-    bsz, _, seq_len, _ = xq.shape
-    out = attn(xq, xk, xv, mask_index).transpose(1, 2).reshape(bsz, seq_len, -1)
-    return out, present if past is not None else out
+    if mask_index is not None:
+      assert 4 >= mask_index.ndim >= 1, f"{mask_index.ndim=}"
+      if mask_index.ndim != 1: mask = mask_index.bool()
+      else:
+        if mask_index.shape[0] == batch_size:
+          mask = Tensor.arange(attn_scores.shape[-1], requires_grad=False, device=mask_index.device).unsqueeze(0) < mask_index.unsqueeze(1)
+        elif mask_index.shape[0] == 2*batch_size:
+          end_positions = mask_index[:batch_size]
+          start_positions = mask_index[batch_size:]
+          arange = Tensor.arange(seq_len).unsqueeze(0)
+          mask = (arange < end_positions.unsqueeze(1)) & (arange >= start_positions.unsqueeze(1))
+        else: raise NotImplementedError("mask_index with shape (3 * batch_size + 2) is not implemented")
+      while mask.ndim < 4: mask = mask.unsqueeze(1)
+      attn_scores = mask.where(attn_scores, mask_filter_value)
+
+    if unidirectional:
+      causal_mask = Tensor.ones((seq_len, seq_len), dtype=dtypes.bool).tril()
+      attn_scores = causal_mask.where(attn_scores, mask_filter_value)
+
+    output = attn_scores.softmax(-1) @ v
+    output = output.transpose(1, 2).reshape(batch_size, seq_len, -1)
+    return output, present

   # ***** Indexing Ops *****
   def ArrayFeatureExtractor(x:Tensor, indices:Tensor): return x[..., indices]
diff --git a/extra/onnx_helpers.py b/extra/onnx_helpers.py
index cbd6ee7163..5c3b64bd38 100644
--- a/extra/onnx_helpers.py
+++ b/extra/onnx_helpers.py
@@ -16,7 +16,6 @@ def get_example_inputs(graph_inputs:dict[str, OnnxValue]):

 def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
   run_onnx = OnnxRunner(onnx.load(onnx_file))
-  tinygrad_out = run_onnx(inputs)

   ort_options = ort.SessionOptions()
   ort_options.log_severity_level = 3
@@ -26,8 +25,10 @@ def validate(onnx_file, inputs, rtol=1e-5, atol=1e-5):
   out_values = ort_sess.run(out_names, np_inputs)
   ort_out = dict(zip(out_names, out_values))

-  assert len(tinygrad_out) == len(ort_out) and tinygrad_out.keys() == ort_out.keys()
+  tinygrad_out = run_onnx(inputs)
+
+  assert tinygrad_out.keys() == ort_out.keys()
   for k in tinygrad_out.keys():
     tiny_v, onnx_v = tinygrad_out[k], ort_out[k]
-    if tiny_v is None: assert tiny_v == onnx_v
+    if tiny_v is None: assert onnx_v is None, f"{k}: {tiny_v=}, {onnx_v=}"
     else: np.testing.assert_allclose(tiny_v.numpy(), onnx_v, rtol=rtol, atol=atol, err_msg=f"For tensor '{k}' in {tinygrad_out.keys()}")
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 5618d77c90..0d1546c558 100644
--- a/setup.py
+++ b/setup.py
@@ -54,6 +54,7 @@ setup(name='tinygrad',
         "pillow",
         "onnx==1.16.0",
         "onnx2torch",
+        "onnxruntime",
         "opencv-python",
         "tabulate",
         "tqdm",
diff --git a/test/external/external_test_onnx_ops.py b/test/external/external_test_onnx_ops.py
new file mode 100644
index 0000000000..884f94421c
--- /dev/null
+++ b/test/external/external_test_onnx_ops.py
@@ -0,0 +1,179 @@
+# inputs, attributes, and outputs for tests are found here:
+# https://github.com/onnx/onnx/blob/main/docs/Operators.md
+# https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md
+
+from typing import Any
+import unittest, onnx, tempfile
+import numpy as np
+from extra.onnx_helpers import validate
+
+class TestOnnxOps(unittest.TestCase):
+  DOMAIN = None
+  def helper_test_single_op(self, op:str, inps:dict[str, np.ndarray], opts:dict[str, Any], outs:list[str], rtol=1e-3, atol=1e-6):
+    onnx_inputs = [onnx.helper.make_tensor_value_info(name, onnx.helper.np_dtype_to_tensor_dtype(arr.dtype), arr.shape) for name, arr in inps.items()]
+    onnx_outputs = [onnx.helper.make_empty_tensor_value_info(name) for name in outs]
+    nodes = [onnx.helper.make_node(op, list(inps), list(outs), domain=self.DOMAIN, **opts)]
+    graph = onnx.helper.make_graph(nodes, f"test_{op.lower()}", onnx_inputs, onnx_outputs)
+    model = onnx.helper.make_model(graph, producer_name=f"test_{op.lower()}")
+    with tempfile.NamedTemporaryFile() as tmp:
+      onnx.save(model, tmp.name)
+      validate(tmp.name, inps, rtol, atol)
+
+class TestMainOnnxOps(TestOnnxOps):
+  DOMAIN = ""
+  def test_reshape(self):
+    inputs = {"in": np.arange(6, dtype=np.float32), "shape": np.array([2,3], dtype=np.int64)}
+    attributes = {}
+    outputs = ["out"]
+    self.helper_test_single_op("Reshape", inputs, attributes, outputs)
+
+  def test_qlinear_conv(self):
+    for dtype, zero_point in [(np.uint8, 128), (np.int8, 0)]:
+      for b in (np.ones([32], dtype=np.int32), np.zeros([32], dtype=np.int32)):
+        with self.subTest(dtype=dtype, zero_point=zero_point):
+          dtype_min, dtype_max = np.iinfo(dtype).min, np.iinfo(dtype).max
+          inputs = {
+            "x": np.random.randint(dtype_min, dtype_max + 1, [1, 3, 224, 224], dtype=dtype),
+            "x_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
+            "x_zero_point": np.array(zero_point, dtype=dtype),
+            "w": np.random.randint(dtype_min, dtype_max + 1, [32, 3, 3, 3], dtype=dtype),
+            "w_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
+            "w_zero_point": np.array(zero_point, dtype=dtype),
+            "y_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
+            "y_zero_point": np.array(zero_point, dtype=dtype),
+            "b": b
+          }
+          attributes = {'auto_pad': 'NOTSET', 'dilations': (1, 1), 'group': 1, 'kernel_shape': (3, 3), 'pads': (1, 1, 1, 1), 'strides': (2, 2)}
+          outputs = ["out"]
+          self.helper_test_single_op("QLinearConv", inputs, attributes, outputs, atol=1)
+
+  def test_qlinear_matmul(self):
+    for dtype, zero_point in [(np.uint8, 128), (np.int8, 0)]:
+      with self.subTest(dtype=dtype, zero_point=zero_point):
+        dtype_min, dtype_max = np.iinfo(dtype).min, np.iinfo(dtype).max
+        inputs = {
+          "A": np.random.randint(dtype_min, dtype_max + 1, [10, 10], dtype=dtype),
+          "A_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
+          "A_zero_point": np.array(zero_point, dtype=dtype),
+          "B": np.random.randint(dtype_min, dtype_max + 1, [10, 10], dtype=dtype),
+          "B_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
+          "B_zero_point": np.array(zero_point, dtype=dtype),
+          "Y_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
+          "Y_zero_point": np.array(zero_point, dtype=dtype)
+        }
+        attributes = {}
+        outputs = ["Y"]
+        self.helper_test_single_op("QLinearMatMul", inputs, attributes, outputs, atol=1)
+
+class TestContribOnnxOps(TestOnnxOps):
+  DOMAIN = "com.microsoft"
+
+  def test_attention(self):
+    batch_size, seq_len, input_hidden_size = 2, 8, 256
+    num_heads, head_size = 4, 64
+    hidden_size = num_heads * head_size
+    v_hidden_size = hidden_size
+
+    # for mask_index
+    right_padding_mask = np.random.randint(1, seq_len + 1, size=(batch_size,), dtype=np.int32)
+    end_positions = np.random.randint(1, seq_len + 1, size=(batch_size,), dtype=np.int32)
+    start_positions = np.array([np.random.randint(0, end) for end in end_positions], dtype=np.int32)
+    left_padding_mask = np.concatenate([end_positions, start_positions])
+
+    base_inps = {
+      "input": np.random.randn(batch_size, seq_len, input_hidden_size).astype(np.float32),
+      "weights": np.random.randn(input_hidden_size, hidden_size * 3).astype(np.float32),
+      # bias is required in ORT (segfaults otherwise), even though the docs say it's optional
+      "bias": np.random.randn(hidden_size * 2 + v_hidden_size).astype(np.float32),
+    }
+    base_opts = {"num_heads": num_heads}
+
+    test_cases = [
+      ({}, {}),
+      ({}, {"scale": 0.1}),
+      ({}, {"scale": 1.0}),
+      ({}, {"unidirectional": 1}),
+      ({"mask_index": right_padding_mask}, {}),
+      ({"mask_index": left_padding_mask}, {}),
+      ({"mask_index": np.random.randint(0, seq_len, size=(batch_size, seq_len), dtype=np.int32)}, {"mask_filter_value": -5000.0}),
+      ({"mask_index": np.random.randint(0, seq_len, size=(batch_size, seq_len, seq_len), dtype=np.int32)}, {"mask_filter_value": -np.inf}),
+      # BUG: when `mask_index` is used with `unidirectional`, the first value must be True,
+      # otherwise this triggers a different ORT behavior where leading consecutive Falses are turned True
+      # e.g. mask_index = [[0, 0, 1, 0, 1, 1, 1, 1], [0, 0, 1, 0, 1, 1, 1, 1]]
+      # will need mask[:, :, 0:1, 0:1] = True
+      ({"mask_index": np.array([[1, 0, 1, 0, 1, 1, 1, 1], [1, 0, 1, 0, 1, 1, 1, 1]], dtype=np.int32)}, {"unidirectional": 1}),
+      ({"weights": np.random.randn(input_hidden_size, hidden_size + hidden_size + 128).astype(np.float32),
+        "bias": np.random.randn(hidden_size + hidden_size + 128).astype(np.float32)},
+       {"qkv_hidden_sizes": [hidden_size, hidden_size, 128]}),
+      # TODO: past is not tested. ORT gives a type error for the input
+    ]
+
+    for i, (extra_inps, extra_opts) in enumerate(test_cases):
+      with self.subTest(f"test_attention_{i}"):
+        inps = {**base_inps, **extra_inps}
+        opts = {**base_opts, **extra_opts}
+        outputs = ["output", "present"] if "past" in inps else ["output"]
+        self.helper_test_single_op("Attention", inps, opts, outputs, atol=1e-4)
+
+  def test_skip_layer_normalization(self):
+    shape = (2, 8, 32)
+    for has_beta in [True, False]:
+      for has_bias in [True, False]:
+        with self.subTest(has_beta=has_beta, has_bias=has_bias):
+          hidden_size = shape[-1]
+          inputs = {
+            "input": np.random.randn(*shape).astype(np.float32),
+            "skip": np.random.randn(*shape).astype(np.float32),
+            "gamma": np.random.randn(hidden_size).astype(np.float32),
+          }
+          if has_beta: inputs["beta"] = np.random.randn(hidden_size).astype(np.float32)
+          if has_bias: inputs["bias"] = np.random.randn(hidden_size).astype(np.float32)
+          attributes = {"epsilon": 1e-12}
+          outputs = ["output", "mean", "inv_std_var", "input_skip_bias_sum"]
+          self.helper_test_single_op("SkipLayerNormalization", inputs, attributes, outputs)
+
+  def test_bias_gelu(self):
+    shape = (2,3,4)
+    inputs = {
+      "A": np.random.randn(*shape).astype(np.float32),
+      "B": np.random.randn(shape[-1]).astype(np.float32)
+    }
+    attributes = {}
+    outputs = ["C"]
+    self.helper_test_single_op("BiasGelu", inputs, attributes, outputs)
+
+  def test_qlinear_add(self):
+    for dtype, zero_point in [(np.uint8, 128), (np.int8, 0)]:
+      with self.subTest(dtype=dtype, zero_point=zero_point):
+        dtype_min, dtype_max = np.iinfo(dtype).min, np.iinfo(dtype).max
+        inputs = {
+          "A": np.random.randint(dtype_min, dtype_max + 1, [10, 10], dtype=dtype),
+          "A_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
+          "A_zero_point": np.array(zero_point, dtype=dtype),
+          "B": np.random.randint(dtype_min, dtype_max + 1, [10, 10], dtype=dtype),
+          "B_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
+          "B_zero_point": np.array(zero_point, dtype=dtype),
+          "C_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
+          "C_zero_point": np.array(zero_point, dtype=dtype)
+        }
+        attributes = {}
+        outputs = ["C"]
+        self.helper_test_single_op("QLinearAdd", inputs, attributes, outputs, atol=1)
+
+  def test_qlinear_global_average_pool(self):
+    for dtype, zero_point in [(np.uint8, 128), (np.int8, 0)]:
+      with self.subTest(dtype=dtype, zero_point=zero_point):
+        dtype_min, dtype_max = np.iinfo(dtype).min, np.iinfo(dtype).max
+        inputs = {
+          "X": np.random.randint(dtype_min, dtype_max + 1, [1, 3, 32, 32], dtype=dtype),
+          "x_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
+          "x_zero_point": np.array(zero_point, dtype=dtype),
+          "y_scale": np.array(np.random.uniform(0.01, 0.1), dtype=np.float32),
+          "y_zero_point": np.array(zero_point, dtype=dtype)
+        }
+        attributes = {"channels_last": 0}
+        outputs = ["C"]
+        self.helper_test_single_op("QLinearGlobalAveragePool", inputs, attributes, outputs, atol=1)
+
+if __name__ == "__main__":
+  unittest.main()
\ No newline at end of file
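
Note (not part of the patch): a minimal sketch of covering one more default-domain op the same way the new helper_test_single_op does, written directly against extra.onnx_helpers.validate from the diff above. It assumes the patch is applied and is run from the tinygrad root with PYTHONPATH=. (as in the added CI step); the choice of Transpose, its perm attribute, and the tensor names are illustrative and follow the ONNX operator docs linked at the top of the new test file.

# illustrative sketch: build a single Transpose node and compare tinygrad vs onnxruntime via validate()
import tempfile
import numpy as np
import onnx
from extra.onnx_helpers import validate  # from the patched extra/onnx_helpers.py

data = np.random.randn(2, 3, 4).astype(np.float32)
inp = onnx.helper.make_tensor_value_info("data", onnx.helper.np_dtype_to_tensor_dtype(data.dtype), data.shape)
out = onnx.helper.make_empty_tensor_value_info("transposed")
node = onnx.helper.make_node("Transpose", ["data"], ["transposed"], perm=[2, 0, 1])
model = onnx.helper.make_model(onnx.helper.make_graph([node], "test_transpose", [inp], [out]))

with tempfile.NamedTemporaryFile() as tmp:
  onnx.save(model, tmp.name)
  validate(tmp.name, {"data": data}, rtol=1e-3, atol=1e-6)  # same tolerances as the helper's defaults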