Mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-09 03:55:04 -05:00)
refactor(bconcrete): Separate bufferization and CAPI call generation
@@ -11,7 +11,8 @@
 namespace mlir {
 namespace concretelang {
 /// Create a pass to convert `BConcrete` dialect to CAPI calls.
-std::unique_ptr<OperationPass<ModuleOp>> createConvertBConcreteToCAPIPass();
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertBConcreteToCAPIPass(bool gpu);
 } // namespace concretelang
 } // namespace mlir
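The factory now takes the GPU switch explicitly instead of reading global state. A minimal sketch of how a caller might wire the pass into a pipeline with the new signature (illustrative only; the PassManager setup is an assumption, not part of this commit):

    #include "concretelang/Conversion/Passes.h"
    #include "mlir/Pass/PassManager.h"

    // Illustrative: forward the caller's GPU choice into the pass, which
    // selects the *_cuda_u64 or plain CPU callees for keyswitch/bootstrap.
    void addBConcreteToCAPI(mlir::PassManager &pm, bool gpu) {
      pm.addPass(mlir::concretelang::createConvertBConcreteToCAPIPass(gpu));
    }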
@@ -15,15 +15,17 @@ include "concretelang/Dialect/RT/IR/RTTypes.td"
 class BConcrete_Op<string mnemonic, list<Trait> traits = []> :
     Op<BConcrete_Dialect, mnemonic, traits>;
 
-def BConcrete_AddLweBuffersOp : BConcrete_Op<"add_lwe_buffer"> {
-  let arguments = (ins
+// BConcrete tensor operators /////////////////////////////////////////////////
+
+def BConcrete_AddLweTensorOp : BConcrete_Op<"add_lwe_tensor"> {
+  let arguments = (ins
     1DTensorOf<[I64]>:$lhs,
     1DTensorOf<[I64]>:$rhs
   );
   let results = (outs 1DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_AddCRTLweBuffersOp : BConcrete_Op<"add_crt_lwe_buffer"> {
+def BConcrete_AddCRTLweTensorOp : BConcrete_Op<"add_crt_lwe_tensor"> {
   let arguments = (ins
     2DTensorOf<[I64]>:$lhs,
     2DTensorOf<[I64]>:$rhs,
@@ -32,12 +34,12 @@ def BConcrete_AddCRTLweBuffersOp : BConcrete_Op<"add_crt_lwe_buffer"> {
   let results = (outs 2DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_AddPlaintextLweBufferOp : BConcrete_Op<"add_plaintext_lwe_buffer"> {
+def BConcrete_AddPlaintextLweTensorOp : BConcrete_Op<"add_plaintext_lwe_tensor"> {
   let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
   let results = (outs 1DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_AddPlaintextCRTLweBufferOp : BConcrete_Op<"add_plaintext_crt_lwe_buffer"> {
+def BConcrete_AddPlaintextCRTLweTensorOp : BConcrete_Op<"add_plaintext_crt_lwe_tensor"> {
   let arguments = (ins
     2DTensorOf<[I64]>:$lhs,
     AnyInteger:$rhs,
@@ -46,12 +48,12 @@ def BConcrete_AddPlaintextCRTLweBufferOp : BConcrete_Op<"add_plaintext_crt_lwe_b
   let results = (outs 2DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_MulCleartextLweBufferOp : BConcrete_Op<"mul_cleartext_lwe_buffer"> {
+def BConcrete_MulCleartextLweTensorOp : BConcrete_Op<"mul_cleartext_lwe_tensor"> {
   let arguments = (ins 1DTensorOf<[I64]>:$lhs, I64:$rhs);
   let results = (outs 1DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_MulCleartextCRTLweBufferOp : BConcrete_Op<"mul_cleartext_crt_lwe_buffer"> {
+def BConcrete_MulCleartextCRTLweTensorOp : BConcrete_Op<"mul_cleartext_crt_lwe_tensor"> {
   let arguments = (ins
     2DTensorOf<[I64]>:$lhs,
     AnyInteger:$rhs,
@@ -60,12 +62,12 @@ def BConcrete_MulCleartextCRTLweBufferOp : BConcrete_Op<"mul_cleartext_crt_lwe_b
   let results = (outs 2DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> {
+def BConcrete_NegateLweTensorOp : BConcrete_Op<"negate_lwe_tensor"> {
   let arguments = (ins 1DTensorOf<[I64]>:$ciphertext);
   let results = (outs 1DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> {
+def BConcrete_NegateCRTLweTensorOp : BConcrete_Op<"negate_crt_lwe_tensor"> {
   let arguments = (ins
     2DTensorOf<[I64]>:$ciphertext,
     I64ArrayAttr:$crtDecomposition
@@ -73,7 +75,7 @@ def BConcrete_NegateCRTLweBufferOp : BConcrete_Op<"negate_crt_lwe_buffer"> {
   let results = (outs 2DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
+def BConcrete_KeySwitchLweTensorOp : BConcrete_Op<"keyswitch_lwe_tensor"> {
   let arguments = (ins
     // LweKeySwitchKeyType:$keyswitch_key,
     1DTensorOf<[I64]>:$ciphertext,
@@ -85,7 +87,7 @@ def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
   let results = (outs 1DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_buffer"> {
+def BConcrete_BatchedKeySwitchLweTensorOp : BConcrete_Op<"batched_keyswitch_lwe_tensor"> {
   let arguments = (ins
     // LweKeySwitchKeyType:$keyswitch_key,
     2DTensorOf<[I64]>:$ciphertext,
@@ -97,7 +99,7 @@ def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_
   let results = (outs 2DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
+def BConcrete_BootstrapLweTensorOp : BConcrete_Op<"bootstrap_lwe_tensor"> {
   let arguments = (ins
     1DTensorOf<[I64]>:$input_ciphertext,
     1DTensorOf<[I64]>:$lookup_table,
@@ -111,7 +113,7 @@ def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
   let results = (outs 1DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_buffer"> {
+def BConcrete_BatchedBootstrapLweTensorOp : BConcrete_Op<"batched_bootstrap_lwe_tensor"> {
   let arguments = (ins
     2DTensorOf<[I64]>:$input_ciphertext,
     1DTensorOf<[I64]>:$lookup_table,
@@ -126,7 +128,7 @@ def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_
 }
 
 // TODO(16bits): hack
-def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
+def BConcrete_WopPBSCRTLweTensorOp : BConcrete_Op<"wop_pbs_crt_lwe_tensor"> {
   let arguments = (ins
     2DTensorOf<[I64]>:$ciphertext,
     1DTensorOf<[I64]>:$lookupTable,
@@ -149,10 +151,9 @@ def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
   let results = (outs 2DTensorOf<[I64]>:$result);
 }
 
-def BConcrete_KeySwitchLweBufferAsyncOffloadOp :
-    BConcrete_Op<"keyswitch_lwe_buffer_async_offload"> {
+def BConcrete_KeySwitchLweTensorAsyncOffloadOp :
+    BConcrete_Op<"keyswitch_lwe_tensor_async_offload"> {
   let arguments = (ins
     // LweKeySwitchKeyType:$keyswitch_key,
     1DTensorOf<[I64]>:$ciphertext,
     I32Attr:$level,
     I32Attr:$baseLog
@@ -160,8 +161,8 @@ def BConcrete_KeySwitchLweBufferAsyncOffloadOp :
   let results = (outs RT_Future : $result);
 }
 
-def BConcrete_BootstrapLweBufferAsyncOffloadOp :
-    BConcrete_Op<"bootstrap_lwe_buffer_async_offload"> {
+def BConcrete_BootstrapLweTensorAsyncOffloadOp :
+    BConcrete_Op<"bootstrap_lwe_tensor_async_offload"> {
   let arguments = (ins
     1DTensorOf<[I64]>:$input_ciphertext,
     1DTensorOf<[I64]>:$lookup_table,
@@ -175,6 +176,121 @@ def BConcrete_BootstrapLweBufferAsyncOffloadOp :
   let results = (outs RT_Future : $result);
 }
 
+// BConcrete memref operators /////////////////////////////////////////////////
+
+def BConcrete_LweBuffer : MemRefRankOf<[I64], [1]>;
+def BConcrete_LweCRTBuffer : MemRefRankOf<[I64], [2]>;
+def BConcrete_BatchLweBuffer : MemRefRankOf<[I64], [2]>;
+
+def BConcrete_AddLweBufferOp : BConcrete_Op<"add_lwe_buffer"> {
+  let arguments = (ins
+    BConcrete_LweBuffer:$result,
+    BConcrete_LweBuffer:$lhs,
+    BConcrete_LweBuffer:$rhs
+  );
+}
+
+def BConcrete_AddPlaintextLweBufferOp : BConcrete_Op<"add_plaintext_lwe_buffer"> {
+  let arguments = (ins
+    BConcrete_LweBuffer:$result,
+    BConcrete_LweBuffer:$lhs,
+    I64:$rhs
+  );
+}
+
+def BConcrete_MulCleartextLweBufferOp : BConcrete_Op<"mul_cleartext_lwe_buffer"> {
+  let arguments = (ins
+    BConcrete_LweBuffer:$result,
+    BConcrete_LweBuffer:$lhs,
+    I64:$rhs
+  );
+}
+
+def BConcrete_NegateLweBufferOp : BConcrete_Op<"negate_lwe_buffer"> {
+  let arguments = (ins
+    BConcrete_LweBuffer:$result,
+    BConcrete_LweBuffer:$ciphertext
+  );
+}
+
+def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
+  let arguments = (ins
+    BConcrete_LweBuffer:$result,
+    BConcrete_LweBuffer:$ciphertext,
+    I32Attr:$level,
+    I32Attr:$baseLog,
+    I32Attr:$lwe_dim_in,
+    I32Attr:$lwe_dim_out
+  );
+}
+
+def BConcrete_BatchedKeySwitchLweBufferOp : BConcrete_Op<"batched_keyswitch_lwe_buffer"> {
+  let arguments = (ins
+    BConcrete_BatchLweBuffer:$result,
+    BConcrete_BatchLweBuffer:$ciphertext,
+    I32Attr:$level,
+    I32Attr:$baseLog,
+    I32Attr:$lwe_dim_in,
+    I32Attr:$lwe_dim_out
+  );
+}
+
+def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
+  let arguments = (ins
+    BConcrete_LweBuffer:$result,
+    BConcrete_LweBuffer:$input_ciphertext,
+    MemRefRankOf<[I64], [1]>:$lookup_table,
+    I32Attr:$inputLweDim,
+    I32Attr:$polySize,
+    I32Attr:$level,
+    I32Attr:$baseLog,
+    I32Attr:$glweDimension,
+    I32Attr:$outPrecision
+  );
+}
+
+def BConcrete_BatchedBootstrapLweBufferOp : BConcrete_Op<"batched_bootstrap_lwe_buffer"> {
+  let arguments = (ins
+    BConcrete_BatchLweBuffer:$result,
+    BConcrete_BatchLweBuffer:$input_ciphertext,
+    MemRefRankOf<[I64], [1]>:$lookup_table,
+    I32Attr:$inputLweDim,
+    I32Attr:$polySize,
+    I32Attr:$level,
+    I32Attr:$baseLog,
+    I32Attr:$glweDimension,
+    I32Attr:$outPrecision
+  );
+}
+
+// TODO(16bits): hack
+def BConcrete_WopPBSCRTLweBufferOp : BConcrete_Op<"wop_pbs_crt_lwe_buffer"> {
+  let arguments = (ins
+    BConcrete_LweCRTBuffer:$result,
+    BConcrete_LweCRTBuffer:$ciphertext,
+    MemRefRankOf<[I64], [1]>:$lookup_table,
+    // Bootstrap parameters
+    I32Attr : $bootstrapLevel,
+    I32Attr : $bootstrapBaseLog,
+    // Keyswitch parameters
+    I32Attr : $keyswitchLevel,
+    I32Attr : $keyswitchBaseLog,
+    // Packing keyswitch key parameters
+    I32Attr : $packingKeySwitchInputLweDimension,
+    I32Attr : $packingKeySwitchoutputPolynomialSize,
+    I32Attr : $packingKeySwitchLevel,
+    I32Attr : $packingKeySwitchBaseLog,
+    // Circuit bootstrap parameters
+    I32Attr : $circuitBootstrapLevel,
+    I32Attr : $circuitBootstrapBaseLog,
+    I64ArrayAttr:$crtDecomposition
+  );
+}
+
 // TODO
 
 def BConcrete_AwaitFutureOp :
     BConcrete_Op<"await_future"> {
   let arguments = (ins RT_Future : $future);
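The refactor gives every operator two TableGen forms: a value-semantics tensor op (which returns `$result`) used before bufferization, and a destination-passing buffer op whose first operand is the output memref and which returns nothing. A hedged C++ sketch of how the two forms are built (it assumes an `mlir::OpBuilder builder`, a `loc`, and SSA values of the right types; this is not code from the commit):

    // Before bufferization: value semantics, the op yields a new tensor.
    mlir::Value sum = builder.create<BConcrete::AddLweTensorOp>(
        loc, resultTensorType, lhsTensor, rhsTensor);

    // After bufferization: destination-passing style, the op writes into
    // its leading $result memref operand and has no results.
    builder.create<BConcrete::AddLweBufferOp>(loc, outMemref, lhsMemref,
                                              rhsMemref);

This split is what lets the bufferization framework handle the tensor ops generically, while the CAPI conversion only ever sees memref-typed buffer ops.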
@@ -69,7 +69,7 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
 mlir::LogicalResult
 lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
                       std::function<bool(mlir::Pass *)> enablePass,
-                      bool parallelizeLoops);
+                      bool parallelizeLoops, bool gpu);
 
 mlir::LogicalResult optimizeLLVMModule(llvm::LLVMContext &llvmContext,
                                        llvm::Module &module);
@@ -7,29 +7,375 @@
 #include <mlir/Transforms/DialectConversion.h>
 
 #include "concretelang/Conversion/Passes.h"
 #include "concretelang/Conversion/Tools.h"
 #include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
+#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
 
 namespace {
-struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
-  void runOnOperation() final;
-};
-} // namespace
-
-void BConcreteToCAPIPass::runOnOperation() {
-  auto op = this->getOperation();
 namespace BConcrete = mlir::concretelang::BConcrete;
 namespace arith = mlir::arith;
 namespace func = mlir::func;
 namespace memref = mlir::memref;
 
-  mlir::ConversionTarget target(getContext());
-  mlir::RewritePatternSet patterns(&getContext());
+char memref_add_lwe_ciphertexts_u64[] = "memref_add_lwe_ciphertexts_u64";
+char memref_add_plaintext_lwe_ciphertext_u64[] =
+    "memref_add_plaintext_lwe_ciphertext_u64";
+char memref_mul_cleartext_lwe_ciphertext_u64[] =
+    "memref_mul_cleartext_lwe_ciphertext_u64";
+char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
+char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
+char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
+char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
+char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
 
-  // Apply conversion
-  if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
-    this->signalPassFailure();
+char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64";
+char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
+char memref_await_future[] = "memref_await_future";
+char memref_keyswitch_lwe_cuda_u64[] = "memref_keyswitch_lwe_cuda_u64";
+char memref_bootstrap_lwe_cuda_u64[] = "memref_bootstrap_lwe_cuda_u64";
+char memref_expand_lut_in_trivial_glwe_ct_u64[] =
+    "memref_expand_lut_in_trivial_glwe_ct_u64";
+
+char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
+
+mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
+                                             size_t rank) {
+  std::vector<int64_t> shape(rank, -1);
+  mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
+  for (size_t i = 0; i < rank; i++) {
+    expr = expr +
+           (rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
+  }
+  return mlir::MemRefType::get(
+      shape, rewriter.getI64Type(),
+      mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
+}
+
+// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
+mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) {
+  mlir::Type valueType = value.getType();
+
+  if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
+    return rewriter.create<mlir::memref::CastOp>(
+        value.getLoc(),
+        getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
+        value);
+  } else {
+    return value;
+  }
+}
+
+mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
+    mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {
+
+  auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
+  auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2);
+  auto futureType =
+      mlir::concretelang::RT::FutureType::get(rewriter.getIndexType());
+  auto contextType =
+      mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
+  auto i32Type = rewriter.getI32Type();
+
+  mlir::FunctionType funcType;
+
+  if (funcName == memref_add_lwe_ciphertexts_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(), {memref1DType, memref1DType, memref1DType}, {});
+  } else if (funcName == memref_add_plaintext_lwe_ciphertext_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(),
+        {memref1DType, memref1DType, rewriter.getI64Type()}, {});
+  } else if (funcName == memref_mul_cleartext_lwe_ciphertext_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(),
+        {memref1DType, memref1DType, rewriter.getI64Type()}, {});
+  } else if (funcName == memref_negate_lwe_ciphertext_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref1DType, memref1DType}, {});
+  } else if (funcName == memref_keyswitch_lwe_u64 ||
+             funcName == memref_keyswitch_lwe_cuda_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref1DType, memref1DType, i32Type,
+                                        i32Type, i32Type, i32Type, contextType},
+                                       {});
+  } else if (funcName == memref_bootstrap_lwe_u64 ||
+             funcName == memref_bootstrap_lwe_cuda_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref1DType, memref1DType,
+                                        memref1DType, i32Type, i32Type, i32Type,
+                                        i32Type, i32Type, i32Type, contextType},
+                                       {});
+  } else if (funcName == memref_keyswitch_async_lwe_u64) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(), {memref1DType, memref1DType, contextType},
+        {futureType});
+  } else if (funcName == memref_bootstrap_async_lwe_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref1DType, memref1DType,
+                                        memref1DType, i32Type, i32Type, i32Type,
+                                        i32Type, i32Type, i32Type, contextType},
+                                       {futureType});
+  } else if (funcName == memref_batched_keyswitch_lwe_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref2DType, memref2DType, i32Type,
+                                        i32Type, i32Type, i32Type, contextType},
+                                       {});
+  } else if (funcName == memref_batched_bootstrap_lwe_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {memref2DType, memref2DType,
+                                        memref1DType, i32Type, i32Type, i32Type,
+                                        i32Type, i32Type, i32Type, contextType},
+                                       {});
+  } else if (funcName == memref_await_future) {
+    funcType = mlir::FunctionType::get(
+        rewriter.getContext(),
+        {memref1DType, futureType, memref1DType, memref1DType}, {});
+  } else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {
+                                           memref1DType,
+                                           rewriter.getI32Type(),
+                                           rewriter.getI32Type(),
+                                           rewriter.getI32Type(),
+                                           memref1DType,
+                                       },
+                                       {});
+  } else if (funcName == memref_wop_pbs_crt_buffer) {
+    funcType = mlir::FunctionType::get(rewriter.getContext(),
+                                       {
+                                           memref2DType,
+                                           memref2DType,
+                                           memref1DType,
+                                           memref1DType,
+                                           rewriter.getI32Type(),
+                                           rewriter.getI32Type(),
+                                           rewriter.getI32Type(),
+                                           rewriter.getI32Type(),
+                                           contextType,
+                                       },
+                                       {});
+  } else {
+    op->emitError("unknown external function") << funcName;
+    return mlir::failure();
+  }
+
+  return insertForwardDeclaration(op, rewriter, funcName, funcType);
+}
+
+template <typename BConcreteOp>
+void addNoOperands(BConcreteOp op, mlir::SmallVector<mlir::Value> &operands,
+                   mlir::RewriterBase &rewriter) {}
+
+template <typename BConcreteOp, char const *callee>
+struct BConcreteToCAPICallPattern : public mlir::OpRewritePattern<BConcreteOp> {
+  BConcreteToCAPICallPattern(
+      ::mlir::MLIRContext *context,
+      std::function<void(BConcreteOp bOp, llvm::SmallVector<mlir::Value> &,
+                         mlir::RewriterBase &)>
+          addOperands = addNoOperands<BConcreteOp>,
+      mlir::PatternBenefit benefit = 1)
+      : ::mlir::OpRewritePattern<BConcreteOp>(context, benefit),
+        addOperands(addOperands) {}
+
+  ::mlir::LogicalResult
+  matchAndRewrite(BConcreteOp bOp,
+                  ::mlir::PatternRewriter &rewriter) const override {
+
+    // Create the operands
+    mlir::SmallVector<mlir::Value> operands;
+    // For each memref operand, take the corresponding casted buffer
+    for (auto &operand : bOp->getOpOperands()) {
+      mlir::Type type = operand.get().getType();
+      if (!type.isa<mlir::MemRefType>()) {
+        operands.push_back(operand.get());
+      } else {
+        operands.push_back(getCastedMemRef(rewriter, operand.get()));
+      }
+    }
+
+    // Append the additional arguments
+    addOperands(bOp, operands, rewriter);
+
+    // Insert forward declaration of the function
+    if (insertForwardDeclarationOfTheCAPI(bOp, rewriter, callee).failed()) {
+      return mlir::failure();
+    }
+
+    rewriter.replaceOpWithNewOp<func::CallOp>(bOp, callee, mlir::TypeRange{},
+                                              operands);
+
+    return ::mlir::success();
+  };
+
+private:
+  std::function<void(BConcreteOp bOp, llvm::SmallVector<mlir::Value> &,
+                     mlir::RewriterBase &)>
+      addOperands;
+};
+
+template <typename KeySwitchOp>
+void keyswitchAddOperands(KeySwitchOp op,
+                          mlir::SmallVector<mlir::Value> &operands,
+                          mlir::RewriterBase &rewriter) {
+  // level
+  operands.push_back(
+      rewriter.create<arith::ConstantOp>(op.getLoc(), op.levelAttr()));
+  // base_log
+  operands.push_back(
+      rewriter.create<arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
+  // lwe_dim_in
+  operands.push_back(
+      rewriter.create<arith::ConstantOp>(op.getLoc(), op.lwe_dim_inAttr()));
+  // lwe_dim_out
+  operands.push_back(
+      rewriter.create<arith::ConstantOp>(op.getLoc(), op.lwe_dim_outAttr()));
+  // context
+  operands.push_back(getContextArgument(op));
+}
+
+template <typename BootstrapOp>
+void bootstrapAddOperands(BootstrapOp op,
+                          mlir::SmallVector<mlir::Value> &operands,
+                          mlir::RewriterBase &rewriter) {
+  // input_lwe_dim
+  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
+      op.getLoc(), op.inputLweDimAttr()));
+  // poly_size
+  operands.push_back(
+      rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
+  // level
+  operands.push_back(
+      rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
+  // base_log
+  operands.push_back(
+      rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
+  // glwe_dim
+  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
+      op.getLoc(), op.glweDimensionAttr()));
+  // out_precision
+  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
+      op.getLoc(), op.outPrecisionAttr()));
+  // context
+  operands.push_back(getContextArgument(op));
+}
+
+void wopPBSAddOperands(BConcrete::WopPBSCRTLweBufferOp op,
+                       mlir::SmallVector<mlir::Value> &operands,
+                       mlir::RewriterBase &rewriter) {
+  mlir::Type crtType = mlir::RankedTensorType::get(
+      {(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
+  std::vector<int64_t> values;
+  for (auto a : op.crtDecomposition()) {
+    values.push_back(a.cast<mlir::IntegerAttr>().getValue().getZExtValue());
+  }
+  auto attr = rewriter.getI64TensorAttr(values);
+  auto x = rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), attr, crtType);
+  auto globalMemref = mlir::bufferization::getGlobalFor(x, 0);
+  rewriter.eraseOp(x);
+  assert(!failed(globalMemref));
+
+  auto globalRef = rewriter.create<memref::GetGlobalOp>(
+      op.getLoc(), (*globalMemref).type(), (*globalMemref).getName());
+  operands.push_back(getCastedMemRef(rewriter, globalRef));
+
+  // lwe_small_size
+  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
+      op.getLoc(), op.packingKeySwitchInputLweDimensionAttr()));
+  // cbs_level_count
+  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
+      op.getLoc(), op.circuitBootstrapLevelAttr()));
+  // cbs_base_log
+  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
+      op.getLoc(), op.circuitBootstrapBaseLogAttr()));
+  // polynomial_size
+  operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
+      op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr()));
+  // context
+  operands.push_back(getContextArgument(op));
+}
+
+struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
+
+  BConcreteToCAPIPass(bool gpu) : gpu(gpu) {}
+
+  void runOnOperation() override {
+    auto op = this->getOperation();
+
+    mlir::ConversionTarget target(getContext());
+    mlir::RewritePatternSet patterns(&getContext());
+
+    // Mark ops from the target dialects as legal operations
+    target.addLegalDialect<func::FuncDialect>();
+    target.addLegalDialect<memref::MemRefDialect>();
+    target.addLegalDialect<arith::ArithmeticDialect>();
+
+    // Make sure that no ops from `BConcrete` remain after the lowering
+    target.addIllegalDialect<BConcrete::BConcreteDialect>();
+
+    // Add patterns to transform BConcrete operators to CAPI calls
+    patterns.add<BConcreteToCAPICallPattern<BConcrete::AddLweBufferOp,
+                                            memref_add_lwe_ciphertexts_u64>>(
+        &getContext());
+    patterns.add<
+        BConcreteToCAPICallPattern<BConcrete::AddPlaintextLweBufferOp,
+                                   memref_add_plaintext_lwe_ciphertext_u64>>(
+        &getContext());
+    patterns.add<
+        BConcreteToCAPICallPattern<BConcrete::MulCleartextLweBufferOp,
+                                   memref_mul_cleartext_lwe_ciphertext_u64>>(
+        &getContext());
+    patterns.add<BConcreteToCAPICallPattern<BConcrete::NegateLweBufferOp,
+                                            memref_negate_lwe_ciphertext_u64>>(
+        &getContext());
+    if (gpu) {
+      patterns.add<BConcreteToCAPICallPattern<BConcrete::KeySwitchLweBufferOp,
+                                              memref_keyswitch_lwe_cuda_u64>>(
+          &getContext(), keyswitchAddOperands<BConcrete::KeySwitchLweBufferOp>);
+      patterns.add<BConcreteToCAPICallPattern<BConcrete::BootstrapLweBufferOp,
+                                              memref_bootstrap_lwe_cuda_u64>>(
+          &getContext(), bootstrapAddOperands<BConcrete::BootstrapLweBufferOp>);
+    } else {
+      patterns.add<BConcreteToCAPICallPattern<BConcrete::KeySwitchLweBufferOp,
+                                              memref_keyswitch_lwe_u64>>(
+          &getContext(), keyswitchAddOperands<BConcrete::KeySwitchLweBufferOp>);
+      patterns.add<BConcreteToCAPICallPattern<BConcrete::BootstrapLweBufferOp,
+                                              memref_bootstrap_lwe_u64>>(
+          &getContext(), bootstrapAddOperands<BConcrete::BootstrapLweBufferOp>);
+      patterns.add<
+          BConcreteToCAPICallPattern<BConcrete::BatchedKeySwitchLweBufferOp,
+                                     memref_batched_keyswitch_lwe_u64>>(
+          &getContext(),
+          keyswitchAddOperands<BConcrete::BatchedKeySwitchLweBufferOp>);
+      patterns.add<
+          BConcreteToCAPICallPattern<BConcrete::BatchedBootstrapLweBufferOp,
+                                     memref_batched_bootstrap_lwe_u64>>(
+          &getContext(),
+          bootstrapAddOperands<BConcrete::BatchedBootstrapLweBufferOp>);
+    }
+
+    patterns.add<BConcreteToCAPICallPattern<BConcrete::WopPBSCRTLweBufferOp,
+                                            memref_wop_pbs_crt_buffer>>(
+        &getContext(), wopPBSAddOperands);
+
+    // Apply conversion
+    if (mlir::applyPartialConversion(op, target, std::move(patterns))
+            .failed()) {
+      this->signalPassFailure();
+    }
+  }
+
+private:
+  bool gpu;
+};
+
+} // namespace
 
 namespace mlir {
 namespace concretelang {
-std::unique_ptr<OperationPass<ModuleOp>> createConvertBConcreteToCAPIPass() {
-  return std::make_unique<BConcreteToCAPIPass>();
+std::unique_ptr<OperationPass<ModuleOp>>
+createConvertBConcreteToCAPIPass(bool gpu) {
+  return std::make_unique<BConcreteToCAPIPass>(gpu);
 }
 } // namespace concretelang
 } // namespace mlir
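Each `BConcreteToCAPICallPattern` instantiation binds one buffer op to one runtime symbol, with extra trailing arguments (attribute constants, the runtime context) supplied through the optional callback. A hedged sketch of what registering one more pattern would look like (the callee name here is hypothetical, purely to illustrate the template parameters; a real callee must also be handled in `insertForwardDeclarationOfTheCAPI` and exist in the runtime):

    // Hypothetical callee symbol, declared at file scope like the others.
    char memref_example_u64[] = "memref_example_u64";

    // Inside runOnOperation(): bind the op to the callee and append the
    // trailing arguments through the callback.
    patterns.add<BConcreteToCAPICallPattern<BConcrete::NegateLweBufferOp,
                                            memref_example_u64>>(
        &getContext(),
        [](BConcrete::NegateLweBufferOp op,
           llvm::SmallVector<mlir::Value> &operands,
           mlir::RewriterBase &rewriter) {
          // e.g. pass the runtime context as the last argument.
          operands.push_back(getContextArgument(op));
        });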
@@ -216,7 +216,7 @@ struct LowerKeySwitch : public mlir::OpRewritePattern<
       rewriter.getI32IntegerAttr(inputType.getDimension());
 
   mlir::Operation *bKeySwitchOp = rewriter.replaceOpWithNewOp<
-      mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(
+      mlir::concretelang::BConcrete::KeySwitchLweTensorOp>(
       ksOp, outType, ksOp.ciphertext(), ksOp.levelAttr(), ksOp.baseLogAttr(),
       inputDimAttr, outDimAttr);
 
@@ -261,7 +261,7 @@ struct LowerBatchedKeySwitch
       rewriter.getI32IntegerAttr(inputType.getDimension());
 
   mlir::Operation *bBatchedKeySwitchOp = rewriter.replaceOpWithNewOp<
-      mlir::concretelang::BConcrete::BatchedKeySwitchLweBufferOp>(
+      mlir::concretelang::BConcrete::BatchedKeySwitchLweTensorOp>(
       bksOp, bksOp.getType(), bksOp.ciphertexts(), bksOp.levelAttr(),
       bksOp.baseLogAttr(), inputDimAttr, outDimAttr);
 
@@ -293,7 +293,7 @@ struct LowerBootstrap : public mlir::OpRewritePattern<
   auto inputDimAttr = rewriter.getI32IntegerAttr(inputType.getDimension());
   auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP());
   mlir::Operation *bBootstrapOp = rewriter.replaceOpWithNewOp<
-      mlir::concretelang::BConcrete::BootstrapLweBufferOp>(
+      mlir::concretelang::BConcrete::BootstrapLweTensorOp>(
       bsOp, outType, bsOp.input_ciphertext(), bsOp.lookup_table(),
       inputDimAttr, bsOp.polySizeAttr(), bsOp.levelAttr(), bsOp.baseLogAttr(),
       bsOp.glweDimensionAttr(), outputPrecisionAttr);
@@ -338,7 +338,7 @@ struct LowerBatchedBootstrap
   auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP());
 
   mlir::Operation *bBatchedBootstrapOp = rewriter.replaceOpWithNewOp<
-      mlir::concretelang::BConcrete::BatchedBootstrapLweBufferOp>(
+      mlir::concretelang::BConcrete::BatchedBootstrapLweTensorOp>(
       bbsOp, bbsOp.getType(), bbsOp.input_ciphertexts(), bbsOp.lookup_table(),
       inputDimAttr, bbsOp.polySizeAttr(), bbsOp.levelAttr(),
       bbsOp.baseLogAttr(), bbsOp.glweDimensionAttr(), outputPrecisionAttr);
@@ -385,7 +385,7 @@ struct AddPlaintextLweCiphertextOpPattern
     auto encoded = rewriter.create<mlir::arith::ShLIOp>(
         loc, rewriter.getI64Type(), castedInt, constantShiftOp);
     bConcreteOp =
-        rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweBufferOp>(
+        rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextLweTensorOp>(
            concreteOp, newResultTy,
            mlir::ValueRange{concreteOp.lhs(), encoded}, attributes);
   } else {
@@ -394,7 +394,7 @@ struct AddPlaintextLweCiphertextOpPattern
     newAttributes.push_back(rewriter.getNamedAttr(
        "crtDecomposition", rewriter.getI64ArrayAttr(crt)));
     bConcreteOp =
-        rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextCRTLweBufferOp>(
+        rewriter.replaceOpWithNewOp<BConcrete::AddPlaintextCRTLweTensorOp>(
            concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
            newAttributes);
   }
@@ -436,7 +436,7 @@ struct MulCleartextLweCiphertextOpPattern
     mlir::Value castedInt = rewriter.create<mlir::arith::ExtUIOp>(
        loc, rewriter.getIntegerType(64), concreteOp.rhs());
     bConcreteOp =
-        rewriter.replaceOpWithNewOp<BConcrete::MulCleartextLweBufferOp>(
+        rewriter.replaceOpWithNewOp<BConcrete::MulCleartextLweTensorOp>(
            concreteOp, newResultTy,
            mlir::ValueRange{concreteOp.lhs(), castedInt}, attributes);
   } else {
@@ -444,7 +444,7 @@ struct MulCleartextLweCiphertextOpPattern
     newAttributes.push_back(rewriter.getNamedAttr(
        "crtDecomposition", rewriter.getI64ArrayAttr(crt)));
     bConcreteOp =
-        rewriter.replaceOpWithNewOp<BConcrete::MulCleartextCRTLweBufferOp>(
+        rewriter.replaceOpWithNewOp<BConcrete::MulCleartextCRTLweTensorOp>(
            concreteOp, newResultTy, concreteOp.getOperation()->getOperands(),
            newAttributes);
   }
@@ -1022,14 +1022,14 @@ void ConcreteToBConcretePass::runOnOperation() {
       LowerBootstrap, LowerBatchedBootstrap, LowerKeySwitch,
       LowerBatchedKeySwitch,
       LowToBConcrete<mlir::concretelang::Concrete::AddLweCiphertextsOp,
-                     mlir::concretelang::BConcrete::AddLweBuffersOp,
-                     BConcrete::AddCRTLweBuffersOp>,
+                     mlir::concretelang::BConcrete::AddLweTensorOp,
+                     BConcrete::AddCRTLweTensorOp>,
       AddPlaintextLweCiphertextOpPattern, MulCleartextLweCiphertextOpPattern,
       LowToBConcrete<mlir::concretelang::Concrete::NegateLweCiphertextOp,
-                     mlir::concretelang::BConcrete::NegateLweBufferOp,
-                     BConcrete::NegateCRTLweBufferOp>,
-      LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweBufferOp,
-                     BConcrete::WopPBSCRTLweBufferOp>>(&getContext());
+                     mlir::concretelang::BConcrete::NegateLweTensorOp,
+                     BConcrete::NegateCRTLweTensorOp>,
+      LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweTensorOp,
+                     BConcrete::WopPBSCRTLweTensorOp>>(&getContext());
 
   // Add patterns to rewrite tensor operators that work on encrypted
   // tensors
@@ -22,12 +22,12 @@ void AsyncOffloadPass::runOnOperation() {
   auto module = getOperation();
   std::vector<mlir::Operation *> ops;
 
-  module.walk([&](mlir::concretelang::BConcrete::KeySwitchLweBufferOp op) {
+  module.walk([&](mlir::concretelang::BConcrete::KeySwitchLweTensorOp op) {
     mlir::OpBuilder builder(op);
     mlir::Type futType =
         mlir::concretelang::RT::FutureType::get(op.getResult().getType());
     mlir::Value future = builder.create<
-        mlir::concretelang::BConcrete::KeySwitchLweBufferAsyncOffloadOp>(
+        mlir::concretelang::BConcrete::KeySwitchLweTensorAsyncOffloadOp>(
        op.getLoc(), mlir::TypeRange{futType}, op.getOperand(), op->getAttrs());
 
     assert(op.getResult().hasOneUse() &&
@@ -43,12 +43,12 @@ void AsyncOffloadPass::runOnOperation() {
     }
     ops.push_back(op);
   });
-  module.walk([&](mlir::concretelang::BConcrete::BootstrapLweBufferOp op) {
+  module.walk([&](mlir::concretelang::BConcrete::BootstrapLweTensorOp op) {
     mlir::OpBuilder builder(op);
     mlir::Type futType =
         mlir::concretelang::RT::FutureType::get(op.getResult().getType());
     mlir::Value future = builder.create<
-        mlir::concretelang::BConcrete::BootstrapLweBufferAsyncOffloadOp>(
+        mlir::concretelang::BConcrete::BootstrapLweTensorAsyncOffloadOp>(
        op.getLoc(), mlir::TypeRange{futType}, op.getOperands(),
        op->getAttrs());
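The offload pass wraps the tensor op's result type in an RT future and rebuilds the op as its `*_async_offload` variant; the assertion relies on the result having exactly one use, where the wait can then be materialized. A short sketch of the assumed completion step at that unique use (this is our reading of the `await_future` op definition; the materialization itself is not shown in these hunks):

    // Assumed: replace the single consumer's operand with the awaited value.
    mlir::Value done =
        builder.create<mlir::concretelang::BConcrete::AwaitFutureOp>(
            op.getLoc(), op.getResult().getType(), future);
    op.getResult().replaceAllUsesWith(done);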
@@ -27,170 +27,13 @@ using namespace mlir;
|
||||
using namespace mlir::bufferization;
|
||||
using namespace mlir::tensor;
|
||||
|
||||
namespace BConcrete = mlir::concretelang::BConcrete;
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace BConcrete {
|
||||
namespace {} // namespace
|
||||
} // namespace BConcrete
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
namespace {
|
||||
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
std::vector<int64_t> shape(rank, -1);
|
||||
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
expr = expr +
|
||||
(rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
|
||||
}
|
||||
return mlir::MemRefType::get(
|
||||
shape, rewriter.getI64Type(),
|
||||
mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
|
||||
}
|
||||
namespace BConcrete = mlir::concretelang::BConcrete;
|
||||
|
||||
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
|
||||
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc,
|
||||
mlir::Value value) {
|
||||
mlir::Type valueType = value.getType();
|
||||
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
|
||||
return rewriter.create<mlir::memref::CastOp>(
|
||||
loc,
|
||||
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
|
||||
value);
|
||||
} else {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
char memref_add_lwe_ciphertexts_u64[] = "memref_add_lwe_ciphertexts_u64";
|
||||
char memref_add_plaintext_lwe_ciphertext_u64[] =
|
||||
"memref_add_plaintext_lwe_ciphertext_u64";
|
||||
char memref_mul_cleartext_lwe_ciphertext_u64[] =
|
||||
"memref_mul_cleartext_lwe_ciphertext_u64";
|
||||
char memref_negate_lwe_ciphertext_u64[] = "memref_negate_lwe_ciphertext_u64";
|
||||
char memref_keyswitch_lwe_u64[] = "memref_keyswitch_lwe_u64";
|
||||
char memref_batched_keyswitch_lwe_u64[] = "memref_batched_keyswitch_lwe_u64";
|
||||
char memref_bootstrap_lwe_u64[] = "memref_bootstrap_lwe_u64";
|
||||
char memref_batched_bootstrap_lwe_u64[] = "memref_batched_bootstrap_lwe_u64";
|
||||
char memref_keyswitch_async_lwe_u64[] = "memref_keyswitch_async_lwe_u64";
|
||||
char memref_bootstrap_async_lwe_u64[] = "memref_bootstrap_async_lwe_u64";
|
||||
char memref_await_future[] = "memref_await_future";
|
||||
char memref_keyswitch_lwe_cuda_u64[] = "memref_keyswitch_lwe_cuda_u64";
|
||||
char memref_bootstrap_lwe_cuda_u64[] = "memref_bootstrap_lwe_cuda_u64";
|
||||
char memref_expand_lut_in_trivial_glwe_ct_u64[] =
|
||||
"memref_expand_lut_in_trivial_glwe_ct_u64";
|
||||
|
||||
char memref_wop_pbs_crt_buffer[] = "memref_wop_pbs_crt_buffer";
|
||||
|
||||
mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {
|
||||
|
||||
auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
|
||||
auto memref2DType = getDynamicMemrefWithUnknownOffset(rewriter, 2);
|
||||
auto futureType =
|
||||
mlir::concretelang::RT::FutureType::get(rewriter.getIndexType());
|
||||
auto contextType =
|
||||
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
|
||||
auto i32Type = rewriter.getI32Type();
|
||||
|
||||
mlir::FunctionType funcType;
|
||||
|
||||
if (funcName == memref_add_lwe_ciphertexts_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {memref1DType, memref1DType, memref1DType}, {});
|
||||
} else if (funcName == memref_add_plaintext_lwe_ciphertext_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, memref1DType, rewriter.getI64Type()}, {});
|
||||
} else if (funcName == memref_mul_cleartext_lwe_ciphertext_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, memref1DType, rewriter.getI64Type()}, {});
|
||||
} else if (funcName == memref_negate_lwe_ciphertext_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType}, {});
|
||||
} else if (funcName == memref_keyswitch_lwe_u64 ||
|
||||
funcName == memref_keyswitch_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_batched_keyswitch_lwe_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref2DType, memref2DType, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_bootstrap_lwe_u64 ||
|
||||
funcName == memref_bootstrap_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_batched_bootstrap_lwe_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref2DType, memref2DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_keyswitch_async_lwe_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {memref1DType, memref1DType, contextType},
|
||||
{futureType});
|
||||
} else if (funcName == memref_bootstrap_async_lwe_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{futureType});
|
||||
} else if (funcName == memref_await_future) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, futureType, memref1DType, memref1DType}, {});
|
||||
} else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
memref1DType,
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
memref1DType,
|
||||
},
|
||||
{});
|
||||
} else if (funcName == memref_wop_pbs_crt_buffer) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
memref2DType,
|
||||
memref2DType,
|
||||
memref1DType,
|
||||
memref1DType,
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
rewriter.getI32Type(),
|
||||
contextType,
|
||||
},
|
||||
{});
|
||||
} else {
|
||||
op->emitError("unknwon external function") << funcName;
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return insertForwardDeclaration(op, rewriter, funcName, funcType);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void pushAdditionalArgs(Op op, mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter);
|
||||
|
||||
template <typename Op, char const *funcName>
|
||||
struct BufferizableWithCallOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
BufferizableWithCallOpInterface<Op, funcName>, Op> {
|
||||
template <typename TensorOp, typename MemrefOp>
|
||||
struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel<
|
||||
TensorToMemrefOp<TensorOp, MemrefOp>, TensorOp> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return true;
|
||||
@@ -215,161 +58,7 @@ struct BufferizableWithCallOpInterface
|
||||
const BufferizationOptions &options) const {
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto castOp = cast<Op>(op);
|
||||
|
||||
// For now we always alloc for the result, we didn't have the in place
|
||||
// operators yet.
|
||||
auto resTensorType =
|
||||
castOp.result().getType().template cast<mlir::TensorType>();
|
||||
|
||||
auto outMemrefType = MemRefType::get(resTensorType.getShape(),
|
||||
resTensorType.getElementType());
|
||||
auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
|
||||
if (mlir::failed(outMemref)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// The first operand is the result
|
||||
mlir::SmallVector<mlir::Value> operands{
|
||||
getCastedMemRef(rewriter, loc, *outMemref),
|
||||
};
|
||||
// For all tensor operand get the corresponding casted buffer
|
||||
for (auto &operand : op->getOpOperands()) {
|
||||
if (!operand.get().getType().isa<mlir::RankedTensorType>()) {
|
||||
operands.push_back(operand.get());
|
||||
} else {
|
||||
auto memrefOperand =
|
||||
bufferization::getBuffer(rewriter, operand.get(), options);
|
||||
operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand));
|
||||
}
|
||||
}
|
||||
// Append additional argument
|
||||
pushAdditionalArgs<Op>(castOp, operands, rewriter);
|
||||
|
||||
// Insert forward declaration of the function
|
||||
if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
rewriter.create<mlir::func::CallOp>(loc, funcName, mlir::TypeRange{},
|
||||
operands);
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, *outMemref);
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Op, char const *funcName>
|
||||
struct BufferizableWithAsyncCallOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
BufferizableWithAsyncCallOpInterface<Op, funcName>, Op> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options) const {
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto castOp = cast<Op>(op);
|
||||
|
||||
// For now we always alloc for the result, we didn't have the in place
|
||||
// operators yet.
|
||||
auto resTensorType =
|
||||
castOp.result()
|
||||
.getType()
|
||||
.template cast<mlir::concretelang::RT::FutureType>()
|
||||
.getElementType()
|
||||
.template cast<mlir::TensorType>();
|
||||
|
||||
auto outMemrefType = MemRefType::get(resTensorType.getShape(),
|
||||
resTensorType.getElementType());
|
||||
auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {});
|
||||
if (mlir::failed(outMemref)) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// The first operand is the result
|
||||
mlir::SmallVector<mlir::Value> operands{
|
||||
getCastedMemRef(rewriter, loc, *outMemref),
|
||||
};
|
||||
// For all tensor operand get the corresponding casted buffer
|
||||
for (auto &operand : op->getOpOperands()) {
|
||||
if (!operand.get().getType().isa<mlir::RankedTensorType>()) {
|
||||
operands.push_back(operand.get());
|
||||
} else {
|
||||
auto memrefOperand =
|
||||
bufferization::getBuffer(rewriter, operand.get(), options);
|
||||
operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand));
|
||||
}
|
||||
}
|
||||
|
||||
// Append additional arguments
|
||||
pushAdditionalArgs<Op>(castOp, operands, rewriter);
|
||||
|
||||
// Insert forward declaration of the function
|
||||
if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto result = rewriter.create<mlir::func::CallOp>(
|
||||
loc, funcName,
|
||||
mlir::TypeRange{
|
||||
mlir::concretelang::RT::FutureType::get(rewriter.getIndexType())},
|
||||
operands);
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, result.getResult(0));
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Op, char const *funcName>
|
||||
struct BufferizableWithSyncCallOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
BufferizableWithSyncCallOpInterface<Op, funcName>, Op> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options) const {
|
||||
|
||||
auto loc = op->getLoc();
|
||||
auto castOp = cast<Op>(op);
|
||||
auto castOp = cast<TensorOp>(op);
|
||||
|
||||
auto resTensorType =
|
||||
castOp.result().getType().template cast<mlir::TensorType>();
|
||||
@@ -383,23 +72,18 @@ struct BufferizableWithSyncCallOpInterface
|
||||
|
||||
// The first operand is the result
|
||||
mlir::SmallVector<mlir::Value, 3> operands{
|
||||
getCastedMemRef(rewriter, loc, *outMemref),
|
||||
*outMemref,
|
||||
};
|
||||
// Then add the future operand
|
||||
operands.push_back(op->getOpOperand(0).get());
|
||||
// Finally add a dependence on the memref covered by the future to
|
||||
// prevent early deallocation
|
||||
auto def = op->getOpOperand(0).get().getDefiningOp();
|
||||
operands.push_back(def->getOpOperand(0).get());
|
||||
operands.push_back(def->getOpOperand(1).get());
|
||||
|
||||
// Insert forward declaration of the function
|
||||
if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) {
|
||||
return mlir::failure();
|
||||
for (auto &operand : op->getOpOperands()) {
|
||||
if (!operand.get().getType().isa<mlir::RankedTensorType>()) {
|
||||
operands.push_back(operand.get());
|
||||
} else {
|
||||
operands.push_back(
|
||||
bufferization::getBuffer(rewriter, operand.get(), options));
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.create<mlir::func::CallOp>(loc, funcName, mlir::TypeRange{},
|
||||
operands);
|
||||
rewriter.create<MemrefOp>(loc, mlir::TypeRange{}, operands, op->getAttrs());
|
||||
|
||||
replaceOpWithBufferizedValues(rewriter, op, *outMemref);
|
||||
|
||||
@@ -407,239 +91,49 @@ struct BufferizableWithSyncCallOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::AddPlaintextLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {}
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::AddLweBuffersOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {}
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::MulCleartextLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {}
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::NegateLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {}
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::KeySwitchLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
// lwe_dim_in
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.lwe_dim_inAttr()));
|
||||
// lwe_dim_out
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.lwe_dim_outAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::BatchedKeySwitchLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
// lwe_dim_in
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.lwe_dim_inAttr()));
|
||||
// lwe_dim_out
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.lwe_dim_outAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::BootstrapLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// input_lwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.inputLweDimAttr()));
|
||||
// poly_size
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
// glwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.glweDimensionAttr()));
|
||||
// out_precision
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.outPrecisionAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::BatchedBootstrapLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// input_lwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.inputLweDimAttr()));
|
||||
// poly_size
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
// glwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.glweDimensionAttr()));
|
||||
// out_precision
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.outPrecisionAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::KeySwitchLweBufferAsyncOffloadOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::BootstrapLweBufferAsyncOffloadOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// input_lwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.inputLweDimAttr()));
|
||||
// poly_size
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
// glwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.glweDimensionAttr()));
|
||||
// out_precision
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.outPrecisionAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::WopPBSCRTLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
mlir::Type crtType = mlir::RankedTensorType::get(
|
||||
{(int)op.crtDecompositionAttr().size()}, rewriter.getI64Type());
|
||||
std::vector<int64_t> values;
|
||||
for (auto a : op.crtDecomposition()) {
|
||||
values.push_back(a.cast<IntegerAttr>().getValue().getZExtValue());
|
||||
}
|
||||
auto attr = rewriter.getI64TensorAttr(values);
|
||||
auto x = rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), attr, crtType);
|
||||
auto globalMemref = bufferization::getGlobalFor(x, 0);
|
||||
assert(!failed(globalMemref));
|
||||
|
||||
auto globalRef = rewriter.create<memref::GetGlobalOp>(
|
||||
op.getLoc(), (*globalMemref).type(), (*globalMemref).getName());
|
||||
operands.push_back(getCastedMemRef(rewriter, op.getLoc(), globalRef));
|
||||
|
||||
// lwe_small_size
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.packingKeySwitchInputLweDimensionAttr()));
|
||||
// cbs_level_count
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.circuitBootstrapLevelAttr()));
|
||||
// cbs_base_log
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.circuitBootstrapBaseLogAttr()));
|
||||
// polynomial_size
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.packingKeySwitchoutputPolynomialSizeAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
} // namespace

void mlir::concretelang::BConcrete::
    registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *ctx,
                            BConcrete::BConcreteDialect *dialect) {
    BConcrete::AddLweBuffersOp::attachInterface<BufferizableWithCallOpInterface<
        BConcrete::AddLweBuffersOp, memref_add_lwe_ciphertexts_u64>>(*ctx);
    BConcrete::AddPlaintextLweBufferOp::attachInterface<
        BufferizableWithCallOpInterface<
            BConcrete::AddPlaintextLweBufferOp,
            memref_add_plaintext_lwe_ciphertext_u64>>(*ctx);
    BConcrete::MulCleartextLweBufferOp::attachInterface<
        BufferizableWithCallOpInterface<
            BConcrete::MulCleartextLweBufferOp,
            memref_mul_cleartext_lwe_ciphertext_u64>>(*ctx);
    BConcrete::NegateLweBufferOp::attachInterface<
        BufferizableWithCallOpInterface<BConcrete::NegateLweBufferOp,
                                        memref_negate_lwe_ciphertext_u64>>(
        *ctx);
    // add_lwe_tensor => add_lwe_buffer
    BConcrete::AddLweTensorOp::attachInterface<
        TensorToMemrefOp<BConcrete::AddLweTensorOp, BConcrete::AddLweBufferOp>>(
        *ctx);
    // add_plaintext_lwe_tensor => add_plaintext_lwe_buffer
    BConcrete::AddPlaintextLweTensorOp::attachInterface<
        TensorToMemrefOp<BConcrete::AddPlaintextLweTensorOp,
                         BConcrete::AddPlaintextLweBufferOp>>(*ctx);
    // mul_cleartext_lwe_tensor => mul_cleartext_lwe_buffer
    BConcrete::MulCleartextLweTensorOp::attachInterface<
        TensorToMemrefOp<BConcrete::MulCleartextLweTensorOp,
                         BConcrete::MulCleartextLweBufferOp>>(*ctx);
    // negate_lwe_tensor => negate_lwe_buffer
    BConcrete::NegateLweTensorOp::attachInterface<TensorToMemrefOp<
        BConcrete::NegateLweTensorOp, BConcrete::NegateLweBufferOp>>(*ctx);
    // keyswitch_lwe_tensor => keyswitch_lwe_buffer
    BConcrete::KeySwitchLweTensorOp::attachInterface<TensorToMemrefOp<
        BConcrete::KeySwitchLweTensorOp, BConcrete::KeySwitchLweBufferOp>>(
        *ctx);
    // bootstrap_lwe_tensor => bootstrap_lwe_buffer
    BConcrete::BootstrapLweTensorOp::attachInterface<TensorToMemrefOp<
        BConcrete::BootstrapLweTensorOp, BConcrete::BootstrapLweBufferOp>>(
        *ctx);
    // batched_keyswitch_lwe_tensor => batched_keyswitch_lwe_buffer
    BConcrete::BatchedKeySwitchLweTensorOp::attachInterface<
        TensorToMemrefOp<BConcrete::BatchedKeySwitchLweTensorOp,
                         BConcrete::BatchedKeySwitchLweBufferOp>>(*ctx);
    // batched_bootstrap_lwe_tensor => batched_bootstrap_lwe_buffer
    BConcrete::BatchedBootstrapLweTensorOp::attachInterface<
        TensorToMemrefOp<BConcrete::BatchedBootstrapLweTensorOp,
                         BConcrete::BatchedBootstrapLweBufferOp>>(*ctx);
    // wop_pbs_crt_lwe_tensor => wop_pbs_crt_lwe_buffer
    BConcrete::WopPBSCRTLweTensorOp::attachInterface<TensorToMemrefOp<
        BConcrete::WopPBSCRTLweTensorOp, BConcrete::WopPBSCRTLweBufferOp>>(
        *ctx);
    if (mlir::concretelang::getEmitGPUOption()) {
      BConcrete::KeySwitchLweBufferOp::attachInterface<
          BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
                                          memref_keyswitch_lwe_cuda_u64>>(*ctx);
      BConcrete::BootstrapLweBufferOp::attachInterface<
          BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
                                          memref_bootstrap_lwe_cuda_u64>>(*ctx);
    } else {
      BConcrete::KeySwitchLweBufferOp::attachInterface<
          BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
                                          memref_keyswitch_lwe_u64>>(*ctx);
      BConcrete::BatchedKeySwitchLweBufferOp::attachInterface<
          BufferizableWithCallOpInterface<
              BConcrete::BatchedKeySwitchLweBufferOp,
              memref_batched_keyswitch_lwe_u64>>(*ctx);
      BConcrete::BootstrapLweBufferOp::attachInterface<
          BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
                                          memref_bootstrap_lwe_u64>>(*ctx);
      BConcrete::BatchedBootstrapLweBufferOp::attachInterface<
          BufferizableWithCallOpInterface<
              BConcrete::BatchedBootstrapLweBufferOp,
              memref_batched_bootstrap_lwe_u64>>(*ctx);
    }
    BConcrete::WopPBSCRTLweBufferOp::attachInterface<
        BufferizableWithCallOpInterface<BConcrete::WopPBSCRTLweBufferOp,
                                        memref_wop_pbs_crt_buffer>>(*ctx);
    BConcrete::KeySwitchLweBufferAsyncOffloadOp::attachInterface<
        BufferizableWithAsyncCallOpInterface<
            BConcrete::KeySwitchLweBufferAsyncOffloadOp,
            memref_keyswitch_async_lwe_u64>>(*ctx);
    BConcrete::BootstrapLweBufferAsyncOffloadOp::attachInterface<
        BufferizableWithAsyncCallOpInterface<
            BConcrete::BootstrapLweBufferAsyncOffloadOp,
            memref_bootstrap_async_lwe_u64>>(*ctx);
    BConcrete::AwaitFutureOp::attachInterface<
        BufferizableWithSyncCallOpInterface<BConcrete::AwaitFutureOp,
                                            memref_await_future>>(*ctx);
  });
}
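
For context, these external models only take effect once they are registered on the DialectRegistry that seeds the MLIRContext. A minimal sketch of that wiring, assuming the standard MLIR registration flow (the helper function below is illustrative, not part of this commit):

// Sketch only: registerForBufferization is a hypothetical helper showing the
// usual MLIR external-model registration order.
void registerForBufferization(mlir::DialectRegistry &registry) {
  registry.insert<mlir::concretelang::BConcrete::BConcreteDialect>();
  mlir::concretelang::BConcrete::registerBufferizableOpInterfaceExternalModels(
      registry);
}

// Usage: build the context from the registry, then run bufferization; the
// bufferization passes can now query the models attached above.
//   mlir::DialectRegistry registry;
//   registerForBufferization(registry);
//   mlir::MLIRContext context(registry);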

@@ -273,16 +273,16 @@ struct BConcreteCRTBinaryOpPattern
//   scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
struct AddPlaintextCRTLweBufferOpPattern
    : public mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp> {
  AddPlaintextCRTLweBufferOpPattern(mlir::MLIRContext *context,
struct AddPlaintextCRTLweTensorOpPattern
    : public mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweTensorOp> {
  AddPlaintextCRTLweTensorOpPattern(mlir::MLIRContext *context,
                                    mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweBufferOp>(context,
      : mlir::OpRewritePattern<BConcrete::AddPlaintextCRTLweTensorOp>(context,
                                                                      benefit) {
  }

  mlir::LogicalResult
  matchAndRewrite(BConcrete::AddPlaintextCRTLweBufferOp op,
  matchAndRewrite(BConcrete::AddPlaintextCRTLweTensorOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto resultTy =
        ((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
@@ -381,7 +381,7 @@ struct AddPlaintextCRTLweBufferOpPattern
    auto blockArg1 = builder.create<tensor::ExtractOp>(loc, x_decomp, i);
    // %tmp = "BConcreteOp"(%blockArg0, %blockArg1)
    //        : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
    auto tmp = builder.create<BConcrete::AddPlaintextLweBufferOp>(
    auto tmp = builder.create<BConcrete::AddPlaintextLweTensorOp>(
        loc, blockTy, blockArg0, blockArg1);

    // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
@@ -436,16 +436,16 @@ struct AddPlaintextCRTLweBufferOpPattern
//   scf.yield %res : tensor<nbBlocksxlweSizexi64>
// }
// ```
struct MulCleartextCRTLweBufferOpPattern
    : public mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp> {
  MulCleartextCRTLweBufferOpPattern(mlir::MLIRContext *context,
struct MulCleartextCRTLweTensorOpPattern
    : public mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweTensorOp> {
  MulCleartextCRTLweTensorOpPattern(mlir::MLIRContext *context,
                                    mlir::PatternBenefit benefit = 1)
      : mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweBufferOp>(context,
      : mlir::OpRewritePattern<BConcrete::MulCleartextCRTLweTensorOp>(context,
                                                                      benefit) {
  }

  mlir::LogicalResult
  matchAndRewrite(BConcrete::MulCleartextCRTLweBufferOp op,
  matchAndRewrite(BConcrete::MulCleartextCRTLweTensorOp op,
                  mlir::PatternRewriter &rewriter) const override {
    auto resultTy =
        ((mlir::Type)op.getResult().getType()).cast<mlir::RankedTensorType>();
@@ -494,7 +494,7 @@ struct MulCleartextCRTLweBufferOpPattern

    // %tmp = BConcrete.mul_cleartext_lwe_buffer(%blockArg0, %x)
    //        : (tensor<lweSizexi64>, i64) -> (tensor<lweSizexi64>)
    auto tmp = builder.create<BConcrete::MulCleartextLweBufferOp>(
    auto tmp = builder.create<BConcrete::MulCleartextLweTensorOp>(
        loc, blockTy, blockArg0, rhs);

    // %res = tensor.insert_slice %tmp into %acc[%i, 0] [1, lweSize] [1,
@@ -520,22 +520,22 @@ void EliminateCRTOpsPass::runOnOperation() {
  mlir::RewritePatternSet patterns(&getContext());

  // add_crt_lwe_buffers
  target.addIllegalOp<BConcrete::AddCRTLweBuffersOp>();
  patterns.add<BConcreteCRTBinaryOpPattern<BConcrete::AddCRTLweBuffersOp,
                                           BConcrete::AddLweBuffersOp>>(
  target.addIllegalOp<BConcrete::AddCRTLweTensorOp>();
  patterns.add<BConcreteCRTBinaryOpPattern<BConcrete::AddCRTLweTensorOp,
                                           BConcrete::AddLweTensorOp>>(
      &getContext());

  // add_plaintext_crt_lwe_buffers
  target.addIllegalOp<BConcrete::AddPlaintextCRTLweBufferOp>();
  patterns.add<AddPlaintextCRTLweBufferOpPattern>(&getContext());
  target.addIllegalOp<BConcrete::AddPlaintextCRTLweTensorOp>();
  patterns.add<AddPlaintextCRTLweTensorOpPattern>(&getContext());

  // mul_cleartext_crt_lwe_buffer
  target.addIllegalOp<BConcrete::MulCleartextCRTLweBufferOp>();
  patterns.add<MulCleartextCRTLweBufferOpPattern>(&getContext());
  target.addIllegalOp<BConcrete::MulCleartextCRTLweTensorOp>();
  patterns.add<MulCleartextCRTLweTensorOpPattern>(&getContext());

  target.addIllegalOp<BConcrete::NegateCRTLweBufferOp>();
  patterns.add<BConcreteCRTUnaryOpPattern<BConcrete::NegateCRTLweBufferOp,
                                          BConcrete::NegateLweBufferOp>>(
  target.addIllegalOp<BConcrete::NegateCRTLweTensorOp>();
  patterns.add<BConcreteCRTUnaryOpPattern<BConcrete::NegateCRTLweTensorOp,
                                          BConcrete::NegateLweTensorOp>>(
      &getContext());

// These dialects are used to transform CRT ops into BConcrete ops

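Patterns and illegal-op declarations like the ones above are typically driven by MLIR's dialect-conversion engine; a hedged sketch of the tail of such a runOnOperation (the exact error handling in this pass may differ):

// Sketch only: the standard MLIR dialect-conversion driver for the
// RewritePatternSet and ConversionTarget built above.
if (mlir::applyPartialConversion(getOperation(), target, std::move(patterns))
        .failed()) {
  signalPassFailure();
}
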
@@ -414,7 +414,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {

  // MLIR canonical dialects -> LLVM Dialect
  if (mlir::concretelang::pipeline::lowerStdToLLVMDialect(
          mlirContext, module, enablePass, loopParallelize)
          mlirContext, module, enablePass, loopParallelize, options.emitGPUOps)
          .failed()) {
    return errorDiag("Failed to lower to LLVM dialect");
  }

@@ -293,15 +293,13 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module,
                           enablePass);
  addPotentiallyNestedPass(pm, mlir::concretelang::createAddRuntimeContext(),
                           enablePass);
  addPotentiallyNestedPass(
      pm, mlir::concretelang::createConvertBConcreteToCAPIPass(), enablePass);
  return pm.run(module.getOperation());
}

mlir::LogicalResult
lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
                      std::function<bool(mlir::Pass *)> enablePass,
                      bool parallelizeLoops) {
                      bool parallelizeLoops, bool gpu) {
  mlir::PassManager pm(&context);
  pipelinePrinting("StdToLLVM", pm, context);

@@ -345,6 +343,10 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
  addPotentiallyNestedPass(
      pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass);

  addPotentiallyNestedPass(
      pm, mlir::concretelang::createConvertBConcreteToCAPIPass(gpu),
      enablePass);

  // Convert to MLIR LLVM Dialect
  addPotentiallyNestedPass(
      pm, mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass(),

@@ -300,6 +300,7 @@ cmdlineCompilationOptions() {
  options.loopParallelize = cmdline::loopParallelize;
  options.dataflowParallelize = cmdline::dataflowParallelize;
  options.batchConcreteOps = cmdline::batchConcreteOps;
  options.asyncOffload = cmdline::asyncOffload;
  options.optimizeConcrete = cmdline::optimizeConcrete;
  options.emitGPUOps = cmdline::emitGPUOps;

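Taken together, these hunks move the BConcrete-to-CAPI conversion out of lowerBConcreteToStd and into lowerStdToLLVMDialect (after bufferization and buffer-deallocation fixup), and thread the GPU choice from the command line down to the pass. A hedged sketch of the resulting caller-side flow (the pass and pipeline entry points are the ones from this commit; the wrapper function and its arguments are assumed driver context):

// Sketch only: how a driver would thread emitGPUOps into the lowering.
// lowerWithOptions is a hypothetical wrapper, not code from this commit.
mlir::LogicalResult
lowerWithOptions(mlir::MLIRContext &mlirContext, mlir::ModuleOp &module,
                 std::function<bool(mlir::Pass *)> enablePass) {
  auto options = cmdlineCompilationOptions();
  // The gpu flag now selects memref_*_cuda_u64 vs. memref_*_u64 CAPI targets
  // inside createConvertBConcreteToCAPIPass.
  return mlir::concretelang::pipeline::lowerStdToLLVMDialect(
      mlirContext, module, enablePass, options.loopParallelize,
      /*gpu=*/options.emitGPUOps);
}
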
@@ -1,7 +1,7 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s

//CHECK: func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !Concrete.lwe_ciphertext<2048,7>) -> !Concrete.lwe_ciphertext<2048,7> {
@@ -10,7 +10,7 @@ func.func @add_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: !
}

//CHECK: func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_crt_lwe_ciphertexts(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>, %arg1: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],2048,7> {

@@ -6,7 +6,7 @@
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
//CHECK: %c56_i64 = arith.constant 56 : i64
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c56_i64 : i64
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V2]] : tensor<1025xi64>
//CHECK: }
func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
@@ -19,7 +19,7 @@ func.func @add_glwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concr
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
//CHECK: %c59_i64 = arith.constant 59 : i64
//CHECK: %[[V1:.*]] = arith.shli %[[V0]], %c59_i64 : i64
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: %[[V2:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[V1]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V2]] : tensor<1025xi64>
//CHECK: }
func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
@@ -30,7 +30,7 @@ func.func @add_glwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !

//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %c1_i8) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i8) -> tensor<5x1025xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_tensor"(%[[A0]], %c1_i8) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i8) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @add_plaintext_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,7> {

@@ -1,8 +1,8 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s

//CHECK: func.func @apply_lookup_table(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: tensor<16xi64>) -> tensor<1025xi64> {
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 600 : i32} : (tensor<1025xi64>) -> tensor<601xi64>
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<1025xi64>
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 600 : i32} : (tensor<1025xi64>) -> tensor<601xi64>
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_tensor"(%[[V1]], %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<1025xi64>
//CHECK: return %[[V2]] : tensor<1025xi64>
//CHECK: }
func.func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> {

@@ -2,8 +2,8 @@

//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 600 : i32} : (tensor<2049xi64>) -> tensor<601xi64>
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %cst) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<2049xi64>
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 600 : i32} : (tensor<2049xi64>) -> tensor<601xi64>
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_tensor"(%[[V1]], %cst) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<2049xi64>
//CHECK: return %[[V2]] : tensor<2049xi64>
//CHECK: }
func.func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> {

@@ -3,7 +3,7 @@
//CHECK: func.func @mul_lwe_const_int(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %c1_i8 = arith.constant 1 : i8
//CHECK: %[[V0:.*]] = arith.extui %c1_i8 : i8 to i64
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V1]] : tensor<1025xi64>
//CHECK: }
func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concrete.lwe_ciphertext<1024,7> {
@@ -14,7 +14,7 @@ func.func @mul_lwe_const_int(%arg0: !Concrete.lwe_ciphertext<1024,7>) -> !Concre

//CHECK: func.func @mul_lwe_int(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: i5) -> tensor<1025xi64> {
//CHECK: %[[V0:.*]] = arith.extui %[[A1]] : i5 to i64
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: %[[V1:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[V0]]) : (tensor<1025xi64>, i64) -> tensor<1025xi64>
//CHECK: return %[[V1]] : tensor<1025xi64>
//CHECK: }
func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<1024,4> {
@@ -24,7 +24,7 @@ func.func @mul_lwe_int(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: i5) -> !C


//CHECK: func.func @mul_cleartext_lwe_ciphertext_crt(%[[A0:.*]]: tensor<5x1025xi64>, %[[A1:.*]]: i5) -> tensor<5x1025xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i5) -> tensor<5x1025xi64>
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>, i5) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @mul_cleartext_lwe_ciphertext_crt(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>, %arg1: i5) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {

@@ -1,7 +1,7 @@
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s

//CHECK: func.func @neg_lwe(%[[A0:.*]]: tensor<1025xi64>) -> tensor<1025xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_buffer"(%[[A0]]) : (tensor<1025xi64>) -> tensor<1025xi64>
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_tensor"(%[[A0]]) : (tensor<1025xi64>) -> tensor<1025xi64>
//CHECK: return %[[V0]] : tensor<1025xi64>
//CHECK: }
func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<1024,4> {
@@ -10,7 +10,7 @@ func.func @neg_lwe(%arg0: !Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_cip
}

//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<5x1025xi64>) -> tensor<5x1025xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>) -> tensor<5x1025xi64>
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_tensor"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x1025xi64>) -> tensor<5x1025xi64>
//CHECK: return %[[V0]] : tensor<5x1025xi64>
//CHECK: }
func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4>) -> !Concrete.lwe_ciphertext<crt=[2,3,5,7,11],1024,4> {

compiler/tests/check_tests/Dialect/BConcrete/ops_memref.mlir (new file, 37 lines)
@@ -0,0 +1,37 @@
// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s

func.func @add_lwe_ciphertexts(%arg0: memref<2049xi64>, %arg1: memref<2049xi64>, %result : memref<2049xi64>) {
  //CHECK: "BConcrete.add_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) : (memref<2049xi64>, memref<2049xi64>, memref<2049xi64>) -> ()
  "BConcrete.add_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, memref<2049xi64>) -> ()
  return
}

func.func @add_plaintext_lwe_ciphertext(%arg0: memref<2049xi64>, %arg1: i64, %result: memref<2049xi64>) {
  //CHECK: "BConcrete.add_plaintext_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) : (memref<2049xi64>, memref<2049xi64>, i64) -> ()
  "BConcrete.add_plaintext_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, i64) -> ()
  return
}

func.func @mul_cleartext_lwe_ciphertext(%arg0: memref<2049xi64>, %arg1: i64, %result: memref<2049xi64>) {
  //CHECK: "BConcrete.mul_cleartext_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) : (memref<2049xi64>, memref<2049xi64>, i64) -> ()
  "BConcrete.mul_cleartext_lwe_buffer"(%result, %arg0, %arg1) : (memref<2049xi64>, memref<2049xi64>, i64) -> ()
  return
}

func.func @negate_lwe_ciphertext(%arg0: memref<2049xi64>, %result: memref<2049xi64>) {
  //CHECK: "BConcrete.negate_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) : (memref<2049xi64>, memref<2049xi64>) -> ()
  "BConcrete.negate_lwe_buffer"(%result, %arg0) : (memref<2049xi64>, memref<2049xi64>) -> ()
  return
}

func.func @bootstrap_lwe(%arg0: memref<2049xi64>, %arg1: memref<16xi64>, %result: memref<2049xi64>) {
  //CHECK: "BConcrete.bootstrap_lwe_buffer"(%[[R:.*]], %[[A0:.*]], %[[A1:.*]]) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
  "BConcrete.bootstrap_lwe_buffer"(%result, %arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>, memref<16xi64>) -> ()
  return
}

func.func @keyswitch_lwe(%arg0: memref<2049xi64>, %result: memref<2049xi64>) {
  //CHECK: "BConcrete.keyswitch_lwe_buffer"(%[[R:.*]], %[[A0:.*]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
  "BConcrete.keyswitch_lwe_buffer"(%result, %arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (memref<2049xi64>, memref<2049xi64>) -> ()
  return
}
@@ -1,91 +1,91 @@
// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s

//CHECK: func.func @add_lwe_ciphertexts(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @add_lwe_ciphertexts(%arg0: tensor<2049xi64>, %arg1: tensor<2049xi64>) -> tensor<2049xi64> {
  %0 = "BConcrete.add_lwe_buffer"(%arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>) -> ( tensor<2049xi64>)
  %0 = "BConcrete.add_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, tensor<2049xi64>) -> ( tensor<2049xi64>)
  return %0 : tensor<2049xi64>
}

//CHECK: func.func @add_crt_lwe_ciphertexts(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_crt_lwe_ciphertexts(%arg0: tensor<5x2049xi64>, %arg1: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
  %0 = "BConcrete.add_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> ( tensor<5x2049xi64>)
  %0 = "BConcrete.add_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, tensor<5x2049xi64>) -> ( tensor<5x2049xi64>)
  return %0 : tensor<5x2049xi64>
}

//CHECK: func.func @add_plaintext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @add_plaintext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> tensor<2049xi64> {
  %0 = "BConcrete.add_plaintext_lwe_buffer"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> ( tensor<2049xi64>)
  %0 = "BConcrete.add_plaintext_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> ( tensor<2049xi64>)
  return %0 : tensor<2049xi64>
}

//CHECK: func.func @add_plaintext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.add_plaintext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @add_plaintext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
  %0 = "BConcrete.add_plaintext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> ( tensor<5x2049xi64>)
  %0 = "BConcrete.add_plaintext_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> ( tensor<5x2049xi64>)
  return %0 : tensor<5x2049xi64>
}

//CHECK: func @mul_cleartext_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>, %[[A1:.*]]: i64) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_buffer"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_lwe_tensor"(%[[A0]], %[[A1]]) : (tensor<2049xi64>, i64) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @mul_cleartext_lwe_ciphertext(%arg0: tensor<2049xi64>, %arg1: i64) -> tensor<2049xi64> {
  %0 = "BConcrete.mul_cleartext_lwe_buffer"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> (tensor<2049xi64>)
  %0 = "BConcrete.mul_cleartext_lwe_tensor"(%arg0, %arg1) : (tensor<2049xi64>, i64) -> (tensor<2049xi64>)
  return %0 : tensor<2049xi64>
}

//CHECK: func.func @mul_cleartext_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>, %[[A1:.*]]: i64) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_buffer"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.mul_cleartext_crt_lwe_tensor"(%[[A0]], %[[A1]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @mul_cleartext_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>, %arg1: i64) -> tensor<5x2049xi64> {
  %0 = "BConcrete.mul_cleartext_crt_lwe_buffer"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> (tensor<5x2049xi64>)
  %0 = "BConcrete.mul_cleartext_crt_lwe_tensor"(%arg0, %arg1) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>, i64) -> (tensor<5x2049xi64>)
  return %0 : tensor<5x2049xi64>
}

//CHECK: func.func @negate_lwe_ciphertext(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_buffer"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.negate_lwe_tensor"(%[[A0]]) : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @negate_lwe_ciphertext(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
  %0 = "BConcrete.negate_lwe_buffer"(%arg0) : (tensor<2049xi64>) -> (tensor<2049xi64>)
  %0 = "BConcrete.negate_lwe_tensor"(%arg0) : (tensor<2049xi64>) -> (tensor<2049xi64>)
  return %0 : tensor<2049xi64>
}

//CHECK: func.func @negate_crt_lwe_ciphertext(%[[A0:.*]]: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_buffer"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.negate_crt_lwe_tensor"(%[[A0]]) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> tensor<5x2049xi64>
//CHECK: return %[[V0]] : tensor<5x2049xi64>
//CHECK: }
func.func @negate_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>) -> tensor<5x2049xi64> {
  %0 = "BConcrete.negate_crt_lwe_buffer"(%arg0) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> (tensor<5x2049xi64>)
  %0 = "BConcrete.negate_crt_lwe_tensor"(%arg0) {crtDecomposition = [2, 3, 5, 7, 11]} : (tensor<5x2049xi64>) -> (tensor<5x2049xi64>)
  return %0 : tensor<5x2049xi64>
}

//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
  %0 = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>)
  %0 = "BConcrete.bootstrap_lwe_tensor"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>)
  return %0 : tensor<2049xi64>
}

//CHECK: func.func @keyswitch_lwe(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
//CHECK: %[[V0:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: %[[V0:.*]] = "BConcrete.keyswitch_lwe_tensor"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64>
//CHECK: return %[[V0]] : tensor<2049xi64>
//CHECK: }
func.func @keyswitch_lwe(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
  %0 = "BConcrete.keyswitch_lwe_buffer"(%arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>)
  %0 = "BConcrete.keyswitch_lwe_tensor"(%arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>)
  return %0 : tensor<2049xi64>
}
@@ -266,7 +266,8 @@ INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(
    JitTest, {defaultOptions()}, mlir::concretelang::JITSupport())

std::vector<mlir::concretelang::CompilationOptions> allOptions{
    defaultOptions(), loopOptions(), asyncOptions(),
    defaultOptions(),
    loopOptions(),
#ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED
    dataflowOptions(),
#endif
