fix(python-bindings): Support np.array with dtype upt to 64 bits

This commit is contained in:
Quentin Bourgerie
2022-08-19 11:58:52 +02:00
parent d647bc735f
commit 9257404f5f
5 changed files with 154 additions and 56 deletions

View File

@@ -227,19 +227,19 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
});
pybind11::class_<lambdaArgument>(m, "LambdaArgument")
.def_static("from_tensor",
.def_static("from_tensor_8",
[](std::vector<uint8_t> tensor, std::vector<int64_t> dims) {
return lambdaArgumentFromTensorU8(tensor, dims);
})
.def_static("from_tensor",
.def_static("from_tensor_16",
[](std::vector<uint16_t> tensor, std::vector<int64_t> dims) {
return lambdaArgumentFromTensorU16(tensor, dims);
})
.def_static("from_tensor",
.def_static("from_tensor_32",
[](std::vector<uint32_t> tensor, std::vector<int64_t> dims) {
return lambdaArgumentFromTensorU32(tensor, dims);
})
.def_static("from_tensor",
.def_static("from_tensor_64",
[](std::vector<uint64_t> tensor, std::vector<int64_t> dims) {
return lambdaArgumentFromTensorU64(tensor, dims);
})

View File

@@ -188,4 +188,12 @@ class ClientSupport(WrapperCpp):
value = value.max()
# should be a single uint here
return LambdaArgument.from_scalar(value)
return LambdaArgument.from_tensor(value.flatten().tolist(), value.shape)
if value.dtype == np.uint8:
return LambdaArgument.from_tensor_8(value.flatten().tolist(), value.shape)
if value.dtype == np.uint16:
return LambdaArgument.from_tensor_16(value.flatten().tolist(), value.shape)
if value.dtype == np.uint32:
return LambdaArgument.from_tensor_32(value.flatten().tolist(), value.shape)
if value.dtype == np.uint64:
return LambdaArgument.from_tensor_64(value.flatten().tolist(), value.shape)
raise TypeError("numpy.array must be of dtype uint{8,16,32,64}")

View File

@@ -63,7 +63,7 @@ class LambdaArgument(WrapperCpp):
return LambdaArgument.wrap(_LambdaArgument.from_scalar(scalar))
@staticmethod
def from_tensor(data: List[int], shape: List[int]) -> "LambdaArgument":
def from_tensor_8(data: List[int], shape: List[int]) -> "LambdaArgument":
"""Build a LambdaArgument containing the given tensor.
Args:
@@ -73,7 +73,46 @@ class LambdaArgument(WrapperCpp):
Returns:
LambdaArgument
"""
return LambdaArgument.wrap(_LambdaArgument.from_tensor(data, shape))
return LambdaArgument.wrap(_LambdaArgument.from_tensor_8(data, shape))
@staticmethod
def from_tensor_16(data: List[int], shape: List[int]) -> "LambdaArgument":
"""Build a LambdaArgument containing the given tensor.
Args:
data (List[int]): flattened tensor data
shape (List[int]): shape of original tensor before flattening
Returns:
LambdaArgument
"""
return LambdaArgument.wrap(_LambdaArgument.from_tensor_16(data, shape))
@staticmethod
def from_tensor_32(data: List[int], shape: List[int]) -> "LambdaArgument":
"""Build a LambdaArgument containing the given tensor.
Args:
data (List[int]): flattened tensor data
shape (List[int]): shape of original tensor before flattening
Returns:
LambdaArgument
"""
return LambdaArgument.wrap(_LambdaArgument.from_tensor_32(data, shape))
@staticmethod
def from_tensor_64(data: List[int], shape: List[int]) -> "LambdaArgument":
"""Build a LambdaArgument containing the given tensor.
Args:
data (List[int]): flattened tensor data
shape (List[int]): shape of original tensor before flattening
Returns:
LambdaArgument
"""
return LambdaArgument.wrap(_LambdaArgument.from_tensor_64(data, shape))
def is_scalar(self) -> bool:
"""Check if the contained argument is a scalar.

View File

@@ -262,44 +262,95 @@ std::string roundTrip(const char *module) {
bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) {
return lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
}
std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>> *arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
if (arg == nullptr) {
throw std::invalid_argument(
"LambdaArgument isn't a tensor, should "
"be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
}
llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
if (!sizeOrErr) {
template <typename T>
std::vector<uint64_t> copyTensorLambdaArgumentTo64bitsvector(
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<T>> *tensor) {
auto numElements = tensor->getNumElements();
if (!numElements) {
std::string backingString;
llvm::raw_string_ostream os(backingString);
os << "Couldn't get size of tensor: "
<< llvm::toString(std::move(sizeOrErr.takeError()));
<< llvm::toString(std::move(numElements.takeError()));
throw std::runtime_error(os.str());
}
std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
return data;
std::vector<uint64_t> res;
res.reserve(*numElements);
T *data = tensor->getValue();
for (size_t i = 0; i < *numElements; i++) {
res.push_back(data[i]);
}
return res;
}
std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
if (!sizeOrErr) {
std::string backingString;
llvm::raw_string_ostream os(backingString);
os << "Couldn't get size of tensor: "
<< llvm::toString(std::move(sizeOrErr.takeError()));
throw std::runtime_error(os.str());
}
std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
return data;
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector(arg);
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector(arg);
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector(arg);
}
throw std::invalid_argument(
"LambdaArgument isn't a tensor or has an unsupported bitwidth");
}
std::vector<int64_t>
lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) {
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>> *arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
if (arg == nullptr) {
throw std::invalid_argument(
"LambdaArgument isn't a tensor, should "
"be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
return arg->getDimensions();
}
return arg->getDimensions();
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
return arg->getDimensions();
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
return arg->getDimensions();
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
return arg->getDimensions();
}
throw std::invalid_argument(
"LambdaArgument isn't a tensor, should "
"be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
}
bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) {

View File

@@ -43,30 +43,30 @@ def test_accepted_ints(value):
# TODO: #495
# @pytest.mark.parametrize(
# "dtype",
# [
# pytest.param(np.uint8, id="uint8"),
# pytest.param(np.uint16, id="uint16"),
# pytest.param(np.uint32, id="uint32"),
# pytest.param(np.uint64, id="uint64"),
# ],
# )
# def test_accepted_ndarray(dtype):
# value = np.array([0, 1, 2], dtype=dtype)
# try:
# arg = ClientSupport._create_lambda_argument(value)
# except Exception:
# pytest.fail(f"value of type {type(value)} should be supported")
@pytest.mark.parametrize(
"dtype, maxvalue",
[
pytest.param(np.uint8, 2**8 - 1, id="uint8"),
pytest.param(np.uint16, 2**16 - 1, id="uint16"),
pytest.param(np.uint32, 2**32 - 1, id="uint32"),
pytest.param(np.uint64, 2**64 - 1, id="uint64"),
],
)
def test_accepted_ndarray(dtype, maxvalue):
value = np.array([0, 1, 2, maxvalue], dtype=dtype)
try:
arg = ClientSupport._create_lambda_argument(value)
except Exception:
pytest.fail(f"value of type {type(value)} should be supported")
# assert arg.is_tensor(), "should have been a tensor"
# assert np.all(np.equal(arg.get_tensor_shape(), value.shape))
# assert np.all(
# np.equal(
# value,
# np.array(arg.get_tensor_data()).reshape(arg.get_tensor_shape()),
# )
# )
assert arg.is_tensor(), "should have been a tensor"
assert np.all(np.equal(arg.get_tensor_shape(), value.shape))
assert np.all(
np.equal(
value,
np.array(arg.get_tensor_data()).reshape(arg.get_tensor_shape()),
)
)
def test_accepted_array_as_scalar():