Mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-09 03:55:04 -05:00)
fix(python-bindings): Support np.array with dtype up to 64 bits
@@ -227,19 +227,19 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
       });

   pybind11::class_<lambdaArgument>(m, "LambdaArgument")
-      .def_static("from_tensor",
+      .def_static("from_tensor_8",
                   [](std::vector<uint8_t> tensor, std::vector<int64_t> dims) {
                     return lambdaArgumentFromTensorU8(tensor, dims);
                   })
-      .def_static("from_tensor",
+      .def_static("from_tensor_16",
                   [](std::vector<uint16_t> tensor, std::vector<int64_t> dims) {
                     return lambdaArgumentFromTensorU16(tensor, dims);
                   })
-      .def_static("from_tensor",
+      .def_static("from_tensor_32",
                   [](std::vector<uint32_t> tensor, std::vector<int64_t> dims) {
                     return lambdaArgumentFromTensorU32(tensor, dims);
                   })
-      .def_static("from_tensor",
+      .def_static("from_tensor_64",
                   [](std::vector<uint64_t> tensor, std::vector<int64_t> dims) {
                     return lambdaArgumentFromTensorU64(tensor, dims);
                   })
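A note on the hunk above: all four statics were previously registered under the single name "from_tensor", leaving the element width to pybind11's overload resolution over plain Python lists; giving each width its own name makes the choice explicit at the call site. A minimal usage sketch of the renamed statics (the native module's import path is assumed, not shown in this diff; values are chosen so each call needs its width):

    # _LambdaArgument is the pybind11-bound class; import path assumed.
    a8 = _LambdaArgument.from_tensor_8([0, 2**8 - 1], [2])
    a16 = _LambdaArgument.from_tensor_16([0, 2**16 - 1], [2])
    a32 = _LambdaArgument.from_tensor_32([0, 2**32 - 1], [2])
    a64 = _LambdaArgument.from_tensor_64([0, 2**64 - 1], [2])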
@@ -188,4 +188,12 @@ class ClientSupport(WrapperCpp):
             value = value.max()
             # should be a single uint here
             return LambdaArgument.from_scalar(value)
-        return LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
+        if value.dtype == np.uint8:
+            return LambdaArgument.from_tensor_8(value.flatten().tolist(), value.shape)
+        if value.dtype == np.uint16:
+            return LambdaArgument.from_tensor_16(value.flatten().tolist(), value.shape)
+        if value.dtype == np.uint32:
+            return LambdaArgument.from_tensor_32(value.flatten().tolist(), value.shape)
+        if value.dtype == np.uint64:
+            return LambdaArgument.from_tensor_64(value.flatten().tolist(), value.shape)
+        raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")
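The Python-side dispatch above selects the binding from the array's dtype and rejects everything else. A sketch of the caller-visible behaviour (the ClientSupport import path is assumed; _create_lambda_argument is the entry point exercised by the tests further down):

    import numpy as np
    # from concrete.compiler import ClientSupport  # import path assumed

    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        arr = np.arange(6, dtype=dtype).reshape(2, 3)
        # Routed to LambdaArgument.from_tensor_{8,16,32,64} by dtype.
        arg = ClientSupport._create_lambda_argument(arr)

    try:
        ClientSupport._create_lambda_argument(np.array([1, 2], dtype=np.int32))
    except TypeError as err:
        print(err)  # numpy.array must be of dtype uint{8,16,32,64}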
@@ -63,7 +63,7 @@ class LambdaArgument(WrapperCpp):
         return LambdaArgument.wrap(_LambdaArgument.from_scalar(scalar))

     @staticmethod
-    def from_tensor(data: List[int], shape: List[int]) -> "LambdaArgument":
+    def from_tensor_8(data: List[int], shape: List[int]) -> "LambdaArgument":
         """Build a LambdaArgument containing the given tensor.

         Args:
@@ -73,7 +73,46 @@ class LambdaArgument(WrapperCpp):
         Returns:
             LambdaArgument
         """
-        return LambdaArgument.wrap(_LambdaArgument.from_tensor(data, shape))
+        return LambdaArgument.wrap(_LambdaArgument.from_tensor_8(data, shape))
+
+    @staticmethod
+    def from_tensor_16(data: List[int], shape: List[int]) -> "LambdaArgument":
+        """Build a LambdaArgument containing the given tensor.
+
+        Args:
+            data (List[int]): flattened tensor data
+            shape (List[int]): shape of original tensor before flattening
+
+        Returns:
+            LambdaArgument
+        """
+        return LambdaArgument.wrap(_LambdaArgument.from_tensor_16(data, shape))
+
+    @staticmethod
+    def from_tensor_32(data: List[int], shape: List[int]) -> "LambdaArgument":
+        """Build a LambdaArgument containing the given tensor.
+
+        Args:
+            data (List[int]): flattened tensor data
+            shape (List[int]): shape of original tensor before flattening
+
+        Returns:
+            LambdaArgument
+        """
+        return LambdaArgument.wrap(_LambdaArgument.from_tensor_32(data, shape))
+
+    @staticmethod
+    def from_tensor_64(data: List[int], shape: List[int]) -> "LambdaArgument":
+        """Build a LambdaArgument containing the given tensor.
+
+        Args:
+            data (List[int]): flattened tensor data
+            shape (List[int]): shape of original tensor before flattening
+
+        Returns:
+            LambdaArgument
+        """
+        return LambdaArgument.wrap(_LambdaArgument.from_tensor_64(data, shape))

     def is_scalar(self) -> bool:
         """Check if the contained argument is a scalar.
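The four wrappers above are deliberately uniform so the Python surface mirrors the native statics one-to-one. A table-driven alternative is possible; this is a hypothetical refactoring sketch, not part of the commit (_FROM_TENSOR and the bit_width parameter are invented here):

    from typing import List

    # Hypothetical: map bit width to the native static bound in this diff.
    _FROM_TENSOR = {
        8: _LambdaArgument.from_tensor_8,
        16: _LambdaArgument.from_tensor_16,
        32: _LambdaArgument.from_tensor_32,
        64: _LambdaArgument.from_tensor_64,
    }

    def from_tensor(data: List[int], shape: List[int], bit_width: int = 64) -> "LambdaArgument":
        """Build a LambdaArgument from flattened tensor data of the given bit width."""
        return LambdaArgument.wrap(_FROM_TENSOR[bit_width](data, shape))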
@@ -262,44 +262,95 @@ std::string roundTrip(const char *module) {

 bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) {
   return lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
-      mlir::concretelang::IntLambdaArgument<uint64_t>>>();
+             mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
+         lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
+             mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
+         lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
+             mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
+         lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
+             mlir::concretelang::IntLambdaArgument<uint64_t>>>();
 }

-std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
-  mlir::concretelang::TensorLambdaArgument<
-      mlir::concretelang::IntLambdaArgument<uint64_t>> *arg =
-      lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
-          mlir::concretelang::IntLambdaArgument<uint64_t>>>();
-  if (arg == nullptr) {
-    throw std::invalid_argument(
-        "LambdaArgument isn't a tensor, should "
-        "be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
-  }
-
-  llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
-  if (!sizeOrErr) {
+template <typename T>
+std::vector<uint64_t> copyTensorLambdaArgumentTo64bitsvector(
+    mlir::concretelang::TensorLambdaArgument<
+        mlir::concretelang::IntLambdaArgument<T>> *tensor) {
+  auto numElements = tensor->getNumElements();
+  if (!numElements) {
     std::string backingString;
     llvm::raw_string_ostream os(backingString);
     os << "Couldn't get size of tensor: "
-       << llvm::toString(std::move(sizeOrErr.takeError()));
+       << llvm::toString(std::move(numElements.takeError()));
     throw std::runtime_error(os.str());
   }
-  std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
-  return data;
+  std::vector<uint64_t> res;
+  res.reserve(*numElements);
+  T *data = tensor->getValue();
+  for (size_t i = 0; i < *numElements; i++) {
+    res.push_back(data[i]);
+  }
+  return res;
+}
+
+std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
+    llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
+    if (!sizeOrErr) {
+      std::string backingString;
+      llvm::raw_string_ostream os(backingString);
+      os << "Couldn't get size of tensor: "
+         << llvm::toString(std::move(sizeOrErr.takeError()));
+      throw std::runtime_error(os.str());
+    }
+    std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
+    return data;
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
+    return copyTensorLambdaArgumentTo64bitsvector(arg);
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
+    return copyTensorLambdaArgumentTo64bitsvector(arg);
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
+    return copyTensorLambdaArgumentTo64bitsvector(arg);
+  }
+  throw std::invalid_argument(
+      "LambdaArgument isn't a tensor or has an unsupported bitwidth");
 }

 std::vector<int64_t>
 lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) {
-  mlir::concretelang::TensorLambdaArgument<
-      mlir::concretelang::IntLambdaArgument<uint64_t>> *arg =
-      lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
-          mlir::concretelang::IntLambdaArgument<uint64_t>>>();
-  if (arg == nullptr) {
-    throw std::invalid_argument(
-        "LambdaArgument isn't a tensor, should "
-        "be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
+    return arg->getDimensions();
   }
-  return arg->getDimensions();
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
+    return arg->getDimensions();
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
+    return arg->getDimensions();
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
+    return arg->getDimensions();
+  }
+  throw std::invalid_argument(
+      "LambdaArgument isn't a tensor, should "
+      "be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
 }

 bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) {
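Note on the C++ hunk above: tensors narrower than 64 bits are copied element-wise into a std::vector<uint64_t>, so Python always reads tensor data back as 64-bit values whatever the input dtype was. A hedged round-trip sketch (ClientSupport import path assumed; is_tensor, get_tensor_shape, and get_tensor_data are the accessors the tests below rely on):

    import numpy as np

    arr = np.array([[1, 2], [3, 255]], dtype=np.uint8)
    arg = ClientSupport._create_lambda_argument(arr)

    assert arg.is_tensor()
    # Data comes back widened to 64 bits; the shape is reported separately.
    restored = np.array(arg.get_tensor_data(), dtype=np.uint64)
    restored = restored.reshape(arg.get_tensor_shape())
    assert np.array_equal(restored, arr.astype(np.uint64))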
@@ -43,30 +43,30 @@ def test_accepted_ints(value):


-# TODO: #495
-# @pytest.mark.parametrize(
-#     "dtype",
-#     [
-#         pytest.param(np.uint8, id="uint8"),
-#         pytest.param(np.uint16, id="uint16"),
-#         pytest.param(np.uint32, id="uint32"),
-#         pytest.param(np.uint64, id="uint64"),
-#     ],
-# )
-# def test_accepted_ndarray(dtype):
-#     value = np.array([0, 1, 2], dtype=dtype)
-#     try:
-#         arg = ClientSupport._create_lambda_argument(value)
-#     except Exception:
-#         pytest.fail(f"value of type {type(value)} should be supported")
+@pytest.mark.parametrize(
+    "dtype, maxvalue",
+    [
+        pytest.param(np.uint8, 2**8 - 1, id="uint8"),
+        pytest.param(np.uint16, 2**16 - 1, id="uint16"),
+        pytest.param(np.uint32, 2**32 - 1, id="uint32"),
+        pytest.param(np.uint64, 2**64 - 1, id="uint64"),
+    ],
+)
+def test_accepted_ndarray(dtype, maxvalue):
+    value = np.array([0, 1, 2, maxvalue], dtype=dtype)
+    try:
+        arg = ClientSupport._create_lambda_argument(value)
+    except Exception:
+        pytest.fail(f"value of type {type(value)} should be supported")

-#     assert arg.is_tensor(), "should have been a tensor"
-#     assert np.all(np.equal(arg.get_tensor_shape(), value.shape))
-#     assert np.all(
-#         np.equal(
-#             value,
-#             np.array(arg.get_tensor_data()).reshape(arg.get_tensor_shape()),
-#         )
-#     )
+    assert arg.is_tensor(), "should have been a tensor"
+    assert np.all(np.equal(arg.get_tensor_shape(), value.shape))
+    assert np.all(
+        np.equal(
+            value,
+            np.array(arg.get_tensor_data()).reshape(arg.get_tensor_shape()),
+        )
+    )


 def test_accepted_array_as_scalar():
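The re-enabled test covers the accepted dtypes up to their maximum values; a natural companion (hypothetical, not part of this diff) would pin down the new TypeError for everything else:

    @pytest.mark.parametrize(
        "dtype",
        [
            pytest.param(np.int8, id="int8"),
            pytest.param(np.int64, id="int64"),
            pytest.param(np.float32, id="float32"),
        ],
    )
    def test_rejected_ndarray_dtype(dtype):
        value = np.array([0, 1, 2], dtype=dtype)
        # Any dtype outside uint{8,16,32,64} must now raise.
        with pytest.raises(TypeError):
            ClientSupport._create_lambda_argument(value)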