mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-14 23:38:10 -05:00
This commit rebases the compiler onto commit 465ee9bfb26d from
llvm-project with locally maintained patches on top, i.e.:
* 5d8669d669ee: Fix the element alignment (size) for memrefCopy
* 4239163ea337: fix: Do not fold the memref.subview if the offset are
!= 0 and strides != 1
* 72c5decfcc21: remove github stuff from llvm
* 8d0ce8f9eca1: Support arbitrary element types in named operations
via attributes
* 94f64805c38c: Copy attributes of scf.for on bufferization and make
it an allocation hoisting barrier
Main upstream changes from llvm-project that required modification of
concretecompiler:
* Switch to C++17
* Various changes in the interfaces for linalg named operations
* Transition from `llvm::Optional` to `std::optional`
* Use of enums instead of string values for iterator types in linalg
* Changed default naming convention of getter methods in
ODS-generated operation classes from `some_value()` to
`getSomeValue()`
* Renaming of Arithmetic dialect to Arith
* Refactoring of side effect interfaces (i.e., renaming from
`NoSideEffect` to `Pure`)
* Re-design of the data flow analysis framework
* Refactoring of build targets for Python bindings
* Refactoring of array attributes with integer values
* Renaming of `linalg.init_tensor` to `tensor.empty`
* Emission of `linalg.map` operations in bufferization of the Tensor
dialect requiring another linalg conversion pass and registration
of the bufferization op interfaces for linalg operations
* Refactoring of the one-shot bufferizer
* Necessity to run the expand-strided-metadata, affine-to-std and
finalize-memref-to-llvm passes before converson to the LLVM
dialect
* Renaming of `BlockAndValueMapping` to `IRMapping`
* Changes in the build function of `LLVM::CallOp`
* Refactoring of the construction of `llvm::ArrayRef` and
`llvm::MutableArrayRef` (direct invocation of constructor instead
of builder functions for some cases)
* New naming conventions for generated SSA values requiring rewrite
of some check tests
* Refactoring of `mlir::LLVM::lookupOrCreateMallocFn()`
* Interface changes in generated type parsers
* New dependencies for to mlir_float16_utils and
MLIRSparseTensorRuntime for the runtime
* Overhaul of MLIR-c deleting `mlir-c/Registration.h`
* Deletion of library MLIRLinalgToSPIRV
* Deletion of library MLIRLinalgAnalysis
* Deletion of library MLIRMemRefUtils
* Deletion of library MLIRQuantTransforms
* Deletion of library MLIRVectorToROCDL
258 lines
8.8 KiB
C++
258 lines
8.8 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
|
// Exceptions. See
|
|
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
|
// for license information.
|
|
#include <cassert>
|
|
#include <map>
|
|
|
|
#include <llvm/ADT/Optional.h>
|
|
#include <llvm/ADT/STLExtras.h>
|
|
#include <llvm/Support/Error.h>
|
|
|
|
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
|
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
|
#include <optional>
|
|
|
|
#include "concrete/curves.h"
|
|
#include "concretelang/ClientLib/ClientParameters.h"
|
|
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
|
|
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
|
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
|
#include "concretelang/Support/Error.h"
|
|
|
|
namespace mlir {
|
|
namespace concretelang {
|
|
|
|
namespace clientlib = ::concretelang::clientlib;
|
|
using ::concretelang::clientlib::ChunkInfo;
|
|
using ::concretelang::clientlib::CircuitGate;
|
|
using ::concretelang::clientlib::ClientParameters;
|
|
using ::concretelang::clientlib::Encoding;
|
|
using ::concretelang::clientlib::EncryptionGate;
|
|
using ::concretelang::clientlib::LweSecretKeyID;
|
|
using ::concretelang::clientlib::Precision;
|
|
using ::concretelang::clientlib::Variance;
|
|
|
|
const auto keyFormat = concrete::BINARY;
|
|
|
|
/// For the v0 the secretKeyID and precision are the same for all gates.
|
|
llvm::Expected<CircuitGate>
|
|
gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
|
|
Variance variance, llvm::Optional<ChunkInfo> chunkInfo,
|
|
mlir::Type type) {
|
|
if (type.isIntOrIndex()) {
|
|
// TODO - The index type is dependant of the target architecture, so
|
|
// actually we assume we target only 64 bits, we need to have some the size
|
|
// of the word of the target system.
|
|
size_t width = 64;
|
|
if (!type.isIndex()) {
|
|
width = type.getIntOrFloatBitWidth();
|
|
}
|
|
|
|
bool sign = type.isSignedInteger();
|
|
|
|
return CircuitGate{
|
|
/*.encryption = */ std::nullopt,
|
|
/*.shape = */
|
|
{/*.width = */ width,
|
|
/*.dimensions = */ std::vector<int64_t>(),
|
|
/*.size = */ 0,
|
|
/* .sign */ sign},
|
|
/*.chunkInfo = */ std::nullopt,
|
|
};
|
|
}
|
|
if (auto lweTy = type.dyn_cast_or_null<
|
|
mlir::concretelang::FHE::FheIntegerInterface>()) {
|
|
bool sign = lweTy.isSigned();
|
|
std::vector<int64_t> crt;
|
|
if (fheContext.parameter.largeInteger.has_value()) {
|
|
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
|
|
}
|
|
size_t width;
|
|
uint64_t size = 0;
|
|
std::vector<int64_t> dims;
|
|
if (chunkInfo.has_value()) {
|
|
width = chunkInfo->size;
|
|
assert(lweTy.getWidth() % chunkInfo->width == 0);
|
|
size = lweTy.getWidth() / chunkInfo->width;
|
|
dims.push_back(size);
|
|
} else {
|
|
width = (size_t)lweTy.getWidth();
|
|
}
|
|
return CircuitGate{
|
|
/* .encryption = */ std::optional<EncryptionGate>({
|
|
/* .secretKeyID = */ secretKeyID,
|
|
/* .variance = */ variance,
|
|
/* .encoding = */
|
|
{
|
|
/* .precision = */ width,
|
|
/* .crt = */ crt,
|
|
/*.sign = */ sign,
|
|
},
|
|
}),
|
|
/*.shape = */
|
|
{
|
|
/*.width = */ width,
|
|
/*.dimensions = */ dims,
|
|
/*.size = */ size,
|
|
/*.sign = */ sign,
|
|
},
|
|
/*.chunkInfo = */ chunkInfo,
|
|
};
|
|
}
|
|
if (auto lweTy = type.dyn_cast_or_null<
|
|
mlir::concretelang::FHE::EncryptedBooleanType>()) {
|
|
size_t width = mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
|
|
return CircuitGate{
|
|
/* .encryption = */ std::optional<EncryptionGate>({
|
|
/* .secretKeyID = */ secretKeyID,
|
|
/* .variance = */ variance,
|
|
/* .encoding = */
|
|
{
|
|
/* .precision = */ width,
|
|
/* .crt = */ std::vector<int64_t>(),
|
|
/* .sign = */ false,
|
|
},
|
|
}),
|
|
/*.shape = */
|
|
{
|
|
/*.width = */ width,
|
|
/*.dimensions = */ std::vector<int64_t>(),
|
|
/*.size = */ 0,
|
|
/*.sign = */ false,
|
|
},
|
|
/*.chunkInfo = */ std::nullopt,
|
|
};
|
|
}
|
|
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
|
if (tensor != nullptr) {
|
|
auto gate = gateFromMLIRType(fheContext, secretKeyID, variance, chunkInfo,
|
|
tensor.getElementType());
|
|
if (auto err = gate.takeError()) {
|
|
return std::move(err);
|
|
}
|
|
gate->shape.dimensions = tensor.getShape().vec();
|
|
gate->shape.size = 1;
|
|
for (auto dimSize : gate->shape.dimensions) {
|
|
gate->shape.size *= dimSize;
|
|
}
|
|
return gate;
|
|
}
|
|
return llvm::make_error<llvm::StringError>(
|
|
"cannot convert MLIR type to shape", llvm::inconvertibleErrorCode());
|
|
}
|
|
|
|
llvm::Expected<ClientParameters>
|
|
createClientParametersForV0(V0FHEContext fheContext,
|
|
llvm::StringRef functionName, mlir::ModuleOp module,
|
|
int bitsOfSecurity,
|
|
llvm::Optional<ChunkInfo> chunkInfo) {
|
|
const auto v0Curve = concrete::getSecurityCurve(bitsOfSecurity, keyFormat);
|
|
|
|
if (v0Curve == nullptr) {
|
|
return StreamStringError("Cannot find security curves for ")
|
|
<< bitsOfSecurity << "bits";
|
|
}
|
|
|
|
V0Parameter &v0Param = fheContext.parameter;
|
|
Variance inputVariance =
|
|
v0Curve->getVariance(1, v0Param.getNBigLweDimension(), 64);
|
|
|
|
Variance bootstrapKeyVariance = v0Curve->getVariance(
|
|
v0Param.glweDimension, v0Param.getPolynomialSize(), 64);
|
|
Variance keyswitchKeyVariance = v0Curve->getVariance(1, v0Param.nSmall, 64);
|
|
// Static client parameters from global parameters for v0
|
|
ClientParameters c;
|
|
|
|
assert(c.secretKeys.size() == clientlib::BIG_KEY);
|
|
clientlib::LweSecretKeyParam skParam;
|
|
skParam.dimension = v0Param.getNBigLweDimension();
|
|
c.secretKeys.push_back(skParam);
|
|
|
|
bool has_small_key = v0Param.nSmall != 0;
|
|
bool has_bootstrap = v0Param.brLevel != 0;
|
|
if (has_small_key) {
|
|
assert(c.secretKeys.size() == clientlib::SMALL_KEY);
|
|
clientlib::LweSecretKeyParam skParam2;
|
|
skParam2.dimension = v0Param.nSmall;
|
|
c.secretKeys.push_back(skParam2);
|
|
}
|
|
if (has_bootstrap) {
|
|
auto inputKey = (has_small_key) ? clientlib::SMALL_KEY : clientlib::BIG_KEY;
|
|
clientlib::BootstrapKeyParam bskParam;
|
|
bskParam.inputSecretKeyID = inputKey;
|
|
bskParam.outputSecretKeyID = clientlib::BIG_KEY;
|
|
bskParam.level = v0Param.brLevel;
|
|
bskParam.baseLog = v0Param.brLogBase;
|
|
bskParam.glweDimension = v0Param.glweDimension;
|
|
bskParam.variance = bootstrapKeyVariance;
|
|
bskParam.polynomialSize = v0Param.getPolynomialSize();
|
|
bskParam.inputLweDimension = v0Param.nSmall;
|
|
c.bootstrapKeys.push_back(bskParam);
|
|
}
|
|
if (v0Param.largeInteger.has_value()) {
|
|
clientlib::PackingKeyswitchKeyParam param;
|
|
param.inputSecretKeyID = clientlib::BIG_KEY;
|
|
param.outputSecretKeyID = clientlib::BIG_KEY;
|
|
param.level = v0Param.largeInteger->wopPBS.packingKeySwitch.level;
|
|
param.baseLog = v0Param.largeInteger->wopPBS.packingKeySwitch.baseLog;
|
|
|
|
param.glweDimension = v0Param.glweDimension;
|
|
param.polynomialSize = v0Param.getPolynomialSize();
|
|
param.inputLweDimension = v0Param.getNBigLweDimension();
|
|
param.variance = v0Curve->getVariance(v0Param.glweDimension,
|
|
v0Param.getPolynomialSize(), 64);
|
|
|
|
c.packingKeyswitchKeys.push_back(param);
|
|
}
|
|
if (has_small_key) {
|
|
clientlib::KeyswitchKeyParam kskParam;
|
|
kskParam.inputSecretKeyID = clientlib::BIG_KEY;
|
|
kskParam.outputSecretKeyID = clientlib::SMALL_KEY;
|
|
kskParam.level = v0Param.ksLevel;
|
|
kskParam.baseLog = v0Param.ksLogBase;
|
|
kskParam.variance = keyswitchKeyVariance;
|
|
c.keyswitchKeys.push_back(kskParam);
|
|
}
|
|
|
|
c.functionName = (std::string)functionName;
|
|
// Find the input function
|
|
auto rangeOps = module.getOps<mlir::func::FuncOp>();
|
|
auto funcOp = llvm::find_if(rangeOps, [&](mlir::func::FuncOp op) {
|
|
return op.getName() == functionName;
|
|
});
|
|
if (funcOp == rangeOps.end()) {
|
|
return StreamStringError(
|
|
"cannot find the function for generate client parameters: ")
|
|
<< functionName;
|
|
}
|
|
|
|
// Create input and output circuit gate parameters
|
|
auto funcType = (*funcOp).getFunctionType();
|
|
|
|
auto inputs = funcType.getInputs();
|
|
|
|
auto gateFromType = [&](mlir::Type ty) {
|
|
return gateFromMLIRType(fheContext, clientlib::BIG_KEY, inputVariance,
|
|
chunkInfo, ty);
|
|
};
|
|
for (auto inType : inputs) {
|
|
auto gate = gateFromType(inType);
|
|
if (auto err = gate.takeError()) {
|
|
return std::move(err);
|
|
}
|
|
c.inputs.push_back(gate.get());
|
|
}
|
|
for (auto outType : funcType.getResults()) {
|
|
auto gate = gateFromType(outType);
|
|
if (auto err = gate.takeError()) {
|
|
return std::move(err);
|
|
}
|
|
c.outputs.push_back(gate.get());
|
|
}
|
|
return c;
|
|
}
|
|
|
|
} // namespace concretelang
|
|
} // namespace mlir
|