diff --git a/extra/onnx.py b/extra/onnx.py index b3d1301c9a..ace358665a 100644 --- a/extra/onnx.py +++ b/extra/onnx.py @@ -38,7 +38,7 @@ def get_run_onnx(onnx_model: ModelProto): elif len(inp.int64_data) > 0: ret = Tensor(np.array(inp.int64_data, dtype=np.float32).reshape(inp.dims), requires_grad=False) elif len(inp.int32_data) > 0: - ret = Tensor(np.array(inp.int32_data, dtype=np.float32).reshape(inp.dims), requires_grad=False) + ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False) else: ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False) else: diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 87465dfade..666e260e7c 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -205,9 +205,10 @@ def Or(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.ones(*x.shape)).cast( def Xor(x:Tensor, y:Tensor): return Where((x==y), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool) def Not(x:Tensor): return Where((x==1), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool) -def ConstantOfShape(input, value=0.0): +def ConstantOfShape(input, value:Tensor=None): + if value is None: value=Tensor([0.0]) shape = [int(x) for x in safe_numpy(input)] - return Tensor.ones(*shape) * value + return Tensor.ones(*shape, dtype=value.dtype) * (value if input.shape !=(0,) else 1) # this is obviously wrong, but since we don't have types, it's better than nothing def Cast(input, to):