refactor(bconcrete): Separate bufferization and CAPI call generation

This commit is contained in:
Quentin Bourgerie
2022-11-08 11:40:39 +01:00
parent fccb6da5b1
commit 9e16f31b87
20 changed files with 668 additions and 670 deletions

View File

@@ -11,7 +11,8 @@
namespace mlir {
namespace concretelang {
/// Create a pass to convert `BConcrete` dialect to CAPI calls.
std::unique_ptr<OperationPass<ModuleOp>> createConvertBConcreteToCAPIPass();
std::unique_ptr<OperationPass<ModuleOp>>
createConvertBConcreteToCAPIPass(bool gpu);
} // namespace concretelang
} // namespace mlir

View File

@@ -15,15 +15,17 @@ include "concretelang/Dialect/RT/IR/RTTypes.td"
class BConcrete_Op<string mnemonic, list<Trait> traits = []> :
Op<BConcrete_Dialect, mnemonic, traits>;
def BConcrete_AddLweBuffersOp : BConcrete_Op<"add_lwe_buffer"> {
let arguments = (ins
// BConcrete tensor operators /////////////////////////////////////////////////
def BConcrete_AddLweTensorOp : BConcrete_Op<"add_lwe_tensor"> {
let arguments = (ins
1DTensorOf<[I64]>:$lhs,
1DTensorOf<[I64]>:$rhs
);
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_AddCRTLweBuffersOp : BConcrete_Op<"add_crt_lwe_buffer"> {
def BConcrete_AddCRTLweTensorOp : BConcrete_Op<"add_crt_lwe_tensor"> {
let arguments = (ins
2DTensorOf<[I64]>:$lhs,
2DTensorOf<[I64]>:$rhs,
@@ -32,12 +34,12 @@ def BConcrete_AddCRTLweBuffersOp : BConcrete_Op<"add_crt_lwe_buffer"> {
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_AddPlaintextLweBufferOp : BConcrete_Op<"add_plaintext_lwe_buffer"> {
def BConcrete_AddPlaintextLweTensorOp : BConcrete_Op<"add_plaintext_lwe_tensor"> {
let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_AddPlaintextCRTLweBufferOp : BConcrete_Op<"add_plaintext_crt_lwe_buffer"> {
def BConcrete_AddPlaintextCRTLweTensorOp : BConcrete_Op<"add_plaintext_crt_lwe_tensor"> {
let arguments = (ins
2DTensorOf<[I64]>:$lhs,
AnyInteger:$rhs,
@@ -46,12 +48,12 @@ def BConcrete_AddPlaintextCRTLweBufferOp : BConcrete_Op<"add_plaintext_crt_lwe_b
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_MulCleartextLweBufferOp : BConcrete_Op<"mul_cleartext_lwe_buffer"> {
def BConcrete_MulCleartextLweTensorOp : BConcrete_Op<"mul_cleartext_lwe_tensor"> {
let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_MulCleartextCRTLweBufferOp : BConcrete_Op<"mul_cleartext_crt_lwe_buffer"> {
def BConcrete_MulCleartextCRTLweTensorOp : BConcrete_Op<"mul_cleartext_crt_lwe_tensor"> {
let arguments = (ins
2DTensorOf<[I64]>:$lhs,
AnyInteger:$rhs,
@@ -60,12 +62,12 @@ def BConcrete_MulCleartextCRTLweBufferOp : BConcrete_Op<"mul_cleartext_crt_lwe_b
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> {
def BConcrete_NegateLweTensorOp : BConcrete_Op<"negate_lwe_tensor"> {
let arguments = (ins 1DTensorOf<[I64]>:$ciphertext);
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> {
def BConcrete_NegateCRTLweTensorOp : BConcrete_Op<"negate_crt_lwe_tensor"> {
let arguments = (ins
2DTensorOf<[I64]>:$ciphertext,
I64ArrayAttr:$crtDecomposition
@@ -73,7 +75,7 @@ def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> {
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
def BConcrete_KeySwitchLweTensorOp : BConcrete_Op<"keyswitch_lwe_tensor"> {
let arguments = (ins
// LweKeySwitchKeyType:$keyswitch_key,
1DTensorOf<[I64]>:$ciphertext,
@@ -85,7 +87,7 @@ def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_buffer"> {
def BConcrete_BatchedKeySwitchLweTensorOp : BConcrete_Op<"batched_keyswitch_lwe_tensor"> {
let arguments = (ins
// LweKeySwitchKeyType:$keyswitch_key,
2DTensorOf<[I64]>:$ciphertext,
@@ -97,7 +99,7 @@ def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
def BConcrete_BootstrapLweTensorOp : BConcrete_Op<"bootstrap_lwe_tensor"> {
let arguments = (ins
1DTensorOf<[I64]>:$input_ciphertext,
1DTensorOf<[I64]>:$lookup_table,
@@ -111,7 +113,7 @@ def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
let results = (outs 1DTensorOf<[I64]>:$result);
}
def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_buffer"> {
def BConcrete_BatchedBootstrapLweTensorOp : BConcrete_Op<"batched_bootstrap_lwe_tensor"> {
let arguments = (ins
2DTensorOf<[I64]>:$input_ciphertext,
1DTensorOf<[I64]>:$lookup_table,
@@ -126,7 +128,7 @@ def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_
}
// TODO(16bits): hack
def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
def BConcrete_WopPBSCRTLweTensorOp : BConcrete_Op<"wop_pbs_crt_lwe_tensor"> {
let arguments = (ins
2DTensorOf<[I64]>:$ciphertext,
1DTensorOf<[I64]>:$lookupTable,
@@ -149,10 +151,9 @@ def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
let results = (outs 2DTensorOf<[I64]>:$result);
}
def BConcrete_KeySwitchLweBufferAsyncOffloadOp :
BConcrete_Op<"keyswitch_lwe_buffer_async_offload"> {
def BConcrete_KeySwitchLweTensorAsyncOffloadOp :
BConcrete_Op<"keyswitch_lwe_tensor_async_offload"> {
let arguments = (ins
// LweKeySwitchKeyType:$keyswitch_key,
1DTensorOf<[I64]>:$ciphertext,
I32Attr:$level,
I32Attr:$baseLog
@@ -160,8 +161,8 @@ def BConcrete_KeySwitchLweBufferAsyncOffloadOp :
let results = (outs RT_Future : $result);
}
def BConcrete_BootstrapLweBufferAsyncOffloadOp :
BConcrete_Op<"bootstrap_lwe_buffer_async_offload"> {
def BConcrete_BootstrapLweTensorAsyncOffloadOp :
BConcrete_Op<"bootstrap_lwe_tensor_async_offload"> {
let arguments = (ins
1DTensorOf<[I64]>:$input_ciphertext,
1DTensorOf<[I64]>:$lookup_table,
@@ -175,6 +176,121 @@ def BConcrete_BootstrapLweBufferAsyncOffloadOp :
let results = (outs RT_Future : $result);
}
// BConcrete memref operators /////////////////////////////////////////////////
def BConcrete_LweBuffer : MemRefRankOf<[I64], [1]>;
def BConcrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
def BConcrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
def BConcrete_AddLweBufferOp : BConcrete_Op<"add_lwe_buffer"> {
let arguments = (ins
BConcrete_LweBuffer:$result,
BConcrete_LweBuffer:$lhs,
BConcrete_LweBuffer:$rhs
);
}
def BConcrete_AddPlaintextLweBufferOp : BConcrete_Op<"add_plaintext_lwe_buffer"> {
let arguments = (ins
BConcrete_LweBuffer:$result,
BConcrete_LweBuffer:$lhs,
I64:$rhs
);
}
def BConcrete_MulCleartextLweBufferOp : BConcrete_Op<"mul_cleartext_lwe_buffer"> {
let arguments = (ins
BConcrete_LweBuffer:$result,
BConcrete_LweBuffer:$lhs,
I64:$rhs
);
}
def BConcrete_NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> {
let arguments = (ins
BConcrete_LweBuffer:$result,
BConcrete_LweBuffer:$ciphertext
);
}
def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
let arguments = (ins
BConcrete_LweBuffer:$result,
BConcrete_LweBuffer:$ciphertext,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$lwe_dim_in,
I32Attr:$lwe_dim_out
);
}
def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_buffer"> {
let arguments = (ins
BConcrete_BatchLweBuffer:$result,
BConcrete_BatchLweBuffer:$ciphertext,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$lwe_dim_in,
I32Attr:$lwe_dim_out
);
}
def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
let arguments = (ins
BConcrete_LweBuffer:$result,
BConcrete_LweBuffer:$input_ciphertext,
MemRefRankOf<[I64], [1]>:$lookup_table,
I32Attr:$inputLweDim,
I32Attr:$polySize,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$glweDimension,
I32Attr:$outPrecision
);
}
def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_buffer"> {
let arguments = (ins
BConcrete_BatchLweBuffer:$result,
BConcrete_BatchLweBuffer:$input_ciphertext,
MemRefRankOf<[I64], [1]>:$lookup_table,
I32Attr:$inputLweDim,
I32Attr:$polySize,
I32Attr:$level,
I32Attr:$baseLog,
I32Attr:$glweDimension,
I32Attr:$outPrecision
);
}
// TODO(16bits): hack
def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
let arguments = (ins
BConcrete_LweCRTBuffer:$result,
BConcrete_LweCRTBuffer:$ciphertext,
MemRefRankOf<[I64], [1]>:$lookup_table,
// Bootstrap parameters
I32Attr : $bootstrapLevel,
I32Attr : $bootstrapBaseLog,
// Keyswitch parameters
I32Attr : $keyswitchLevel,
I32Attr : $keyswitchBaseLog,
// Packing keyswitch key parameters
I32Attr : $packingKeySwitchInputLweDimension,
I32Attr : $packingKeySwitchoutputPolynomialSize,
I32Attr : $packingKeySwitchLevel,
I32Attr : $packingKeySwitchBaseLog,
// Circuit bootstrap parameters
I32Attr : $circuitBootstrapLevel,
I32Attr : $circuitBootstrapBaseLog,
I64ArrayAttr:$crtDecomposition
);
}
// TODO
def BConcrete_AwaitFutureOp :
BConcrete_Op<"await_future"> {
let arguments = (ins RT_Future : $future);

View File

@@ -69,7 +69,7 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops);
bool parallelizeLoops, bool gpu);
mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
llvm::Module &module);

View File

@@ -7,29 +7,375 @@
#include <mlir/Transforms/DialectConversion.h>
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
namespace {
struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
void runOnOperation() final;
};
} // namespace
void BConcreteToCAPIPass::runOnOperation() {
auto op = this->getOperation();
namespace BConcrete = mlir::concretelang::BConcrete;
namespace arith = mlir::arith;
namespace func = mlir::func;
namespace memref = mlir::memref;
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
char memref_add_lwe_ciphertexts_u64[] = "memref_add_lwe_ciphertexts_u64";
char memref_add_plaintext_lwe_ciphertext_u64[] =
"memref_add_plaintext_lwe_ciphertext_u64";
char memref_mul_cleartext_lwe_ciphertext_u64[] =
"memref_mul_cleartext_lwe_ciphertext_u64";
char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();
char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64";
char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
char memref_await_future[] = "memref_await_future";
char memref_keyswitch_lwe_cuda_u64[] = "memref_keyswitch_lwe_cuda_u64";
char memref_bootstrap_lwe_cuda_u64[] = "memref_bootstrap_lwe_cuda_u64";
char memref_expand_lut_in_trivial_glwe_ct_u64[] =
"memref_expand_lut_in_trivial_glwe_ct_u64";
char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
std::vector<int64_t> shape(rank, -1);
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
for (size_t i = 0; i < rank; i++) {
expr = expr +
(rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
}
return mlir::MemRefType::get(
shape, rewriter.getI64Type(),
mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
}
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) {
mlir::Type valueType = value.getType();
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
return rewriter.create<mlir::memref::CastOp>(
value.getLoc(),
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
value);
} else {
return value;
}
}
mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {
auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2);
auto futureType =
mlir::concretelang::RT::FutureType::get(rewriter.getIndexType());
auto contextType =
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
auto i32Type = rewriter.getI32Type();
mlir::FunctionType funcType;
if (funcName == memref_add_lwe_ciphertexts_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(), {memref1DType, memref1DType, memref1DType}, {});
} else if (funcName == memref_add_plaintext_lwe_ciphertext_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref1DType, memref1DType, rewriter.getI64Type()}, {});
} else if (funcName == memref_mul_cleartext_lwe_ciphertext_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref1DType, memref1DType, rewriter.getI64Type()}, {});
} else if (funcName == memref_negate_lwe_ciphertext_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType}, {});
} else if (funcName == memref_keyswitch_lwe_u64 ||
funcName == memref_keyswitch_lwe_cuda_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_bootstrap_lwe_u64 ||
funcName == memref_bootstrap_lwe_cuda_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_keyswitch_async_lwe_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(), {memref1DType, memref1DType, contextType},
{futureType});
} else if (funcName == memref_bootstrap_async_lwe_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
{futureType});
} else if (funcName == memref_batched_keyswitch_lwe_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref2DType, memref2DType, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_batched_bootstrap_lwe_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref2DType, memref2DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_await_future) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref1DType, futureType, memref1DType, memref1DType}, {});
} else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{
memref1DType,
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
memref1DType,
},
{});
} else if (funcName == memref_wop_pbs_crt_buffer) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{
memref2DType,
memref2DType,
memref1DType,
memref1DType,
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
contextType,
},
{});
} else {
op->emitError("unknwon external function") << funcName;
return mlir::failure();
}
return insertForwardDeclaration(op, rewriter, funcName, funcType);
}
template <typename BConcreteOp>
void addNoOperands(BConcreteOp op, mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {}
template <typename BConcreteOp, char const *callee>
struct BConcreteToCAPICallPattern : public mlir::OpRewritePattern<BConcreteOp> {
BConcreteToCAPICallPattern(
::mlir::MLIRContext *context,
std::function<void(BConcreteOp bOp, llvm::SmallVector<mlir::Value> &,
mlir::RewriterBase &)>
addOperands = addNoOperands<BConcreteOp>,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<BConcreteOp>(context, benefit),
addOperands(addOperands) {}
::mlir::LogicalResult
matchAndRewrite(BConcreteOp bOp,
::mlir::PatternRewriter &rewriter) const override {
// Create the operands
mlir::SmallVector<mlir::Value> operands;
// For all tensor operand get the corresponding casted buffer
for (auto &operand : bOp->getOpOperands()) {
mlir::Type type = operand.get().getType();
if (!type.isa<mlir::MemRefType>()) {
operands.push_back(operand.get());
} else {
operands.push_back(getCastedMemRef(rewriter, operand.get()));
}
}
// append additional argument
addOperands(bOp, operands, rewriter);
// Insert forward declaration of the function
if (insertForwardDeclarationOfTheCAPI(bOp, rewriter, callee).failed()) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<func::CallOp>(bOp, callee, mlir::TypeRange{},
operands);
return ::mlir::success();
};
private:
std::function<void(BConcreteOp bOp, llvm::SmallVector<mlir::Value> &,
mlir::RewriterBase &)>
addOperands;
};
template <typename KeySwitchOp>
void keyswitchAddOperands(KeySwitchOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
// level
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.levelAttr()));
// base_log
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
// lwe_dim_in
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.lwe_dim_inAttr()));
// lwe_dim_out
operands.push_back(
rewriter.create<arith::ConstantOp>(op.getLoc(), op.lwe_dim_outAttr()));
// context
operands.push_back(getContextArgument(op));
}
template <typename BootstrapOp>
void bootstrapAddOperands(BootstrapOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
// input_lwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.inputLweDimAttr()));
// poly_size
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
// level
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
// base_log
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
// glwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.glweDimensionAttr()));
// out_precision
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.outPrecisionAttr()));
// context
operands.push_back(getContextArgument(op));
}
void wopPBSAddOperands(BConcrete::WopPBSCRTLweBufferOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
mlir::Type crtType = mlir::RankedTensorType::get(
{(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
std::vector<int64_t> values;
for (auto a : op.crtDecomposition()) {
values.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
}
auto attr = rewriter.getI64TensorAttr(values);
auto x = rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), attr, crtType);
auto globalMemref = mlir::bufferization::getGlobalFor(x, 0);
rewriter.eraseOp(x);
assert(!failed(globalMemref));
auto globalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*globalMemref).type(), (*globalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, globalRef));
// lwe_small_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.packingKeySwitchInputLweDimensionAttr()));
// cbs_level_count
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.circuitBootstrapLevelAttr()));
// cbs_base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.circuitBootstrapBaseLogAttr()));
// polynomial_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr()));
// context
operands.push_back(getContextArgument(op));
}
struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
BConcreteToCAPIPass(bool gpu) : gpu(gpu) {}
void runOnOperation() override {
auto op = this->getOperation();
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
// Mark ops from the target dialect as legal operations
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<arith::ArithmeticDialect>();
// Make sure that no ops from `FHE` remain after the lowering
target.addIllegalDialect<BConcrete::BConcreteDialect>();
// Add patterns to transform BConcrete operators to CAPI call
patterns.add<BConcreteToCAPICallPattern<BConcrete::AddLweBufferOp,
memref_add_lwe_ciphertexts_u64>>(
&getContext());
patterns.add<
BConcreteToCAPICallPattern<BConcrete::AddPlaintextLweBufferOp,
memref_add_plaintext_lwe_ciphertext_u64>>(
&getContext());
patterns.add<
BConcreteToCAPICallPattern<BConcrete::MulCleartextLweBufferOp,
memref_mul_cleartext_lwe_ciphertext_u64>>(
&getContext());
patterns.add<BConcreteToCAPICallPattern<BConcrete::NegateLweBufferOp,
memref_negate_lwe_ciphertext_u64>>(
&getContext());
if (gpu) {
patterns.add<BConcreteToCAPICallPattern<BConcrete::KeySwitchLweBufferOp,
memref_keyswitch_lwe_cuda_u64>>(
&getContext(), keyswitchAddOperands<BConcrete::KeySwitchLweBufferOp>);
patterns.add<BConcreteToCAPICallPattern<BConcrete::BootstrapLweBufferOp,
memref_bootstrap_lwe_cuda_u64>>(
&getContext(), bootstrapAddOperands<BConcrete::BootstrapLweBufferOp>);
} else {
patterns.add<BConcreteToCAPICallPattern<BConcrete::KeySwitchLweBufferOp,
memref_keyswitch_lwe_u64>>(
&getContext(), keyswitchAddOperands<BConcrete::KeySwitchLweBufferOp>);
patterns.add<BConcreteToCAPICallPattern<BConcrete::BootstrapLweBufferOp,
memref_bootstrap_lwe_u64>>(
&getContext(), bootstrapAddOperands<BConcrete::BootstrapLweBufferOp>);
patterns.add<
BConcreteToCAPICallPattern<BConcrete::BatchedKeySwitchLweBufferOp,
memref_batched_keyswitch_lwe_u64>>(
&getContext(),
keyswitchAddOperands<BConcrete::BatchedKeySwitchLweBufferOp>);
patterns.add<
BConcreteToCAPICallPattern<BConcrete::BatchedBootstrapLweBufferOp,
memref_batched_bootstrap_lwe_u64>>(
&getContext(),
bootstrapAddOperands<BConcrete::BatchedBootstrapLweBufferOp>);
}
patterns.add<BConcreteToCAPICallPattern<BConcrete::WopPBSCRTLweBufferOp,
memref_wop_pbs_crt_buffer>>(
&getContext(), wopPBSAddOperands);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
}
}
private:
bool gpu;
};
} // namespace
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>> createConvertBConcreteToCAPIPass() {
return std::make_unique<BConcreteToCAPIPass>();
std::unique_ptr<OperationPass<ModuleOp>>
createConvertBConcreteToCAPIPass(bool gpu) {
return std::make_unique<BConcreteToCAPIPass>(gpu);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -216,7 +216,7 @@ struct LowerKeySwitch : public mlir::OpRewritePattern<
rewriter.getI32IntegerAttr(inputType.getDimension());
mlir::Operation *bKeySwitchOp = rewriter.replaceOpWithNewOp<
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(
mlir::concretelang::BConcrete::KeySwitchLweTensorOp>(
ksOp, outType, ksOp.ciphertext(), ksOp.levelAttr(), ksOp.baseLogAttr(),
inputDimAttr, outDimAttr);
@@ -261,7 +261,7 @@ struct LowerBatchedKeySwitch
rewriter.getI32IntegerAttr(inputType.getDimension());
mlir::Operation *bBatchedKeySwitchOp = rewriter.replaceOpWithNewOp<
mlir::concretelang::BConcrete::BatchedKeySwitchLweBufferOp>(
mlir::concretelang::BConcrete::BatchedKeySwitchLweTensorOp>(
bksOp, bksOp.getType(), bksOp.ciphertexts(), bksOp.levelAttr(),
bksOp.baseLogAttr(), inputDimAttr, outDimAttr);
@@ -293,7 +293,7 @@ struct LowerBootstrap : public mlir::OpRewritePattern<
auto inputDimAttr = rewriter.getI32IntegerAttr(inputType.getDimension());
auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP());
mlir::Operation *bBootstrapOp = rewriter.replaceOpWithNewOp<
mlir::concretelang::BConcrete::BootstrapLweBufferOp>(
mlir::concretelang::BConcrete::BootstrapLweTensorOp>(
bsOp, outType, bsOp.input_ciphertext(), bsOp.lookup_table(),
inputDimAttr, bsOp.polySizeAttr(), bsOp.levelAttr(), bsOp.baseLogAttr(),
bsOp.glweDimensionAttr(), outputPrecisionAttr);
@@ -338,7 +338,7 @@ struct LowerBatchedBootstrap
auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP());
mlir::Operation *bBatchedBootstrapOp = rewriter.replaceOpWithNewOp<
mlir::concretelang::BConcrete::BatchedBootstrapLweBufferOp>(
mlir::concretelang::BConcrete::BatchedBootstrapLweTensorOp>(
bbsOp, bbsOp.getType(), bbsOp.input_ciphertexts(), bbsOp.lookup_table(),
inputDimAttr, bbsOp.polySizeAttr(), bbsOp.levelAttr(),
bbsOp.baseLogAttr(), bbsOp.glweDimensionAttr(), outputPrecisionAttr);
@@ -385,7 +385,7 @@ struct AddPlaintextLweCiphertextOpPattern
auto encoded = rewriter.create<mlir::arith::ShLIOp>(
loc, rewriter.getI64Type(), castedInt, constantShiftOp);
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweBufferOp>(
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweTensorOp>(
concreteOp, newResultTy,
mlir::ValueRange{concreteOp.lhs(), encoded}, attributes);
} else {
@@ -394,7 +394,7 @@ struct AddPlaintextLweCiphertextOpPattern
newAttributes.push_back(rewriter.getNamedAttr(
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextCRTLweBufferOp>(
rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextCRTLweTensorOp>(
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
newAttributes);
}
@@ -436,7 +436,7 @@ struct MulCleartextLweCiphertextOpPattern
mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
loc, rewriter.getIntegerType(64), concreteOp.rhs());
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextLweBufferOp>(
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextLweTensorOp>(
concreteOp, newResultTy,
mlir::ValueRange{concreteOp.lhs(), castedInt}, attributes);
} else {
@@ -444,7 +444,7 @@ struct MulCleartextLweCiphertextOpPattern
newAttributes.push_back(rewriter.getNamedAttr(
"crtDecomposition", rewriter.getI64ArrayAttr(crt)));
bConcreteOp =
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextCRTLweBufferOp>(
rewriter.replaceOpWithNewOp<BConcrete::MulCleartextCRTLweTensorOp>(
concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
newAttributes);
}
@@ -1022,14 +1022,14 @@ void ConcreteToBConcretePass::runOnOperation() {
LowerBootstrap, LowerBatchedBootstrap, LowerKeySwitch,
LowerBatchedKeySwitch,
LowToBConcrete<mlir::concretelang::Concrete::AddLweCiphertextsOp,
mlir::concretelang::BConcrete::AddLweBuffersOp,
BConcrete::AddCRTLweBuffersOp>,
mlir::concretelang::BConcrete::AddLweTensorOp,
BConcrete::AddCRTLweTensorOp>,
AddPlaintextLweCiphertextOpPattern, MulCleartextLweCiphertextOpPattern,
LowToBConcrete<mlir::concretelang::Concrete::NegateLweCiphertextOp,
mlir::concretelang::BConcrete::NegateLweBufferOp,
BConcrete::NegateCRTLweBufferOp>,
LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweBufferOp,
BConcrete::WopPBSCRTLweBufferOp>>(&getContext());
mlir::concretelang::BConcrete::NegateLweTensorOp,
BConcrete::NegateCRTLweTensorOp>,
LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweTensorOp,
BConcrete::WopPBSCRTLweTensorOp>>(&getContext());
// Add patterns to rewrite tensor operators that works on encrypted
// tensors

View File

@@ -22,12 +22,12 @@ void AsyncOffloadPass::runOnOperation() {
auto module = getOperation();
std::vector<mlir::Operation *> ops;
module.walk([&](mlir::concretelang::BConcrete::KeySwitchLweBufferOp op) {
module.walk([&](mlir::concretelang::BConcrete::KeySwitchLweTensorOp op) {
mlir::OpBuilder builder(op);
mlir::Type futType =
mlir::concretelang::RT::FutureType::get(op.getResult().getType());
mlir::Value future = builder.create<
mlir::concretelang::BConcrete::KeySwitchLweBufferAsyncOffloadOp>(
mlir::concretelang::BConcrete::KeySwitchLweTensorAsyncOffloadOp>(
op.getLoc(), mlir::TypeRange{futType}, op.getOperand(), op->getAttrs());
assert(op.getResult().hasOneUse() &&
@@ -43,12 +43,12 @@ void AsyncOffloadPass::runOnOperation() {
}
ops.push_back(op);
});
module.walk([&](mlir::concretelang::BConcrete::BootstrapLweBufferOp op) {
module.walk([&](mlir::concretelang::BConcrete::BootstrapLweTensorOp op) {
mlir::OpBuilder builder(op);
mlir::Type futType =
mlir::concretelang::RT::FutureType::get(op.getResult().getType());
mlir::Value future = builder.create<
mlir::concretelang::BConcrete::BootstrapLweBufferAsyncOffloadOp>(
mlir::concretelang::BConcrete::BootstrapLweTensorAsyncOffloadOp>(
op.getLoc(), mlir::TypeRange{futType}, op.getOperands(),
op->getAttrs());

View File

@@ -27,170 +27,13 @@ using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::tensor;
namespace BConcrete = mlir::concretelang::BConcrete;
namespace mlir {
namespace concretelang {
namespace BConcrete {
namespace {} // namespace
} // namespace BConcrete
} // namespace concretelang
} // namespace mlir
namespace {
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
std::vector<int64_t> shape(rank, -1);
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
for (size_t i = 0; i < rank; i++) {
expr = expr +
(rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
}
return mlir::MemRefType::get(
shape, rewriter.getI64Type(),
mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
}
namespace BConcrete = mlir::concretelang::BConcrete;
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc,
mlir::Value value) {
mlir::Type valueType = value.getType();
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
return rewriter.create<mlir::memref::CastOp>(
loc,
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
value);
} else {
return value;
}
}
char memref_add_lwe_ciphertexts_u64[] = "memref_add_lwe_ciphertexts_u64";
char memref_add_plaintext_lwe_ciphertext_u64[] =
"memref_add_plaintext_lwe_ciphertext_u64";
char memref_mul_cleartext_lwe_ciphertext_u64[] =
"memref_mul_cleartext_lwe_ciphertext_u64";
char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64";
char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
char memref_await_future[] = "memref_await_future";
char memref_keyswitch_lwe_cuda_u64[] = "memref_keyswitch_lwe_cuda_u64";
char memref_bootstrap_lwe_cuda_u64[] = "memref_bootstrap_lwe_cuda_u64";
char memref_expand_lut_in_trivial_glwe_ct_u64[] =
"memref_expand_lut_in_trivial_glwe_ct_u64";
char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {
auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2);
auto futureType =
mlir::concretelang::RT::FutureType::get(rewriter.getIndexType());
auto contextType =
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
auto i32Type = rewriter.getI32Type();
mlir::FunctionType funcType;
if (funcName == memref_add_lwe_ciphertexts_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(), {memref1DType, memref1DType, memref1DType}, {});
} else if (funcName == memref_add_plaintext_lwe_ciphertext_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref1DType, memref1DType, rewriter.getI64Type()}, {});
} else if (funcName == memref_mul_cleartext_lwe_ciphertext_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref1DType, memref1DType, rewriter.getI64Type()}, {});
} else if (funcName == memref_negate_lwe_ciphertext_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType}, {});
} else if (funcName == memref_keyswitch_lwe_u64 ||
funcName == memref_keyswitch_lwe_cuda_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_batched_keyswitch_lwe_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref2DType, memref2DType, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_bootstrap_lwe_u64 ||
funcName == memref_bootstrap_lwe_cuda_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_batched_bootstrap_lwe_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref2DType, memref2DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
{});
} else if (funcName == memref_keyswitch_async_lwe_u64) {
funcType = mlir::FunctionType::get(
rewriter.getContext(), {memref1DType, memref1DType, contextType},
{futureType});
} else if (funcName == memref_bootstrap_async_lwe_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{memref1DType, memref1DType,
memref1DType, i32Type, i32Type, i32Type,
i32Type, i32Type, i32Type, contextType},
{futureType});
} else if (funcName == memref_await_future) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref1DType, futureType, memref1DType, memref1DType}, {});
} else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{
memref1DType,
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
memref1DType,
},
{});
} else if (funcName == memref_wop_pbs_crt_buffer) {
funcType = mlir::FunctionType::get(rewriter.getContext(),
{
memref2DType,
memref2DType,
memref1DType,
memref1DType,
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
rewriter.getI32Type(),
contextType,
},
{});
} else {
op->emitError("unknwon external function") << funcName;
return mlir::failure();
}
return insertForwardDeclaration(op, rewriter, funcName, funcType);
}
template <typename Op>
void pushAdditionalArgs(Op op, mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter);
template <typename Op, char const *funcName>
struct BufferizableWithCallOpInterface
: public BufferizableOpInterface::ExternalModel<
BufferizableWithCallOpInterface<Op, funcName>, Op> {
template <typename TensorOp, typename MemrefOp>
struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel<
TensorToMemrefOp<TensorOp, MemrefOp>, TensorOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
@@ -215,161 +58,7 @@ struct BufferizableWithCallOpInterface
const BufferizationOptions &options) const {
auto loc = op->getLoc();
auto castOp = cast<Op>(op);
// For now we always alloc for the result, we didn't have the in place
// operators yet.
auto resTensorType =
castOp.result().getType().template cast<mlir::TensorType>();
auto outMemrefType = MemRefType::get(resTensorType.getShape(),
resTensorType.getElementType());
auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
if (mlir::failed(outMemref)) {
return mlir::failure();
}
// The first operand is the result
mlir::SmallVector<mlir::Value> operands{
getCastedMemRef(rewriter, loc, *outMemref),
};
// For all tensor operand get the corresponding casted buffer
for (auto &operand : op->getOpOperands()) {
if (!operand.get().getType().isa<mlir::RankedTensorType>()) {
operands.push_back(operand.get());
} else {
auto memrefOperand =
bufferization::getBuffer(rewriter, operand.get(), options);
operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand));
}
}
// Append additional argument
pushAdditionalArgs<Op>(castOp, operands, rewriter);
// Insert forward declaration of the function
if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) {
return mlir::failure();
}
rewriter.create<mlir::func::CallOp>(loc, funcName, mlir::TypeRange{},
operands);
replaceOpWithBufferizedValues(rewriter, op, *outMemref);
return success();
}
};
template <typename Op, char const *funcName>
struct BufferizableWithAsyncCallOpInterface
: public BufferizableOpInterface::ExternalModel<
BufferizableWithAsyncCallOpInterface<Op, funcName>, Op> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto loc = op->getLoc();
auto castOp = cast<Op>(op);
// For now we always alloc for the result, we didn't have the in place
// operators yet.
auto resTensorType =
castOp.result()
.getType()
.template cast<mlir::concretelang::RT::FutureType>()
.getElementType()
.template cast<mlir::TensorType>();
auto outMemrefType = MemRefType::get(resTensorType.getShape(),
resTensorType.getElementType());
auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
if (mlir::failed(outMemref)) {
return mlir::failure();
}
// The first operand is the result
mlir::SmallVector<mlir::Value> operands{
getCastedMemRef(rewriter, loc, *outMemref),
};
// For all tensor operand get the corresponding casted buffer
for (auto &operand : op->getOpOperands()) {
if (!operand.get().getType().isa<mlir::RankedTensorType>()) {
operands.push_back(operand.get());
} else {
auto memrefOperand =
bufferization::getBuffer(rewriter, operand.get(), options);
operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand));
}
}
// Append additional arguments
pushAdditionalArgs<Op>(castOp, operands, rewriter);
// Insert forward declaration of the function
if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) {
return mlir::failure();
}
auto result = rewriter.create<mlir::func::CallOp>(
loc, funcName,
mlir::TypeRange{
mlir::concretelang::RT::FutureType::get(rewriter.getIndexType())},
operands);
replaceOpWithBufferizedValues(rewriter, op, result.getResult(0));
return success();
}
};
template <typename Op, char const *funcName>
struct BufferizableWithSyncCallOpInterface
: public BufferizableOpInterface::ExternalModel<
BufferizableWithSyncCallOpInterface<Op, funcName>, Op> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
auto loc = op->getLoc();
auto castOp = cast<Op>(op);
auto castOp = cast<TensorOp>(op);
auto resTensorType =
castOp.result().getType().template cast<mlir::TensorType>();
@@ -383,23 +72,18 @@ struct BufferizableWithSyncCallOpInterface
// The first operand is the result
mlir::SmallVector<mlir::Value, 3> operands{
getCastedMemRef(rewriter, loc, *outMemref),
*outMemref,
};
// Then add the future operand
operands.push_back(op->getOpOperand(0).get());
// Finally add a dependence on the memref covered by the future to
// prevent early deallocation
auto def = op->getOpOperand(0).get().getDefiningOp();
operands.push_back(def->getOpOperand(0).get());
operands.push_back(def->getOpOperand(1).get());
// Insert forward declaration of the function
if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) {
return mlir::failure();
for (auto &operand : op->getOpOperands()) {
if (!operand.get().getType().isa<mlir::RankedTensorType>()) {
operands.push_back(operand.get());
} else {
operands.push_back(
bufferization::getBuffer(rewriter, operand.get(), options));
}
}
rewriter.create<mlir::func::CallOp>(loc, funcName, mlir::TypeRange{},
operands);
rewriter.create<MemrefOp>(loc, mlir::TypeRange{}, operands, op->getAttrs());
replaceOpWithBufferizedValues(rewriter, op, *outMemref);
@@ -407,239 +91,49 @@ struct BufferizableWithSyncCallOpInterface
}
};
template <>
void pushAdditionalArgs(BConcrete::AddPlaintextLweBufferOp op,
mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter) {}
template <>
void pushAdditionalArgs(BConcrete::AddLweBuffersOp op,
mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter) {}
template <>
void pushAdditionalArgs(BConcrete::MulCleartextLweBufferOp op,
mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter) {}
template <>
void pushAdditionalArgs(BConcrete::NegateLweBufferOp op,
mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter) {}
template <>
void pushAdditionalArgs(BConcrete::KeySwitchLweBufferOp op,
mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter) {
// level
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
// base_log
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
// lwe_dim_in
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.lwe_dim_inAttr()));
// lwe_dim_out
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.lwe_dim_outAttr()));
// context
operands.push_back(getContextArgument(op));
}
template <>
void pushAdditionalArgs(BConcrete::BatchedKeySwitchLweBufferOp op,
mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter) {
// level
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
// base_log
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
// lwe_dim_in
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.lwe_dim_inAttr()));
// lwe_dim_out
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.lwe_dim_outAttr()));
// context
operands.push_back(getContextArgument(op));
}
template <>
void pushAdditionalArgs(BConcrete::BootstrapLweBufferOp op,
mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter) {
// input_lwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.inputLweDimAttr()));
// poly_size
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
// level
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
// base_log
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
// glwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.glweDimensionAttr()));
// out_precision
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.outPrecisionAttr()));
// context
operands.push_back(getContextArgument(op));
}
template <>
void pushAdditionalArgs(BConcrete::BatchedBootstrapLweBufferOp op,
mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter) {
// input_lwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.inputLweDimAttr()));
// poly_size
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
// level
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
// base_log
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
// glwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.glweDimensionAttr()));
// out_precision
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.outPrecisionAttr()));
// context
operands.push_back(getContextArgument(op));
}
template <>
void pushAdditionalArgs(BConcrete::KeySwitchLweBufferAsyncOffloadOp op,
mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter) {
// context
operands.push_back(getContextArgument(op));
}
template <>
void pushAdditionalArgs(BConcrete::BootstrapLweBufferAsyncOffloadOp op,
mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter) {
// input_lwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.inputLweDimAttr()));
// poly_size
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
// level
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
// base_log
operands.push_back(
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
// glwe_dim
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.glweDimensionAttr()));
// out_precision
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.outPrecisionAttr()));
// context
operands.push_back(getContextArgument(op));
}
template <>
void pushAdditionalArgs(BConcrete::WopPBSCRTLweBufferOp op,
mlir::SmallVector<mlir::Value> &operands,
RewriterBase &rewriter) {
mlir::Type crtType = mlir::RankedTensorType::get(
{(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
std::vector<int64_t> values;
for (auto a : op.crtDecomposition()) {
values.push_back(a.cast<IntegerAttr>().getValue().getZExtValue());
}
auto attr = rewriter.getI64TensorAttr(values);
auto x = rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), attr, crtType);
auto globalMemref = bufferization::getGlobalFor(x, 0);
assert(!failed(globalMemref));
auto globalRef = rewriter.create<memref::GetGlobalOp>(
op.getLoc(), (*globalMemref).type(), (*globalMemref).getName());
operands.push_back(getCastedMemRef(rewriter, op.getLoc(), globalRef));
// lwe_small_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.packingKeySwitchInputLweDimensionAttr()));
// cbs_level_count
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.circuitBootstrapLevelAttr()));
// cbs_base_log
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.circuitBootstrapBaseLogAttr()));
// polynomial_size
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr()));
// context
operands.push_back(getContextArgument(op));
}
} // namespace
void mlir::concretelang::BConcrete::
registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx,
BConcrete::BConcreteDialect *dialect) {
BConcrete::AddLweBuffersOp::attachInterface<BufferizableWithCallOpInterface<
BConcrete::AddLweBuffersOp, memref_add_lwe_ciphertexts_u64>>(*ctx);
BConcrete::AddPlaintextLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<
BConcrete::AddPlaintextLweBufferOp,
memref_add_plaintext_lwe_ciphertext_u64>>(*ctx);
BConcrete::MulCleartextLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<
BConcrete::MulCleartextLweBufferOp,
memref_mul_cleartext_lwe_ciphertext_u64>>(*ctx);
BConcrete::NegateLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<BConcrete::NegateLweBufferOp,
memref_negate_lwe_ciphertext_u64>>(
// add_lwe_tensor => add_lwe_buffer
BConcrete::AddLweTensorOp::attachInterface<
TensorToMemrefOp<BConcrete::AddLweTensorOp, BConcrete::AddLweBufferOp>>(
*ctx);
// add_plaintext_lwe_tensor => add_plaintext_lwe_buffer
BConcrete::AddPlaintextLweTensorOp::attachInterface<
TensorToMemrefOp<BConcrete::AddPlaintextLweTensorOp,
BConcrete::AddPlaintextLweBufferOp>>(*ctx);
// mul_cleartext_lwe_tensor => mul_cleartext_lwe_buffer
BConcrete::MulCleartextLweTensorOp::attachInterface<
TensorToMemrefOp<BConcrete::MulCleartextLweTensorOp,
BConcrete::MulCleartextLweBufferOp>>(*ctx);
// negate_cleartext_lwe_tensor => negate_cleartext_lwe_buffer
BConcrete::NegateLweTensorOp::attachInterface<TensorToMemrefOp<
BConcrete::NegateLweTensorOp, BConcrete::NegateLweBufferOp>>(*ctx);
// negate_cleartext_lwe_tensor => negate_cleartext_lwe_buffer
BConcrete::NegateLweTensorOp::attachInterface<TensorToMemrefOp<
BConcrete::NegateLweTensorOp, BConcrete::NegateLweBufferOp>>(*ctx);
// keyswitch_lwe_tensor => keyswitch_lwe_buffer
BConcrete::KeySwitchLweTensorOp::attachInterface<TensorToMemrefOp<
BConcrete::KeySwitchLweTensorOp, BConcrete::KeySwitchLweBufferOp>>(
*ctx);
// bootstrap_lwe_tensor => bootstrap_lwe_buffer
BConcrete::BootstrapLweTensorOp::attachInterface<TensorToMemrefOp<
BConcrete::BootstrapLweTensorOp, BConcrete::BootstrapLweBufferOp>>(
*ctx);
// batched_keyswitch_lwe_tensor => batched_keyswitch_lwe_buffer
BConcrete::BatchedKeySwitchLweTensorOp::attachInterface<
TensorToMemrefOp<BConcrete::BatchedKeySwitchLweTensorOp,
BConcrete::BatchedKeySwitchLweBufferOp>>(*ctx);
// batched_bootstrap_lwe_tensor => batched_bootstrap_lwe_buffer
BConcrete::BatchedBootstrapLweTensorOp::attachInterface<
TensorToMemrefOp<BConcrete::BatchedBootstrapLweTensorOp,
BConcrete::BatchedBootstrapLweBufferOp>>(*ctx);
// wop_pbs_crt_lwe_tensor => wop_pbs_crt_lwe_buffer
BConcrete::WopPBSCRTLweTensorOp::attachInterface<TensorToMemrefOp<
BConcrete::WopPBSCRTLweTensorOp, BConcrete::WopPBSCRTLweBufferOp>>(
*ctx);
if (mlir::concretelang::getEmitGPUOption()) {
BConcrete::KeySwitchLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
memref_keyswitch_lwe_cuda_u64>>(*ctx);
BConcrete::BootstrapLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
memref_bootstrap_lwe_cuda_u64>>(*ctx);
} else {
BConcrete::KeySwitchLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
memref_keyswitch_lwe_u64>>(*ctx);
BConcrete::BatchedKeySwitchLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<
BConcrete::BatchedKeySwitchLweBufferOp,
memref_batched_keyswitch_lwe_u64>>(*ctx);
BConcrete::BootstrapLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
memref_bootstrap_lwe_u64>>(*ctx);
BConcrete::BatchedBootstrapLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<
BConcrete::BatchedBootstrapLweBufferOp,
memref_batched_bootstrap_lwe_u64>>(*ctx);
}
BConcrete::WopPBSCRTLweBufferOp::attachInterface<
BufferizableWithCallOpInterface<BConcrete::WopPBSCRTLweBufferOp,
memref_wop_pbs_crt_buffer>>(*ctx);
BConcrete::KeySwitchLweBufferAsyncOffloadOp::attachInterface<
BufferizableWithAsyncCallOpInterface<
BConcrete::KeySwitchLweBufferAsyncOffloadOp,
memref_keyswitch_async_lwe_u64>>(*ctx);
BConcrete::BootstrapLweBufferAsyncOffloadOp::attachInterface<
BufferizableWithAsyncCallOpInterface<
BConcrete::BootstrapLweBufferAsyncOffloadOp,
memref_bootstrap_async_lwe_u64>>(*ctx);
BConcrete::AwaitFutureOp::attachInterface<
BufferizableWithSyncCallOpInterface<BConcrete::AwaitFutureOp,
memref_await_future>>(*ctx);
});
}

View File

@@ -273,16 +273,16 @@ struct BConcreteCRTBinaryOpPattern
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
struct AddPlaintextCRTLweBufferOpPattern
: public mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp> {
AddPlaintextCRTLweBufferOpPattern(mlir::MLIRContext *context,
struct AddPlaintextCRTLweTensorOpPattern
: public mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweTensorOp> {
AddPlaintextCRTLweTensorOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp>(context,
: mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweTensorOp>(context,
benefit) {
}
mlir::LogicalResult
matchAndRewrite(BConcrete::AddPlaintextCRTLweBufferOp op,
matchAndRewrite(BConcrete::AddPlaintextCRTLweTensorOp op,
mlir::PatternRewriter &rewriter) const override {
auto resultTy =
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
@@ -381,7 +381,7 @@ struct AddPlaintextCRTLweBufferOpPattern
auto blockArg1 = builder.create<tensor::ExtractOp>(loc, x_decomp, i);
// %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
auto tmp = builder.create<BConcrete::AddPlaintextLweBufferOp>(
auto tmp = builder.create<BConcrete::AddPlaintextLweTensorOp>(
loc, blockTy, blockArg0, blockArg1);
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
@@ -436,16 +436,16 @@ struct AddPlaintextCRTLweBufferOpPattern
// scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
struct MulCleartextCRTLweBufferOpPattern
: public mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp> {
MulCleartextCRTLweBufferOpPattern(mlir::MLIRContext *context,
struct MulCleartextCRTLweTensorOpPattern
: public mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweTensorOp> {
MulCleartextCRTLweTensorOpPattern(mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp>(context,
: mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweTensorOp>(context,
benefit) {
}
mlir::LogicalResult
matchAndRewrite(BConcrete::MulCleartextCRTLweBufferOp op,
matchAndRewrite(BConcrete::MulCleartextCRTLweTensorOp op,
mlir::PatternRewriter &rewriter) const override {
auto resultTy =
((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
@@ -494,7 +494,7 @@ struct MulCleartextCRTLweBufferOpPattern
// %tmp = BConcrete.mul_cleartext_lwe_buffer(%blockArg0, %x)
// : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
auto tmp = builder.create<BConcrete::MulCleartextLweBufferOp>(
auto tmp = builder.create<BConcrete::MulCleartextLweTensorOp>(
loc, blockTy, blockArg0, rhs);
// %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
@@ -520,22 +520,22 @@ void EliminateCRTOpsPass::runOnOperation() {
mlir::RewritePatternSet patterns(&getContext());
// add_crt_lwe_buffers
target.addIllegalOp<BConcrete::AddCRTLweBuffersOp>();
patterns.add<BConcreteCRTBinaryOpPattern<BConcrete::AddCRTLweBuffersOp,
BConcrete::AddLweBuffersOp>>(
target.addIllegalOp<BConcrete::AddCRTLweTensorOp>();
patterns.add<BConcreteCRTBinaryOpPattern<BConcrete::AddCRTLweTensorOp,
BConcrete::AddLweTensorOp>>(
&getContext());
// add_plaintext_crt_lwe_buffers
target.addIllegalOp<BConcrete::AddPlaintextCRTLweBufferOp>();
patterns.add<AddPlaintextCRTLweBufferOpPattern>(&getContext());
target.addIllegalOp<BConcrete::AddPlaintextCRTLweTensorOp>();
patterns.add<AddPlaintextCRTLweTensorOpPattern>(&getContext());
// mul_cleartext_crt_lwe_buffer
target.addIllegalOp<BConcrete::MulCleartextCRTLweBufferOp>();
patterns.add<MulCleartextCRTLweBufferOpPattern>(&getContext());
target.addIllegalOp<BConcrete::MulCleartextCRTLweTensorOp>();
patterns.add<MulCleartextCRTLweTensorOpPattern>(&getContext());
target.addIllegalOp<BConcrete::NegateCRTLweBufferOp>();
patterns.add<BConcreteCRTUnaryOpPattern<BConcrete::NegateCRTLweBufferOp,
BConcrete::NegateLweBufferOp>>(
target.addIllegalOp<BConcrete::NegateCRTLweTensorOp>();
patterns.add<BConcreteCRTUnaryOpPattern<BConcrete::NegateCRTLweTensorOp,
BConcrete::NegateLweTensorOp>>(
&getContext());
// This dialect are used to transforms crt ops to bconcrete ops

View File

@@ -414,7 +414,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
// MLIR canonical dialects -> LLVM Dialect
if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(
mlirContext, module, enablePass, loopParallelize)
mlirContext, module, enablePass, loopParallelize, options.emitGPUOps)
.failed()) {
return errorDiag("Failed to lower to LLVM dialect");
}

View File

@@ -293,15 +293,13 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
enablePass);
addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(),
enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertBConcreteToCAPIPass(), enablePass);
return pm.run(module.getOperation());
}
mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
std::function<bool(mlir::Pass *)> enablePass,
bool parallelizeLoops) {
bool parallelizeLoops, bool gpu) {
mlir::PassManager pm(&context);
pipelinePrinting("StdToLLVM", pm, context);
@@ -345,6 +343,10 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
addPotentiallyNestedPass(
pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertBConcreteToCAPIPass(gpu),
enablePass);
// Convert to MLIR LLVM Dialect
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass(),

View File

@@ -300,6 +300,7 @@ cmdlineCompilationOptions() {
options.loopParallelize = cmdline::loopParallelize;
options.dataflowParallelize = cmdline::dataflowParallelize;
options.batchConcreteOps = cmdline::batchConcreteOps;
options.asyncOffload = cmdline::asyncOffload;
options.optimizeConcrete = cmdline::optimizeConcrete;
options.emitGPUOps = cmdline::emitGPUOps;

View File

@@ -1,7 +1,7 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
//CHECK: func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
@@ -10,7 +10,7 @@ func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !
}
//CHECK: func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_crt_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>, %arg1: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7> {

View File

@@ -6,7 +6,7 @@
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
//CHECK: %c56_i64 = arith.constant 56 : i64
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c56_i64 : i64
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V2]] : tensor<1025xi64>
//CHECK: }
func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
@@ -19,7 +19,7 @@ func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concr
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
//CHECK: %c59_i64 = arith.constant 59 : i64
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c59_i64 : i64
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V2]] : tensor<1025xi64>
//CHECK: }
func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
@@ -30,7 +30,7 @@ func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !
//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %c1_i8) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i8) -> tensor<5x1025xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_tensor"(%[[A0]], %c1_i8) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i8) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7> {

View File

@@ -1,8 +1,8 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
//CHECK: func.func @apply_lookup_table(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: tensor<16xi64>) -> tensor<1025xi64> {
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 600 : i32} : (tensor<1025xi64>) -> tensor<601xi64>
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<1025xi64>
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 600 : i32} : (tensor<1025xi64>) -> tensor<601xi64>
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_tensor"(%[[V1]], %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<1025xi64>
//CHECK: return %[[V2]] : tensor<1025xi64>
//CHECK: }
func.func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> {

View File

@@ -2,8 +2,8 @@
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 600 : i32} : (tensor<2049xi64>) -> tensor<601xi64>
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %cst) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<2049xi64>
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 600 : i32} : (tensor<2049xi64>) -> tensor<601xi64>
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_tensor"(%[[V1]], %cst) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<2049xi64>
//CHECK: return %[[V2]] : tensor<2049xi64>
//CHECK: }
func.func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> {

View File

@@ -3,7 +3,7 @@
//CHECK: func.func @mul_lwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V1]] : tensor<1025xi64>
//CHECK: }
func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
@@ -14,7 +14,7 @@ func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concre
//CHECK: func.func @mul_lwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> {
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V1]] : tensor<1025xi64>
//CHECK: }
func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
@@ -24,7 +24,7 @@ func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !C
//CHECK: func.func @mul_cleartext_lwe_ciphertext_crt(%[[A0:.*]]: tensor<5x1025xi64>, %[[A1:.*]]: i5) -> tensor<5x1025xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i5) -> tensor<5x1025xi64>
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i5) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @mul_cleartext_lwe_ciphertext_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {

View File

@@ -1,7 +1,7 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
//CHECK: func.func @neg_lwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_buffer"(%[[A0]]) : (tensor<1025xi64>) -> tensor<1025xi64>
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_tensor"(%[[A0]]) : (tensor<1025xi64>) -> tensor<1025xi64>
//CHECK: return %[[V0]] : tensor<1025xi64>
//CHECK: }
func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> {
@@ -10,7 +10,7 @@ func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_cip
}
//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>) -> tensor<5x1025xi64>
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_tensor"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {

View File

@@ -0,0 +1,37 @@
// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
func.func @add_lwe_ciphertexts(%arg0: memref<2049xi64>, %arg1: memref<2049xi64>, %result : memref<2049xi64>) {
//CHECK: "BConcrete.add_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) : (memref<2049xi64>, memref<2049xi64>, memref<2049xi64>) -> ()
"BConcrete.add_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, memref<2049xi64>) -> ()
return
}
func.func @add_plaintext_lwe_ciphertext(%arg0: memref<2049xi64>, %arg1: i64, %result: memref<2049xi64>) {
//CHECK: "BConcrete.add_plaintext_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) : (memref<2049xi64>, memref<2049xi64>, i64) -> ()
"BConcrete.add_plaintext_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, i64) -> ()
return
}
func.func @mul_cleartext_lwe_ciphertext(%arg0: memref<2049xi64>, %arg1: i64, %result: memref<2049xi64>) {
//CHECK: "BConcrete.mul_cleartext_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A0:.*]]) : (memref<2049xi64>, memref<2049xi64>, i64) -> ()
"BConcrete.mul_cleartext_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, i64) -> ()
return
}
func.func @negate_lwe_ciphertext(%arg0: memref<2049xi64>, %result: memref<2049xi64>) {
//CHECK: "BConcrete.negate_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) : (memref<2049xi64>, memref<2049xi64>) -> ()
"BConcrete.negate_lwe_buffer"(%result, %arg0) : (memref<2049xi64>, memref<2049xi64>) -> ()
return
}
func.func @bootstrap_lwe(%arg0: memref<2049xi64>, %arg1: memref<16xi64>, %result: memref<2049xi64>) {
//CHECK: "BConcrete.bootstrap_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
"BConcrete.bootstrap_lwe_buffer"(%result, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
return
}
func.func @keyswitch_lwe(%arg0: memref<2049xi64>, %result: memref<2049xi64>) {
//CHECK: "BConcrete.keyswitch_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
"BConcrete.keyswitch_lwe_buffer"(%result, %arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
return
}

View File

@@ -1,91 +1,91 @@
// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
//CHECK: func.func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @add_lwe_ciphertexts(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64> {
%0 = "BConcrete.add_lwe_buffer"(%arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>) -> ( tensor<2049xi64>)
%0 = "BConcrete.add_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>) -> ( tensor<2049xi64>)
return %0 : tensor<2049xi64>
}
//CHECK: func.func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_crt_lwe_ciphertexts(%arg0: tensor<5x2049xi64>, %arg1: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
%0 = "BConcrete.add_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> ( tensor<5x2049xi64>)
%0 = "BConcrete.add_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> ( tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @add_plaintext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> tensor<2049xi64> {
%0 = "BConcrete.add_plaintext_lwe_buffer"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> ( tensor<2049xi64>)
%0 = "BConcrete.add_plaintext_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> ( tensor<2049xi64>)
return %0 : tensor<2049xi64>
}
//CHECK: func.func @add_plaintext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_plaintext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
%0 = "BConcrete.add_plaintext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> ( tensor<5x2049xi64>)
%0 = "BConcrete.add_plaintext_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> ( tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @mul_cleartext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> tensor<2049xi64> {
%0 = "BConcrete.mul_cleartext_lwe_buffer"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> (tensor<2049xi64>)
%0 = "BConcrete.mul_cleartext_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> (tensor<2049xi64>)
return %0 : tensor<2049xi64>
}
//CHECK: func.func @mul_cleartext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @mul_cleartext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
%0 = "BConcrete.mul_cleartext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> (tensor<5x2049xi64>)
%0 = "BConcrete.mul_cleartext_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> (tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_buffer"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_tensor"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
%0 = "BConcrete.negate_lwe_buffer"(%arg0) : (tensor<2049xi64>) -> (tensor<2049xi64>)
%0 = "BConcrete.negate_lwe_tensor"(%arg0) : (tensor<2049xi64>) -> (tensor<2049xi64>)
return %0 : tensor<2049xi64>
}
//CHECK: func.func @negate_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_tensor"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @negate_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
%0 = "BConcrete.negate_crt_lwe_buffer"(%arg0) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> (tensor<5x2049xi64>)
%0 = "BConcrete.negate_crt_lwe_tensor"(%arg0) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> (tensor<5x2049xi64>)
return %0 : tensor<5x2049xi64>
}
//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
%0 = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>)
%0 = "BConcrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>)
return %0 : tensor<2049xi64>
}
//CHECK: func.func @keyswitch_lwe(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @keyswitch_lwe(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
%0 = "BConcrete.keyswitch_lwe_buffer"(%arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>)
%0 = "BConcrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>)
return %0 : tensor<2049xi64>
}

View File

@@ -266,7 +266,8 @@ INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(
JitTest, {defaultOptions()}, mlir::concretelang::JITSupport())
std::vector<mlir::concretelang::CompilationOptions> allOptions{
defaultOptions(), loopOptions(), asyncOptions(),
defaultOptions(),
loopOptions(),
#ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED
dataflowOptions(),
#endif