mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: support more dtype for scalars/tensors
dtype supported now: uint8, uint16, uint32, uint64
This commit is contained in:
@@ -47,9 +47,15 @@ MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
|
||||
MLIR_CAPI_EXPORTED lambdaArgument invokeLambda(lambda l,
|
||||
executionArguments args);
|
||||
|
||||
// Create a lambdaArgument from a tensor
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensor(
|
||||
// Create a lambdaArgument from a tensor of different data types
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU8(
|
||||
std::vector<uint8_t> data, std::vector<int64_t> dimensions);
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU16(
|
||||
std::vector<uint16_t> data, std::vector<int64_t> dimensions);
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU32(
|
||||
std::vector<uint32_t> data, std::vector<int64_t> dimensions);
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromTensorU64(
|
||||
std::vector<uint64_t> data, std::vector<int64_t> dimensions);
|
||||
// Create a lambdaArgument from a scalar
|
||||
MLIR_CAPI_EXPORTED lambdaArgument lambdaArgumentFromScalar(uint64_t scalar);
|
||||
// Check if a lambdaArgument holds a tensor
|
||||
|
||||
@@ -45,7 +45,22 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
|
||||
});
|
||||
|
||||
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
|
||||
.def_static("from_tensor", lambdaArgumentFromTensor)
|
||||
.def_static("from_tensor",
|
||||
[](std::vector<uint8_t> tensor, std::vector<int64_t> dims) {
|
||||
return lambdaArgumentFromTensorU8(tensor, dims);
|
||||
})
|
||||
.def_static("from_tensor",
|
||||
[](std::vector<uint16_t> tensor, std::vector<int64_t> dims) {
|
||||
return lambdaArgumentFromTensorU16(tensor, dims);
|
||||
})
|
||||
.def_static("from_tensor",
|
||||
[](std::vector<uint32_t> tensor, std::vector<int64_t> dims) {
|
||||
return lambdaArgumentFromTensorU32(tensor, dims);
|
||||
})
|
||||
.def_static("from_tensor",
|
||||
[](std::vector<uint64_t> tensor, std::vector<int64_t> dims) {
|
||||
return lambdaArgumentFromTensorU64(tensor, dims);
|
||||
})
|
||||
.def_static("from_scalar", lambdaArgumentFromScalar)
|
||||
.def("is_tensor",
|
||||
[](lambdaArgument &lambda_arg) {
|
||||
|
||||
@@ -10,6 +10,11 @@ from mlir._mlir_libs._zamalang._compiler import library as _library
|
||||
import numpy as np
|
||||
|
||||
|
||||
ACCEPTED_NUMPY_UINTS = (np.uint8, np.uint16, np.uint32, np.uint64)
|
||||
ACCEPTED_INTS = (int,) + ACCEPTED_NUMPY_UINTS
|
||||
ACCEPTED_TYPES = (np.ndarray,) + ACCEPTED_INTS
|
||||
|
||||
|
||||
def _lookup_runtime_lib() -> str:
|
||||
"""Try to find the absolute path to the runtime library.
|
||||
|
||||
@@ -95,10 +100,10 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
|
||||
Returns:
|
||||
_LambdaArgument: lambda argument holding the appropriate value
|
||||
"""
|
||||
if not isinstance(value, (int, np.ndarray, np.uint8)):
|
||||
raise TypeError("value of execution argument must be either int, numpy.array or numpy.uint8")
|
||||
if isinstance(value, (int, np.uint8)):
|
||||
if not (0 <= value < (2 ** 64 - 1)):
|
||||
if not isinstance(value, ACCEPTED_TYPES):
|
||||
raise TypeError("value of execution argument must be either int, numpy.array or numpy.uint{8,16,32,64}")
|
||||
if isinstance(value, ACCEPTED_INTS):
|
||||
if isinstance(value, int) and not (0 <= value < np.iinfo(np.uint64).max):
|
||||
raise TypeError(
|
||||
"single integer must be in the range [0, 2**64 - 1] (uint64)"
|
||||
)
|
||||
@@ -107,8 +112,8 @@ def create_execution_argument(value: Union[int, np.ndarray]) -> "_LambdaArgument
|
||||
assert isinstance(value, np.ndarray)
|
||||
if value.shape == ():
|
||||
return _LambdaArgument.from_scalar(value)
|
||||
if value.dtype != np.uint8:
|
||||
raise TypeError("numpy.array must be of dtype uint8")
|
||||
if value.dtype not in ACCEPTED_NUMPY_UINTS:
|
||||
raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
|
||||
return _LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
|
||||
|
||||
|
||||
|
||||
@@ -143,14 +143,38 @@ uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg) {
|
||||
return arg->getValue();
|
||||
}
|
||||
|
||||
lambdaArgument lambdaArgumentFromTensor(std::vector<uint8_t> data,
|
||||
std::vector<int64_t> dimensions) {
|
||||
lambdaArgument lambdaArgumentFromTensorU8(std::vector<uint8_t> data,
|
||||
std::vector<int64_t> dimensions) {
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint8_t>>>(data, dimensions)};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
lambdaArgument lambdaArgumentFromTensorU16(std::vector<uint16_t> data,
|
||||
std::vector<int64_t> dimensions) {
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint16_t>>>(data, dimensions)};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
lambdaArgument lambdaArgumentFromTensorU32(std::vector<uint32_t> data,
|
||||
std::vector<int64_t> dimensions) {
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint32_t>>>(data, dimensions)};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
lambdaArgument lambdaArgumentFromTensorU64(std::vector<uint64_t> data,
|
||||
std::vector<int64_t> dimensions) {
|
||||
lambdaArgument tensor_arg{
|
||||
std::make_shared<mlir::zamalang::TensorLambdaArgument<
|
||||
mlir::zamalang::IntLambdaArgument<uint64_t>>>(data, dimensions)};
|
||||
return tensor_arg;
|
||||
}
|
||||
|
||||
lambdaArgument lambdaArgumentFromScalar(uint64_t scalar) {
|
||||
lambdaArgument scalar_arg{
|
||||
std::make_shared<mlir::zamalang::IntLambdaArgument<uint64_t>>(scalar)};
|
||||
|
||||
@@ -43,6 +43,39 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
|
||||
8,
|
||||
id="add_eint_int_with_np_uint8_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
|
||||
%1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(np.uint16(3), np.uint16(5)),
|
||||
8,
|
||||
id="add_eint_int_with_np_uint16_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
|
||||
%1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(np.uint32(3), np.uint32(5)),
|
||||
8,
|
||||
id="add_eint_int_with_np_uint32_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: !HLFHE.eint<7>, %arg1: i8) -> !HLFHE.eint<7> {
|
||||
%1 = "HLFHE.add_eint_int"(%arg0, %arg1): (!HLFHE.eint<7>, i8) -> (!HLFHE.eint<7>)
|
||||
return %1: !HLFHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(np.uint64(3), np.uint64(5)),
|
||||
8,
|
||||
id="add_eint_int_with_np_uint64_as_scalar",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
@@ -57,7 +90,55 @@ KEY_SET_CACHE_PATH = os.path.join(tempfile.gettempdir(), 'KeySetCache')
|
||||
np.array([4, 3, 2, 1], dtype=np.uint8),
|
||||
),
|
||||
20,
|
||||
id="dot_eint_int",
|
||||
id="dot_eint_int_uint8",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
{
|
||||
%ret = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
(tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
return %ret : !HLFHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint16),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint16),
|
||||
),
|
||||
20,
|
||||
id="dot_eint_int_uint16",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
{
|
||||
%ret = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
(tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
return %ret : !HLFHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint32),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint32),
|
||||
),
|
||||
20,
|
||||
id="dot_eint_int_uint32",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
func @main(%arg0: tensor<4x!HLFHE.eint<7>>, %arg1: tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
{
|
||||
%ret = "HLFHELinalg.dot_eint_int"(%arg0, %arg1) :
|
||||
(tensor<4x!HLFHE.eint<7>>, tensor<4xi8>) -> !HLFHE.eint<7>
|
||||
return %ret : !HLFHE.eint<7>
|
||||
}
|
||||
""",
|
||||
(
|
||||
np.array([1, 2, 3, 4], dtype=np.uint64),
|
||||
np.array([4, 3, 2, 1], dtype=np.uint64),
|
||||
),
|
||||
20,
|
||||
id="dot_eint_int_uint64",
|
||||
),
|
||||
pytest.param(
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user