mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move test over (#10508)
This commit is contained in:
62
test/external/external_test_onnx_ops.py
vendored
62
test/external/external_test_onnx_ops.py
vendored
@@ -8,6 +8,8 @@ from tinygrad import dtypes
|
||||
from tinygrad.frontend.onnx import OnnxRunner
|
||||
import numpy as np
|
||||
from extra.onnx_helpers import validate
|
||||
from onnx.defs import ONNX_DOMAIN, AI_ONNX_PREVIEW_TRAINING_DOMAIN
|
||||
MICROSOFT_CONTRIB_OPS_DOMAIN = "com.microsoft"
|
||||
|
||||
class TestOnnxOps(unittest.TestCase):
|
||||
DOMAIN = None
|
||||
@@ -26,7 +28,7 @@ class TestOnnxOps(unittest.TestCase):
|
||||
validate(tmp.name, inps, rtol, atol)
|
||||
|
||||
class TestMainOnnxOps(TestOnnxOps):
|
||||
DOMAIN = ""
|
||||
DOMAIN = ONNX_DOMAIN
|
||||
def test_reshape(self):
|
||||
inputs = {"in": np.arange(6, dtype=np.float32), "shape": np.array([2,3], dtype=np.int64)}
|
||||
attributes = {}
|
||||
@@ -195,8 +197,64 @@ class TestMainOnnxOps(TestOnnxOps):
|
||||
def test_qlinearmatmul_2D_int8_float32(self): self._run_qlinearmatmul_test(np.int8, np.float32, 2)
|
||||
def test_qlinearmatmul_3D_int8_float32(self): self._run_qlinearmatmul_test(np.int8, np.float32, 3)
|
||||
|
||||
class TestTrainingOnnxOps(TestOnnxOps):
|
||||
# NOTE: ORT doesn't actually support training ops on cpu so we test using functions provided by onnx
|
||||
DOMAIN = AI_ONNX_PREVIEW_TRAINING_DOMAIN
|
||||
def _validate_training(self, op:str, onnx_fxn, inps:dict[str, np.ndarray], opts:dict[str, Any], outs:list[str]):
|
||||
model = self.helper_build_model(op, inps, opts, outs)
|
||||
if op == "Momentum": del opts['mode']
|
||||
runner = OnnxRunner(model)
|
||||
tiny_out = runner(inps)
|
||||
onnx_out = onnx_fxn(**inps, **opts)
|
||||
for (nm, t_out), o_out in zip(tiny_out.items(), onnx_out):
|
||||
np.testing.assert_allclose(t_out.numpy(), o_out, rtol=1e-3, atol=1e-6, err_msg=f"{nm} failed")
|
||||
|
||||
def test_adagrad_t_greater_than_zero(self):
|
||||
from onnx.backend.test.case.node.adagrad import apply_adagrad
|
||||
for t in [1, 3, 100]:
|
||||
inputs = {
|
||||
"r": np.array(0.01, dtype=np.float32),
|
||||
"t": np.array(t, dtype=np.int32),
|
||||
"x": np.random.randn(3, 3).astype(np.float32),
|
||||
"g": np.random.randn(3, 3).astype(np.float32),
|
||||
"h": np.random.randn(3, 3).astype(np.float32),
|
||||
}
|
||||
attributes = {"decay_factor": 0.1, "epsilon": 1e-6, "norm_coefficient": 0.01}
|
||||
outputs = ["X_out", "H_out"]
|
||||
self._validate_training("Adagrad", apply_adagrad, inputs, attributes, outputs)
|
||||
|
||||
def test_momentum_t_greater_than_zero(self):
|
||||
from onnx.backend.test.case.node.momentum import apply_momentum, apply_nesterov
|
||||
for onnx_fxn, mode in ((apply_momentum, "standard"), (apply_nesterov, "nesterov")):
|
||||
for t in [1, 3, 100]:
|
||||
inputs = {
|
||||
"r": np.array(0.01, dtype=np.float32),
|
||||
"t": np.array(t, dtype=np.int32),
|
||||
"x": np.random.randn(3, 3).astype(np.float32),
|
||||
"g": np.random.randn(3, 3).astype(np.float32),
|
||||
"v": np.random.randn(3, 3).astype(np.float32),
|
||||
}
|
||||
attributes = {"alpha": 0.9, "beta": 0.1, "mode": mode, "norm_coefficient": 0.01}
|
||||
outputs = ["X_out", "V_out"]
|
||||
self._validate_training("Momentum", onnx_fxn, inputs, attributes, outputs)
|
||||
|
||||
def test_adam_t_greater_than_zero(self):
|
||||
from onnx.backend.test.case.node.adam import apply_adam
|
||||
for t in [1, 3, 100]:
|
||||
inputs = {
|
||||
"r": np.array(0.01, dtype=np.float32),
|
||||
"t": np.array(t, dtype=np.int32),
|
||||
"x": np.random.randn(3, 3).astype(np.float32),
|
||||
"g": np.random.randn(3, 3).astype(np.float32),
|
||||
"v": np.random.randn(3, 3).astype(np.float32),
|
||||
"h": np.random.randn(3, 3).astype(np.float32),
|
||||
}
|
||||
attributes = { "alpha": 0.9, "beta": 0.999, "epsilon": 1e-8, "norm_coefficient": 0.01, "norm_coefficient_post": 0.02 }
|
||||
outputs = ["X_new", "V_new", "H_new"]
|
||||
self._validate_training("Adam", apply_adam, inputs, attributes, outputs)
|
||||
|
||||
class TestContribOnnxOps(TestOnnxOps):
|
||||
DOMAIN = "com.microsoft"
|
||||
DOMAIN = MICROSOFT_CONTRIB_OPS_DOMAIN
|
||||
def test_attention(self):
|
||||
batch_size, seq_len, input_hidden_size = 2, 8, 256
|
||||
num_heads, head_size = 4, 64
|
||||
|
||||
Reference in New Issue
Block a user