mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): Add new dialect SDFG for static data flow graphs
This adds a new dialect called "SDFG" for data flow graphs. An SDFG data flow graph is composed of a set of processes, connected through data streams. Special streams allow for data to be injected into and to be retrieved from the data flow graph. The dialect is intended to be lowered to API calls that allow for offloading of the graph on hardware accelerators.
This commit is contained in:
@@ -4,3 +4,4 @@ add_subdirectory(TFHE)
|
||||
add_subdirectory(Concrete)
|
||||
add_subdirectory(BConcrete)
|
||||
add_subdirectory(RT)
|
||||
add_subdirectory(SDFG)
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
add_subdirectory(IR)
|
||||
17
compiler/include/concretelang/Dialect/SDFG/IR/CMakeLists.txt
Normal file
17
compiler/include/concretelang/Dialect/SDFG/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,17 @@
|
||||
set(LLVM_TARGET_DEFINITIONS SDFGOps.td)
|
||||
mlir_tablegen(SDFGEnums.h.inc -gen-enum-decls)
|
||||
mlir_tablegen(SDFGEnums.cpp.inc -gen-enum-defs)
|
||||
mlir_tablegen(SDFGOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(SDFGOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(SDFGTypes.h.inc -gen-typedef-decls -typedefs-dialect=SDFG)
|
||||
mlir_tablegen(SDFGTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=SDFG)
|
||||
mlir_tablegen(SDFGDialect.h.inc -gen-dialect-decls -dialect=SDFG)
|
||||
mlir_tablegen(SDFGDialect.cpp.inc -gen-dialect-defs -dialect=SDFG)
|
||||
mlir_tablegen(SDFGAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=SDFG)
|
||||
mlir_tablegen(SDFGAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=SDFG)
|
||||
add_public_tablegen_target(MLIRSDFGOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRSDFGOpsIncGen)
|
||||
|
||||
add_concretelang_doc(SDFGOps SDFGDialect concretelang/ -gen-dialect-doc -dialect=SDFG)
|
||||
add_concretelang_doc(SDFGOps SDFGOps concretelang/ -gen-op-doc)
|
||||
add_concretelang_doc(SDFGTypes SDFGTypes concretelang/ -gen-typedef-doc)
|
||||
14
compiler/include/concretelang/Dialect/SDFG/IR/SDFGDialect.h
Normal file
14
compiler/include/concretelang/Dialect/SDFG/IR/SDFGDialect.h
Normal file
@@ -0,0 +1,14 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFGDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_SDFG_IR_SDFGDIALECT_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h.inc"
|
||||
|
||||
#endif
|
||||
28
compiler/include/concretelang/Dialect/SDFG/IR/SDFGDialect.td
Normal file
28
compiler/include/concretelang/Dialect/SDFG/IR/SDFGDialect.td
Normal file
@@ -0,0 +1,28 @@
|
||||
//===- SDFGDialect.td - SDFG dialect ----------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFG_DIALECT
|
||||
#define CONCRETELANG_DIALECT_SDFG_IR_SDFG_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def SDFG_Dialect : Dialect {
|
||||
let name = "SDFG";
|
||||
let summary = "Dialect for the construction of static data flow graphs";
|
||||
let description = [{
|
||||
A dialect for the construction of static data flow graphs. The
|
||||
data flow graph is composed of a set of processes, connected
|
||||
through data streams. Special streams allow for data to be
|
||||
injected into and to be retrieved from the data flow graph.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::SDFG";
|
||||
let useDefaultTypePrinterParser = 1;
|
||||
let useDefaultAttributePrinterParser = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
22
compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.h
Normal file
22
compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.h
Normal file
@@ -0,0 +1,22 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFGOPS_H
|
||||
#define CONCRETELANG_DIALECT_SDFG_IR_SDFGOPS_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGEnums.h.inc"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGAttributes.h.inc"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h.inc"
|
||||
|
||||
#endif
|
||||
202
compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td
Normal file
202
compiler/include/concretelang/Dialect/SDFG/IR/SDFGOps.td
Normal file
@@ -0,0 +1,202 @@
|
||||
//===- SDFGOps.td - High level SDFG dialect ops ----------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFG_OPS
|
||||
#define CONCRETELANG_DIALECT_SDFG_IR_SDFG_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/IR/EnumAttr.td"
|
||||
|
||||
include "concretelang/Dialect/SDFG/IR/SDFGDialect.td"
|
||||
include "concretelang/Dialect/SDFG/IR/SDFGTypes.td"
|
||||
|
||||
class SDFG_Op<string mnemonic, list<Trait> traits = []> :
|
||||
Op<SDFG_Dialect, mnemonic, traits>;
|
||||
|
||||
def StreamKindHostToDevice : I32EnumAttrCase<"host_to_device", 0>;
|
||||
def StreamKindOnDevice : I32EnumAttrCase<"on_device", 1>;
|
||||
def StreamKindDeviceToHost : I32EnumAttrCase<"device_to_host", 2>;
|
||||
|
||||
def StreamKind : I32EnumAttr<"StreamKind", "Stream kind",
|
||||
[StreamKindOnDevice, StreamKindHostToDevice, StreamKindDeviceToHost]> {
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "::mlir::concretelang::SDFG";
|
||||
}
|
||||
|
||||
def StreamKindAttr : EnumAttr<SDFG_Dialect, StreamKind, "stream_kind"> {
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
def SDFG_Init : SDFG_Op<"init"> {
|
||||
let summary = "Initializes the streaming framework";
|
||||
|
||||
let description = [{
|
||||
Initializes the streaming framework. This operation must be
|
||||
performed before control reaches any other operation from the
|
||||
dialect.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"SDFG.init" : () -> !SDFG.dfg
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
let results = (outs SDFG_DFG);
|
||||
}
|
||||
|
||||
def SDFG_MakeStream : SDFG_Op<"make_stream"> {
|
||||
let summary = "Returns a new SDFG stream";
|
||||
|
||||
let description = [{
|
||||
Returns a new SDFG stream, transporting data either between
|
||||
processes on the device, from the host to the device or from
|
||||
the device to the host. All streams are typed, allowing data
|
||||
to be read / written through `SDFG.get` and `SDFG.put` only
|
||||
using the stream's type.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"SDFG.make_stream" { name = "stream", type = #SDFG.stream_kind<host_to_device> }(%dfg)
|
||||
: (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins SDFG_DFG:$dfg, StrAttr:$name, StreamKindAttr:$type);
|
||||
let results = (outs SDFG_Stream);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
bool createsInputStream() {
|
||||
return type() == StreamKind::host_to_device ||
|
||||
type() == StreamKind::on_device;
|
||||
}
|
||||
|
||||
bool createsOutputStream() {
|
||||
return type() == StreamKind::device_to_host ||
|
||||
type() == StreamKind::on_device;
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def ProcessKindAddEint : I32EnumAttrCase<"add_eint", 0>;
|
||||
def ProcessKindAddEintInt : I32EnumAttrCase<"add_eint_int", 1>;
|
||||
def ProcessKindMulEintInt : I32EnumAttrCase<"mul_eint_int", 2>;
|
||||
def ProcessKindNegEint : I32EnumAttrCase<"neg_eint", 3>;
|
||||
def ProcessKindKeyswitch : I32EnumAttrCase<"keyswitch", 4>;
|
||||
def ProcessKindBootstrap : I32EnumAttrCase<"bootstrap", 5>;
|
||||
|
||||
def ProcessKind : I32EnumAttr<"ProcessKind", "Process kind",
|
||||
[ProcessKindAddEint, ProcessKindAddEintInt, ProcessKindMulEintInt,
|
||||
ProcessKindNegEint, ProcessKindKeyswitch, ProcessKindBootstrap]> {
|
||||
let genSpecializedAttr = 0;
|
||||
let cppNamespace = "::mlir::concretelang::SDFG";
|
||||
}
|
||||
|
||||
def ProcessKindAttr : EnumAttr<SDFG_Dialect, ProcessKind, "process_kind"> {
|
||||
let assemblyFormat = "`<` $value `>`";
|
||||
}
|
||||
|
||||
def SDFG_MakeProcess : SDFG_Op<"make_process"> {
|
||||
let summary = "Creates a new SDFG process";
|
||||
|
||||
let description = [{
|
||||
Creates a new SDFG process and connects it to the input and
|
||||
output streams.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%in0 = "SDFG.make_stream" { type = #SDFG.stream_kind<host_to_device> }(%dfg) : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
%in1 = "SDFG.make_stream" { type = #SDFG.stream_kind<host_to_device> }(%dfg) : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
%out = "SDFG.make_stream" { type = #SDFG.stream_kind<device_to_host> }(%dfg) : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
"SDFG.make_process" { type = #SDFG.process_kind<add_eint> }(%dfg, %in0, %in1, %out) :
|
||||
(!SDFG.dfg, !SDFG.stream<tensor<1024xi64>>, !SDFG.stream<tensor<1024xi64>>, !SDFG.stream<tensor<1024xi64>>) -> ()
|
||||
```
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::LogicalResult checkStreams(size_t numIn, size_t numOut);
|
||||
}];
|
||||
|
||||
let arguments = (ins ProcessKindAttr:$type, SDFG_DFG:$dfg, Variadic<SDFG_Stream>:$streams);
|
||||
let results = (outs);
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SDFG_Put : SDFG_Op<"put"> {
|
||||
let summary = "Writes a data element to a stream";
|
||||
|
||||
let description = [{
|
||||
Writes the input operand to the specified stream. The
|
||||
operand's type must meet the element type of the stream.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"SDFG.put" (%stream, %data) : (!SDFG.stream<1024xi64>, tensor<1024xi64>) -> ()
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins SDFG_Stream:$stream, AnyType:$data);
|
||||
let results = (outs);
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def SDFG_Get : SDFG_Op<"get"> {
|
||||
let summary = "Retrieves a data element from a stream";
|
||||
|
||||
let description = [{
|
||||
Retrieves a single data element from the specified stream
|
||||
(i.e., an instance of the element type of the stream).
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"SDFG.get" (%stream) : (!SDFG.stream<1024xi64>) -> (tensor<1024xi64>)
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins SDFG_Stream:$stream);
|
||||
let results = (outs AnyType:$data);
|
||||
}
|
||||
|
||||
def SDFG_Start : SDFG_Op<"start"> {
|
||||
let summary = "Finalizes the creation of an SDFG and starts execution of its processes";
|
||||
|
||||
|
||||
let description = [{
|
||||
Finalizes the creation of an SDFG and starts execution of its
|
||||
processes. Any creation of streams and processes must take
|
||||
place before control reaches this operation.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"SDFG.start"(%dfg) : !SDFG.dfg
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins SDFG_DFG:$dfg);
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
|
||||
def SDFG_Shutdown : SDFG_Op<"shutdown"> {
|
||||
let summary = "Shuts down the streaming framework";
|
||||
|
||||
let description = [{
|
||||
Shuts down the streaming framework. This operation must be
|
||||
performed after any other operation from the dialect.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
"SDFG.shutdown" (%dfg) : !SDFG.dfg
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins SDFG_DFG:$dfg);
|
||||
let results = (outs);
|
||||
}
|
||||
|
||||
#endif
|
||||
16
compiler/include/concretelang/Dialect/SDFG/IR/SDFGTypes.h
Normal file
16
compiler/include/concretelang/Dialect/SDFG/IR/SDFGTypes.h
Normal file
@@ -0,0 +1,16 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFGTYPES_H
|
||||
#define CONCRETELANG_DIALECT_SDFG_IR_SDFGTYPES_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h.inc"
|
||||
|
||||
#endif
|
||||
37
compiler/include/concretelang/Dialect/SDFG/IR/SDFGTypes.td
Normal file
37
compiler/include/concretelang/Dialect/SDFG/IR/SDFGTypes.td
Normal file
@@ -0,0 +1,37 @@
|
||||
#ifndef CONCRETELANG_DIALECT_SDFG_IR_SDFG_TYPES
|
||||
#define CONCRETELANG_DIALECT_SDFG_IR_SDFG_TYPES
|
||||
|
||||
include "concretelang/Dialect/SDFG/IR/SDFGDialect.td"
|
||||
include "mlir/IR/BuiltinTypes.td"
|
||||
|
||||
class SDFG_Type<string name, list<Trait> traits = []> :
|
||||
TypeDef<SDFG_Dialect, name, traits> { }
|
||||
|
||||
def SDFG_DFG : SDFG_Type<"DFG", []> {
|
||||
let mnemonic = "dfg";
|
||||
|
||||
let summary = "An SDFG data flow graph";
|
||||
|
||||
let description = [{
|
||||
A handle to an SDFG data flow graph
|
||||
}];
|
||||
|
||||
let parameters = (ins);
|
||||
let hasCustomAssemblyFormat = 0;
|
||||
}
|
||||
|
||||
|
||||
def SDFG_Stream : SDFG_Type<"Stream", []> {
|
||||
let mnemonic = "stream";
|
||||
|
||||
let summary = "An SDFG data stream";
|
||||
|
||||
let description = [{
|
||||
An SDFG stream to connect SDFG processes.
|
||||
}];
|
||||
|
||||
let parameters = (ins "Type":$elementType);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -241,7 +241,7 @@ const LLVM_STATIC_LIBS: [&str; 51] = [
|
||||
"LLVMX86Info",
|
||||
];
|
||||
|
||||
const CONCRETE_COMPILER_LIBS: [&str; 29] = [
|
||||
const CONCRETE_COMPILER_LIBS: [&str; 30] = [
|
||||
"RTDialect",
|
||||
"RTDialectTransforms",
|
||||
"ConcretelangSupport",
|
||||
@@ -271,6 +271,7 @@ const CONCRETE_COMPILER_LIBS: [&str; 29] = [
|
||||
"FHEDialectAnalysis",
|
||||
"ConcreteDialect",
|
||||
"RTDialectAnalysis",
|
||||
"SDFGDialect"
|
||||
];
|
||||
|
||||
fn main() {
|
||||
|
||||
@@ -4,3 +4,4 @@ add_subdirectory(TFHE)
|
||||
add_subdirectory(Concrete)
|
||||
add_subdirectory(BConcrete)
|
||||
add_subdirectory(RT)
|
||||
add_subdirectory(SDFG)
|
||||
|
||||
1
compiler/lib/Dialect/SDFG/CMakeLists.txt
Normal file
1
compiler/lib/Dialect/SDFG/CMakeLists.txt
Normal file
@@ -0,0 +1 @@
|
||||
add_subdirectory(IR)
|
||||
13
compiler/lib/Dialect/SDFG/IR/CMakeLists.txt
Normal file
13
compiler/lib/Dialect/SDFG/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
add_mlir_dialect_library(
|
||||
SDFGDialect
|
||||
SDFGDialect.cpp
|
||||
SDFGOps.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/SDFG
|
||||
DEPENDS
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR)
|
||||
|
||||
target_link_libraries(SDFGDialect PUBLIC MLIRIR)
|
||||
58
compiler/lib/Dialect/SDFG/IR/SDFGDialect.cpp
Normal file
58
compiler/lib/Dialect/SDFG/IR/SDFGDialect.cpp
Normal file
@@ -0,0 +1,58 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
|
||||
|
||||
using namespace mlir::concretelang::SDFG;
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.cpp.inc"
|
||||
|
||||
void SDFGDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGOps.cpp.inc"
|
||||
>();
|
||||
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.cpp.inc"
|
||||
>();
|
||||
|
||||
addAttributes<
|
||||
#define GET_ATTRDEF_LIST
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGAttributes.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGAttributes.cpp.inc"
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.cpp.inc"
|
||||
|
||||
void StreamType::print(mlir::AsmPrinter &p) const {
|
||||
p << "<" << getElementType() << ">";
|
||||
}
|
||||
|
||||
mlir::Type StreamType::parse(mlir::AsmParser &p) {
|
||||
if (p.parseLess())
|
||||
return mlir::Type();
|
||||
|
||||
mlir::Type t;
|
||||
if (p.parseType(t))
|
||||
return mlir::Type();
|
||||
|
||||
if (p.parseGreater())
|
||||
return mlir::Type();
|
||||
|
||||
mlir::Location loc = p.getEncodedSourceLoc(p.getNameLoc());
|
||||
|
||||
return getChecked(loc, loc.getContext(), t);
|
||||
}
|
||||
91
compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp
Normal file
91
compiler/lib/Dialect/SDFG/IR/SDFGOps.cpp
Normal file
@@ -0,0 +1,91 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGOps.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGTypes.h"
|
||||
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGEnums.cpp.inc"
|
||||
#include <mlir/Support/LogicalResult.h>
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGOps.cpp.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace SDFG {
|
||||
mlir::LogicalResult Put::verify() {
|
||||
mlir::Type streamElementType =
|
||||
stream().getType().cast<StreamType>().getElementType();
|
||||
mlir::Type elementType = data().getType();
|
||||
|
||||
if (streamElementType != elementType) {
|
||||
emitError()
|
||||
<< "The type " << elementType
|
||||
<< " of the element to be written does not match the element type "
|
||||
<< streamElementType << " of the stream.";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult MakeProcess::checkStreams(size_t numIn, size_t numOut) {
|
||||
mlir::OperandRange streams = this->streams();
|
||||
|
||||
if (streams.size() != numIn + numOut) {
|
||||
emitError() << "Process `" << stringifyProcessKind(type())
|
||||
<< "` expects 3 streams, but " << streams.size()
|
||||
<< " were given.";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < numIn; i++) {
|
||||
MakeStream in = dyn_cast_or_null<MakeStream>(streams[i].getDefiningOp());
|
||||
|
||||
if (in && !in.createsInputStream()) {
|
||||
emitError() << "Stream #" << (i + 1) << " of process `"
|
||||
<< stringifyProcessKind(type())
|
||||
<< "` must be an input stream.";
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = numIn; i < numIn + numOut; i++) {
|
||||
MakeStream out = dyn_cast_or_null<MakeStream>(streams[i].getDefiningOp());
|
||||
|
||||
if (out && !out.createsOutputStream()) {
|
||||
emitError() << "Stream #" << (i + 1) << " of process `"
|
||||
<< stringifyProcessKind(type())
|
||||
<< "` must be an output stream.";
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
mlir::LogicalResult MakeProcess::verify() {
|
||||
switch (type()) {
|
||||
case ProcessKind::add_eint:
|
||||
return checkStreams(2, 1);
|
||||
case ProcessKind::add_eint_int:
|
||||
return checkStreams(2, 1);
|
||||
case ProcessKind::mul_eint_int:
|
||||
return checkStreams(2, 1);
|
||||
case ProcessKind::neg_eint:
|
||||
return checkStreams(1, 1);
|
||||
case ProcessKind::keyswitch:
|
||||
return checkStreams(1, 1);
|
||||
case ProcessKind::bootstrap:
|
||||
return checkStreams(2, 1);
|
||||
}
|
||||
|
||||
return mlir::failure();
|
||||
}
|
||||
} // namespace SDFG
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -34,6 +34,7 @@
|
||||
#include <concretelang/Dialect/FHELinalg/IR/FHELinalgDialect.h>
|
||||
#include <concretelang/Dialect/RT/IR/RTDialect.h>
|
||||
#include <concretelang/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <concretelang/Dialect/SDFG/IR/SDFGDialect.h>
|
||||
#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
#include <concretelang/Support/CompilerEngine.h>
|
||||
@@ -69,16 +70,16 @@ CompilationContext::~CompilationContext() {
|
||||
mlir::MLIRContext *CompilationContext::getMLIRContext() {
|
||||
if (this->mlirContext == nullptr) {
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<mlir::concretelang::RT::RTDialect,
|
||||
mlir::concretelang::FHE::FHEDialect,
|
||||
mlir::concretelang::TFHE::TFHEDialect,
|
||||
mlir::concretelang::FHELinalg::FHELinalgDialect,
|
||||
mlir::concretelang::Concrete::ConcreteDialect,
|
||||
mlir::concretelang::BConcrete::BConcreteDialect,
|
||||
mlir::func::FuncDialect, mlir::memref::MemRefDialect,
|
||||
mlir::linalg::LinalgDialect, mlir::LLVM::LLVMDialect,
|
||||
mlir::scf::SCFDialect, mlir::omp::OpenMPDialect,
|
||||
mlir::bufferization::BufferizationDialect>();
|
||||
registry.insert<
|
||||
mlir::concretelang::RT::RTDialect, mlir::concretelang::FHE::FHEDialect,
|
||||
mlir::concretelang::TFHE::TFHEDialect,
|
||||
mlir::concretelang::FHELinalg::FHELinalgDialect,
|
||||
mlir::concretelang::Concrete::ConcreteDialect,
|
||||
mlir::concretelang::BConcrete::BConcreteDialect,
|
||||
mlir::concretelang::SDFG::SDFGDialect, mlir::func::FuncDialect,
|
||||
mlir::memref::MemRefDialect, mlir::linalg::LinalgDialect,
|
||||
mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect,
|
||||
mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>();
|
||||
BConcrete::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
|
||||
|
||||
@@ -16,6 +16,7 @@ target_link_libraries(
|
||||
ConcreteDialect
|
||||
TFHEDialect
|
||||
FHEDialect
|
||||
SDFGDialect
|
||||
ConcretelangSupport
|
||||
ConcretelangTransforms
|
||||
MLIRIR
|
||||
|
||||
41
compiler/tests/check_tests/Dialect/SDFG/invalid.mlir
Normal file
41
compiler/tests/check_tests/Dialect/SDFG/invalid.mlir
Normal file
@@ -0,0 +1,41 @@
|
||||
// RUN: concretecompiler --split-input-file --verify-diagnostics --action=roundtrip %s
|
||||
|
||||
func.func @wrong_element_type(%arg0: tensor<2xi32>, %arg1: tensor<1024xi64>) -> tensor<1024xi64> {
|
||||
%dfg = "SDFG.init"() : () -> !SDFG.dfg
|
||||
%in0 = "SDFG.make_stream" (%dfg) { name = "in0", type = #SDFG.stream_kind<host_to_device> } : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
%in1 = "SDFG.make_stream" (%dfg) { name = "in1", type = #SDFG.stream_kind<host_to_device> } : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
%out = "SDFG.make_stream" (%dfg) { name = "out", type = #SDFG.stream_kind<device_to_host> } : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
"SDFG.make_process" (%dfg, %in0, %in1, %out) { type = #SDFG.process_kind<add_eint> } :
|
||||
(!SDFG.dfg, !SDFG.stream<tensor<1024xi64>>, !SDFG.stream<tensor<1024xi64>>, !SDFG.stream<tensor<1024xi64>>) -> ()
|
||||
"SDFG.start"(%dfg) : (!SDFG.dfg) -> ()
|
||||
|
||||
// expected-error @+1 {{The type 'tensor<2xi32>' of the element to be written does not match the element type 'tensor<1024xi64>' of the stream.}}
|
||||
"SDFG.put"(%in0, %arg0) : (!SDFG.stream<tensor<1024xi64>>, tensor<2xi32>) -> ()
|
||||
"SDFG.put"(%in1, %arg1) : (!SDFG.stream<tensor<1024xi64>>, tensor<1024xi64>) -> ()
|
||||
%res = "SDFG.get"(%out) : (!SDFG.stream<tensor<1024xi64>>) -> tensor<1024xi64>
|
||||
|
||||
"SDFG.shutdown"(%dfg) : (!SDFG.dfg) -> ()
|
||||
|
||||
return %res : tensor<1024xi64>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @wrong_stream_direction(%arg0: tensor<1024xi64>, %arg1: tensor<1024xi64>) -> tensor<1024xi64> {
|
||||
%dfg = "SDFG.init"() : () -> !SDFG.dfg
|
||||
%in0 = "SDFG.make_stream" (%dfg) { name = "inXXX0", type = #SDFG.stream_kind<device_to_host> } : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
%in1 = "SDFG.make_stream" (%dfg) { name = "in1", type = #SDFG.stream_kind<host_to_device> } : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
%out = "SDFG.make_stream" (%dfg) { name = "out", type = #SDFG.stream_kind<device_to_host> } : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
// expected-error @+1 {{Stream #1 of process `add_eint` must be an input stream.}}
|
||||
"SDFG.make_process" (%dfg, %in0, %in1, %out) { type = #SDFG.process_kind<add_eint> } :
|
||||
(!SDFG.dfg, !SDFG.stream<tensor<1024xi64>>, !SDFG.stream<tensor<1024xi64>>, !SDFG.stream<tensor<1024xi64>>) -> ()
|
||||
"SDFG.start"(%dfg) : (!SDFG.dfg) -> ()
|
||||
|
||||
"SDFG.put"(%in0, %arg0) : (!SDFG.stream<tensor<1024xi64>>, tensor<1024xi64>) -> ()
|
||||
"SDFG.put"(%in1, %arg1) : (!SDFG.stream<tensor<1024xi64>>, tensor<1024xi64>) -> ()
|
||||
%res = "SDFG.get"(%out) : (!SDFG.stream<tensor<1024xi64>>) -> tensor<1024xi64>
|
||||
|
||||
"SDFG.shutdown"(%dfg) : (!SDFG.dfg) -> ()
|
||||
|
||||
return %res : tensor<1024xi64>
|
||||
}
|
||||
45
compiler/tests/check_tests/Dialect/SDFG/ops.mlir
Normal file
45
compiler/tests/check_tests/Dialect/SDFG/ops.mlir
Normal file
@@ -0,0 +1,45 @@
|
||||
// RUN: concretecompiler --action=roundtrip --split-input-file %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: func.func @init_shutdown
|
||||
func.func @init_shutdown() -> () {
|
||||
// CHECK-NEXT: %[[DFG:.*]] = "SDFG.init"() : () -> !SDFG.dfg
|
||||
// CHECK-NEXT: "SDFG.shutdown"(%[[DFG]]) : (!SDFG.dfg) -> ()
|
||||
// CHECK-NEXT: return
|
||||
|
||||
%dfg = "SDFG.init"() : () -> !SDFG.dfg
|
||||
"SDFG.shutdown"(%dfg) : (!SDFG.dfg) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: func.func @simple_graph(%[[Varg0:.*]]: tensor<1024xi64>, %[[Varg1:.*]]: tensor<1024xi64>) -> tensor<1024xi64> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "SDFG.init"() : () -> !SDFG.dfg
|
||||
// CHECK-NEXT: %[[V1:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "in0", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
// CHECK-NEXT: %[[V2:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "in1", type = #SDFG.stream_kind<host_to_device>} : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
// CHECK-NEXT: %[[V3:.*]] = "SDFG.make_stream"(%[[V0]]) {name = "out", type = #SDFG.stream_kind<device_to_host>} : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
// CHECK-NEXT: "SDFG.make_process"(%[[V0]], %[[V1]], %[[V2]], %[[V3]]) {type = #SDFG.process_kind<add_eint>} : (!SDFG.dfg, !SDFG.stream<tensor<1024xi64>>, !SDFG.stream<tensor<1024xi64>>, !SDFG.stream<tensor<1024xi64>>) -> ()
|
||||
// CHECK-NEXT: "SDFG.start"(%[[V0]]) : (!SDFG.dfg) -> ()
|
||||
// CHECK-NEXT: "SDFG.put"(%[[V1]], %[[Varg0]]) : (!SDFG.stream<tensor<1024xi64>>, tensor<1024xi64>) -> ()
|
||||
// CHECK-NEXT: "SDFG.put"(%[[V2]], %[[Varg1]]) : (!SDFG.stream<tensor<1024xi64>>, tensor<1024xi64>) -> ()
|
||||
// CHECK-NEXT: %[[V4:.*]] = "SDFG.get"(%[[V3]]) : (!SDFG.stream<tensor<1024xi64>>) -> tensor<1024xi64>
|
||||
// CHECK-NEXT: "SDFG.shutdown"(%[[V0]]) : (!SDFG.dfg) -> ()
|
||||
// CHECK-NEXT: return %[[V4]] : tensor<1024xi64>
|
||||
// CHECK-NEXT: }
|
||||
func.func @simple_graph(%arg0: tensor<1024xi64>, %arg1: tensor<1024xi64>) -> tensor<1024xi64> {
|
||||
%dfg = "SDFG.init"() : () -> !SDFG.dfg
|
||||
%in0 = "SDFG.make_stream" (%dfg) { name = "in0", type = #SDFG.stream_kind<host_to_device> } : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
%in1 = "SDFG.make_stream" (%dfg) { name = "in1", type = #SDFG.stream_kind<host_to_device> } : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
%out = "SDFG.make_stream" (%dfg) { name = "out", type = #SDFG.stream_kind<device_to_host> } : (!SDFG.dfg) -> !SDFG.stream<tensor<1024xi64>>
|
||||
"SDFG.make_process" (%dfg, %in0, %in1, %out) { type = #SDFG.process_kind<add_eint> } :
|
||||
(!SDFG.dfg, !SDFG.stream<tensor<1024xi64>>, !SDFG.stream<tensor<1024xi64>>, !SDFG.stream<tensor<1024xi64>>) -> ()
|
||||
"SDFG.start"(%dfg) : (!SDFG.dfg) -> ()
|
||||
|
||||
"SDFG.put"(%in0, %arg0) : (!SDFG.stream<tensor<1024xi64>>, tensor<1024xi64>) -> ()
|
||||
"SDFG.put"(%in1, %arg1) : (!SDFG.stream<tensor<1024xi64>>, tensor<1024xi64>) -> ()
|
||||
%res = "SDFG.get"(%out) : (!SDFG.stream<tensor<1024xi64>>) -> tensor<1024xi64>
|
||||
|
||||
"SDFG.shutdown"(%dfg) : (!SDFG.dfg) -> ()
|
||||
|
||||
return %res : tensor<1024xi64>
|
||||
}
|
||||
Reference in New Issue
Block a user