diff --git a/extra/onnx.py b/extra/onnx.py index 61a557bb4b..86fe64d0fb 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -153,7 +153,7 @@ def get_run_onnx(onnx_model: ModelProto): if n.op_type == "Add": ret = inp[0] + inp[1] if n.op_type == "Sub": ret = inp[0] - inp[1] if n.op_type == "Mul": ret = inp[0] * inp[1] - if n.op_type == "Pow": ret = inp[0] ** inp[1] + if n.op_type == "Pow": ret = (inp[0] ** inp[1]).cast(inp[0].dtype) elif n.op_type == "Split": if 'split' not in opt: opt['split'] = [int(x) for x in safe_numpy(inp[1])] # split can be a tensor if 'axis' not in opt: opt['axis'] = 0