ConstantOfShape ONNX test fixed. (#890)

* ConstantOfShape ONNX test fixed.

* removed redundant if statement

* value is optional and should default to a float32 tensor with value of 0

* fixed: default parameters are created at function definition, bad for mutable objects.
This commit is contained in:
Steven Anderson
2023-06-02 07:34:25 -07:00
committed by GitHub
parent 5feee9c94b
commit 301f7b54c6
2 changed files with 4 additions and 3 deletions

View File

@@ -38,7 +38,7 @@ def get_run_onnx(onnx_model: ModelProto):
elif len(inp.int64_data) > 0:
ret = Tensor(np.array(inp.int64_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
elif len(inp.int32_data) > 0:
ret = Tensor(np.array(inp.int32_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False)
else:
ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
else:

View File

@@ -205,9 +205,10 @@ def Or(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.ones(*x.shape)).cast(
def Xor(x:Tensor, y:Tensor): return Where((x==y), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
def Not(x:Tensor): return Where((x==1), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
def ConstantOfShape(input, value=0.0):
def ConstantOfShape(input, value:Tensor=None):
if value is None: value=Tensor([0.0])
shape = [int(x) for x in safe_numpy(input)]
return Tensor.ones(*shape) * value
return Tensor.ones(*shape, dtype=value.dtype) * (value if input.shape !=(0,) else 1)
# this is obviously wrong, but since we don't have types, it's better than nothing
def Cast(input, to):