From 0dbb86bb36b3c96a74a85664c2e45cee6439f9c2 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Wed, 30 Nov 2022 09:41:34 +0000 Subject: [PATCH] feat(compiler): add lowering and bufferization for SDFG dialect, generate code to Stream Emulator API. --- .../include/concretelang/Conversion/Passes.h | 1 + .../include/concretelang/Conversion/Passes.td | 7 + .../Conversion/SDFGToStreamEmulator/Pass.h | 19 + .../Transforms/BufferizableOpInterfaceImpl.h | 19 + .../include/concretelang/Support/Pipeline.h | 4 + compiler/lib/Bindings/Rust/build.rs | 5 +- compiler/lib/Conversion/CMakeLists.txt | 1 + .../MLIRLowerableDialectsToLLVM.cpp | 5 +- .../SDFGToStreamEmulator/CMakeLists.txt | 14 + .../SDFGToStreamEmulator.cpp | 390 ++++++++++++++++++ .../BufferizableOpInterfaceImpl.cpp | 139 +++++++ .../Dialect/SDFG/Transforms/CMakeLists.txt | 5 + compiler/lib/Support/CompilerEngine.cpp | 9 + compiler/lib/Support/Pipeline.cpp | 11 + 14 files changed, 626 insertions(+), 3 deletions(-) create mode 100644 compiler/include/concretelang/Conversion/SDFGToStreamEmulator/Pass.h create mode 100644 compiler/include/concretelang/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.h create mode 100644 compiler/lib/Conversion/SDFGToStreamEmulator/CMakeLists.txt create mode 100644 compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp create mode 100644 compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp diff --git a/compiler/include/concretelang/Conversion/Passes.h b/compiler/include/concretelang/Conversion/Passes.h index ab644af9b..62f55593d 100644 --- a/compiler/include/concretelang/Conversion/Passes.h +++ b/compiler/include/concretelang/Conversion/Passes.h @@ -20,6 +20,7 @@ #include "concretelang/Conversion/FHEToTFHE/Pass.h" #include "concretelang/Conversion/LinalgExtras/Passes.h" #include "concretelang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h" +#include "concretelang/Conversion/SDFGToStreamEmulator/Pass.h" #include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h" #include "concretelang/Conversion/TFHEToConcrete/Pass.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" diff --git a/compiler/include/concretelang/Conversion/Passes.td b/compiler/include/concretelang/Conversion/Passes.td index 498fc4dfd..6ebf0facc 100644 --- a/compiler/include/concretelang/Conversion/Passes.td +++ b/compiler/include/concretelang/Conversion/Passes.td @@ -61,6 +61,13 @@ def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> { let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect"]; } +def SDFGToStreamEmulator : Pass<"sdfg-to-stream-emulator", "mlir::ModuleOp"> { + let summary = "Lowers operations from the SDFG dialect to Stream Emulator calls"; + let description = [{ Lowers operations from the SDFG dialect to Stream Emulator calls }]; + let constructor = "mlir::concretelang::createConvertSDFGToStreamEmulatorPass()"; + let dependentDialects = ["mlir::concretelang::SDFG::SDFGDialect"]; +} + def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> { let summary = "Lowers operations from MLIR lowerable dialects to LLVM"; let constructor = "mlir::concretelang::createConvertMLIRLowerableDialectsToLLVMPass()"; diff --git a/compiler/include/concretelang/Conversion/SDFGToStreamEmulator/Pass.h b/compiler/include/concretelang/Conversion/SDFGToStreamEmulator/Pass.h new file mode 100644 index 000000000..5c7fe612f --- /dev/null +++ b/compiler/include/concretelang/Conversion/SDFGToStreamEmulator/Pass.h @@ -0,0 +1,19 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef ZAMALANG_CONVERSION_SDFGTOSTREAMEMULATOR_PASS_H_ +#define ZAMALANG_CONVERSION_SDFGTOSTREAMEMULATOR_PASS_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace concretelang { +/// Create a pass to convert `SDFG` dialect to Stream Emulator calls. +std::unique_ptr> +createConvertSDFGToStreamEmulatorPass(); +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.h b/compiler/include/concretelang/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 index 000000000..89bc6d6d0 --- /dev/null +++ b/compiler/include/concretelang/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,19 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_SDFG_BUFFERIZABLEOPINTERFACEIMPL_H +#define CONCRETELANG_DIALECT_SDFG_BUFFERIZABLEOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace concretelang { +namespace SDFG { +void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); +} // namespace SDFG +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index f42533d6d..063fff10e 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -66,6 +66,10 @@ mlir::LogicalResult lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); +mlir::LogicalResult +lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); + mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass, diff --git a/compiler/lib/Bindings/Rust/build.rs b/compiler/lib/Bindings/Rust/build.rs index 1269f52e3..3ba3bebea 100644 --- a/compiler/lib/Bindings/Rust/build.rs +++ b/compiler/lib/Bindings/Rust/build.rs @@ -241,7 +241,7 @@ const LLVM_STATIC_LIBS: [&str; 51] = [ "LLVMX86Info", ]; -const CONCRETE_COMPILER_LIBS: [&str; 33] = [ +const CONCRETE_COMPILER_LIBS: [&str; 34] = [ "RTDialect", "RTDialectTransforms", "ConcretelangSupport", @@ -274,7 +274,8 @@ const CONCRETE_COMPILER_LIBS: [&str; 33] = [ "ConcreteDialect", "RTDialectAnalysis", "SDFGDialect", - "ExtractSDFGOps" + "ExtractSDFGOps", + "SDFGToStreamEmulator", ]; fn main() { diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index 13db1bd1f..dd16024a6 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -4,6 +4,7 @@ add_subdirectory(TFHEToConcrete) add_subdirectory(FHETensorOpsToLinalg) add_subdirectory(ConcreteToBConcrete) add_subdirectory(BConcreteToCAPI) +add_subdirectory(SDFGToStreamEmulator) add_subdirectory(MLIRLowerableDialectsToLLVM) add_subdirectory(LinalgExtras) add_subdirectory(ExtractSDFGOps) diff --git a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp index fa5d491d6..d57ae774a 100644 --- a/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp +++ b/compiler/lib/Conversion/MLIRLowerableDialectsToLLVM/MLIRLowerableDialectsToLLVM.cpp @@ -32,6 +32,7 @@ #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" #include "concretelang/Dialect/RT/Analysis/Autopar.h" #include "concretelang/Dialect/RT/IR/RTTypes.h" +#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h" namespace { struct MLIRLowerableDialectsToLLVMPass @@ -155,7 +156,9 @@ MLIRLowerableDialectsToLLVMPass::convertTypes(mlir::Type type) { if (type.isa() || type.isa() || type.isa() || - type.isa()) { + type.isa() || + type.isa() || + type.isa()) { return mlir::LLVM::LLVMPointerType::get( mlir::IntegerType::get(type.getContext(), 64)); } diff --git a/compiler/lib/Conversion/SDFGToStreamEmulator/CMakeLists.txt b/compiler/lib/Conversion/SDFGToStreamEmulator/CMakeLists.txt new file mode 100644 index 000000000..383e5ccec --- /dev/null +++ b/compiler/lib/Conversion/SDFGToStreamEmulator/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library( + SDFGToStreamEmulator + SDFGToStreamEmulator.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/SDFG + DEPENDS + SDFGDialect + mlir-headers + LINK_LIBS + PUBLIC + MLIRIR + MLIRTransforms) + +target_link_libraries(SDFGToStreamEmulator PUBLIC SDFGDialect MLIRIR) diff --git a/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp b/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp new file mode 100644 index 000000000..c0fefd51e --- /dev/null +++ b/compiler/lib/Conversion/SDFGToStreamEmulator/SDFGToStreamEmulator.cpp @@ -0,0 +1,390 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include +#include + +#include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Tools.h" +#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" +#include "concretelang/Dialect/SDFG/IR/SDFGOps.h" +#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h" +#include "concretelang/Runtime/stream_emulator_api.h" + +namespace SDFG = mlir::concretelang::SDFG; + +namespace { +struct SDFGToStreamEmulatorPass + : public SDFGToStreamEmulatorBase { + void runOnOperation() final; +}; + +char stream_emulator_init[] = "stream_emulator_init"; +char stream_emulator_run[] = "stream_emulator_run"; +char stream_emulator_delete[] = "stream_emulator_delete"; +char stream_emulator_make_memref_add_lwe_ciphertexts_u64_process[] = + "stream_emulator_make_memref_add_lwe_ciphertexts_u64_process"; +char stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process[] = + "stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process"; +char stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process[] = + "stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process"; +char stream_emulator_make_memref_negate_lwe_ciphertext_u64_process[] = + "stream_emulator_make_memref_negate_lwe_ciphertext_u64_process"; +char stream_emulator_make_memref_keyswitch_lwe_u64_process[] = + "stream_emulator_make_memref_keyswitch_lwe_u64_process"; +char stream_emulator_make_memref_bootstrap_lwe_u64_process[] = + "stream_emulator_make_memref_bootstrap_lwe_u64_process"; + +char stream_emulator_make_memref_stream[] = + "stream_emulator_make_memref_stream"; +char stream_emulator_put_memref[] = "stream_emulator_put_memref"; +char stream_emulator_make_uint64_stream[] = + "stream_emulator_make_uint64_stream"; +char stream_emulator_put_uint64[] = "stream_emulator_put_uint64"; +char stream_emulator_get_uint64[] = "stream_emulator_get_uint64"; + +mlir::Type getDynamicTensor(mlir::OpBuilder &rewriter, size_t rank) { + std::vector shape(rank, -1); + return mlir::RankedTensorType::get(shape, rewriter.getI64Type()); +} + +mlir::Type makeDynamicTensorTypes(mlir::OpBuilder &rewriter, mlir::Type oldTy) { + if (auto ttype = oldTy.dyn_cast_or_null()) + return getDynamicTensor(rewriter, ttype.getRank()); + if (auto stTy = oldTy.dyn_cast_or_null()) + return SDFG::StreamType::get( + rewriter.getContext(), + makeDynamicTensorTypes(rewriter, stTy.getElementType())); + return oldTy; +} + +mlir::LogicalResult insertGenericForwardDeclaration(mlir::Operation *op, + mlir::OpBuilder &rewriter, + llvm::StringRef funcName, + mlir::TypeRange opTys, + mlir::TypeRange resTys) { + mlir::SmallVector operands; + for (mlir::Type opTy : opTys) + operands.push_back(makeDynamicTensorTypes(rewriter, opTy)); + mlir::SmallVector results; + for (mlir::Type resTy : resTys) + results.push_back(makeDynamicTensorTypes(rewriter, resTy)); + + mlir::FunctionType funcType = + mlir::FunctionType::get(rewriter.getContext(), operands, results); + return insertForwardDeclaration(op, rewriter, funcName, funcType); +} + +void castDynamicTensorOps(mlir::Operation *op, mlir::OpBuilder &rewriter, + mlir::ValueRange operands, + mlir::SmallVector &newOps) { + for (auto val : operands) { + auto oldTy = val.getType(); + if (auto ttype = oldTy.dyn_cast_or_null()) + newOps.push_back(rewriter.create( + op->getLoc(), getDynamicTensor(rewriter, ttype.getRank()), val)); + else + newOps.push_back(val); + } +} + +struct LowerSDFGInit + : public mlir::OpRewritePattern { + LowerSDFGInit(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, + benefit) {} + ::mlir::LogicalResult + matchAndRewrite(mlir::concretelang::SDFG::Init initOp, + ::mlir::PatternRewriter &rewriter) const override { + mlir::FunctionType funcType = mlir::FunctionType::get( + rewriter.getContext(), {}, {SDFG::DFGType::get(rewriter.getContext())}); + if (insertForwardDeclaration(initOp, rewriter, stream_emulator_init, + funcType) + .failed()) + return ::mlir::failure(); + rewriter.replaceOpWithNewOp( + initOp, stream_emulator_init, + mlir::TypeRange{SDFG::DFGType::get(rewriter.getContext())}); + return ::mlir::success(); + }; +}; + +struct LowerSDFGStart + : public mlir::OpRewritePattern { + LowerSDFGStart(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, + benefit) {} + ::mlir::LogicalResult + matchAndRewrite(mlir::concretelang::SDFG::Start startOp, + ::mlir::PatternRewriter &rewriter) const override { + mlir::FunctionType funcType = mlir::FunctionType::get( + rewriter.getContext(), {SDFG::DFGType::get(rewriter.getContext())}, {}); + if (insertForwardDeclaration(startOp, rewriter, stream_emulator_run, + funcType) + .failed()) + return ::mlir::failure(); + rewriter.replaceOpWithNewOp( + startOp, stream_emulator_run, mlir::TypeRange{}, + startOp.getOperation()->getOperands()); + return ::mlir::success(); + }; +}; + +struct LowerSDFGShutdown + : public mlir::OpRewritePattern { + LowerSDFGShutdown(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, + benefit) {} + ::mlir::LogicalResult + matchAndRewrite(mlir::concretelang::SDFG::Shutdown desOp, + ::mlir::PatternRewriter &rewriter) const override { + mlir::FunctionType funcType = mlir::FunctionType::get( + rewriter.getContext(), {SDFG::DFGType::get(rewriter.getContext())}, {}); + if (insertForwardDeclaration(desOp, rewriter, stream_emulator_delete, + funcType) + .failed()) + return ::mlir::failure(); + rewriter.replaceOpWithNewOp( + desOp, stream_emulator_delete, mlir::TypeRange{}, + desOp.getOperation()->getOperands()); + return ::mlir::success(); + }; +}; + +struct LowerSDFGMakeProcess + : public mlir::OpRewritePattern { + LowerSDFGMakeProcess(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern( + context, benefit) {} + ::mlir::LogicalResult + matchAndRewrite(mlir::concretelang::SDFG::MakeProcess mpOp, + ::mlir::PatternRewriter &rewriter) const override { + const char *funcName; + mlir::SmallVector operands(mpOp->getOperands()); + switch (mpOp.type()) { + case SDFG::ProcessKind::add_eint: + funcName = stream_emulator_make_memref_add_lwe_ciphertexts_u64_process; + break; + case SDFG::ProcessKind::add_eint_int: + funcName = + stream_emulator_make_memref_add_plaintext_lwe_ciphertext_u64_process; + break; + case SDFG::ProcessKind::mul_eint_int: + funcName = + stream_emulator_make_memref_mul_cleartext_lwe_ciphertext_u64_process; + break; + case SDFG::ProcessKind::neg_eint: + funcName = stream_emulator_make_memref_negate_lwe_ciphertext_u64_process; + break; + case SDFG::ProcessKind::keyswitch: + funcName = stream_emulator_make_memref_keyswitch_lwe_u64_process; + // level + operands.push_back(rewriter.create( + mpOp.getLoc(), mpOp->getAttrOfType("level"))); + // base_log + operands.push_back(rewriter.create( + mpOp.getLoc(), mpOp->getAttrOfType("baseLog"))); + // lwe_dim_in + operands.push_back(rewriter.create( + mpOp.getLoc(), mpOp->getAttrOfType("lwe_dim_in"))); + // lwe_dim_out + operands.push_back(rewriter.create( + mpOp.getLoc(), + mpOp->getAttrOfType("lwe_dim_out"))); + // context + operands.push_back(getContextArgument(mpOp)); + break; + case SDFG::ProcessKind::bootstrap: + funcName = stream_emulator_make_memref_bootstrap_lwe_u64_process; + // input_lwe_dim + operands.push_back(rewriter.create( + mpOp.getLoc(), + mpOp->getAttrOfType("inputLweDim"))); + // poly_size + operands.push_back(rewriter.create( + mpOp.getLoc(), mpOp->getAttrOfType("polySize"))); + // level + operands.push_back(rewriter.create( + mpOp.getLoc(), mpOp->getAttrOfType("level"))); + // base_log + operands.push_back(rewriter.create( + mpOp.getLoc(), mpOp->getAttrOfType("baseLog"))); + // glwe_dim + operands.push_back(rewriter.create( + mpOp.getLoc(), + mpOp->getAttrOfType("glweDimension"))); + // out_precision + operands.push_back(rewriter.create( + mpOp.getLoc(), + mpOp->getAttrOfType("outPrecision"))); + // context + operands.push_back(getContextArgument(mpOp)); + break; + } + if (insertGenericForwardDeclaration(mpOp, rewriter, funcName, + mlir::ValueRange{operands}.getTypes(), + mpOp->getResultTypes()) + .failed()) + return ::mlir::failure(); + rewriter.replaceOpWithNewOp( + mpOp, funcName, mpOp->getResultTypes(), operands); + return ::mlir::success(); + }; +}; + +struct LowerSDFGMakeStream + : public mlir::OpRewritePattern { + LowerSDFGMakeStream(::mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern( + context, benefit) {} + ::mlir::LogicalResult + matchAndRewrite(mlir::concretelang::SDFG::MakeStream msOp, + ::mlir::PatternRewriter &rewriter) const override { + const char *funcName; + + stream_type t; + switch (msOp.type()) { + case SDFG::StreamKind::host_to_device: + t = TS_STREAM_TYPE_X86_TO_TOPO_LSAP; + break; + case SDFG::StreamKind::on_device: + t = TS_STREAM_TYPE_TOPO_TO_TOPO_LSAP; + break; + case SDFG::StreamKind::device_to_host: + t = TS_STREAM_TYPE_TOPO_TO_X86_LSAP; + break; + } + auto sType = msOp->getResultTypes()[0].dyn_cast_or_null(); + assert(sType && "SDFG MakeStream operation should return a stream type"); + + if (sType.getElementType().isa()) { + funcName = stream_emulator_make_memref_stream; + } else { + assert(sType.getElementType().isa() && + "SDFG streams only support memrefs and integers."); + funcName = stream_emulator_make_uint64_stream; + } + if (insertGenericForwardDeclaration( + msOp, rewriter, funcName, + {rewriter.getI64Type(), rewriter.getI64Type()}, + msOp->getResultTypes()) + .failed()) + return ::mlir::failure(); + mlir::Value nullStringPtr = rewriter.create( + msOp.getLoc(), rewriter.getI64IntegerAttr(0)); + mlir::Value streamTypeCst = rewriter.create( + msOp.getLoc(), rewriter.getI64IntegerAttr((int)t)); + auto callop = rewriter.replaceOpWithNewOp( + msOp, funcName, + makeDynamicTensorTypes(rewriter, msOp->getResultTypes()[0]), + mlir::ValueRange{nullStringPtr, streamTypeCst}); + for (auto &use : llvm::make_early_inc_range(msOp->getUses())) + use.set(callop.getResult(0)); + return ::mlir::success(); + }; +}; + +struct LowerSDFGPut + : public mlir::OpRewritePattern { + LowerSDFGPut(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, + benefit) {} + ::mlir::LogicalResult + matchAndRewrite(mlir::concretelang::SDFG::Put putOp, + ::mlir::PatternRewriter &rewriter) const override { + const char *funcName; + auto sType = + putOp->getOperandTypes()[0].dyn_cast_or_null(); + assert(sType && + "SDFG Put operation must take a stream type as first parameter."); + if (sType.getElementType().isa()) { + funcName = stream_emulator_put_memref; + } else { + assert(sType.getElementType().isa() && + "SDFG streams only support memrefs and integers."); + funcName = stream_emulator_put_uint64; + } + if (insertGenericForwardDeclaration(putOp, rewriter, funcName, + putOp->getOperandTypes(), + putOp->getResultTypes()) + .failed()) + return ::mlir::failure(); + mlir::SmallVector newOps; + castDynamicTensorOps(putOp, rewriter, putOp->getOperands(), newOps); + rewriter.replaceOpWithNewOp( + putOp, funcName, putOp->getResultTypes(), newOps); + return ::mlir::success(); + }; +}; + +struct LowerSDFGGet + : public mlir::OpRewritePattern { + LowerSDFGGet(::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, + benefit) {} + ::mlir::LogicalResult + matchAndRewrite(mlir::concretelang::SDFG::Get getOp, + ::mlir::PatternRewriter &rewriter) const override { + const char *funcName; + auto sType = + getOp->getOperandTypes()[0].dyn_cast_or_null(); + assert(sType && + "SDFG Get operation must take a stream type as first parameter."); + if (sType.getElementType().isa()) { + // TODO: SDFG.Get for memref streams is lowered during bufferization + // as returning a memref requires allocation for now + return ::mlir::success(); + } else { + assert(sType.getElementType().isa() && + "SDFG streams only support memrefs and integers."); + funcName = stream_emulator_get_uint64; + } + if (insertGenericForwardDeclaration(getOp, rewriter, funcName, + getOp->getOperandTypes(), + getOp->getResultTypes()) + .failed()) + return ::mlir::failure(); + rewriter.replaceOpWithNewOp( + getOp, funcName, getOp->getResultTypes(), getOp->getOperands()); + return ::mlir::success(); + }; +}; +} // namespace + +void SDFGToStreamEmulatorPass::runOnOperation() { + auto op = this->getOperation(); + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + + patterns.insert(&getContext()); + + target.addIllegalOp(); + // All BConcrete ops are legal after the conversion + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + + // Apply conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { + this->signalPassFailure(); + } +} + +namespace mlir { +namespace concretelang { +std::unique_ptr> +createConvertSDFGToStreamEmulatorPass() { + return std::make_unique(); +} +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 000000000..e3c01e291 --- /dev/null +++ b/compiler/lib/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,139 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +#include "concretelang/Conversion/Tools.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" +#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" +#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" +#include "concretelang/Dialect/SDFG/IR/SDFGOps.h" +#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h" +#include "concretelang/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.h" +#include "concretelang/Support/CompilerEngine.h" +#include +#include +#include + +using namespace mlir; +using namespace mlir::bufferization; +using namespace mlir::tensor; + +namespace SDFG = mlir::concretelang::SDFG; + +namespace mlir { +namespace concretelang { +namespace SDFG { +namespace {} // namespace +} // namespace SDFG +} // namespace concretelang +} // namespace mlir + +namespace { +mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, + size_t rank) { + std::vector shape(rank, -1); + mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0); + for (size_t i = 0; i < rank; i++) { + expr = expr + + (rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1)); + } + return mlir::MemRefType::get( + shape, rewriter.getI64Type(), + mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext())); +} + +// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>` +mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Location loc, + mlir::Value value) { + mlir::Type valueType = value.getType(); + if (auto memrefTy = valueType.dyn_cast_or_null()) { + return rewriter.create( + loc, + getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()), + value); + } else { + return value; + } +} + +char stream_emulator_get_memref[] = "stream_emulator_get_memref"; + +template +struct BufferizableWithCallOpInterface + : public BufferizableOpInterface::ExternalModel< + BufferizableWithCallOpInterface, Op> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return {}; + } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const AnalysisState &state) const { + return BufferRelation::None; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + + auto loc = op->getLoc(); + + // TODO: For now we allocate for the result of GET but we might be + // able to avoid the copy depending on the stream semantics. + auto resTensorType = + op->getResultTypes()[0].template cast(); + auto outMemrefType = MemRefType::get(resTensorType.getShape(), + resTensorType.getElementType()); + auto outMemref = options.createAlloc(rewriter, loc, outMemrefType, {}); + if (mlir::failed(outMemref)) { + return mlir::failure(); + } + + // The last operand is the result + mlir::SmallVector operands(op->getOperands()); + operands.push_back(getCastedMemRef(rewriter, loc, *outMemref)); + + mlir::FunctionType funcType = mlir::FunctionType::get( + rewriter.getContext(), mlir::ValueRange{operands}.getTypes(), + mlir::TypeRange()); + if (insertForwardDeclaration(op, rewriter, funcName, funcType).failed()) + return ::mlir::failure(); + rewriter.create(loc, funcName, mlir::TypeRange{}, + operands); + replaceOpWithBufferizedValues(rewriter, op, *outMemref); + + return success(); + } +}; + +} // namespace + +void mlir::concretelang::SDFG::registerBufferizableOpInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, SDFG::SDFGDialect *dialect) { + SDFG::Get::attachInterface< + BufferizableWithCallOpInterface>( + *ctx); + }); +} diff --git a/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt b/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt index 05bfa724c..723edc0fc 100644 --- a/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt +++ b/compiler/lib/Dialect/SDFG/Transforms/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library( ConcretelangSDFGTransforms + BufferizableOpInterfaceImpl.cpp SDFGConvertibleOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/BConcrete @@ -12,6 +13,10 @@ add_mlir_dialect_library( PUBLIC SDFGDialect ConcretelangSDFGInterfaces + ConcretelangConversion + MLIRArithmeticDialect + MLIRBufferizationDialect + MLIRBufferizationTransforms MLIRIR MLIRMemRefDialect MLIRPass diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 3af3dd0a7..5165c3e34 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -35,6 +35,7 @@ #include #include #include +#include #include #include #include @@ -83,6 +84,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() { mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>(); BConcrete::registerBufferizableOpInterfaceExternalModels(registry); SDFG::registerSDFGConvertibleOpInterfaceExternalModels(registry); + SDFG::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( registry); @@ -417,6 +419,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { "Lowering from Bufferized Concrete to canonical MLIR dialects failed"); } + // SDFG -> Canonical dialects + if (mlir::concretelang::pipeline::lowerSDFGToStd(mlirContext, module, + enablePass) + .failed()) { + return errorDiag("Lowering from SDFG to canonical MLIR dialects failed"); + } + if (target == Target::STD) return std::move(res); diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index ea12e60f2..67421c482 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -296,6 +296,17 @@ lowerBConcreteToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, return pm.run(module.getOperation()); } +mlir::LogicalResult +lowerSDFGToStd(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { + mlir::PassManager pm(&context); + pipelinePrinting("SDFGToStd", pm, context); + addPotentiallyNestedPass( + pm, mlir::concretelang::createConvertSDFGToStreamEmulatorPass(), + enablePass); + return pm.run(module.getOperation()); +} + mlir::LogicalResult lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass,