fix: fix the type

This commit is contained in:
Benoit Chevallier-Mames
2021-09-29 15:42:52 +02:00
committed by Benoit Chevallier
parent 7ea39fb77d
commit 3d6baf4101
2 changed files with 2 additions and 1 deletions

View File

@@ -27,7 +27,7 @@ NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
numpy.dtype(numpy.uint64): Integer(64, is_signed=False),
numpy.dtype(numpy.float32): Float(32),
numpy.dtype(numpy.float64): Float(64),
numpy.dtype(bool): Integer(32, is_signed=False),
numpy.dtype(bool): Integer(8, is_signed=False),
}
SUPPORTED_NUMPY_DTYPES = tuple(NUMPY_TO_COMMON_DTYPE_MAPPING)

View File

@@ -314,6 +314,7 @@ def test_trace_numpy_supported_ufuncs(inputs, expected_output_node):
assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64))
else:
assert op_graph.output_nodes[0].outputs[0] in [
EncryptedScalar(Integer(8, is_signed=False)),
EncryptedScalar(Integer(32, is_signed=False)),
EncryptedScalar(Integer(32, is_signed=True)),
EncryptedScalar(Integer(64, is_signed=True)),