chore: let us use dtype=np.uint8 if we want

closes #986
This commit is contained in:
Benoit Chevallier-Mames
2021-11-22 19:13:27 +01:00
committed by Benoit Chevallier
parent 4921ebadf8
commit 446dfaf834
2 changed files with 12 additions and 0 deletions

View File

@@ -22,8 +22,12 @@ from ..common.tracing import BaseTracer
from ..common.values import BaseValue, TensorValue
NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
numpy.dtype(numpy.int8): Integer(8, is_signed=True),
numpy.dtype(numpy.int16): Integer(16, is_signed=True),
numpy.dtype(numpy.int32): Integer(32, is_signed=True),
numpy.dtype(numpy.int64): Integer(64, is_signed=True),
numpy.dtype(numpy.uint8): Integer(8, is_signed=False),
numpy.dtype(numpy.uint16): Integer(16, is_signed=False),
numpy.dtype(numpy.uint32): Integer(32, is_signed=False),
numpy.dtype(numpy.uint64): Integer(64, is_signed=False),
numpy.dtype(numpy.float32): Float(32),

View File

@@ -16,10 +16,18 @@ from concrete.numpy.np_dtypes_helpers import (
@pytest.mark.parametrize(
"numpy_dtype,expected_common_type",
[
pytest.param(numpy.int8, Integer(8, is_signed=True)),
pytest.param("int8", Integer(8, is_signed=True)),
pytest.param(numpy.int16, Integer(16, is_signed=True)),
pytest.param("int16", Integer(16, is_signed=True)),
pytest.param(numpy.int32, Integer(32, is_signed=True)),
pytest.param("int32", Integer(32, is_signed=True)),
pytest.param(numpy.int64, Integer(64, is_signed=True)),
pytest.param("int64", Integer(64, is_signed=True)),
pytest.param(numpy.uint8, Integer(8, is_signed=False)),
pytest.param("uint8", Integer(8, is_signed=False)),
pytest.param(numpy.uint16, Integer(16, is_signed=False)),
pytest.param("uint16", Integer(16, is_signed=False)),
pytest.param(numpy.uint32, Integer(32, is_signed=False)),
pytest.param("uint32", Integer(32, is_signed=False)),
pytest.param(numpy.uint64, Integer(64, is_signed=False)),