diff --git a/compiler/include/concretelang/Dialect/CMakeLists.txt b/compiler/include/concretelang/Dialect/CMakeLists.txt index ed435f204..f05954ed4 100644 --- a/compiler/include/concretelang/Dialect/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(TFHE) add_subdirectory(Concrete) add_subdirectory(BConcrete) add_subdirectory(RT) +add_subdirectory(SDFG) diff --git a/compiler/include/concretelang/Dialect/SDFG/CMakeLists.txt b/compiler/include/concretelang/Dialect/SDFG/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/compiler/include/concretelang/Dialect/SDFG/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/compiler/include/concretelang/Dialect/SDFG/IR/CMakeLists.txt b/compiler/include/concretelang/Dialect/SDFG/IR/CMakeLists.txt new file mode 100644 index 000000000..9964e8e8b --- /dev/null +++ b/compiler/include/concretelang/Dialect/SDFG/IR/CMakeLists.txt @@ -0,0 +1,17 @@ +set(LLVM_TARGET_DEFINITIONS SDFGOps.td) +mlir_tablegen(SDFGEnums.h.inc -gen-enum-decls) +mlir_tablegen(SDFGEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(SDFGOps.h.inc -gen-op-decls) +mlir_tablegen(SDFGOps.cpp.inc -gen-op-defs) +mlir_tablegen(SDFGTypes.h.inc -gen-typedef-decls -typedefs-dialect=SDFG) +mlir_tablegen(SDFGTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=SDFG) +mlir_tablegen(SDFGDialect.h.inc -gen-dialect-decls -dialect=SDFG) +mlir_tablegen(SDFGDialect.cpp.inc -gen-dialect-defs -dialect=SDFG) +mlir_tablegen(SDFGAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=SDFG) +mlir_tablegen(SDFGAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=SDFG) +add_public_tablegen_target(MLIRSDFGOpsIncGen) +add_dependencies(mlir-headers MLIRSDFGOpsIncGen) + +add_concretelang_doc(SDFGOps SDFGDialect concretelang/ -gen-dialect-doc -dialect=SDFG) +add_concretelang_doc(SDFGOps SDFGOps concretelang/ -gen-op-doc) +add_concretelang_doc(SDFGTypes SDFGTypes concretelang/ -gen-typedef-doc) diff --git a/compiler/include/concretelang/Dialect/SDFG/IR/SDFGDialect.h b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGDialect.h new file mode 100644 index 000000000..563e81897 --- /dev/null +++ b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGDialect.h @@ -0,0 +1,14 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFGDIALECT_H +#define CONCRETELANG_DIALECT_SDFG_IR_SDFGDIALECT_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" + +#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h.inc" + +#endif diff --git a/compiler/include/concretelang/Dialect/SDFG/IR/SDFGDialect.td b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGDialect.td new file mode 100644 index 000000000..0f4c0cf66 --- /dev/null +++ b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGDialect.td @@ -0,0 +1,28 @@ +//===- SDFGDialect.td - SDFG dialect ----------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFG_DIALECT +#define CONCRETELANG_DIALECT_SDFG_IR_SDFG_DIALECT + +include "mlir/IR/OpBase.td" + +def SDFG_Dialect : Dialect { + let name = "SDFG"; + let summary = "Dialect for the construction of static data flow graphs"; + let description = [{ + A dialect for the construction of static data flow graphs. The + data flow graph is composed of a set of processes, connected + through data streams. Special streams allow for data to be + injected into and to be retrieved from the data flow graph. + }]; + let cppNamespace = "::mlir::concretelang::SDFG"; + let useDefaultTypePrinterParser = 1; + let useDefaultAttributePrinterParser = 1; +} + +#endif diff --git a/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.h b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.h new file mode 100644 index 000000000..bc8eb776d --- /dev/null +++ b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.h @@ -0,0 +1,22 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFGOPS_H +#define CONCRETELANG_DIALECT_SDFG_IR_SDFGOPS_H + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "concretelang/Dialect/SDFG/IR/SDFGEnums.h.inc" +#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h" + +#define GET_ATTRDEF_CLASSES +#include "concretelang/Dialect/SDFG/IR/SDFGAttributes.h.inc" + +#define GET_OP_CLASSES +#include "concretelang/Dialect/SDFG/IR/SDFGOps.h.inc" + +#endif diff --git a/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td new file mode 100644 index 000000000..50bdbdc0c --- /dev/null +++ b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td @@ -0,0 +1,202 @@ +//===- SDFGOps.td - High level SDFG dialect ops ----------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFG_OPS +#define CONCRETELANG_DIALECT_SDFG_IR_SDFG_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/EnumAttr.td" + +include "concretelang/Dialect/SDFG/IR/SDFGDialect.td" +include "concretelang/Dialect/SDFG/IR/SDFGTypes.td" + +class SDFG_Op traits = []> : + Op; + +def StreamKindHostToDevice : I32EnumAttrCase<"host_to_device", 0>; +def StreamKindOnDevice : I32EnumAttrCase<"on_device", 1>; +def StreamKindDeviceToHost : I32EnumAttrCase<"device_to_host", 2>; + +def StreamKind : I32EnumAttr<"StreamKind", "Stream kind", + [StreamKindOnDevice, StreamKindHostToDevice, StreamKindDeviceToHost]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::concretelang::SDFG"; +} + +def StreamKindAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def SDFG_Init : SDFG_Op<"init"> { + let summary = "Initializes the streaming framework"; + + let description = [{ + Initializes the streaming framework. This operation must be + performed before control reaches any other operation from the + dialect. + + Example: + ```mlir + "SDFG.init" : () -> !SDFG.dfg + ``` + }]; + + let arguments = (ins); + let results = (outs SDFG_DFG); +} + +def SDFG_MakeStream : SDFG_Op<"make_stream"> { + let summary = "Returns a new SDFG stream"; + + let description = [{ + Returns a new SDFG stream, transporting data either between + processes on the device, from the host to the device or from + the device to the host. All streams are typed, allowing data + to be read / written through `SDFG.get` and `SDFG.put` only + using the stream's type. + + Example: + ```mlir + "SDFG.make_stream" { name = "stream", type = #SDFG.stream_kind }(%dfg) + : (!SDFG.dfg) -> !SDFG.stream> + ``` + }]; + + let arguments = (ins SDFG_DFG:$dfg, StrAttr:$name, StreamKindAttr:$type); + let results = (outs SDFG_Stream); + + let extraClassDeclaration = [{ + bool createsInputStream() { + return type() == StreamKind::host_to_device || + type() == StreamKind::on_device; + } + + bool createsOutputStream() { + return type() == StreamKind::device_to_host || + type() == StreamKind::on_device; + } + }]; +} + +def ProcessKindAddEint : I32EnumAttrCase<"add_eint", 0>; +def ProcessKindAddEintInt : I32EnumAttrCase<"add_eint_int", 1>; +def ProcessKindMulEintInt : I32EnumAttrCase<"mul_eint_int", 2>; +def ProcessKindNegEint : I32EnumAttrCase<"neg_eint", 3>; +def ProcessKindKeyswitch : I32EnumAttrCase<"keyswitch", 4>; +def ProcessKindBootstrap : I32EnumAttrCase<"bootstrap", 5>; + +def ProcessKind : I32EnumAttr<"ProcessKind", "Process kind", + [ProcessKindAddEint, ProcessKindAddEintInt, ProcessKindMulEintInt, + ProcessKindNegEint, ProcessKindKeyswitch, ProcessKindBootstrap]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::concretelang::SDFG"; +} + +def ProcessKindAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + +def SDFG_MakeProcess : SDFG_Op<"make_process"> { + let summary = "Creates a new SDFG process"; + + let description = [{ + Creates a new SDFG process and connects it to the input and + output streams. + + Example: + ```mlir + %in0 = "SDFG.make_stream" { type = #SDFG.stream_kind }(%dfg) : (!SDFG.dfg) -> !SDFG.stream> + %in1 = "SDFG.make_stream" { type = #SDFG.stream_kind }(%dfg) : (!SDFG.dfg) -> !SDFG.stream> + %out = "SDFG.make_stream" { type = #SDFG.stream_kind }(%dfg) : (!SDFG.dfg) -> !SDFG.stream> + "SDFG.make_process" { type = #SDFG.process_kind }(%dfg, %in0, %in1, %out) : + (!SDFG.dfg, !SDFG.stream>, !SDFG.stream>, !SDFG.stream>) -> () + ``` + }]; + + let extraClassDeclaration = [{ + ::mlir::LogicalResult checkStreams(size_t numIn, size_t numOut); + }]; + + let arguments = (ins ProcessKindAttr:$type, SDFG_DFG:$dfg, Variadic:$streams); + let results = (outs); + let hasVerifier = 1; +} + +def SDFG_Put : SDFG_Op<"put"> { + let summary = "Writes a data element to a stream"; + + let description = [{ + Writes the input operand to the specified stream. The + operand's type must meet the element type of the stream. + + Example: + ```mlir + "SDFG.put" (%stream, %data) : (!SDFG.stream<1024xi64>, tensor<1024xi64>) -> () + ``` + }]; + + let arguments = (ins SDFG_Stream:$stream, AnyType:$data); + let results = (outs); + let hasVerifier = 1; +} + +def SDFG_Get : SDFG_Op<"get"> { + let summary = "Retrieves a data element from a stream"; + + let description = [{ + Retrieves a single data element from the specified stream + (i.e., an instance of the element type of the stream). + + Example: + ```mlir + "SDFG.get" (%stream) : (!SDFG.stream<1024xi64>) -> (tensor<1024xi64>) + ``` + }]; + + let arguments = (ins SDFG_Stream:$stream); + let results = (outs AnyType:$data); +} + +def SDFG_Start : SDFG_Op<"start"> { + let summary = "Finalizes the creation of an SDFG and starts execution of its processes"; + + + let description = [{ + Finalizes the creation of an SDFG and starts execution of its + processes. Any creation of streams and processes must take + place before control reaches this operation. + + Example: + ```mlir + "SDFG.start"(%dfg) : !SDFG.dfg + ``` + }]; + + let arguments = (ins SDFG_DFG:$dfg); + let results = (outs); +} + + +def SDFG_Shutdown : SDFG_Op<"shutdown"> { + let summary = "Shuts down the streaming framework"; + + let description = [{ + Shuts down the streaming framework. This operation must be + performed after any other operation from the dialect. + + Example: + ```mlir + "SDFG.shutdown" (%dfg) : !SDFG.dfg + ``` + }]; + + let arguments = (ins SDFG_DFG:$dfg); + let results = (outs); +} + +#endif diff --git a/compiler/include/concretelang/Dialect/SDFG/IR/SDFGTypes.h b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGTypes.h new file mode 100644 index 000000000..2d3e8bd45 --- /dev/null +++ b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGTypes.h @@ -0,0 +1,16 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFGTYPES_H +#define CONCRETELANG_DIALECT_SDFG_IR_SDFGTYPES_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectImplementation.h" +#include "llvm/ADT/TypeSwitch.h" + +#define GET_TYPEDEF_CLASSES +#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h.inc" + +#endif diff --git a/compiler/include/concretelang/Dialect/SDFG/IR/SDFGTypes.td b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGTypes.td new file mode 100644 index 000000000..da3531bd6 --- /dev/null +++ b/compiler/include/concretelang/Dialect/SDFG/IR/SDFGTypes.td @@ -0,0 +1,37 @@ +#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFG_TYPES +#define CONCRETELANG_DIALECT_SDFG_IR_SDFG_TYPES + +include "concretelang/Dialect/SDFG/IR/SDFGDialect.td" +include "mlir/IR/BuiltinTypes.td" + +class SDFG_Type traits = []> : + TypeDef { } + +def SDFG_DFG : SDFG_Type<"DFG", []> { + let mnemonic = "dfg"; + + let summary = "An SDFG data flow graph"; + + let description = [{ + A handle to an SDFG data flow graph + }]; + + let parameters = (ins); + let hasCustomAssemblyFormat = 0; +} + + +def SDFG_Stream : SDFG_Type<"Stream", []> { + let mnemonic = "stream"; + + let summary = "An SDFG data stream"; + + let description = [{ + An SDFG stream to connect SDFG processes. + }]; + + let parameters = (ins "Type":$elementType); + let hasCustomAssemblyFormat = 1; +} + +#endif diff --git a/compiler/lib/Bindings/Rust/build.rs b/compiler/lib/Bindings/Rust/build.rs index aca30cc89..77a1e6288 100644 --- a/compiler/lib/Bindings/Rust/build.rs +++ b/compiler/lib/Bindings/Rust/build.rs @@ -241,7 +241,7 @@ const LLVM_STATIC_LIBS: [&str; 51] = [ "LLVMX86Info", ]; -const CONCRETE_COMPILER_LIBS: [&str; 29] = [ +const CONCRETE_COMPILER_LIBS: [&str; 30] = [ "RTDialect", "RTDialectTransforms", "ConcretelangSupport", @@ -271,6 +271,7 @@ const CONCRETE_COMPILER_LIBS: [&str; 29] = [ "FHEDialectAnalysis", "ConcreteDialect", "RTDialectAnalysis", + "SDFGDialect" ]; fn main() { diff --git a/compiler/lib/Dialect/CMakeLists.txt b/compiler/lib/Dialect/CMakeLists.txt index 2a809ceab..d0ca5b248 100644 --- a/compiler/lib/Dialect/CMakeLists.txt +++ b/compiler/lib/Dialect/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(TFHE) add_subdirectory(Concrete) add_subdirectory(BConcrete) add_subdirectory(RT) +add_subdirectory(SDFG) diff --git a/compiler/lib/Dialect/SDFG/CMakeLists.txt b/compiler/lib/Dialect/SDFG/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/compiler/lib/Dialect/SDFG/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/compiler/lib/Dialect/SDFG/IR/CMakeLists.txt b/compiler/lib/Dialect/SDFG/IR/CMakeLists.txt new file mode 100644 index 000000000..2e1ed6fcd --- /dev/null +++ b/compiler/lib/Dialect/SDFG/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library( + SDFGDialect + SDFGDialect.cpp + SDFGOps.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/SDFG + DEPENDS + mlir-headers + LINK_LIBS + PUBLIC + MLIRIR) + +target_link_libraries(SDFGDialect PUBLIC MLIRIR) diff --git a/compiler/lib/Dialect/SDFG/IR/SDFGDialect.cpp b/compiler/lib/Dialect/SDFG/IR/SDFGDialect.cpp new file mode 100644 index 000000000..b0a667a01 --- /dev/null +++ b/compiler/lib/Dialect/SDFG/IR/SDFGDialect.cpp @@ -0,0 +1,58 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "mlir/IR/Builders.h" + +#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" +#include "concretelang/Dialect/SDFG/IR/SDFGOps.h" +#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h" + +using namespace mlir::concretelang::SDFG; + +#define GET_TYPEDEF_CLASSES +#include "concretelang/Dialect/SDFG/IR/SDFGTypes.cpp.inc" + +void SDFGDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "concretelang/Dialect/SDFG/IR/SDFGOps.cpp.inc" + >(); + + addTypes< +#define GET_TYPEDEF_LIST +#include "concretelang/Dialect/SDFG/IR/SDFGTypes.cpp.inc" + >(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "concretelang/Dialect/SDFG/IR/SDFGAttributes.cpp.inc" + >(); +} + +#define GET_ATTRDEF_CLASSES +#include "concretelang/Dialect/SDFG/IR/SDFGAttributes.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "concretelang/Dialect/SDFG/IR/SDFGDialect.cpp.inc" + +void StreamType::print(mlir::AsmPrinter &p) const { + p << "<" << getElementType() << ">"; +} + +mlir::Type StreamType::parse(mlir::AsmParser &p) { + if (p.parseLess()) + return mlir::Type(); + + mlir::Type t; + if (p.parseType(t)) + return mlir::Type(); + + if (p.parseGreater()) + return mlir::Type(); + + mlir::Location loc = p.getEncodedSourceLoc(p.getNameLoc()); + + return getChecked(loc, loc.getContext(), t); +} diff --git a/compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp b/compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp new file mode 100644 index 000000000..169d3f16c --- /dev/null +++ b/compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp @@ -0,0 +1,91 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "mlir/IR/Builders.h" + +#include "concretelang/Dialect/SDFG/IR/SDFGOps.h" +#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h" + +#include "concretelang/Dialect/SDFG/IR/SDFGEnums.cpp.inc" +#include + +#define GET_OP_CLASSES +#include "concretelang/Dialect/SDFG/IR/SDFGOps.cpp.inc" + +namespace mlir { +namespace concretelang { +namespace SDFG { +mlir::LogicalResult Put::verify() { + mlir::Type streamElementType = + stream().getType().cast().getElementType(); + mlir::Type elementType = data().getType(); + + if (streamElementType != elementType) { + emitError() + << "The type " << elementType + << " of the element to be written does not match the element type " + << streamElementType << " of the stream."; + return mlir::failure(); + } + + return mlir::success(); +} + +mlir::LogicalResult MakeProcess::checkStreams(size_t numIn, size_t numOut) { + mlir::OperandRange streams = this->streams(); + + if (streams.size() != numIn + numOut) { + emitError() << "Process `" << stringifyProcessKind(type()) + << "` expects 3 streams, but " << streams.size() + << " were given."; + return mlir::failure(); + } + + for (size_t i = 0; i < numIn; i++) { + MakeStream in = dyn_cast_or_null(streams[i].getDefiningOp()); + + if (in && !in.createsInputStream()) { + emitError() << "Stream #" << (i + 1) << " of process `" + << stringifyProcessKind(type()) + << "` must be an input stream."; + return mlir::failure(); + } + } + + for (size_t i = numIn; i < numIn + numOut; i++) { + MakeStream out = dyn_cast_or_null(streams[i].getDefiningOp()); + + if (out && !out.createsOutputStream()) { + emitError() << "Stream #" << (i + 1) << " of process `" + << stringifyProcessKind(type()) + << "` must be an output stream."; + return mlir::failure(); + } + } + + return mlir::success(); +} + +mlir::LogicalResult MakeProcess::verify() { + switch (type()) { + case ProcessKind::add_eint: + return checkStreams(2, 1); + case ProcessKind::add_eint_int: + return checkStreams(2, 1); + case ProcessKind::mul_eint_int: + return checkStreams(2, 1); + case ProcessKind::neg_eint: + return checkStreams(1, 1); + case ProcessKind::keyswitch: + return checkStreams(1, 1); + case ProcessKind::bootstrap: + return checkStreams(2, 1); + } + + return mlir::failure(); +} +} // namespace SDFG +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 05c449f2f..5eca491b3 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -69,16 +70,16 @@ CompilationContext::~CompilationContext() { mlir::MLIRContext *CompilationContext::getMLIRContext() { if (this->mlirContext == nullptr) { mlir::DialectRegistry registry; - registry.insert(); + registry.insert< + mlir::concretelang::RT::RTDialect, mlir::concretelang::FHE::FHEDialect, + mlir::concretelang::TFHE::TFHEDialect, + mlir::concretelang::FHELinalg::FHELinalgDialect, + mlir::concretelang::Concrete::ConcreteDialect, + mlir::concretelang::BConcrete::BConcreteDialect, + mlir::concretelang::SDFG::SDFGDialect, mlir::func::FuncDialect, + mlir::memref::MemRefDialect, mlir::linalg::LinalgDialect, + mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect, + mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>(); BConcrete::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( diff --git a/compiler/src/CMakeLists.txt b/compiler/src/CMakeLists.txt index 721c3e62d..804bb435a 100644 --- a/compiler/src/CMakeLists.txt +++ b/compiler/src/CMakeLists.txt @@ -16,6 +16,7 @@ target_link_libraries( ConcreteDialect TFHEDialect FHEDialect + SDFGDialect ConcretelangSupport ConcretelangTransforms MLIRIR diff --git a/compiler/tests/check_tests/Dialect/SDFG/invalid.mlir b/compiler/tests/check_tests/Dialect/SDFG/invalid.mlir new file mode 100644 index 000000000..acfb48b00 --- /dev/null +++ b/compiler/tests/check_tests/Dialect/SDFG/invalid.mlir @@ -0,0 +1,41 @@ +// RUN: concretecompiler --split-input-file --verify-diagnostics --action=roundtrip %s + +func.func @wrong_element_type(%arg0: tensor<2xi32>, %arg1: tensor<1024xi64>) -> tensor<1024xi64> { + %dfg = "SDFG.init"() : () -> !SDFG.dfg + %in0 = "SDFG.make_stream" (%dfg) { name = "in0", type = #SDFG.stream_kind } : (!SDFG.dfg) -> !SDFG.stream> + %in1 = "SDFG.make_stream" (%dfg) { name = "in1", type = #SDFG.stream_kind } : (!SDFG.dfg) -> !SDFG.stream> + %out = "SDFG.make_stream" (%dfg) { name = "out", type = #SDFG.stream_kind } : (!SDFG.dfg) -> !SDFG.stream> + "SDFG.make_process" (%dfg, %in0, %in1, %out) { type = #SDFG.process_kind } : + (!SDFG.dfg, !SDFG.stream>, !SDFG.stream>, !SDFG.stream>) -> () + "SDFG.start"(%dfg) : (!SDFG.dfg) -> () + + // expected-error @+1 {{The type 'tensor<2xi32>' of the element to be written does not match the element type 'tensor<1024xi64>' of the stream.}} + "SDFG.put"(%in0, %arg0) : (!SDFG.stream>, tensor<2xi32>) -> () + "SDFG.put"(%in1, %arg1) : (!SDFG.stream>, tensor<1024xi64>) -> () + %res = "SDFG.get"(%out) : (!SDFG.stream>) -> tensor<1024xi64> + + "SDFG.shutdown"(%dfg) : (!SDFG.dfg) -> () + + return %res : tensor<1024xi64> +} + +// ----- + +func.func @wrong_stream_direction(%arg0: tensor<1024xi64>, %arg1: tensor<1024xi64>) -> tensor<1024xi64> { + %dfg = "SDFG.init"() : () -> !SDFG.dfg + %in0 = "SDFG.make_stream" (%dfg) { name = "inXXX0", type = #SDFG.stream_kind } : (!SDFG.dfg) -> !SDFG.stream> + %in1 = "SDFG.make_stream" (%dfg) { name = "in1", type = #SDFG.stream_kind } : (!SDFG.dfg) -> !SDFG.stream> + %out = "SDFG.make_stream" (%dfg) { name = "out", type = #SDFG.stream_kind } : (!SDFG.dfg) -> !SDFG.stream> + // expected-error @+1 {{Stream #1 of process `add_eint` must be an input stream.}} + "SDFG.make_process" (%dfg, %in0, %in1, %out) { type = #SDFG.process_kind } : + (!SDFG.dfg, !SDFG.stream>, !SDFG.stream>, !SDFG.stream>) -> () + "SDFG.start"(%dfg) : (!SDFG.dfg) -> () + + "SDFG.put"(%in0, %arg0) : (!SDFG.stream>, tensor<1024xi64>) -> () + "SDFG.put"(%in1, %arg1) : (!SDFG.stream>, tensor<1024xi64>) -> () + %res = "SDFG.get"(%out) : (!SDFG.stream>) -> tensor<1024xi64> + + "SDFG.shutdown"(%dfg) : (!SDFG.dfg) -> () + + return %res : tensor<1024xi64> +} diff --git a/compiler/tests/check_tests/Dialect/SDFG/ops.mlir b/compiler/tests/check_tests/Dialect/SDFG/ops.mlir new file mode 100644 index 000000000..1f7cbd16c --- /dev/null +++ b/compiler/tests/check_tests/Dialect/SDFG/ops.mlir @@ -0,0 +1,45 @@ +// RUN: concretecompiler --action=roundtrip --split-input-file %s 2>&1| FileCheck %s + +// CHECK: func.func @init_shutdown +func.func @init_shutdown() -> () { + // CHECK-NEXT: %[[DFG:.*]] = "SDFG.init"() : () -> !SDFG.dfg + // CHECK-NEXT: "SDFG.shutdown"(%[[DFG]]) : (!SDFG.dfg) -> () + // CHECK-NEXT: return + + %dfg = "SDFG.init"() : () -> !SDFG.dfg + "SDFG.shutdown"(%dfg) : (!SDFG.dfg) -> () + return +} + +// ----- + +// CHECK: func.func @simple_graph(%[[Varg0:.*]]: tensor<1024xi64>, %[[Varg1:.*]]: tensor<1024xi64>) -> tensor<1024xi64> { +// CHECK-NEXT: %[[V0:.*]] = "SDFG.init"() : () -> !SDFG.dfg +// CHECK-NEXT: %[[V1:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "in0", type = #SDFG.stream_kind} : (!SDFG.dfg) -> !SDFG.stream> +// CHECK-NEXT: %[[V2:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "in1", type = #SDFG.stream_kind} : (!SDFG.dfg) -> !SDFG.stream> +// CHECK-NEXT: %[[V3:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "out", type = #SDFG.stream_kind} : (!SDFG.dfg) -> !SDFG.stream> +// CHECK-NEXT: "SDFG.make_process"(%[[V0]], %[[V1]], %[[V2]], %[[V3]]) {type = #SDFG.process_kind} : (!SDFG.dfg, !SDFG.stream>, !SDFG.stream>, !SDFG.stream>) -> () +// CHECK-NEXT: "SDFG.start"(%[[V0]]) : (!SDFG.dfg) -> () +// CHECK-NEXT: "SDFG.put"(%[[V1]], %[[Varg0]]) : (!SDFG.stream>, tensor<1024xi64>) -> () +// CHECK-NEXT: "SDFG.put"(%[[V2]], %[[Varg1]]) : (!SDFG.stream>, tensor<1024xi64>) -> () +// CHECK-NEXT: %[[V4:.*]] = "SDFG.get"(%[[V3]]) : (!SDFG.stream>) -> tensor<1024xi64> +// CHECK-NEXT: "SDFG.shutdown"(%[[V0]]) : (!SDFG.dfg) -> () +// CHECK-NEXT: return %[[V4]] : tensor<1024xi64> +// CHECK-NEXT: } +func.func @simple_graph(%arg0: tensor<1024xi64>, %arg1: tensor<1024xi64>) -> tensor<1024xi64> { + %dfg = "SDFG.init"() : () -> !SDFG.dfg + %in0 = "SDFG.make_stream" (%dfg) { name = "in0", type = #SDFG.stream_kind } : (!SDFG.dfg) -> !SDFG.stream> + %in1 = "SDFG.make_stream" (%dfg) { name = "in1", type = #SDFG.stream_kind } : (!SDFG.dfg) -> !SDFG.stream> + %out = "SDFG.make_stream" (%dfg) { name = "out", type = #SDFG.stream_kind } : (!SDFG.dfg) -> !SDFG.stream> + "SDFG.make_process" (%dfg, %in0, %in1, %out) { type = #SDFG.process_kind } : + (!SDFG.dfg, !SDFG.stream>, !SDFG.stream>, !SDFG.stream>) -> () + "SDFG.start"(%dfg) : (!SDFG.dfg) -> () + + "SDFG.put"(%in0, %arg0) : (!SDFG.stream>, tensor<1024xi64>) -> () + "SDFG.put"(%in1, %arg1) : (!SDFG.stream>, tensor<1024xi64>) -> () + %res = "SDFG.get"(%out) : (!SDFG.stream>) -> tensor<1024xi64> + + "SDFG.shutdown"(%dfg) : (!SDFG.dfg) -> () + + return %res : tensor<1024xi64> +}