feat: support more dtypes for scalars/tensors

dtypes supported now: uint8, uint16, uint32, uint64
youben11, 2021-12-10 10:36:04 +01:00 (committed by Ayoub Benaissa)
parent 550318f67e, commit 60b2cfd9b7
5 changed files with 143 additions and 12 deletions
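As a quick illustration of the user-visible effect, the following sketch uses the create_execution_argument wrapper patched below (hypothetical call site; before this commit only int and numpy.uint8 values were accepted):

    import numpy as np

    create_execution_argument(np.uint32(7))                          # numpy scalar of any uint width
    create_execution_argument(np.array([1, 2, 3], dtype=np.uint64))  # ndarray of any uint width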


@@ -47,9 +47,15 @@ MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
MLIR_CAPI_EXPORTED lambdaArgument invokeLambda(lambda l,
executionArguments args);
-// Create a lambdaArgument from a tensor
-MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensor(
+// Create a lambdaArgument from a tensor of different data types
+MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
     std::vector<uint8_t> data, std::vector<int64_t> dimensions);
+MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16(
+    std::vector<uint16_t> data, std::vector<int64_t> dimensions);
+MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32(
+    std::vector<uint32_t> data, std::vector<int64_t> dimensions);
+MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64(
+    std::vector<uint64_t> data, std::vector<int64_t> dimensions);
// Create a lambdaArgument from a scalar
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar);
// Check if a lambdaArgument holds a tensor


@@ -45,7 +45,22 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
});
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
.def_static("from_tensor", lambdaArgumentFromTensor)
.def_static("from_tensor",
[](std::vector<uint8_t> tensor, std::vector<int64_t> dims) {
return lambdaArgumentFromTensorU8(tensor, dims);
})
.def_static("from_tensor",
[](std::vector<uint16_t> tensor, std::vector<int64_t> dims) {
return lambdaArgumentFromTensorU16(tensor, dims);
})
.def_static("from_tensor",
[](std::vector<uint32_t> tensor, std::vector<int64_t> dims) {
return lambdaArgumentFromTensorU32(tensor, dims);
})
.def_static("from_tensor",
[](std::vector<uint64_t> tensor, std::vector<int64_t> dims) {
return lambdaArgumentFromTensorU64(tensor, dims);
})
.def_static("from_scalar", lambdaArgumentFromScalar)
.def("is_tensor",
[](lambdaArgument &lambda_arg) {

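All four overloads above are registered under the single Python name from_tensor. pybind11 tries overloads in registration order, and its integer caster rejects out-of-range elements, so a call should fall through to the first element type wide enough for the data. A rough sketch of the resulting behavior (import path taken from the wrapper module below):

    from mlir._mlir_libs._zamalang._compiler import LambdaArgument

    arg8 = LambdaArgument.from_tensor([1, 2, 3, 4], [4])  # all values fit uint8 -> U8 overload
    arg16 = LambdaArgument.from_tensor([300, 2], [2])     # 300 > 255 -> falls through to U16
    assert arg8.is_tensor() and arg16.is_tensor()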

@@ -10,6 +10,11 @@ from mlir._mlir_libs._zamalang._compiler import library as _library
import numpy as np
+ACCEPTED_NUMPY_UINTS = (np.uint8, np.uint16, np.uint32, np.uint64)
+ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS
+ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS
def _lookup_runtime_lib() -> str:
"""Try to find the absolute path to the runtime library.
@@ -95,10 +100,10 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
Returns:
_LambdaArgument: lambda argument holding the appropriate value
"""
-    if not isinstance(value, (int, np.ndarray, np.uint8)):
-        raise TypeError("value of execution argument must be either int, numpy.array or numpy.uint8")
-    if isinstance(value, (int, np.uint8)):
-        if not (0 <= value < (2 ** 64 - 1)):
+    if not isinstance(value, ACCEPTED_TYPES):
+        raise TypeError("value of execution argument must be either int, numpy.ndarray or numpy.uint{8,16,32,64}")
+    if isinstance(value, ACCEPTED_INTS):
+        if isinstance(value, int) and not (0 <= value <= np.iinfo(np.uint64).max):
             raise TypeError(
                 "single integer must be in the range [0, 2**64 - 1] (uint64)"
             )
@@ -107,8 +112,8 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
    assert isinstance(value, np.ndarray)
    if value.shape == ():
        return _LambdaArgument.from_scalar(value)
-    if value.dtype != np.uint8:
-        raise TypeError("numpy.array must be of dtype uint8")
+    if value.dtype not in ACCEPTED_NUMPY_UINTS:
+        raise TypeError("numpy.ndarray must be of dtype uint{8,16,32,64}")
    return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
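In effect the wrapper now accepts plain ints plus the four unsigned numpy widths, as scalars or as ndarrays. A sketch of the accepted and rejected inputs (a hypothetical walk-through, not part of the commit):

    import numpy as np

    create_execution_argument(7)                                  # plain int, range-checked against uint64
    create_execution_argument(np.uint16(42))                      # numpy scalar, passed through as a scalar argument
    create_execution_argument(np.array([1, 2], dtype=np.uint32))  # ndarray -> from_tensor
    create_execution_argument(np.array([1.0, 2.0]))               # raises TypeError: dtype not uint{8,16,32,64}
    create_execution_argument(-1)                                 # raises TypeError: outside [0, 2**64 - 1]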


@@ -143,14 +143,38 @@ uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg) {
return arg->getValue();
}
-lambdaArgument lambdaArgumentFromTensor(std::vector<uint8_t> data,
-                                        std::vector<int64_t> dimensions) {
+lambdaArgument lambdaArgumentFromTensorU8(std::vector<uint8_t> data,
+                                          std::vector<int64_t> dimensions) {
  lambdaArgument tensor_arg{
      std::make_shared<mlir::zamalang::TensorLambdaArgument<
          mlir::zamalang::IntLambdaArgument<uint8_t>>>(data, dimensions)};
  return tensor_arg;
}
+
+lambdaArgument lambdaArgumentFromTensorU16(std::vector<uint16_t> data,
+                                           std::vector<int64_t> dimensions) {
+  lambdaArgument tensor_arg{
+      std::make_shared<mlir::zamalang::TensorLambdaArgument<
+          mlir::zamalang::IntLambdaArgument<uint16_t>>>(data, dimensions)};
+  return tensor_arg;
+}
+
+lambdaArgument lambdaArgumentFromTensorU32(std::vector<uint32_t> data,
+                                           std::vector<int64_t> dimensions) {
+  lambdaArgument tensor_arg{
+      std::make_shared<mlir::zamalang::TensorLambdaArgument<
+          mlir::zamalang::IntLambdaArgument<uint32_t>>>(data, dimensions)};
+  return tensor_arg;
+}
+
+lambdaArgument lambdaArgumentFromTensorU64(std::vector<uint64_t> data,
+                                           std::vector<int64_t> dimensions) {
+  lambdaArgument tensor_arg{
+      std::make_shared<mlir::zamalang::TensorLambdaArgument<
+          mlir::zamalang::IntLambdaArgument<uint64_t>>>(data, dimensions)};
+  return tensor_arg;
+}
lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) {
lambdaArgument scalar_arg{
std::make_shared<mlir::zamalang::IntLambdaArgument<uint64_t>>(scalar)};

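The four bodies above differ only in the element type handed to IntLambdaArgument; keeping them as separate exported functions keeps the C API template-free, at the cost of some duplication. Scalars, by contrast, still travel through a single uint64-backed path, as this sketch against the bindings above shows (narrower numpy scalars are widened on the way in):

    import numpy as np
    from mlir._mlir_libs._zamalang._compiler import LambdaArgument

    scalar_arg = LambdaArgument.from_scalar(np.uint16(5))  # converted to uint64 by the binding
    assert not scalar_arg.is_tensor()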

@@ -43,6 +43,39 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
8,
id="add_eint_int_with_np_uint8_as_scalar",
),
+    pytest.param(
+        """
+        func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
+            %1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
+            return %1: !HLFHE.eint<7>
+        }
+        """,
+        (np.uint16(3), np.uint16(5)),
+        8,
+        id="add_eint_int_with_np_uint16_as_scalar",
+    ),
+    pytest.param(
+        """
+        func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
+            %1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
+            return %1: !HLFHE.eint<7>
+        }
+        """,
+        (np.uint32(3), np.uint32(5)),
+        8,
+        id="add_eint_int_with_np_uint32_as_scalar",
+    ),
+    pytest.param(
+        """
+        func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
+            %1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
+            return %1: !HLFHE.eint<7>
+        }
+        """,
+        (np.uint64(3), np.uint64(5)),
+        8,
+        id="add_eint_int_with_np_uint64_as_scalar",
+    ),
pytest.param(
"""
func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
@@ -57,7 +90,55 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
np.array([4, 3, 2, 1], dtype=np.uint8),
),
20,
id="dot_eint_int",
id="dot_eint_int_uint8",
),
+    pytest.param(
+        """
+        func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
+        {
+            %ret = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
+                (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
+            return %ret : !HLFHE.eint<7>
+        }
+        """,
+        (
+            np.array([1, 2, 3, 4], dtype=np.uint16),
+            np.array([4, 3, 2, 1], dtype=np.uint16),
+        ),
+        20,
+        id="dot_eint_int_uint16",
+    ),
+    pytest.param(
+        """
+        func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
+        {
+            %ret = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
+                (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
+            return %ret : !HLFHE.eint<7>
+        }
+        """,
+        (
+            np.array([1, 2, 3, 4], dtype=np.uint32),
+            np.array([4, 3, 2, 1], dtype=np.uint32),
+        ),
+        20,
+        id="dot_eint_int_uint32",
+    ),
+    pytest.param(
+        """
+        func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
+        {
+            %ret = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
+                (tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
+            return %ret : !HLFHE.eint<7>
+        }
+        """,
+        (
+            np.array([1, 2, 3, 4], dtype=np.uint64),
+            np.array([4, 3, 2, 1], dtype=np.uint64),
+        ),
+        20,
+        id="dot_eint_int_uint64",
+    ),
pytest.param(
"""