mirror of https://github.com/zama-ai/concrete.git
feat(python): support functions returning tensors
@@ -23,13 +23,27 @@ struct executionArguments {
 };
 typedef struct executionArguments exectuionArguments;
 
+struct lambdaArgument {
+  std::unique_ptr<mlir::zamalang::LambdaArgument> ptr;
+};
+typedef struct lambdaArgument lambdaArgument;
+
 MLIR_CAPI_EXPORTED mlir::zamalang::JitCompilerEngine::Lambda
 buildLambda(const char *module, const char *funcName);
 
-MLIR_CAPI_EXPORTED uint64_t invokeLambda(lambda l, executionArguments args);
+MLIR_CAPI_EXPORTED lambdaArgument invokeLambda(lambda l,
+                                               executionArguments args);
 
 MLIR_CAPI_EXPORTED std::string roundTrip(const char *module);
+
+MLIR_CAPI_EXPORTED bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg);
+MLIR_CAPI_EXPORTED std::vector<uint64_t>
+lambdaArgumentGetTensorData(lambdaArgument &lambda_arg);
+MLIR_CAPI_EXPORTED std::vector<int64_t>
+lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg);
+MLIR_CAPI_EXPORTED bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg);
+MLIR_CAPI_EXPORTED uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg);
 
 #ifdef __cplusplus
 }
 #endif
@@ -17,6 +17,7 @@
 
 using mlir::zamalang::ExecutionArgument;
 using mlir::zamalang::JitCompilerEngine;
+using mlir::zamalang::LambdaArgument;
 
 /// Populate the compiler API python module.
 void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
@@ -41,6 +42,27 @@ void mlir::zamalang::python::populateCompilerAPISubmodule(pybind11::module &m) {
           return buildLambda(mlir_input.c_str(), func_name.c_str());
         });
 
+  pybind11::class_<lambdaArgument>(m, "LambdaArgument")
+      .def("is_tensor",
+           [](lambdaArgument &lambda_arg) {
+             return lambdaArgumentIsTensor(lambda_arg);
+           })
+      .def("get_tensor_data",
+           [](lambdaArgument &lambda_arg) {
+             return lambdaArgumentGetTensorData(lambda_arg);
+           })
+      .def("get_tensor_shape",
+           [](lambdaArgument &lambda_arg) {
+             return lambdaArgumentGetTensorDimensions(lambda_arg);
+           })
+      .def("is_scalar",
+           [](lambdaArgument &lambda_arg) {
+             return lambdaArgumentIsScalar(lambda_arg);
+           })
+      .def("get_scalar", [](lambdaArgument &lambda_arg) {
+        return lambdaArgumentGetScalar(lambda_arg);
+      });
+
   pybind11::class_<JitCompilerEngine::Lambda>(m, "Lambda")
       .def("invoke", [](JitCompilerEngine::Lambda &py_lambda,
                         std::vector<ExecutionArgument> args) {
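For context, a minimal sketch of what the LambdaArgument binding above exposes on the Python side. The names lambda_obj and execution_arguments are illustrative placeholders for a Lambda produced by build_lambda and its usual list of execution arguments; invoke() now hands back a LambdaArgument instead of a bare integer.

# Sketch only: inspect the LambdaArgument returned by Lambda.invoke().
result = lambda_obj.invoke(execution_arguments)
if result.is_scalar():
    value = result.get_scalar()          # plain Python int
elif result.is_tensor():
    shape = result.get_tensor_shape()    # list of dimensions
    data = result.get_tensor_data()      # flat list of uint64 values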
@@ -3,6 +3,8 @@ from typing import List, Union
 from mlir._mlir_libs._zamalang._compiler import JitCompilerEngine as _JitCompilerEngine
 from mlir._mlir_libs._zamalang._compiler import ExecutionArgument as _ExecutionArgument
 from mlir._mlir_libs._zamalang._compiler import round_trip as _round_trip
+import numpy as np
+
 
 def round_trip(mlir_str: str) -> str:
     """Parse the MLIR input, then return it back.
@@ -67,7 +69,7 @@ class CompilerEngine:
             raise TypeError("input must be an `str`")
         self._lambda = self._engine.build_lambda(mlir_str, func_name)
 
-    def run(self, *args: List[Union[int, List[int]]]) -> int:
+    def run(self, *args: List[Union[int, List[int]]]) -> Union[int, np.array]:
         """Run the compiled code.
 
         Args:
@@ -76,11 +78,20 @@ class CompilerEngine:
         Raises:
             TypeError: if execution arguments can't be constructed
             RuntimeError: if the engine has not compiled any code yet
+            RuntimeError: if the return type is unknown
 
         Returns:
-            int: result of execution.
+            int or numpy.array: result of execution.
         """
         if self._lambda is None:
             raise RuntimeError("need to compile an MLIR code first")
         execution_arguments = [create_execution_argument(arg) for arg in args]
-        return self._lambda.invoke(execution_arguments)
+        lambda_arg = self._lambda.invoke(execution_arguments)
+        if lambda_arg.is_scalar():
+            return lambda_arg.get_scalar()
+        elif lambda_arg.is_tensor():
+            shape = lambda_arg.get_tensor_shape()
+            tensor = np.array(lambda_arg.get_tensor_data()).reshape(shape)
+            return tensor
+        else:
+            raise RuntimeError("unknown return type")
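An illustrative end-to-end sketch of the new run() behaviour, reusing the HLFHELinalg.add_eint_int case from the test further down; scalar results still come back as a plain int, tensor results as a numpy array.

import numpy as np
from zamalang import CompilerEngine

mlir_input = """
func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
    %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
    return %res : tensor<4x!HLFHE.eint<6>>
}
"""

engine = CompilerEngine()
engine.compile_fhe(mlir_input)
result = engine.run([31, 6, 12, 9], [32, 9, 2, 3])
# run() rebuilds the tensor from get_tensor_data() / get_tensor_shape()
assert isinstance(result, np.ndarray)
assert (result == [63, 15, 14, 12]).all()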
@@ -23,7 +23,7 @@ mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module,
   return std::move(*lambdaOrErr);
 }
 
-uint64_t invokeLambda(lambda l, executionArguments args) {
+lambdaArgument invokeLambda(lambda l, executionArguments args) {
   mlir::zamalang::JitCompilerEngine::Lambda *lambda_ptr =
       (mlir::zamalang::JitCompilerEngine::Lambda *)l.ptr;
 
@@ -45,8 +45,12 @@ uint64_t invokeLambda(lambda l, executionArguments args) {
     }
   }
   // Run lambda
-  llvm::Expected<uint64_t> resOrError = (*lambda_ptr)(
-      llvm::ArrayRef<mlir::zamalang::LambdaArgument *>(lambdaArgumentsRef));
+  llvm::Expected<std::unique_ptr<mlir::zamalang::LambdaArgument>> resOrError =
+      (*lambda_ptr)
+          .
+          operator()<std::unique_ptr<mlir::zamalang::LambdaArgument>>(
+              llvm::ArrayRef<mlir::zamalang::LambdaArgument *>(
+                  lambdaArgumentsRef));
   // Free heap
   for (size_t i = 0; i < lambdaArgumentsRef.size(); i++)
     delete lambdaArgumentsRef[i];
@@ -58,7 +62,8 @@ uint64_t invokeLambda(lambda l, executionArguments args) {
         << llvm::toString(std::move(resOrError.takeError()));
     throw std::runtime_error(os.str());
   }
-  return *resOrError;
+  lambdaArgument result{std::move(*resOrError)};
+  return result;
 }
 
 std::string roundTrip(const char *module) {
@@ -80,3 +85,59 @@ std::string roundTrip(const char *module) {
   retOrErr->mlirModuleRef->get().print(os);
   return os.str();
 }
+
+bool lambdaArgumentIsTensor(lambdaArgument &lambda_arg) {
+  return lambda_arg.ptr->isa<mlir::zamalang::TensorLambdaArgument<
+      mlir::zamalang::IntLambdaArgument<uint64_t>>>();
+}
+
+std::vector<uint64_t> lambdaArgumentGetTensorData(lambdaArgument &lambda_arg) {
+  mlir::zamalang::TensorLambdaArgument<
+      mlir::zamalang::IntLambdaArgument<uint64_t>> *arg =
+      lambda_arg.ptr->dyn_cast<mlir::zamalang::TensorLambdaArgument<
+          mlir::zamalang::IntLambdaArgument<uint64_t>>>();
+  if (arg == nullptr) {
+    throw std::invalid_argument(
+        "LambdaArgument isn't a tensor, should "
+        "be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
+  }
+
+  llvm::Expected<size_t> sizeOrErr = arg->getNumElements();
+  if (!sizeOrErr) {
+    std::string backingString;
+    llvm::raw_string_ostream os(backingString);
+    os << "Couldn't get size of tensor: "
+       << llvm::toString(std::move(sizeOrErr.takeError()));
+    throw std::runtime_error(os.str());
+  }
+  std::vector<uint64_t> data(arg->getValue(), arg->getValue() + *sizeOrErr);
+  return data;
+}
+
+std::vector<int64_t>
+lambdaArgumentGetTensorDimensions(lambdaArgument &lambda_arg) {
+  mlir::zamalang::TensorLambdaArgument<
+      mlir::zamalang::IntLambdaArgument<uint64_t>> *arg =
+      lambda_arg.ptr->dyn_cast<mlir::zamalang::TensorLambdaArgument<
+          mlir::zamalang::IntLambdaArgument<uint64_t>>>();
+  if (arg == nullptr) {
+    throw std::invalid_argument(
+        "LambdaArgument isn't a tensor, should "
+        "be a TensorLambdaArgument<IntLambdaArgument<uint64_t>>");
+  }
+  return arg->getDimensions();
+}
+
+bool lambdaArgumentIsScalar(lambdaArgument &lambda_arg) {
+  return lambda_arg.ptr->isa<mlir::zamalang::IntLambdaArgument<uint64_t>>();
+}
+
+uint64_t lambdaArgumentGetScalar(lambdaArgument &lambda_arg) {
+  mlir::zamalang::IntLambdaArgument<uint64_t> *arg =
+      lambda_arg.ptr->dyn_cast<mlir::zamalang::IntLambdaArgument<uint64_t>>();
+  if (arg == nullptr) {
+    throw std::invalid_argument("LambdaArgument isn't a scalar, should "
+                                "be an IntLambdaArgument<uint64_t>");
+  }
+  return arg->getValue();
+}
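The tensor accessors above return the contents as a flat vector plus a separate list of dimensions, which is why run() rebuilds the result with np.array(...).reshape(...). On a kind mismatch they throw std::invalid_argument, which pybind11's default exception translation should surface in Python as a ValueError (an assumption about the binding setup, not something this diff shows). A small illustration with hypothetical values:

import numpy as np

# Hypothetical stand-ins for get_tensor_data() / get_tensor_shape(),
# assuming the flat data is laid out in row-major order (numpy's
# default for reshape).
flat_data = [1, 2, 3, 4, 5, 6]
shape = [2, 3]

tensor = np.array(flat_data).reshape(shape)
# tensor -> array([[1, 2, 3],
#                  [4, 5, 6]])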
@@ -29,12 +29,28 @@ from zamalang import CompilerEngine
             20,
             id="dot_eint_int"
         ),
+        pytest.param(
+            """
+func @main(%a0: tensor<4x!HLFHE.eint<6>>, %a1: tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>> {
+    %res = "HLFHELinalg.add_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<6>>, tensor<4xi7>) -> tensor<4x!HLFHE.eint<6>>
+    return %res : tensor<4x!HLFHE.eint<6>>
+}
+            """,
+            ([31, 6, 12, 9], [32, 9, 2, 3]),
+            [63, 15, 14, 12],
+            id="add_eint_int_1D"
+        ),
     ],
 )
 def test_compile_and_run(mlir_input, args, expected_result):
     engine = CompilerEngine()
     engine.compile_fhe(mlir_input)
-    assert engine.run(*args) == expected_result
+    if isinstance(expected_result, int):
+        assert engine.run(*args) == expected_result
+    else:
+        # numpy array on the left
+        assert (engine.run(*args) == expected_result).all()
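The test now branches on the expected type because comparing a numpy array with == is elementwise and yields a boolean array rather than a single bool, hence the .all() reduction in the tensor case:

import numpy as np

np.array([63, 15, 14, 12]) == [63, 15, 14, 12]
# -> array([ True,  True,  True,  True])

(np.array([63, 15, 14, 12]) == [63, 15, 14, 12]).all()
# -> True, which is what the tensor branch asserts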