refactor: redesign GPU support

- unify CPU and GPU bootstrapping operations
- remove operations to build GLWE from table: this is now done in
  wrapper functions
- remove GPU memory management operations: done in wrappers now, but we
  will have to think about how to deal with it later in MLIR
This commit is contained in:
youben11
2022-08-25 14:32:54 +01:00
committed by Ayoub Benaissa
parent d169a27fc0
commit a7a65025ff
35 changed files with 293 additions and 708 deletions

View File

@@ -3,88 +3,10 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
char move_bsk_to_gpu[] = "move_bsk_to_gpu";
char free_from_gpu[] = "free_from_gpu";
/// \brief Rewrites `BConcrete.move_bsk_to_gpu` into a CAPI call to
/// `move_bsk_to_gpu`
///
/// Also insert the forward declaration of `move_bsk_to_gpu`
struct MoveBskOpPattern : public mlir::OpRewritePattern<
mlir::concretelang::BConcrete::MoveBskToGPUOp> {
MoveBskOpPattern(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<mlir::concretelang::BConcrete::MoveBskToGPUOp>(
context, benefit) {}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::BConcrete::MoveBskToGPUOp moveBskOp,
::mlir::PatternRewriter &rewriter) const override {
auto ctx = getContextArgument(moveBskOp);
mlir::SmallVector<mlir::Value> operands{ctx};
// Insert forward declaration of the function
auto contextType =
mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
auto funcType = mlir::FunctionType::get(
rewriter.getContext(), {contextType},
{mlir::LLVM::LLVMPointerType::get(rewriter.getI64Type())});
if (insertForwardDeclaration(moveBskOp, rewriter, move_bsk_to_gpu, funcType)
.failed()) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
moveBskOp, move_bsk_to_gpu, moveBskOp.getResult().getType(), operands);
return ::mlir::success();
};
};
/// \brief Rewrites `BConcrete.free_bsk_from_gpu` into a CAPI call to
/// `free_from_gpu`
///
/// Also insert the forward declaration of `free_from_gpu`
struct FreeBskOpPattern : public mlir::OpRewritePattern<
mlir::concretelang::BConcrete::FreeBskFromGPUOp> {
FreeBskOpPattern(::mlir::MLIRContext *context,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<
mlir::concretelang::BConcrete::FreeBskFromGPUOp>(context, benefit) {
}
::mlir::LogicalResult
matchAndRewrite(mlir::concretelang::BConcrete::FreeBskFromGPUOp freeBskOp,
::mlir::PatternRewriter &rewriter) const override {
mlir::SmallVector<mlir::Value> operands{freeBskOp.bsk()};
// Insert forward declaration of the function
auto funcType = mlir::FunctionType::get(
rewriter.getContext(),
{mlir::LLVM::LLVMPointerType::get(rewriter.getI64Type())}, {});
if (insertForwardDeclaration(freeBskOp, rewriter, free_from_gpu, funcType)
.failed()) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
freeBskOp, free_from_gpu, mlir::TypeRange({}), operands);
return ::mlir::success();
};
};
namespace {
struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
@@ -98,12 +20,6 @@ void BConcreteToCAPIPass::runOnOperation() {
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
target.addIllegalOp<mlir::concretelang::BConcrete::MoveBskToGPUOp>();
target.addLegalDialect<mlir::func::FuncDialect>();
patterns.insert<MoveBskOpPattern>(&getContext());
patterns.insert<FreeBskOpPattern>(&getContext());
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) {
this->signalPassFailure();