feat: add support for all numpy integer types

This commit is contained in:
Umut
2022-03-08 14:30:07 +03:00
parent 47fe98640d
commit 62d9872cf2

View File

@@ -22,14 +22,24 @@ from ..common.tracing import BaseTracer
from ..common.values import BaseValue, TensorValue
NUMPY_TO_COMMON_DTYPE_MAPPING: Dict[numpy.dtype, BaseDataType] = {
numpy.dtype(numpy.int8): Integer(8, is_signed=True),
numpy.dtype(numpy.int16): Integer(16, is_signed=True),
numpy.dtype(numpy.int32): Integer(32, is_signed=True),
numpy.dtype(numpy.int64): Integer(64, is_signed=True),
numpy.dtype(numpy.uint8): Integer(8, is_signed=False),
numpy.dtype(numpy.uint16): Integer(16, is_signed=False),
numpy.dtype(numpy.uint32): Integer(32, is_signed=False),
numpy.dtype(numpy.uint64): Integer(64, is_signed=False),
numpy.dtype(numpy.byte): Integer(numpy.byte(0).nbytes * 8, is_signed=True),
numpy.dtype(numpy.short): Integer(numpy.short(0).nbytes * 8, is_signed=True),
numpy.dtype(numpy.intc): Integer(numpy.intc(0).nbytes * 8, is_signed=True),
numpy.dtype(numpy.int_): Integer(numpy.int_(0).nbytes * 8, is_signed=True),
numpy.dtype(numpy.longlong): Integer(numpy.longlong(0).nbytes * 8, is_signed=True),
numpy.dtype(numpy.int8): Integer(numpy.int8(0).nbytes * 8, is_signed=True),
numpy.dtype(numpy.int16): Integer(numpy.int16(0).nbytes * 8, is_signed=True),
numpy.dtype(numpy.int32): Integer(numpy.int32(0).nbytes * 8, is_signed=True),
numpy.dtype(numpy.int64): Integer(numpy.int64(0).nbytes * 8, is_signed=True),
numpy.dtype(numpy.ubyte): Integer(numpy.ubyte(0).nbytes * 8, is_signed=False),
numpy.dtype(numpy.ushort): Integer(numpy.ushort(0).nbytes * 8, is_signed=False),
numpy.dtype(numpy.uintc): Integer(numpy.uintc(0).nbytes * 8, is_signed=False),
numpy.dtype(numpy.uint): Integer(numpy.uint(0).nbytes * 8, is_signed=False),
numpy.dtype(numpy.ulonglong): Integer(numpy.ulonglong(0).nbytes * 8, is_signed=False),
numpy.dtype(numpy.uint8): Integer(numpy.uint8(0).nbytes * 8, is_signed=False),
numpy.dtype(numpy.uint16): Integer(numpy.uint16(0).nbytes * 8, is_signed=False),
numpy.dtype(numpy.uint32): Integer(numpy.uint32(0).nbytes * 8, is_signed=False),
numpy.dtype(numpy.uint64): Integer(numpy.uint64(0).nbytes * 8, is_signed=False),
numpy.dtype(numpy.float16): Float(16),
numpy.dtype(numpy.float32): Float(32),
numpy.dtype(numpy.float64): Float(64),