diff --git a/concrete/numpy/np_dtypes_helpers.py b/concrete/numpy/np_dtypes_helpers.py index 617b02199..1e3307644 100644 --- a/concrete/numpy/np_dtypes_helpers.py +++ b/concrete/numpy/np_dtypes_helpers.py @@ -22,14 +22,24 @@ from ..common.tracing import BaseTracer from ..common.values import BaseValue, TensorValue NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = { - numpy.dtype(numpy.int8): Integer(8, is_signed=True), - numpy.dtype(numpy.int16): Integer(16, is_signed=True), - numpy.dtype(numpy.int32): Integer(32, is_signed=True), - numpy.dtype(numpy.int64): Integer(64, is_signed=True), - numpy.dtype(numpy.uint8): Integer(8, is_signed=False), - numpy.dtype(numpy.uint16): Integer(16, is_signed=False), - numpy.dtype(numpy.uint32): Integer(32, is_signed=False), - numpy.dtype(numpy.uint64): Integer(64, is_signed=False), + numpy.dtype(numpy.byte): Integer(numpy.byte(0).nbytes * 8, is_signed=True), + numpy.dtype(numpy.short): Integer(numpy.short(0).nbytes * 8, is_signed=True), + numpy.dtype(numpy.intc): Integer(numpy.intc(0).nbytes * 8, is_signed=True), + numpy.dtype(numpy.int_): Integer(numpy.int_(0).nbytes * 8, is_signed=True), + numpy.dtype(numpy.longlong): Integer(numpy.longlong(0).nbytes * 8, is_signed=True), + numpy.dtype(numpy.int8): Integer(numpy.int8(0).nbytes * 8, is_signed=True), + numpy.dtype(numpy.int16): Integer(numpy.int16(0).nbytes * 8, is_signed=True), + numpy.dtype(numpy.int32): Integer(numpy.int32(0).nbytes * 8, is_signed=True), + numpy.dtype(numpy.int64): Integer(numpy.int64(0).nbytes * 8, is_signed=True), + numpy.dtype(numpy.ubyte): Integer(numpy.ubyte(0).nbytes * 8, is_signed=False), + numpy.dtype(numpy.ushort): Integer(numpy.ushort(0).nbytes * 8, is_signed=False), + numpy.dtype(numpy.uintc): Integer(numpy.uintc(0).nbytes * 8, is_signed=False), + numpy.dtype(numpy.uint): Integer(numpy.uint(0).nbytes * 8, is_signed=False), + numpy.dtype(numpy.ulonglong): Integer(numpy.ulonglong(0).nbytes * 8, is_signed=False), + numpy.dtype(numpy.uint8): Integer(numpy.uint8(0).nbytes * 8, is_signed=False), + numpy.dtype(numpy.uint16): Integer(numpy.uint16(0).nbytes * 8, is_signed=False), + numpy.dtype(numpy.uint32): Integer(numpy.uint32(0).nbytes * 8, is_signed=False), + numpy.dtype(numpy.uint64): Integer(numpy.uint64(0).nbytes * 8, is_signed=False), numpy.dtype(numpy.float16): Float(16), numpy.dtype(numpy.float32): Float(32), numpy.dtype(numpy.float64): Float(64),