Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 15:38:29 -05:00
ConstantOfShape ONNX test fixed. (#890)
* ConstantOfShape ONNX test fixed.
* removed redundant if statement
* value is optional and should default to a float32 tensor with a value of 0
* fixed: default parameter values are evaluated once, at function definition time, which is unsafe for mutable objects.
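The last bullet refers to Python's mutable-default pitfall: a default value is evaluated once, when the def statement runs, so a mutable default such as a Tensor would be shared across every call. A minimal sketch of the pitfall and of the None-sentinel fix the commit applies (hypothetical example, not code from the commit):

def append_bad(item, bucket=[]):
  # the list is created once, at function definition time, and reused on every call
  bucket.append(item)
  return bucket

def append_good(item, bucket=None):
  # None-sentinel pattern: build a fresh default inside the body on each call
  if bucket is None: bucket = []
  bucket.append(item)
  return bucket

assert append_bad(1) == [1]
assert append_bad(2) == [1, 2]    # state leaked from the previous call
assert append_good(1) == [1]
assert append_good(2) == [2]      # fresh default every call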
@@ -38,7 +38,7 @@ def get_run_onnx(onnx_model: ModelProto):
       elif len(inp.int64_data) > 0:
         ret = Tensor(np.array(inp.int64_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
       elif len(inp.int32_data) > 0:
-        ret = Tensor(np.array(inp.int32_data, dtype=np.float32).reshape(inp.dims), requires_grad=False)
+        ret = Tensor(np.array(inp.int32_data, dtype=np.int32).reshape(inp.dims), requires_grad=False)
       else:
         ret = Tensor(np.frombuffer(inp.raw_data, dtype=tensor_dtype_to_np_dtype(inp.data_type)).reshape(inp.dims).astype(np.float32).copy(), requires_grad=False)
     else:
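The hunk above stops coercing ONNX int32 initializer data to float32. A small illustration of the difference, assuming a shape-like payload such as the one ConstantOfShape consumes (assumed repro, not from the commit):

import numpy as np

int32_data = [2, 3, 4]                        # e.g. the payload of a shape tensor
old = np.array(int32_data, dtype=np.float32)  # old behavior: values become 2.0, 3.0, 4.0
new = np.array(int32_data, dtype=np.int32)    # new behavior: integer dtype preserved
assert old.dtype == np.float32 and new.dtype == np.int32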
@@ -205,9 +205,10 @@ def Or(x:Tensor, y:Tensor): return Where((x==y), x, Tensor.ones(*x.shape)).cast(
 def Xor(x:Tensor, y:Tensor): return Where((x==y), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
 def Not(x:Tensor): return Where((x==1), Tensor.zeros(*x.shape), Tensor.ones(*x.shape)).cast(dtypes.bool)
 
-def ConstantOfShape(input, value=0.0):
+def ConstantOfShape(input, value:Tensor=None):
+  if value is None: value=Tensor([0.0])
   shape = [int(x) for x in safe_numpy(input)]
-  return Tensor.ones(*shape) * value
+  return Tensor.ones(*shape, dtype=value.dtype) * (value if input.shape !=(0,) else 1)
 
 # this is obviously wrong, but since we don't have types, it's better than nothing
 def Cast(input, to):
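Putting the second hunk's pieces together, here is a hedged numpy re-statement of the patched ConstantOfShape, with np.asarray standing in for safe_numpy and plain arrays standing in for tinygrad Tensors (the stand-ins are assumptions, not the commit's code):

import numpy as np

def constant_of_shape(input, value=None):
  # fresh default per call, avoiding the mutable-default pitfall the commit fixes
  if value is None: value = np.array([0.0], dtype=np.float32)
  shape = [int(x) for x in np.asarray(input)]
  # mirror of the patched logic: an empty shape tensor, input.shape == (0,), fills with 1
  fill = value if input.shape != (0,) else 1
  return np.ones(shape, dtype=value.dtype) * fill

out = constant_of_shape(np.array([2, 3], dtype=np.int64), np.array([5], dtype=np.int32))
assert out.shape == (2, 3) and out.dtype == np.int32 and (out == 5).all()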