fix(python-bindings): Support np.array with dtype upt to 64 bits

This commit is contained in:
Quentin Bourgerie
2022-08-19 11:58:52 +02:00
parent d647bc735f
commit 9257404f5f
5 changed files with 154 additions and 56 deletions

View File

@@ -262,44 +262,95 @@ std::string roundTrip(const char *module) {
bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) {
return lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
}
std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>> *arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
if (arg == nullptr) {
throw std::invalid_argument(
"LambdaArgument isn't a tensor, should "
"be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
}
llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
if (!sizeOrErr) {
template <typename T>
std::vector<uint64_t> copyTensorLambdaArgumentTo64bitsvector(
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<T>> *tensor) {
auto numElements = tensor->getNumElements();
if (!numElements) {
std::string backingString;
llvm::raw_string_ostream os(backingString);
os << "Couldn't get size of tensor: "
<< llvm::toString(std::move(sizeOrErr.takeError()));
<< llvm::toString(std::move(numElements.takeError()));
throw std::runtime_error(os.str());
}
std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
return data;
std::vector<uint64_t> res;
res.reserve(*numElements);
T *data = tensor->getValue();
for (size_t i = 0; i < *numElements; i++) {
res.push_back(data[i]);
}
return res;
}
std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
if (!sizeOrErr) {
std::string backingString;
llvm::raw_string_ostream os(backingString);
os << "Couldn't get size of tensor: "
<< llvm::toString(std::move(sizeOrErr.takeError()));
throw std::runtime_error(os.str());
}
std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
return data;
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector(arg);
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector(arg);
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
return copyTensorLambdaArgumentTo64bitsvector(arg);
}
throw std::invalid_argument(
"LambdaArgument isn't a tensor or has an unsupported bitwidth");
}
std::vector<int64_t>
lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) {
mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>> *arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>();
if (arg == nullptr) {
throw std::invalid_argument(
"LambdaArgument isn't a tensor, should "
"be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
return arg->getDimensions();
}
return arg->getDimensions();
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
return arg->getDimensions();
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
return arg->getDimensions();
}
if (auto arg =
lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
return arg->getDimensions();
}
throw std::invalid_argument(
"LambdaArgument isn't a tensor, should "
"be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
}
bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) {