mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
all onnx model tests pass
This commit is contained in:
@@ -74,6 +74,8 @@ def get_run_onnx(onnx_model):
|
||||
attribute_dict = {}
|
||||
for num,n in enumerate(onnx_model.graph.node):
|
||||
attribute_dict[num] = attribute_to_dict(n.attribute)
|
||||
|
||||
onnx_version = onnx_model.opset_import[0].version
|
||||
|
||||
def run_onnx(inputs={}, debug=False):
|
||||
if getenv("DEBUGONNX"): debug = True
|
||||
@@ -156,7 +158,7 @@ def get_run_onnx(onnx_model):
|
||||
i = i+s
|
||||
continue
|
||||
elif n.op_type == "Slice":
|
||||
assert onnx_model.opset_import[0].version == 10
|
||||
assert onnx_version == 10
|
||||
arg = [(0,x) for x in inp[0].shape]
|
||||
starts, ends, axes = inp[1:4]
|
||||
assert axes.shape == (1,)
|
||||
@@ -166,7 +168,14 @@ def get_run_onnx(onnx_model):
|
||||
arg[axis] = (starts, ends)
|
||||
ret = inp[0].slice(arg=arg)
|
||||
elif hasattr(onnx_ops, n.op_type):
|
||||
ret = getattr(onnx_ops, n.op_type)(*inp, **opt)
|
||||
fxn = getattr(onnx_ops, n.op_type)
|
||||
if isinstance(fxn, dict):
|
||||
for k in sorted(fxn.keys()):
|
||||
if k < onnx_version:
|
||||
real_fxn = fxn[k]
|
||||
else:
|
||||
real_fxn = fxn
|
||||
ret = real_fxn(*inp, **opt)
|
||||
else:
|
||||
print("UNSUPPORTED", n.op_type, n.input, n.output)
|
||||
raise Exception(f"op_type {n.op_type} not supported")
|
||||
@@ -174,7 +183,7 @@ def get_run_onnx(onnx_model):
|
||||
assert len(n.output) <= len(ret), f"expected output size must be less than {len(ret)}, it's {n.output}"
|
||||
if debug: print([x.shape if isinstance(x, Tensor) else None for x in ret])
|
||||
for i in range(len(n.output)): intermediate_tensors[n.output[i]] = ret[i]
|
||||
#print(ret.numpy().mean())
|
||||
#print(ret[0].numpy().mean())
|
||||
if num == ONNXLIMIT:
|
||||
output_tensor_names = n.output
|
||||
break
|
||||
|
||||
@@ -116,7 +116,9 @@ def Softplus(X): return X.softplus()
|
||||
def PRelu(X, slope): return X.leakyrelu(slope)
|
||||
def LeakyRelu(X, alpha=0.01): return X.leakyrelu(alpha)
|
||||
def ThresholdedRelu(X, alpha=1.0): return (X-alpha).relu() + (X-alpha).relu().sign() * alpha
|
||||
def Softmax(input, axis=-1): return input.softmax(axis)
|
||||
def Softmax_1(input, axis=1): return input.softmax(axis)
|
||||
def Softmax_13(input, axis=-1): return input.softmax(axis)
|
||||
Softmax = {1: Softmax_1, 13: Softmax_13} # Softmax default axis changed
|
||||
def LogSoftmax(input, axis=-1): return input.log_softmax(axis)
|
||||
def Clip(input, min=-3.4e38, max=3.4e38): return input.clip(min, max)
|
||||
|
||||
|
||||
@@ -152,89 +152,21 @@ backend_test.exclude('test_stft_*')
|
||||
backend_test.exclude('test_melweightmatrix_*')
|
||||
|
||||
# disable model tests for now since they are slow
|
||||
for x in backend_test.test_suite:
|
||||
if 'OnnxBackendRealModelTest' in str(type(x)):
|
||||
backend_test.exclude(str(x).split(" ")[0])
|
||||
|
||||
# passing node tests
|
||||
"""
|
||||
backend_test.include('test_unsqueeze_*')
|
||||
backend_test.include('test_gemm_*')
|
||||
backend_test.include('test_batchnorm_*')
|
||||
backend_test.include('test_transpose_*')
|
||||
backend_test.include('test_shape_*')
|
||||
backend_test.include('test_flatten_*')
|
||||
backend_test.include('test_sum_*')
|
||||
backend_test.include('test_global*')
|
||||
backend_test.include('test_log_softmax*')
|
||||
backend_test.include('test_softplus*')
|
||||
"""
|
||||
|
||||
# requires Less, which would be a new llop
|
||||
#backend_test.include('test_clip_*')
|
||||
|
||||
# broken empty tensor
|
||||
#backend_test.include('test_reduce_sum_*')
|
||||
#backend_test.include('test_reduce_l1_')
|
||||
|
||||
# requires cast
|
||||
#backend_test.include('test_reduce_log_sum*')
|
||||
#backend_test.include('test_pow_*')
|
||||
|
||||
# almost passing node tests
|
||||
#backend_test.include('test_PReLU*')
|
||||
#backend_test.include('test_expand_*')
|
||||
#backend_test.include('test_conv_.*')
|
||||
#backend_test.include('test_dropout_*')
|
||||
#backend_test.include('test_reshape_*')
|
||||
|
||||
# good to investigate
|
||||
#backend_test.include('test_slice_*')
|
||||
|
||||
# failing for real reasons
|
||||
#backend_test.include('test_averagepool_2d_*')
|
||||
#backend_test.include('test_maxpool_2d_*')
|
||||
|
||||
"""
|
||||
backend_test.include('test_tanh_*')
|
||||
|
||||
# should be passing (good place to start!)
|
||||
"""
|
||||
|
||||
# requires CastLike?
|
||||
#backend_test.include('test_relu_*')
|
||||
#backend_test.include('test_elu_*')
|
||||
#backend_test.include('test_leakyrelu_*')
|
||||
#backend_test.include('test_hardsigmoid_*')
|
||||
|
||||
# failing for lack of type support
|
||||
#backend_test.include('test_add_*')
|
||||
#backend_test.include('test_sub_*')
|
||||
#backend_test.include('test_div_*')
|
||||
|
||||
|
||||
# the node tests, slowly
|
||||
#backend_test.include('test_softmax_*')
|
||||
#backend_test.include('test_lrn_*')
|
||||
|
||||
# working big model tests
|
||||
#backend_test.include('test_resnet50')
|
||||
#backend_test.include('test_densenet121')
|
||||
#backend_test.include('test_vgg19')
|
||||
|
||||
"""
|
||||
# wrong big model tests
|
||||
backend_test.include('test_shufflenet')
|
||||
backend_test.include('test_inception_v2')
|
||||
backend_test.include('test_squeezenet')
|
||||
"""
|
||||
|
||||
"""
|
||||
# unsupported big model tests : LRN
|
||||
backend_test.include('test_bvlc_alexnet')
|
||||
backend_test.include('test_inception_v1')
|
||||
backend_test.include('test_zfnet512')
|
||||
"""
|
||||
if True:
|
||||
for x in backend_test.test_suite:
|
||||
if 'OnnxBackendRealModelTest' in str(type(x)):
|
||||
backend_test.exclude(str(x).split(" ")[0])
|
||||
else:
|
||||
# model tests all pass!
|
||||
backend_test.include('test_resnet50')
|
||||
backend_test.include('test_inception_v1')
|
||||
backend_test.include('test_inception_v2')
|
||||
backend_test.include('test_densenet121')
|
||||
backend_test.include('test_shufflenet')
|
||||
backend_test.include('test_squeezenet')
|
||||
backend_test.include('test_bvlc_alexnet')
|
||||
backend_test.include('test_zfnet512')
|
||||
backend_test.include('test_vgg19')
|
||||
|
||||
globals().update(backend_test.enable_report().test_cases)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user