From 3d6baf410187b074d9371ee1b88c27e4e78f4e29 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Wed, 29 Sep 2021 15:42:52 +0200 Subject: [PATCH] fix: fix the type --- concrete/numpy/np_dtypes_helpers.py | 2 +- tests/numpy/test_tracing.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index e974bc8ab..dc5a0d3bf 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -27,7 +27,7 @@ NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = { numpy.dtype(numpy.uint64): Integer(64, is_signed=False), numpy.dtype(numpy.float32): Float(32), numpy.dtype(numpy.float64): Float(64), - numpy.dtype(bool): Integer(32, is_signed=False), + numpy.dtype(bool): Integer(8, is_signed=False), } SUPPORTED_NUMPY_DTYPES = tuple(NUMPY_TO_COMMON_DTYPE_MAPPING) diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 0b9b4c42b..2ffca1a90 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -314,6 +314,7 @@ def test_trace_numpy_supported_ufuncs(inputs, expected_output_node): assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64)) else: assert op_graph.output_nodes[0].outputs[0] in [ + EncryptedScalar(Integer(8, is_signed=False)), EncryptedScalar(Integer(32, is_signed=False)), EncryptedScalar(Integer(32, is_signed=True)), EncryptedScalar(Integer(64, is_signed=True)),