Mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-10 04:35:03 -05:00)
fix(python-bindings): Support np.array with dtype up to 64 bits
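The change below widens tensor handling in the bindings so that 8-, 16-, 32- and 64-bit unsigned integer tensors are all accepted and copied into the uint64_t representation handed back to Python. As a minimal standalone sketch of the widening-copy pattern that the new copyTensorLambdaArgumentTo64bitsvector helper relies on (hypothetical names, no concretelang types involved, not the actual binding code):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Hypothetical helper: widen a buffer of T (uint8_t/uint16_t/uint32_t/uint64_t)
    // into a std::vector<uint64_t>, so a single 64-bit return type can be exposed.
    template <typename T>
    std::vector<uint64_t> copyTo64BitVector(const T *data, size_t numElements) {
      std::vector<uint64_t> res;
      res.reserve(numElements);
      for (size_t i = 0; i < numElements; i++) {
        res.push_back(data[i]); // implicit zero-extension to 64 bits
      }
      return res;
    }

    int main() {
      uint8_t small[] = {1, 2, 255};
      uint32_t medium[] = {1u << 20, 42};
      // Both widths end up in the same 64-bit representation.
      for (uint64_t v : copyTo64BitVector(small, 3))
        std::printf("%llu ", (unsigned long long)v);
      std::printf("\n");
      for (uint64_t v : copyTo64BitVector(medium, 2))
        std::printf("%llu ", (unsigned long long)v);
      std::printf("\n");
      return 0;
    }

Keeping a single std::vector<uint64_t> result type means only one conversion path to Python has to exist: smaller dtypes pay an element-wise widening copy, while the uint64_t case can still be copied directly from the argument's buffer.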
@@ -262,44 +262,95 @@ std::string roundTrip(const char *module) {
 bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) {
   return lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
-      mlir::concretelang::IntLambdaArgument<uint64_t>>>();
+             mlir::concretelang::IntLambdaArgument<uint8_t>>>() ||
+         lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
+             mlir::concretelang::IntLambdaArgument<uint16_t>>>() ||
+         lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
+             mlir::concretelang::IntLambdaArgument<uint32_t>>>() ||
+         lambda_arg.ptr->isa<mlir::concretelang::TensorLambdaArgument<
+             mlir::concretelang::IntLambdaArgument<uint64_t>>>();
 }
 
-std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
-  mlir::concretelang::TensorLambdaArgument<
-      mlir::concretelang::IntLambdaArgument<uint64_t>> *arg =
-      lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
-          mlir::concretelang::IntLambdaArgument<uint64_t>>>();
-  if (arg == nullptr) {
-    throw std::invalid_argument(
-        "LambdaArgument isn't a tensor, should "
-        "be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
-  }
-
-  llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
-  if (!sizeOrErr) {
+template <typename T>
+std::vector<uint64_t> copyTensorLambdaArgumentTo64bitsvector(
+    mlir::concretelang::TensorLambdaArgument<
+        mlir::concretelang::IntLambdaArgument<T>> *tensor) {
+  auto numElements = tensor->getNumElements();
+  if (!numElements) {
     std::string backingString;
     llvm::raw_string_ostream os(backingString);
     os << "Couldn't get size of tensor: "
-       << llvm::toString(std::move(sizeOrErr.takeError()));
+       << llvm::toString(std::move(numElements.takeError()));
     throw std::runtime_error(os.str());
   }
-  std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
-  return data;
+  std::vector<uint64_t> res;
+  res.reserve(*numElements);
+  T *data = tensor->getValue();
+  for (size_t i = 0; i < *numElements; i++) {
+    res.push_back(data[i]);
+  }
+  return res;
+}
+
+std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
+    llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
+    if (!sizeOrErr) {
+      std::string backingString;
+      llvm::raw_string_ostream os(backingString);
+      os << "Couldn't get size of tensor: "
+         << llvm::toString(std::move(sizeOrErr.takeError()));
+      throw std::runtime_error(os.str());
+    }
+    std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
+    return data;
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
+    return copyTensorLambdaArgumentTo64bitsvector(arg);
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
+    return copyTensorLambdaArgumentTo64bitsvector(arg);
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
+    return copyTensorLambdaArgumentTo64bitsvector(arg);
+  }
+  throw std::invalid_argument(
+      "LambdaArgument isn't a tensor or has an unsupported bitwidth");
 }
 
 std::vector<int64_t>
 lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) {
-  mlir::concretelang::TensorLambdaArgument<
-      mlir::concretelang::IntLambdaArgument<uint64_t>> *arg =
-      lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
-          mlir::concretelang::IntLambdaArgument<uint64_t>>>();
-  if (arg == nullptr) {
-    throw std::invalid_argument(
-        "LambdaArgument isn't a tensor, should "
-        "be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint8_t>>>()) {
+    return arg->getDimensions();
   }
-  return arg->getDimensions();
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint16_t>>>()) {
+    return arg->getDimensions();
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint32_t>>>()) {
+    return arg->getDimensions();
+  }
+  if (auto arg =
+          lambda_arg.ptr->dyn_cast<mlir::concretelang::TensorLambdaArgument<
+              mlir::concretelang::IntLambdaArgument<uint64_t>>>()) {
+    return arg->getDimensions();
+  }
+  throw std::invalid_argument(
+      "LambdaArgument isn't a tensor, should "
+      "be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
 }
 
 bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) {