diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index e974bc8ab..dc5a0d3bf 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -27,7 +27,7 @@ NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = { numpy.dtype(numpy.uint64): Integer(64, is_signed=False), numpy.dtype(numpy.float32): Float(32), numpy.dtype(numpy.float64): Float(64), - numpy.dtype(bool): Integer(32, is_signed=False), + numpy.dtype(bool): Integer(8, is_signed=False), } SUPPORTED_NUMPY_DTYPES = tuple(NUMPY_TO_COMMON_DTYPE_MAPPING) diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 0b9b4c42b..2ffca1a90 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -314,6 +314,7 @@ def test_trace_numpy_supported_ufuncs(inputs, expected_output_node): assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64)) else: assert op_graph.output_nodes[0].outputs[0] in [ + EncryptedScalar(Integer(8, is_signed=False)), EncryptedScalar(Integer(32, is_signed=False)), EncryptedScalar(Integer(32, is_signed=True)), EncryptedScalar(Integer(64, is_signed=True)),