mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: replace some operands by attrs in bs/ks
This commit is contained in:
@@ -12,7 +12,7 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `Concrete` dialect to `BConcrete` dialect.
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize, bool emitGPUOps);
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize);
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
@@ -78,7 +78,9 @@ def BConcrete_KeySwitchLweBufferOp : BConcrete_Op<"keyswitch_lwe_buffer"> {
|
||||
// LweKeySwitchKeyType:$keyswitch_key,
|
||||
1DTensorOf<[I64]>:$ciphertext,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$lwe_dim_in,
|
||||
I32Attr:$lwe_dim_out
|
||||
);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
@@ -87,12 +89,12 @@ def BConcrete_BootstrapLweBufferOp : BConcrete_Op<"bootstrap_lwe_buffer"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
1DTensorOf<[I64]>:$lookup_table,
|
||||
I32:$inputLweDim,
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
I32Attr:$inputLweDim,
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$outPrecision
|
||||
);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
@@ -137,12 +139,12 @@ def BConcrete_BootstrapLweBufferAsyncOffloadOp :
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
1DTensorOf<[I64]>:$lookup_table,
|
||||
I32:$inputLweDim,
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
I32Attr:$inputLweDim,
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$glweDimension,
|
||||
I32Attr:$outPrecision
|
||||
);
|
||||
let results = (outs RT_Future : $result);
|
||||
}
|
||||
@@ -153,35 +155,4 @@ def BConcrete_AwaitFutureOp :
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
// This is a different op in BConcrete just because of the way we are lowering to CAPI
|
||||
// When the CAPI lowering is detached from bufferization, we can remove this op, and lower
|
||||
// to the appropriate CAPI (gpu or cpu) depending on the emitGPUOps compilation option
|
||||
def BConcrete_BootstrapLweGPUBufferOp : BConcrete_Op<"bootstrap_lwe_gpu_buffer"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$input_ciphertext,
|
||||
1DTensorOf<[I64]>:$table,
|
||||
I32:$inputLweDim,
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
// This is a different op in BConcrete just because of the way we are lowering to CAPI
|
||||
// When the CAPI lowering is detached from bufferization, we can remove this op, and lower
|
||||
// to the appropriate CAPI (gpu or cpu) depending on the emitGPUOps compilation option
|
||||
def BConcrete_KeySwitchLweGPUBufferOp : BConcrete_Op<"keyswitch_lwe_gpu_buffer"> {
|
||||
let arguments = (ins
|
||||
1DTensorOf<[I64]>:$ciphertext,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
I32:$lwe_dim_in,
|
||||
I32:$lwe_dim_out
|
||||
);
|
||||
let results = (outs 1DTensorOf<[I64]>:$result);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -58,12 +58,10 @@ def Concrete_BootstrapLweOp : Concrete_Op<"bootstrap_lwe"> {
|
||||
let arguments = (ins
|
||||
Concrete_LweCiphertextType:$input_ciphertext,
|
||||
1DTensorOf<[I64]>:$lookup_table,
|
||||
I32:$inputLweDim,
|
||||
I32:$polySize,
|
||||
I32:$level,
|
||||
I32:$baseLog,
|
||||
I32:$glweDimension,
|
||||
I32:$outPrecision
|
||||
I32Attr:$level,
|
||||
I32Attr:$baseLog,
|
||||
I32Attr:$polySize,
|
||||
I32Attr:$glweDimension
|
||||
);
|
||||
let results = (outs Concrete_LweCiphertextType:$result);
|
||||
}
|
||||
|
||||
@@ -99,10 +99,9 @@ def TFHE_BootstrapGLWEOp : TFHE_Op<"bootstrap_glwe"> {
|
||||
let arguments = (ins
|
||||
TFHE_GLWECipherTextType : $ciphertext,
|
||||
1DTensorOf<[I64]> : $lookup_table,
|
||||
I32Attr : $inputLweDim,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $level,
|
||||
I32Attr : $baseLog,
|
||||
I32Attr : $polySize,
|
||||
I32Attr : $glweDimension
|
||||
);
|
||||
|
||||
|
||||
@@ -62,6 +62,8 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset,
|
||||
uint64_t ct0_size, uint64_t ct0_stride,
|
||||
uint32_t level, uint32_t base_log,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
void *memref_keyswitch_async_lwe_u64(
|
||||
uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
bool getEmitGPUOption();
|
||||
|
||||
/// Compilation context that acts as the root owner of LLVM and MLIR
|
||||
/// data structures directly and indirectly referenced by artefacts
|
||||
/// produced by the `CompilerEngine`.
|
||||
|
||||
@@ -47,7 +47,7 @@ lowerTFHEToConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::LogicalResult
|
||||
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops, bool emitGPUOps);
|
||||
bool parallelizeLoops);
|
||||
|
||||
mlir::LogicalResult
|
||||
optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
|
||||
@@ -48,12 +48,11 @@ struct ConcreteToBConcretePass
|
||||
: public ConcreteToBConcreteBase<ConcreteToBConcretePass> {
|
||||
void runOnOperation() final;
|
||||
ConcreteToBConcretePass() = delete;
|
||||
ConcreteToBConcretePass(bool loopParallelize, bool emitGPUOps)
|
||||
: loopParallelize(loopParallelize), emitGPUOps(emitGPUOps){};
|
||||
ConcreteToBConcretePass(bool loopParallelize)
|
||||
: loopParallelize(loopParallelize){};
|
||||
|
||||
private:
|
||||
bool loopParallelize;
|
||||
bool emitGPUOps;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@@ -201,43 +200,67 @@ struct LowToBConcrete : public mlir::OpRewritePattern<ConcreteOp> {
|
||||
};
|
||||
};
|
||||
|
||||
struct KeySwitchToGPU : public mlir::OpRewritePattern<
|
||||
struct LowerKeySwitch : public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::KeySwitchLweOp> {
|
||||
KeySwitchToGPU(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
LowerKeySwitch(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::Concrete::KeySwitchLweOp>(
|
||||
context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp keySwitchOp,
|
||||
matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp ksOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
|
||||
mlir::Value levelCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
keySwitchOp.getLoc(), keySwitchOp.level(), 32);
|
||||
mlir::Value baseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
keySwitchOp.getLoc(), keySwitchOp.baseLog(), 32);
|
||||
|
||||
// construct operands for in/out dimensions
|
||||
mlir::concretelang::Concrete::LweCiphertextType outType =
|
||||
keySwitchOp.getType();
|
||||
auto outDim = outType.getDimension();
|
||||
mlir::Value outDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
keySwitchOp.getLoc(), outDim, 32);
|
||||
// construct attributes for in/out dimensions
|
||||
mlir::concretelang::Concrete::LweCiphertextType outType = ksOp.getType();
|
||||
auto outDimAttr = rewriter.getI32IntegerAttr(outType.getDimension());
|
||||
auto inputType =
|
||||
keySwitchOp.ciphertext()
|
||||
ksOp.ciphertext()
|
||||
.getType()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
auto inputDim = inputType.getDimension();
|
||||
mlir::Value inputDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
keySwitchOp.getLoc(), inputDim, 32);
|
||||
mlir::IntegerAttr inputDimAttr =
|
||||
rewriter.getI32IntegerAttr(inputType.getDimension());
|
||||
|
||||
mlir::Operation *bKeySwitchGPUOp = rewriter.replaceOpWithNewOp<
|
||||
mlir::concretelang::BConcrete::KeySwitchLweGPUBufferOp>(
|
||||
keySwitchOp, outType, keySwitchOp.ciphertext(), levelCst, baseLogCst,
|
||||
inputDimCst, outDimCst);
|
||||
mlir::Operation *bKeySwitchOp = rewriter.replaceOpWithNewOp<
|
||||
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>(
|
||||
ksOp, outType, ksOp.ciphertext(), ksOp.levelAttr(), ksOp.baseLogAttr(),
|
||||
inputDimAttr, outDimAttr);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, bKeySwitchGPUOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
rewriter, bKeySwitchOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
|
||||
struct LowerBootstrap : public mlir::OpRewritePattern<
|
||||
mlir::concretelang::Concrete::BootstrapLweOp> {
|
||||
LowerBootstrap(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<mlir::concretelang::Concrete::BootstrapLweOp>(
|
||||
context, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp bsOp,
|
||||
::mlir::PatternRewriter &rewriter) const override {
|
||||
ConcreteToBConcreteTypeConverter converter;
|
||||
|
||||
mlir::concretelang::Concrete::LweCiphertextType outType = bsOp.getType();
|
||||
auto inputType =
|
||||
bsOp.input_ciphertext()
|
||||
.getType()
|
||||
.cast<mlir::concretelang::Concrete::LweCiphertextType>();
|
||||
auto inputDimAttr = rewriter.getI32IntegerAttr(inputType.getDimension());
|
||||
auto outputPrecisionAttr = rewriter.getI32IntegerAttr(outType.getP());
|
||||
mlir::Operation *bBootstrapOp = rewriter.replaceOpWithNewOp<
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp>(
|
||||
bsOp, outType, bsOp.input_ciphertext(), bsOp.lookup_table(),
|
||||
inputDimAttr, bsOp.polySizeAttr(), bsOp.levelAttr(), bsOp.baseLogAttr(),
|
||||
bsOp.glweDimensionAttr(), outputPrecisionAttr);
|
||||
|
||||
mlir::concretelang::convertOperandAndResultTypes(
|
||||
rewriter, bBootstrapOp, [&](mlir::MLIRContext *, mlir::Type t) {
|
||||
return converter.convertType(t);
|
||||
});
|
||||
|
||||
@@ -909,6 +932,7 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
// Add patterns to trivialy convert Concrete op to the equivalent
|
||||
// BConcrete op
|
||||
patterns.insert<
|
||||
LowerBootstrap, LowerKeySwitch,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::AddLweCiphertextsOp,
|
||||
mlir::concretelang::BConcrete::AddLweBuffersOp,
|
||||
BConcrete::AddCRTLweBuffersOp>,
|
||||
@@ -919,24 +943,6 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
LowToBConcrete<Concrete::WopPBSLweOp, BConcrete::WopPBSCRTLweBufferOp,
|
||||
BConcrete::WopPBSCRTLweBufferOp>>(&getContext());
|
||||
|
||||
if (this->emitGPUOps) {
|
||||
patterns
|
||||
.insert<LowToBConcrete<
|
||||
mlir::concretelang::Concrete::BootstrapLweOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweGPUBufferOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweGPUBufferOp>,
|
||||
KeySwitchToGPU>(&getContext());
|
||||
} else {
|
||||
patterns.insert<
|
||||
LowToBConcrete<mlir::concretelang::Concrete::BootstrapLweOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp,
|
||||
mlir::concretelang::BConcrete::BootstrapLweBufferOp>,
|
||||
LowToBConcrete<mlir::concretelang::Concrete::KeySwitchLweOp,
|
||||
mlir::concretelang::BConcrete::KeySwitchLweBufferOp,
|
||||
mlir::concretelang::BConcrete::KeySwitchLweBufferOp>>(
|
||||
&getContext());
|
||||
}
|
||||
|
||||
// Add patterns to rewrite tensor operators that works on encrypted
|
||||
// tensors
|
||||
patterns
|
||||
@@ -1063,8 +1069,8 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize, bool emitGPUOps) {
|
||||
return std::make_unique<ConcreteToBConcretePass>(loopParallelize, emitGPUOps);
|
||||
createConvertConcreteToBConcretePass(bool loopParallelize) {
|
||||
return std::make_unique<ConcreteToBConcretePass>(loopParallelize);
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -113,7 +113,7 @@ struct ApplyLookupTableEintOpToKeyswitchBootstrapPattern
|
||||
});
|
||||
// %0 = "TFHE.bootstrap_glwe"(%glwe_ks, %glwe_lut)
|
||||
rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
lutOp, resultTy, glweKs, lutOp.lut(), -1, -1, -1, -1, -1);
|
||||
lutOp, resultTy, glweKs, lutOp.lut(), -1, -1, -1, -1);
|
||||
return ::mlir::success();
|
||||
};
|
||||
};
|
||||
@@ -146,8 +146,6 @@ struct ApplyLookupTableEintOpToWopPBSPattern
|
||||
matchAndRewrite(FHE::ApplyLookupTableEintOp lutOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
FHEToTFHETypeConverter converter;
|
||||
auto inputTy = converter.convertType(lutOp.a().getType())
|
||||
.cast<TFHE::GLWECipherTextType>();
|
||||
auto resultTy = converter.convertType(lutOp.getType());
|
||||
// %0 = "TFHE.wop_pbs_glwe"(%ct, %lut)
|
||||
// : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
|
||||
|
||||
@@ -154,9 +154,8 @@ struct BootstrapGLWEOpPattern
|
||||
auto newOutputTy = converter.convertType(outputTy);
|
||||
auto newOp = rewriter.replaceOpWithNewOp<TFHE::BootstrapGLWEOp>(
|
||||
bsOp, newOutputTy, bsOp.ciphertext(), bsOp.lookup_table(),
|
||||
cryptoParameters.nSmall, cryptoParameters.getPolynomialSize(),
|
||||
cryptoParameters.brLevel, cryptoParameters.brLogBase,
|
||||
cryptoParameters.glweDimension);
|
||||
cryptoParameters.getPolynomialSize(), cryptoParameters.glweDimension);
|
||||
rewriter.startRootUpdate(newOp);
|
||||
newOp.ciphertext().setType(newInputTy);
|
||||
rewriter.finalizeRootUpdate(newOp);
|
||||
|
||||
@@ -79,25 +79,9 @@ struct BootstrapGLWEOpPattern
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
mlir::Type resultType = converter.convertType(bsOp.getType());
|
||||
|
||||
auto precision = bsOp.getType().cast<TFHE::GLWECipherTextType>().getP();
|
||||
|
||||
mlir::Value inputLweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), bsOp.inputLweDim(), 32);
|
||||
mlir::Value polySizeCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), bsOp.polySize(), 32);
|
||||
mlir::Value levelCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), bsOp.level(), 32);
|
||||
mlir::Value baseLogCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), bsOp.baseLog(), 32);
|
||||
mlir::Value glweDimCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), bsOp.glweDimension(), 32);
|
||||
mlir::Value precisionCst = rewriter.create<mlir::arith::ConstantIntOp>(
|
||||
bsOp.getLoc(), precision, 32);
|
||||
|
||||
auto newOp = rewriter.replaceOpWithNewOp<Concrete::BootstrapLweOp>(
|
||||
bsOp, resultType, bsOp.ciphertext(), bsOp.lookup_table(),
|
||||
inputLweDimCst, polySizeCst, levelCst, baseLogCst, glweDimCst,
|
||||
precisionCst);
|
||||
bsOp, resultType, bsOp.ciphertext(), bsOp.lookup_table(), bsOp.level(),
|
||||
bsOp.baseLog(), bsOp.polySize(), bsOp.glweDimension());
|
||||
|
||||
rewriter.startRootUpdate(newOp);
|
||||
newOp.input_ciphertext().setType(
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
|
||||
#include "concretelang/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include <mlir/IR/AffineExpr.h>
|
||||
#include <mlir/IR/AffineMap.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
@@ -110,15 +111,14 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
} else if (funcName == memref_negate_lwe_ciphertext_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType}, {});
|
||||
} else if (funcName == memref_keyswitch_lwe_u64) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(), {memref1DType, memref1DType, contextType}, {});
|
||||
} else if (funcName == memref_keyswitch_lwe_cuda_u64) {
|
||||
} else if (funcName == memref_keyswitch_lwe_u64 ||
|
||||
funcName == memref_keyswitch_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_bootstrap_lwe_u64) {
|
||||
} else if (funcName == memref_bootstrap_lwe_u64 ||
|
||||
funcName == memref_bootstrap_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
@@ -138,12 +138,6 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, futureType, memref1DType, memref1DType}, {});
|
||||
} else if (funcName == memref_bootstrap_lwe_cuda_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{memref1DType, memref1DType,
|
||||
memref1DType, i32Type, i32Type, i32Type,
|
||||
i32Type, i32Type, i32Type, contextType},
|
||||
{});
|
||||
} else if (funcName == memref_expand_lut_in_trivial_glwe_ct_u64) {
|
||||
funcType = mlir::FunctionType::get(rewriter.getContext(),
|
||||
{
|
||||
@@ -253,10 +247,10 @@ struct BufferizableWithCallOpInterface
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Op, char const *funcName, bool withContext = false>
|
||||
template <typename Op, char const *funcName>
|
||||
struct BufferizableWithAsyncCallOpInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
BufferizableWithAsyncCallOpInterface<Op, funcName, withContext>, Op> {
|
||||
BufferizableWithAsyncCallOpInterface<Op, funcName>, Op> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return true;
|
||||
@@ -300,7 +294,7 @@ struct BufferizableWithAsyncCallOpInterface
|
||||
}
|
||||
|
||||
// The first operand is the result
|
||||
mlir::SmallVector<mlir::Value, 3> operands{
|
||||
mlir::SmallVector<mlir::Value> operands{
|
||||
getCastedMemRef(rewriter, loc, *outMemref),
|
||||
};
|
||||
// For all tensor operand get the corresponding casted buffer
|
||||
@@ -313,10 +307,9 @@ struct BufferizableWithAsyncCallOpInterface
|
||||
operands.push_back(getCastedMemRef(rewriter, loc, memrefOperand));
|
||||
}
|
||||
}
|
||||
// Append the context argument
|
||||
if (withContext) {
|
||||
operands.push_back(getContextArgument(op));
|
||||
}
|
||||
|
||||
// Append additional arguments
|
||||
pushAdditionalArgs<Op>(castOp, operands, rewriter);
|
||||
|
||||
// Insert forward declaration of the function
|
||||
if (insertForwardDeclarationOfTheCAPI(op, rewriter, funcName).failed()) {
|
||||
@@ -422,6 +415,18 @@ template <>
|
||||
void pushAdditionalArgs(BConcrete::KeySwitchLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
// lwe_dim_in
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.lwe_dim_inAttr()));
|
||||
// lwe_dim_out
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.lwe_dim_outAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
};
|
||||
@@ -430,6 +435,58 @@ template <>
|
||||
void pushAdditionalArgs(BConcrete::BootstrapLweBufferOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// input_lwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.inputLweDimAttr()));
|
||||
// poly_size
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
// glwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.glweDimensionAttr()));
|
||||
// out_precision
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.outPrecisionAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
};
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::KeySwitchLweBufferAsyncOffloadOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
};
|
||||
|
||||
template <>
|
||||
void pushAdditionalArgs(BConcrete::BootstrapLweBufferAsyncOffloadOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
RewriterBase &rewriter) {
|
||||
// input_lwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.inputLweDimAttr()));
|
||||
// poly_size
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.polySizeAttr()));
|
||||
// level
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.levelAttr()));
|
||||
// base_log
|
||||
operands.push_back(
|
||||
rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op.baseLogAttr()));
|
||||
// glwe_dim
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.glweDimensionAttr()));
|
||||
// out_precision
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op.outPrecisionAttr()));
|
||||
// context
|
||||
operands.push_back(getContextArgument(op));
|
||||
};
|
||||
@@ -488,31 +545,32 @@ void mlir::concretelang::BConcrete::
|
||||
BufferizableWithCallOpInterface<BConcrete::NegateLweBufferOp,
|
||||
memref_negate_lwe_ciphertext_u64>>(
|
||||
*ctx);
|
||||
BConcrete::KeySwitchLweGPUBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::KeySwitchLweGPUBufferOp,
|
||||
memref_keyswitch_lwe_cuda_u64, true>>(
|
||||
*ctx);
|
||||
BConcrete::BootstrapLweGPUBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::BootstrapLweGPUBufferOp,
|
||||
memref_bootstrap_lwe_cuda_u64, true>>(
|
||||
*ctx);
|
||||
BConcrete::KeySwitchLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
|
||||
memref_keyswitch_lwe_u64>>(*ctx);
|
||||
BConcrete::BootstrapLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
|
||||
memref_bootstrap_lwe_u64>>(*ctx);
|
||||
if (mlir::concretelang::getEmitGPUOption()) {
|
||||
BConcrete::KeySwitchLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
|
||||
memref_keyswitch_lwe_cuda_u64>>(*ctx);
|
||||
BConcrete::BootstrapLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
|
||||
memref_bootstrap_lwe_cuda_u64>>(*ctx);
|
||||
} else {
|
||||
BConcrete::KeySwitchLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::KeySwitchLweBufferOp,
|
||||
memref_keyswitch_lwe_u64>>(*ctx);
|
||||
BConcrete::BootstrapLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::BootstrapLweBufferOp,
|
||||
memref_bootstrap_lwe_u64>>(*ctx);
|
||||
}
|
||||
BConcrete::WopPBSCRTLweBufferOp::attachInterface<
|
||||
BufferizableWithCallOpInterface<BConcrete::WopPBSCRTLweBufferOp,
|
||||
memref_wop_pbs_crt_buffer>>(*ctx);
|
||||
BConcrete::KeySwitchLweBufferAsyncOffloadOp::attachInterface<
|
||||
BufferizableWithAsyncCallOpInterface<
|
||||
BConcrete::KeySwitchLweBufferAsyncOffloadOp,
|
||||
memref_keyswitch_async_lwe_u64, true>>(*ctx);
|
||||
memref_keyswitch_async_lwe_u64>>(*ctx);
|
||||
BConcrete::BootstrapLweBufferAsyncOffloadOp::attachInterface<
|
||||
BufferizableWithAsyncCallOpInterface<
|
||||
BConcrete::BootstrapLweBufferAsyncOffloadOp,
|
||||
memref_bootstrap_async_lwe_u64, true>>(*ctx);
|
||||
memref_bootstrap_async_lwe_u64>>(*ctx);
|
||||
BConcrete::AwaitFutureOp::attachInterface<
|
||||
BufferizableWithSyncCallOpInterface<BConcrete::AwaitFutureOp,
|
||||
memref_await_future>>(*ctx);
|
||||
|
||||
@@ -327,6 +327,8 @@ void memref_keyswitch_lwe_u64(uint64_t *out_allocated, uint64_t *out_aligned,
|
||||
uint64_t out_stride, uint64_t *ct0_allocated,
|
||||
uint64_t *ct0_aligned, uint64_t ct0_offset,
|
||||
uint64_t ct0_size, uint64_t ct0_stride,
|
||||
uint32_t level, uint32_t base_log,
|
||||
uint32_t input_lwe_dim, uint32_t output_lwe_dim,
|
||||
mlir::concretelang::RuntimeContext *context) {
|
||||
CAPI_ASSERT_ERROR(
|
||||
default_engine_discard_keyswitch_lwe_ciphertext_u64_raw_ptr_buffers(
|
||||
|
||||
@@ -44,6 +44,11 @@
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
// TODO: should be removed when bufferization is not related to CAPI lowering
|
||||
// Control whether we should call a cpu of gpu function when lowering
|
||||
// to CAPI
|
||||
static bool EMIT_GPU_OPS;
|
||||
bool getEmitGPUOption() { return EMIT_GPU_OPS; }
|
||||
|
||||
/// Creates a new compilation context that can be shared across
|
||||
/// compilation engines and results
|
||||
@@ -223,9 +228,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
|
||||
CompilationResult res(this->compilationContext);
|
||||
|
||||
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
|
||||
CompilationOptions &options = this->compilerOptions;
|
||||
|
||||
// enable/disable usage of gpu functions during bufferization
|
||||
EMIT_GPU_OPS = options.emitGPUOps;
|
||||
|
||||
mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
|
||||
|
||||
if (options.verifyDiagnostics) {
|
||||
// Only build diagnostics verifier handler if diagnostics should
|
||||
// be verified in order to avoid diagnostic messages to be
|
||||
@@ -355,8 +364,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
|
||||
// Concrete -> BConcrete
|
||||
if (mlir::concretelang::pipeline::lowerConcreteToBConcrete(
|
||||
mlirContext, module, this->enablePass, loopParallelize,
|
||||
options.emitGPUOps)
|
||||
mlirContext, module, this->enablePass, loopParallelize)
|
||||
.failed()) {
|
||||
return StreamStringError(
|
||||
"Lowering from Concrete to Bufferized Concrete failed");
|
||||
|
||||
@@ -242,13 +242,13 @@ optimizeConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
mlir::LogicalResult
|
||||
lowerConcreteToBConcrete(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
std::function<bool(mlir::Pass *)> enablePass,
|
||||
bool parallelizeLoops, bool emitGPUOps) {
|
||||
bool parallelizeLoops) {
|
||||
mlir::PassManager pm(&context);
|
||||
pipelinePrinting("ConcreteToBConcrete", pm, context);
|
||||
|
||||
std::unique_ptr<Pass> conversionPass =
|
||||
mlir::concretelang::createConvertConcreteToBConcretePass(parallelizeLoops,
|
||||
emitGPUOps);
|
||||
mlir::concretelang::createConvertConcreteToBConcretePass(
|
||||
parallelizeLoops);
|
||||
|
||||
bool passEnabled = enablePass(conversionPass.get());
|
||||
|
||||
|
||||
@@ -1,24 +1,12 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @apply_lookup_table(%[[A0:.*]]: tensor<1025xi64>, %[[A1:.*]]: tensor<16xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: %[[C1:.*]] = arith.constant 600 : i32
|
||||
//CHECK: %[[C2:.*]] = arith.constant 1024 : i32
|
||||
//CHECK: %[[C3:.*]] = arith.constant 2 : i32
|
||||
//CHECK: %[[C4:.*]] = arith.constant 3 : i32
|
||||
//CHECK: %[[C5:.*]] = arith.constant 1 : i32
|
||||
//CHECK: %[[C6:.*]] = arith.constant 4 : i32
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<1025xi64>) -> tensor<601xi64>
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %arg1, %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]]) : (tensor<601xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> tensor<1025xi64>
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 1024 : i32, lwe_dim_out = 600 : i32} : (tensor<1025xi64>) -> tensor<601xi64>
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 1024 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<1025xi64>
|
||||
//CHECK: return %[[V2]] : tensor<1025xi64>
|
||||
//CHECK: }
|
||||
func.func @apply_lookup_table(%arg0: !Concrete.lwe_ciphertext<1024,4>, %arg1: tensor<16xi64>) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
%c1 = arith.constant 600 : i32
|
||||
%c2 = arith.constant 1024 : i32
|
||||
%c3 = arith.constant 2 : i32
|
||||
%c4 = arith.constant 3 : i32
|
||||
%c5 = arith.constant 1 : i32
|
||||
%c6 = arith.constant 4 : i32
|
||||
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %arg1, %c1, %c2, %c3, %c4, %c5, %c6) : (!Concrete.lwe_ciphertext<600,4>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<1024,4>) -> !Concrete.lwe_ciphertext<600,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %arg1) {baseLog = 2 : i32, polySize = 1024 : i32, level = 3 : i32, glweDimension = 4 : i32} : (!Concrete.lwe_ciphertext<600,4>, tensor<16xi64> ) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
return %2 : !Concrete.lwe_ciphertext<1024,4>
|
||||
}
|
||||
|
||||
@@ -1,25 +1,14 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[C1:.*]] = arith.constant 600 : i32
|
||||
//CHECK: %[[C2:.*]] = arith.constant 2048 : i32
|
||||
//CHECK: %[[C3:.*]] = arith.constant 4 : i32
|
||||
//CHECK: %[[C4:.*]] = arith.constant 5 : i32
|
||||
//CHECK: %[[C5:.*]] = arith.constant 1 : i32
|
||||
//CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (tensor<2049xi64>) -> tensor<601xi64>
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %cst, %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C3]]) : (tensor<601xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> tensor<2049xi64>
|
||||
//CHECK: %[[V1:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 600 : i32} : (tensor<2049xi64>) -> tensor<601xi64>
|
||||
//CHECK: %[[V2:.*]] = "BConcrete.bootstrap_lwe_buffer"(%[[V1]], %cst) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<601xi64>, tensor<16xi64>) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V2]] : tensor<2049xi64>
|
||||
//CHECK: }
|
||||
func.func @apply_lookup_table_cst(%arg0: !Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<2048,4> {
|
||||
%c1 = arith.constant 600 : i32
|
||||
%c2 = arith.constant 2048 : i32
|
||||
%c3 = arith.constant 4 : i32
|
||||
%c4 = arith.constant 5 : i32
|
||||
%c5 = arith.constant 1 : i32
|
||||
%c6 = arith.constant 4 : i32
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 600 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %tlu, %c1, %c2, %c3, %c4, %c5, %c6) : (!Concrete.lwe_ciphertext<600,4>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<2048,4>
|
||||
%1 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 3 : i32} : (!Concrete.lwe_ciphertext<2048,4>) -> !Concrete.lwe_ciphertext<600,4>
|
||||
%2 = "Concrete.bootstrap_lwe"(%1, %tlu) {baseLog = 2 : i32, polySize = 2048 : i32, level = 3 : i32, glweDimension = 4 : i32} : (!Concrete.lwe_ciphertext<600,4>, tensor<16xi64>) -> !Concrete.lwe_ciphertext<2048,4>
|
||||
return %2 : !Concrete.lwe_ciphertext<2048,4>
|
||||
}
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
// RUN: concretecompiler --passes concrete-to-bconcrete --action=dump-bconcrete --emit-gpu-ops %s 2>&1| FileCheck %s
|
||||
|
||||
|
||||
//CHECK: func.func @main(%arg0: tensor<1025xi64>) -> tensor<1025xi64> {
|
||||
//CHECK: %c1_i32 = arith.constant 1 : i32
|
||||
//CHECK: %c8_i32 = arith.constant 8 : i32
|
||||
//CHECK: %c2_i32 = arith.constant 2 : i32
|
||||
//CHECK: %c1024_i32 = arith.constant 1024 : i32
|
||||
//CHECK: %c575_i32 = arith.constant 575 : i32
|
||||
//CHECK: %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
|
||||
//CHECK: %c5_i32 = arith.constant 5 : i32
|
||||
//CHECK: %c2_i32_0 = arith.constant 2 : i32
|
||||
//CHECK: %c575_i32_1 = arith.constant 575 : i32
|
||||
//CHECK: %c1024_i32_2 = arith.constant 1024 : i32
|
||||
//CHECK: %0 = "BConcrete.keyswitch_lwe_gpu_buffer"(%arg0, %c5_i32, %c2_i32_0, %c1024_i32_2, %c575_i32_1) : (tensor<1025xi64>, i32, i32, i32, i32) -> tensor<576xi64>
|
||||
//CHECK: %1 = "BConcrete.bootstrap_lwe_gpu_buffer"(%0, %cst, %c575_i32, %c1024_i32, %c2_i32, %c8_i32, %c1_i32, %c2_i32) : (tensor<576xi64>, tensor<4xi64>, i32, i32, i32, i32, i32, i32) -> tensor<1025xi64>
|
||||
//CHECK: return %1 : tensor<1025xi64>
|
||||
//CHECK: }
|
||||
func.func @main(%arg0: !Concrete.lwe_ciphertext<1024,2>) -> !Concrete.lwe_ciphertext<1024,2> {
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%c2_i32 = arith.constant 2 : i32
|
||||
%c1024_i32 = arith.constant 1024 : i32
|
||||
%c575_i32 = arith.constant 575 : i32
|
||||
%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
|
||||
%0 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 5 : i32} : (!Concrete.lwe_ciphertext<1024,2>) -> !Concrete.lwe_ciphertext<575,2>
|
||||
%1 = "Concrete.bootstrap_lwe"(%0, %cst, %c575_i32, %c1024_i32, %c2_i32, %c8_i32, %c1_i32, %c2_i32) : (!Concrete.lwe_ciphertext<575,2>, tensor<4xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<1024,2>
|
||||
return %1 : !Concrete.lwe_ciphertext<1024,2>
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --action=dump-llvm-dialect --emit-gpu-ops %s 2>&1| FileCheck %s
|
||||
|
||||
//CHECK: llvm.call @memref_keyswitch_lwe_cuda_u64
|
||||
//CHECK: llvm.call @memref_bootstrap_lwe_cuda_u64
|
||||
func.func @main(%arg0: !Concrete.lwe_ciphertext<1024,2>) -> !Concrete.lwe_ciphertext<1024,2> {
|
||||
%cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
|
||||
%0 = "Concrete.keyswitch_lwe"(%arg0) {baseLog = 2 : i32, level = 5 : i32} : (!Concrete.lwe_ciphertext<1024,2>) -> !Concrete.lwe_ciphertext<575,2>
|
||||
%1 = "Concrete.bootstrap_lwe"(%0, %cst) {baseLog = 2 : i32, level = 5 : i32, polySize = 1024: i32, glweDimension = 1 : i32} : (!Concrete.lwe_ciphertext<575,2>, tensor<4xi64>) -> !Concrete.lwe_ciphertext<1024,2>
|
||||
return %1 : !Concrete.lwe_ciphertext<1024,2>
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
// CHECK: func.func @apply_lookup_table(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{2}>, %[[LUT:.*]]: tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>) -> !TFHE.glwe<{_,_,_}{2}>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[LUT]]) {baseLog = -1 : i32, glweDimension = -1 : i32, inputLweDim = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %[[LUT]]) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{3}>
|
||||
// CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{3}>
|
||||
func.func @apply_lookup_table(%arg0: !FHE.eint<2>, %arg1: tensor<4xi64>) -> !FHE.eint<3> {
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<2>, tensor<4xi64>) -> (!FHE.eint<3>)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//CHECK: func.func @apply_lookup_table_cst(%[[A0:.*]]: !TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}> {
|
||||
//CHECK-NEXT: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
//CHECK-NEXT: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {baseLog = -1 : i32, glweDimension = -1 : i32, inputLweDim = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, tensor<128xi64>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{7}>, tensor<128xi64>) -> !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: return %[[V2]] : !TFHE.glwe<{_,_,_}{7}>
|
||||
//CHECK-NEXT: }
|
||||
func.func @apply_lookup_table_cst(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
//CHECK: func.func @main(%[[A0:.*]]: !TFHE.glwe<{2048,1,64}{4}>) -> !TFHE.glwe<{2048,1,64}{4}> {
|
||||
//CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
//CHECK: %[[V1:.*]] = "TFHE.keyswitch_glwe"(%[[A0]]) {baseLog = 4 : i32, level = 3 : i32} : (!TFHE.glwe<{2048,1,64}{4}>) -> !TFHE.glwe<{750,1,64}{4}>
|
||||
//CHECK: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {baseLog = 23 : i32, glweDimension = 2 : i32, inputLweDim = 750 : i32, level = 1 : i32, polySize = 1024 : i32} : (!TFHE.glwe<{750,1,64}{4}>, tensor<16xi64>) -> !TFHE.glwe<{2048,1,64}{4}>
|
||||
//CHECK: %[[V2:.*]] = "TFHE.bootstrap_glwe"(%[[V1]], %cst) {baseLog = 23 : i32, glweDimension = 2 : i32, level = 1 : i32, polySize = 1024 : i32} : (!TFHE.glwe<{750,1,64}{4}>, tensor<16xi64>) -> !TFHE.glwe<{2048,1,64}{4}>
|
||||
//CHECK: return %[[V2]] : !TFHE.glwe<{2048,1,64}{4}>
|
||||
//CHECK: }
|
||||
func.func @main(%arg0: !TFHE.glwe<{_,_,_}{4}>) -> !TFHE.glwe<{_,_,_}{4}> {
|
||||
%cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64>
|
||||
%1 = "TFHE.keyswitch_glwe"(%arg0) {baseLog = -1 : i32, level = -1 : i32} : (!TFHE.glwe<{_,_,_}{4}>) -> !TFHE.glwe<{_,_,_}{4}>
|
||||
%2 = "TFHE.bootstrap_glwe"(%1, %cst) {baseLog = -1 : i32, glweDimension = -1 : i32, inputLweDim = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{4}>, tensor<16xi64>) -> !TFHE.glwe<{_,_,_}{4}>
|
||||
%2 = "TFHE.bootstrap_glwe"(%1, %cst) {baseLog = -1 : i32, glweDimension = -1 : i32, level = -1 : i32, polySize = -1 : i32} : (!TFHE.glwe<{_,_,_}{4}>, tensor<16xi64>) -> !TFHE.glwe<{_,_,_}{4}>
|
||||
return %2 : !TFHE.glwe<{_,_,_}{4}>
|
||||
}
|
||||
|
||||
@@ -2,17 +2,11 @@
|
||||
|
||||
//CHECK: func.func @bootstrap_lwe(%[[A0:.*]]: !Concrete.lwe_ciphertext<600,7>) -> !Concrete.lwe_ciphertext<1024,4> {
|
||||
//CHECK: %cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
//CHECK: %[[C1:.*]] = arith.constant 600 : i32
|
||||
//CHECK: %[[C2:.*]] = arith.constant 1024 : i32
|
||||
//CHECK: %[[C3:.*]] = arith.constant 3 : i32
|
||||
//CHECK: %[[C4:.*]] = arith.constant 1 : i32
|
||||
//CHECK: %[[C5:.*]] = arith.constant 1 : i32
|
||||
//CHECK: %[[C6:.*]] = arith.constant 4 : i32
|
||||
//CHECK: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %cst, %[[C1]], %[[C2]], %[[C3]], %[[C4]], %[[C5]], %[[C6]]) : (!Concrete.lwe_ciphertext<600,7>, tensor<128xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %cst) {baseLog = 1 : i32, glweDimension = 1 : i32, level = 3 : i32, polySize = 1024 : i32} : (!Concrete.lwe_ciphertext<600,7>, tensor<128xi64>) -> !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: return %[[V1]] : !Concrete.lwe_ciphertext<1024,4>
|
||||
//CHECK: }
|
||||
func.func @bootstrap_lwe(%ciphertext: !TFHE.glwe<{600,1,64}{7}>) -> !TFHE.glwe<{1024,1,64}{4}> {
|
||||
%cst = arith.constant dense<"0x00000000000000000100000000000000020000000000000003000000000000000400000000000000050000000000000006000000000000000700000000000000080000000000000009000000000000000A000000000000000B000000000000000C000000000000000D000000000000000E000000000000000F0000000000000010000000000000001100000000000000120000000000000013000000000000001400000000000000150000000000000016000000000000001700000000000000180000000000000019000000000000001A000000000000001B000000000000001C000000000000001D000000000000001E000000000000001F0000000000000020000000000000002100000000000000220000000000000023000000000000002400000000000000250000000000000026000000000000002700000000000000280000000000000029000000000000002A000000000000002B000000000000002C000000000000002D000000000000002E000000000000002F0000000000000030000000000000003100000000000000320000000000000033000000000000003400000000000000350000000000000036000000000000003700000000000000380000000000000039000000000000003A000000000000003B000000000000003C000000000000003D000000000000003E000000000000003F0000000000000040000000000000004100000000000000420000000000000043000000000000004400000000000000450000000000000046000000000000004700000000000000480000000000000049000000000000004A000000000000004B000000000000004C000000000000004D000000000000004E000000000000004F0000000000000050000000000000005100000000000000520000000000000053000000000000005400000000000000550000000000000056000000000000005700000000000000580000000000000059000000000000005A000000000000005B000000000000005C000000000000005D000000000000005E000000000000005F0000000000000060000000000000006100000000000000620000000000000063000000000000006400000000000000650000000000000066000000000000006700000000000000680000000000000069000000000000006A000000000000006B000000000000006C000000000000006D000000000000006E000000000000006F0000000000000070000000000000007100000000000000720000000000000073000000000000007400000000000000750000000000000076000000000000007700000000000000780000000000000079000000000000007A000000000000007B000000000000007C000000000000007D000000000000007E000000000000007F00000000000000"> : tensor<128xi64>
|
||||
%bootstraped = "TFHE.bootstrap_glwe"(%ciphertext, %cst) {baseLog = 1 : i32, glweDimension = 1 : i32, inputLweDim = 600 : i32, level = 3 : i32, polySize = 1024 : i32} : (!TFHE.glwe<{600,1,64}{7}>, tensor<128xi64>) -> !TFHE.glwe<{1024,1,64}{4}>
|
||||
%bootstraped = "TFHE.bootstrap_glwe"(%ciphertext, %cst) {baseLog = 1 : i32, glweDimension = 1 : i32, level = 3 : i32, polySize = 1024 : i32} : (!TFHE.glwe<{600,1,64}{7}>, tensor<128xi64>) -> !TFHE.glwe<{1024,1,64}{4}>
|
||||
return %bootstraped : !TFHE.glwe<{1024,1,64}{4}>
|
||||
}
|
||||
|
||||
@@ -72,20 +72,20 @@ func.func @negate_crt_lwe_ciphertext(%arg0: tensor<5x2049xi64>) -> tensor<5x2049
|
||||
return %0 : tensor<5x2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2049xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> tensor<2049xi64>
|
||||
//CHECK: func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
//CHECK: }
|
||||
func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> tensor<2049xi64> {
|
||||
%0 = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (tensor<2049xi64>, tensor<16xi64>, i32, i32, i32, i32, i32, i32) -> (tensor<2049xi64>)
|
||||
func.func @bootstrap_lwe(%arg0: tensor<2049xi64>, %arg1: tensor<16xi64>) -> tensor<2049xi64> {
|
||||
%0 = "BConcrete.bootstrap_lwe_buffer"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, inputLweDim = 600 : i32, level = 3 : i32, outPrecision = 4 : i32, polySize = 2048 : i32} : (tensor<2049xi64>, tensor<16xi64>) -> (tensor<2049xi64>)
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
//CHECK: func.func @keyswitch_lwe(%[[A0:.*]]: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 1 : i32} : (tensor<2049xi64>) -> tensor<2049xi64>
|
||||
//CHECK: %[[V0:.*]] = "BConcrete.keyswitch_lwe_buffer"(%[[A0]]) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> tensor<2049xi64>
|
||||
//CHECK: return %[[V0]] : tensor<2049xi64>
|
||||
//CHECK: }
|
||||
func.func @keyswitch_lwe(%arg0: tensor<2049xi64>) -> tensor<2049xi64> {
|
||||
%0 = "BConcrete.keyswitch_lwe_buffer"(%arg0) {baseLog = 2 : i32, inputLweDimension = 1 : i32, level = 3 : i32, outputLweDimension = 1 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>)
|
||||
%0 = "BConcrete.keyswitch_lwe_buffer"(%arg0) {baseLog = 2 : i32, level = 3 : i32, lwe_dim_in = 2048 : i32, lwe_dim_out = 2048 : i32} : (tensor<2049xi64>) -> (tensor<2049xi64>)
|
||||
return %0 : tensor<2049xi64>
|
||||
}
|
||||
|
||||
@@ -36,11 +36,11 @@ func.func @negate_lwe_ciphertext(%arg0: !Concrete.lwe_ciphertext<2048,7>) -> !Co
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: tensor<128xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: tensor<128xi64>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!Concrete.lwe_ciphertext<2048,7>, tensor<128xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-LABEL: func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: tensor<128xi64>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
func.func @bootstrap_lwe(%arg0: !Concrete.lwe_ciphertext<2048,7>, %arg1: tensor<128xi64>) -> !Concrete.lwe_ciphertext<2048,7> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, glweDimension = 4 : i32, level = 3 : i32, polySize = 2048 : i32} : (!Concrete.lwe_ciphertext<2048,7>, tensor<128xi64>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
// CHECK-NEXT: return %[[V1]] : !Concrete.lwe_ciphertext<2048,7>
|
||||
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!Concrete.lwe_ciphertext<2048,7>, tensor<128xi64>, i32, i32, i32, i32, i32, i32) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
%1 = "Concrete.bootstrap_lwe"(%arg0, %arg1) {baseLog = 2 : i32, polySize = 2048 : i32, level = 3 : i32, glweDimension = 4 : i32} : (!Concrete.lwe_ciphertext<2048,7>, tensor<128xi64>) -> !Concrete.lwe_ciphertext<2048,7>
|
||||
return %1: !Concrete.lwe_ciphertext<2048,7>
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user