mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(concrete-compiler): adds a tracing op in all dialects.
This commit is contained in:
@@ -24,11 +24,13 @@
|
||||
#include "concretelang/Conversion/SDFGToStreamEmulator/Pass.h"
|
||||
#include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h"
|
||||
#include "concretelang/Conversion/TFHEToConcrete/Pass.h"
|
||||
#include "concretelang/Conversion/TracingToCAPI/Pass.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
|
||||
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "concretelang/Conversion/Passes.h.inc"
|
||||
|
||||
@@ -69,6 +69,13 @@ def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> {
|
||||
let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect"];
|
||||
}
|
||||
|
||||
def TracingToCAPI : Pass<"tracing-to-capi", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the Tracing dialect to CAPI calls";
|
||||
let description = [{ Lowers operations from the Tracing dialect to CAPI calls }];
|
||||
let constructor = "mlir::concretelang::createConvertTracingToCAPIPass()";
|
||||
let dependentDialects = ["mlir::concretelang::Tracing::TracingDialect"];
|
||||
}
|
||||
|
||||
def SDFGToStreamEmulator : Pass<"sdfg-to-stream-emulator", "mlir::ModuleOp"> {
|
||||
let summary = "Lowers operations from the SDFG dialect to Stream Emulator calls";
|
||||
let description = [{ Lowers operations from the SDFG dialect to Stream Emulator calls }];
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef ZAMALANG_CONVERSION_TRACINGTOCAPI_PASS_H_
|
||||
#define ZAMALANG_CONVERSION_TRACINGTOCAPI_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
/// Create a pass to convert `Tracing` dialect to CAPI calls.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTracingToCAPIPass();
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -5,3 +5,4 @@ add_subdirectory(Concrete)
|
||||
add_subdirectory(BConcrete)
|
||||
add_subdirectory(RT)
|
||||
add_subdirectory(SDFG)
|
||||
add_subdirectory(Tracing)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H
|
||||
#define CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H
|
||||
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/Interfaces/ControlFlowInterfaces.h>
|
||||
|
||||
@@ -179,4 +179,5 @@ def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> {
|
||||
let results = (outs Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>:$result);
|
||||
}
|
||||
|
||||
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
add_subdirectory(IR)
|
||||
@@ -0,0 +1,9 @@
|
||||
set(LLVM_TARGET_DEFINITIONS TracingOps.td)
|
||||
mlir_tablegen(TracingOps.h.inc -gen-op-decls)
|
||||
mlir_tablegen(TracingOps.cpp.inc -gen-op-defs)
|
||||
mlir_tablegen(TracingOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=Tracing)
|
||||
mlir_tablegen(TracingOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=Tracing)
|
||||
mlir_tablegen(TracingOpsDialect.h.inc -gen-dialect-decls -dialect=Tracing)
|
||||
mlir_tablegen(TracingOpsDialect.cpp.inc -gen-dialect-defs -dialect=Tracing)
|
||||
add_public_tablegen_target(MLIRTracingOpsIncGen)
|
||||
add_dependencies(mlir-headers MLIRTracingOpsIncGen)
|
||||
@@ -0,0 +1,18 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACINGDIALECT_H
|
||||
#define CONCRETELANG_DIALECT_TRACING_IR_TRACINGDIALECT_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOpsDialect.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,23 @@
|
||||
//===- TracingDialect.td - Tracing dialect ----------------*- tablegen -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACING_DIALECT
|
||||
#define CONCRETELANG_DIALECT_TRACING_IR_TRACING_DIALECT
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Tracing_Dialect : Dialect {
|
||||
let name = "Tracing";
|
||||
let summary = "Tracing dialect";
|
||||
let description = [{
|
||||
A dialect to print program values at runtime.
|
||||
}];
|
||||
let cppNamespace = "::mlir::concretelang::Tracing";
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,18 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACINGOPS_H
|
||||
#define CONCRETELANG_DIALECT_TRACING_IR_TRACINGOPS_H
|
||||
|
||||
#include <mlir/IR/Builders.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/Interfaces/ControlFlowInterfaces.h>
|
||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h.inc"
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,57 @@
|
||||
//===- TracingOps.td - Tracing dialect ops ----------------*- tablegen
|
||||
//-*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACING_OPS
|
||||
#define CONCRETELANG_DIALECT_TRACING_IR_TRACING_OPS
|
||||
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
|
||||
include "concretelang/Dialect/Tracing/IR/TracingDialect.td"
|
||||
include "concretelang/Dialect/FHE/IR/FHETypes.td"
|
||||
include "concretelang/Dialect/TFHE/IR/TFHETypes.td"
|
||||
include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td"
|
||||
|
||||
class Tracing_Op<string mnemonic, list<Trait> traits = []>
|
||||
: Op<Tracing_Dialect, mnemonic, traits>;
|
||||
|
||||
def Tracing_TraceCiphertextOp : Tracing_Op<"trace_ciphertext"> {
|
||||
let summary = "Prints a ciphertext.";
|
||||
|
||||
let arguments = (ins
|
||||
Type<Or<[
|
||||
FHE_EncryptedIntegerType.predicate,
|
||||
FHE_EncryptedSignedIntegerType.predicate,
|
||||
TFHE_GLWECipherTextType.predicate,
|
||||
Concrete_LweCiphertextType.predicate,
|
||||
1DTensorOf<[I64]>.predicate,
|
||||
MemRefRankOf<[I64], [1]>.predicate
|
||||
]>>: $ciphertext,
|
||||
OptionalAttr<StrAttr>: $msg,
|
||||
OptionalAttr<I32Attr>: $nmsb
|
||||
);
|
||||
}
|
||||
|
||||
def Tracing_TracePlaintextOp : Tracing_Op<"trace_plaintext"> {
|
||||
let summary = "Prints a plaintext.";
|
||||
|
||||
let arguments = (ins
|
||||
AnyInteger: $plaintext,
|
||||
OptionalAttr<StrAttr>: $msg,
|
||||
OptionalAttr<I32Attr>: $nmsb
|
||||
);
|
||||
}
|
||||
|
||||
def Tracing_TraceMessageOp : Tracing_Op<"trace_message"> {
|
||||
let summary = "Prints a message.";
|
||||
|
||||
let arguments = (ins OptionalAttr<StrAttr> : $msg);
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,19 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DIALECT_TRACING_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
#define CONCRETELANG_DIALECT_TRACING_BUFFERIZABLEOPINTERFACEIMPL_H
|
||||
|
||||
namespace mlir {
|
||||
class DialectRegistry;
|
||||
|
||||
namespace concretelang {
|
||||
namespace Tracing {
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry);
|
||||
} // namespace Tracing
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -215,5 +215,18 @@ void memref_batched_bootstrap_lwe_cuda_u64(
|
||||
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
|
||||
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
|
||||
mlir::concretelang::RuntimeContext *context);
|
||||
|
||||
// Tracing ////////////////////////////////////////////////////////////////////
|
||||
void memref_trace_ciphertext(uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, char *message_ptr,
|
||||
uint32_t message_len, uint32_t msb);
|
||||
|
||||
void memref_trace_plaintext(uint64_t input, uint64_t input_width,
|
||||
char *message_ptr, uint32_t message_len,
|
||||
uint32_t msb);
|
||||
|
||||
void memref_trace_message(char *message_ptr, uint32_t message_len);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -48,6 +48,7 @@ char memref_encode_expand_lut_for_bootstrap[] =
|
||||
"memref_encode_expand_lut_for_bootstrap";
|
||||
char memref_encode_expand_lut_for_woppbs[] =
|
||||
"memref_encode_expand_lut_for_woppbs";
|
||||
char memref_trace[] = "memref_trace";
|
||||
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
@@ -185,6 +186,12 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
{memref1DType, memref1DType, memref1DType, memref1DType,
|
||||
rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI1Type()},
|
||||
{});
|
||||
} else if (funcName == memref_trace) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
|
||||
rewriter.getI32Type(), rewriter.getI32Type()},
|
||||
{});
|
||||
} else {
|
||||
op->emitError("unknwon external function") << funcName;
|
||||
return mlir::failure();
|
||||
@@ -431,6 +438,7 @@ struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
|
||||
target.addLegalDialect<func::FuncDialect>();
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
target.addLegalDialect<arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
|
||||
|
||||
// Make sure that no ops from `FHE` remain after the lowering
|
||||
target.addIllegalDialect<BConcrete::BConcreteDialect>();
|
||||
|
||||
@@ -5,6 +5,7 @@ add_subdirectory(TFHEToConcrete)
|
||||
add_subdirectory(FHETensorOpsToLinalg)
|
||||
add_subdirectory(ConcreteToBConcrete)
|
||||
add_subdirectory(BConcreteToCAPI)
|
||||
add_subdirectory(TracingToCAPI)
|
||||
add_subdirectory(SDFGToStreamEmulator)
|
||||
add_subdirectory(MLIRLowerableDialectsToLLVM)
|
||||
add_subdirectory(LinalgExtras)
|
||||
|
||||
@@ -36,6 +36,7 @@
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
@@ -43,6 +44,7 @@
|
||||
|
||||
namespace Concrete = ::mlir::concretelang::Concrete;
|
||||
namespace BConcrete = ::mlir::concretelang::BConcrete;
|
||||
namespace Tracing = ::mlir::concretelang::Tracing;
|
||||
|
||||
namespace {
|
||||
struct ConcreteToBConcretePass
|
||||
@@ -986,6 +988,16 @@ void ConcreteToBConcretePass::runOnOperation() {
|
||||
converter.isLegal(op->getOperandTypes());
|
||||
});
|
||||
|
||||
// Conversion of Tracing dialect
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
Tracing::TraceCiphertextOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
Tracing::TracePlaintextOp>>(&getContext(), converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<Tracing::TraceCiphertextOp>(
|
||||
target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<Tracing::TracePlaintextOp>(
|
||||
target, converter);
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/Dialect/Linalg/IR/Linalg.h>
|
||||
#include <mlir/IR/Operation.h>
|
||||
|
||||
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
|
||||
@@ -29,9 +28,12 @@
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
|
||||
namespace FHE = mlir::concretelang::FHE;
|
||||
namespace TFHE = mlir::concretelang::TFHE;
|
||||
namespace Tracing = mlir::concretelang::Tracing;
|
||||
namespace concretelang = mlir::concretelang;
|
||||
|
||||
namespace fhe_to_tfhe_crt_conversion {
|
||||
@@ -583,6 +585,45 @@ struct ApplyLookupTableEintOpPattern
|
||||
};
|
||||
};
|
||||
|
||||
/// Rewriter for the `Tracing::trace_ciphertext` operation.
|
||||
struct TraceCiphertextOpPattern : CrtOpPattern<Tracing::TraceCiphertextOp> {
|
||||
|
||||
TraceCiphertextOpPattern(mlir::MLIRContext *context,
|
||||
concretelang::CrtLoweringParameters params,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: CrtOpPattern<Tracing::TraceCiphertextOp>(context, params, benefit) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(Tracing::TraceCiphertextOp op,
|
||||
Tracing::TraceCiphertextOp::Adaptor adaptor,
|
||||
mlir::ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
typing::TypeConverter converter{loweringParameters};
|
||||
mlir::Type ciphertextScalarType =
|
||||
converter.convertType(op.ciphertext().getType())
|
||||
.cast<mlir::RankedTensorType>()
|
||||
.getElementType();
|
||||
|
||||
for (size_t i = 0; i < (loweringParameters.nMods - 1); ++i) {
|
||||
auto extractedCiphertext = rewriter.create<mlir::tensor::ExtractOp>(
|
||||
op.getLoc(), ciphertextScalarType, adaptor.ciphertext(),
|
||||
mlir::ValueRange{rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(i))});
|
||||
rewriter.create<Tracing::TraceCiphertextOp>(
|
||||
op.getLoc(), extractedCiphertext, op.msgAttr(), op.nmsbAttr());
|
||||
}
|
||||
|
||||
auto extractedCiphertext = rewriter.create<mlir::tensor::ExtractOp>(
|
||||
op.getLoc(), ciphertextScalarType, adaptor.ciphertext(),
|
||||
mlir::ValueRange{rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getIndexAttr(loweringParameters.nMods - 1))});
|
||||
rewriter.replaceOpWithNewOp<Tracing::TraceCiphertextOp>(
|
||||
op, extractedCiphertext, op.msgAttr(), op.nmsbAttr());
|
||||
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Rewriter for the `tensor::extract` operation.
|
||||
struct TensorExtractOpPattern : public CrtOpPattern<mlir::tensor::ExtractOp> {
|
||||
|
||||
@@ -924,6 +965,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::tensor::CollapseShapeOp>(
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<Tracing::TraceCiphertextOp>(
|
||||
target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<
|
||||
concretelang::RT::MakeReadyFutureOp>(target, converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<concretelang::RT::AwaitFutureOp>(
|
||||
@@ -1006,6 +1049,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
|
||||
loweringParameters);
|
||||
patterns.add<lowering::InsertSliceOpPattern>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<lowering::TraceCiphertextOpPattern>(patterns.getContext(),
|
||||
loweringParameters);
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::tensor::GenerateOp, true>>(&getContext(), converter);
|
||||
|
||||
|
||||
@@ -28,9 +28,11 @@
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
|
||||
namespace FHE = mlir::concretelang::FHE;
|
||||
namespace TFHE = mlir::concretelang::TFHE;
|
||||
namespace Tracing = mlir::concretelang::Tracing;
|
||||
namespace concretelang = mlir::concretelang;
|
||||
|
||||
namespace fhe_to_tfhe_scalar_conversion {
|
||||
@@ -546,10 +548,15 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
|
||||
converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForOp>(target,
|
||||
converter);
|
||||
|
||||
patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
// Patterns for `tracing` dialect.
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
Tracing::TraceCiphertextOp, true>>(&getContext(), converter);
|
||||
concretelang::addDynamicallyLegalTypeOp<Tracing::TraceCiphertextOp>(
|
||||
target, converter);
|
||||
|
||||
// Patterns for `bufferization` dialect operations.
|
||||
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
|
||||
mlir::bufferization::AllocTensorOp, true>>(patterns.getContext(),
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
#include "concretelang/Support/Constants.h"
|
||||
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
|
||||
|
||||
@@ -332,6 +333,8 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::Tracing::TraceCiphertextOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
@@ -351,6 +354,8 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
|
||||
mlir::concretelang::GenericTypeConverterPattern<
|
||||
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
|
||||
converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::Tracing::TraceCiphertextOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<
|
||||
|
||||
@@ -19,9 +19,11 @@
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
|
||||
namespace TFHE = mlir::concretelang::TFHE;
|
||||
namespace Concrete = mlir::concretelang::Concrete;
|
||||
namespace Tracing = mlir::concretelang::Tracing;
|
||||
|
||||
namespace {
|
||||
struct TFHEToConcretePass : public TFHEToConcreteBase<TFHEToConcretePass> {
|
||||
@@ -136,6 +138,31 @@ private:
|
||||
mlir::TypeConverter &converter;
|
||||
};
|
||||
|
||||
struct TracePlaintextOpPattern
|
||||
: public mlir::OpRewritePattern<Tracing::TracePlaintextOp> {
|
||||
TracePlaintextOpPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: mlir::OpRewritePattern<Tracing::TracePlaintextOp>(context, benefit) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(Tracing::TracePlaintextOp op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto inputWidth =
|
||||
op.plaintext().getType().cast<mlir::IntegerType>().getWidth();
|
||||
if (inputWidth == 64) {
|
||||
op->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth));
|
||||
return mlir::success();
|
||||
}
|
||||
auto extendedInput = rewriter.create<mlir::arith::ExtUIOp>(
|
||||
op.getLoc(), rewriter.getI64Type(), op.plaintext());
|
||||
auto newOp = rewriter.replaceOpWithNewOp<Tracing::TracePlaintextOp>(
|
||||
op, extendedInput, op.msgAttr(), op.nmsbAttr());
|
||||
newOp->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth));
|
||||
return ::mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void TFHEToConcretePass::runOnOperation() {
|
||||
auto op = this->getOperation();
|
||||
|
||||
@@ -171,6 +198,7 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
return FunctionConstantOpConversion<
|
||||
TFHEToConcreteTypeConverter>::isLegal(op, converter);
|
||||
});
|
||||
|
||||
// Add all patterns required to lower all ops from `TFHE` to
|
||||
// `Concrete`
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
@@ -228,6 +256,19 @@ void TFHEToConcretePass::runOnOperation() {
|
||||
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
|
||||
patterns, converter);
|
||||
|
||||
// Conversion of Tracing dialect
|
||||
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
|
||||
Tracing::TraceCiphertextOp>>(&getContext(), converter);
|
||||
mlir::concretelang::addDynamicallyLegalTypeOp<Tracing::TraceCiphertextOp>(
|
||||
target, converter);
|
||||
patterns.add<TracePlaintextOpPattern>(&getContext(), converter);
|
||||
target.addLegalOp<mlir::arith::ExtUIOp>();
|
||||
target.addDynamicallyLegalOp<Tracing::TracePlaintextOp>(
|
||||
[&](Tracing::TracePlaintextOp op) {
|
||||
return (op.plaintext().getType().cast<mlir::IntegerType>().getWidth() ==
|
||||
64);
|
||||
});
|
||||
|
||||
// Conversion of RT Dialect Ops
|
||||
patterns.add<
|
||||
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
|
||||
|
||||
14
compiler/lib/Conversion/TracingToCAPI/CMakeLists.txt
Normal file
14
compiler/lib/Conversion/TracingToCAPI/CMakeLists.txt
Normal file
@@ -0,0 +1,14 @@
|
||||
add_mlir_dialect_library(
|
||||
TracingToCAPI
|
||||
TracingToCAPI.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Tracing
|
||||
DEPENDS
|
||||
TracingDialect
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR
|
||||
MLIRTransforms)
|
||||
|
||||
target_link_libraries(TracingToCAPI PUBLIC TracingDialect MLIRIR)
|
||||
238
compiler/lib/Conversion/TracingToCAPI/TracingToCAPI.cpp
Normal file
238
compiler/lib/Conversion/TracingToCAPI/TracingToCAPI.cpp
Normal file
@@ -0,0 +1,238 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <mlir/Pass/Pass.h>
|
||||
#include <mlir/Transforms/DialectConversion.h>
|
||||
|
||||
#include "concretelang/Conversion/Passes.h"
|
||||
#include "concretelang/Conversion/Tools.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
|
||||
|
||||
namespace {
|
||||
|
||||
namespace Tracing = mlir::concretelang::Tracing;
|
||||
namespace arith = mlir::arith;
|
||||
namespace func = mlir::func;
|
||||
namespace memref = mlir::memref;
|
||||
|
||||
char memref_trace_ciphertext[] = "memref_trace_ciphertext";
|
||||
char memref_trace_plaintext[] = "memref_trace_plaintext";
|
||||
char memref_trace_message[] = "memref_trace_message";
|
||||
|
||||
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
|
||||
size_t rank) {
|
||||
std::vector<int64_t> shape(rank, -1);
|
||||
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
|
||||
for (size_t i = 0; i < rank; i++) {
|
||||
expr = expr +
|
||||
(rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
|
||||
}
|
||||
return mlir::MemRefType::get(
|
||||
shape, rewriter.getI64Type(),
|
||||
mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
|
||||
}
|
||||
|
||||
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
|
||||
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) {
|
||||
mlir::Type valueType = value.getType();
|
||||
|
||||
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
|
||||
return rewriter.create<mlir::memref::CastOp>(
|
||||
value.getLoc(),
|
||||
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
|
||||
value);
|
||||
} else {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
|
||||
mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {
|
||||
|
||||
auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
|
||||
|
||||
mlir::FunctionType funcType;
|
||||
|
||||
if (funcName == memref_trace_ciphertext) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{memref1DType, mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
|
||||
rewriter.getI32Type(), rewriter.getI32Type()},
|
||||
{});
|
||||
} else if (funcName == memref_trace_plaintext) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{rewriter.getI64Type(), rewriter.getI64Type(),
|
||||
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
|
||||
rewriter.getI32Type(), rewriter.getI32Type()},
|
||||
{});
|
||||
} else if (funcName == memref_trace_message) {
|
||||
funcType = mlir::FunctionType::get(
|
||||
rewriter.getContext(),
|
||||
{mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
|
||||
rewriter.getI32Type()},
|
||||
{});
|
||||
} else {
|
||||
op->emitError("unknwon external function") << funcName;
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
return insertForwardDeclaration(op, rewriter, funcName, funcType);
|
||||
}
|
||||
|
||||
template <typename Op>
|
||||
void addNoOperands(Op op, mlir::SmallVector<mlir::Value> &operands,
|
||||
mlir::RewriterBase &rewriter) {}
|
||||
|
||||
template <typename Op, char const *callee>
|
||||
struct TracingToCAPICallPattern : public mlir::OpRewritePattern<Op> {
|
||||
TracingToCAPICallPattern(
|
||||
::mlir::MLIRContext *context,
|
||||
std::function<void(Op op, llvm::SmallVector<mlir::Value> &,
|
||||
mlir::RewriterBase &)>
|
||||
addOperands = addNoOperands<Op>,
|
||||
mlir::PatternBenefit benefit = 1)
|
||||
: ::mlir::OpRewritePattern<Op>(context, benefit),
|
||||
addOperands(addOperands) {}
|
||||
|
||||
::mlir::LogicalResult
|
||||
matchAndRewrite(Op op, ::mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
// Create the operands
|
||||
mlir::SmallVector<mlir::Value> operands;
|
||||
// For all tensor operand get the corresponding casted buffer
|
||||
for (auto &operand : op->getOpOperands()) {
|
||||
mlir::Type type = operand.get().getType();
|
||||
if (!type.isa<mlir::MemRefType>()) {
|
||||
operands.push_back(operand.get());
|
||||
} else {
|
||||
operands.push_back(getCastedMemRef(rewriter, operand.get()));
|
||||
}
|
||||
}
|
||||
|
||||
// append additional argument
|
||||
addOperands(op, operands, rewriter);
|
||||
|
||||
// Insert forward declaration of the function
|
||||
if (insertForwardDeclarationOfTheCAPI(op, rewriter, callee).failed()) {
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<func::CallOp>(op, callee, mlir::TypeRange{},
|
||||
operands);
|
||||
|
||||
return ::mlir::success();
|
||||
};
|
||||
|
||||
private:
|
||||
std::function<void(Op op, llvm::SmallVector<mlir::Value> &,
|
||||
mlir::RewriterBase &)>
|
||||
addOperands;
|
||||
};
|
||||
|
||||
void traceCiphertextAddOperands(Tracing::TraceCiphertextOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
mlir::RewriterBase &rewriter) {
|
||||
auto msg = op.msg().getValueOr("");
|
||||
auto nmsb = op.nmsb().getValueOr(0);
|
||||
std::string msgName;
|
||||
std::stringstream stream;
|
||||
stream << rand();
|
||||
stream >> msgName;
|
||||
auto messageVal =
|
||||
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
|
||||
mlir::LLVM::linkage::Linkage::Linkonce);
|
||||
operands.push_back(messageVal);
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(msg.size())));
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(nmsb)));
|
||||
}
|
||||
|
||||
void tracePlaintextAddOperands(Tracing::TracePlaintextOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
mlir::RewriterBase &rewriter) {
|
||||
auto msg = op.msg().getValueOr("");
|
||||
auto nmsb = op.nmsb().getValueOr(0);
|
||||
std::string msgName;
|
||||
std::stringstream stream;
|
||||
stream << rand();
|
||||
stream >> msgName;
|
||||
auto messageVal =
|
||||
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
|
||||
mlir::LLVM::linkage::Linkage::Linkonce);
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), op->getAttr("input_width")));
|
||||
operands.push_back(messageVal);
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(msg.size())));
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(nmsb)));
|
||||
}
|
||||
|
||||
void traceMessageAddOperands(Tracing::TraceMessageOp op,
|
||||
mlir::SmallVector<mlir::Value> &operands,
|
||||
mlir::RewriterBase &rewriter) {
|
||||
auto msg = op.msg().getValueOr("");
|
||||
std::string msgName;
|
||||
std::stringstream stream;
|
||||
stream << rand();
|
||||
stream >> msgName;
|
||||
auto messageVal =
|
||||
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
|
||||
mlir::LLVM::linkage::Linkage::Linkonce);
|
||||
operands.push_back(messageVal);
|
||||
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
|
||||
op.getLoc(), rewriter.getI32IntegerAttr(msg.size())));
|
||||
}
|
||||
|
||||
struct TracingToCAPIPass : public TracingToCAPIBase<TracingToCAPIPass> {
|
||||
|
||||
TracingToCAPIPass() {}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto op = this->getOperation();
|
||||
|
||||
mlir::ConversionTarget target(getContext());
|
||||
mlir::RewritePatternSet patterns(&getContext());
|
||||
|
||||
// Mark ops from the target dialect as legal operations
|
||||
target.addLegalDialect<func::FuncDialect>();
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
target.addLegalDialect<arith::ArithmeticDialect>();
|
||||
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
|
||||
|
||||
// Make sure that no ops from `Tracing` remain after the lowering
|
||||
target.addIllegalDialect<Tracing::TracingDialect>();
|
||||
|
||||
// Add patterns to transform Tracing operators to CAPI call
|
||||
patterns.add<TracingToCAPICallPattern<Tracing::TraceCiphertextOp,
|
||||
memref_trace_ciphertext>>(
|
||||
&getContext(), traceCiphertextAddOperands);
|
||||
patterns.add<TracingToCAPICallPattern<Tracing::TracePlaintextOp,
|
||||
memref_trace_plaintext>>(
|
||||
&getContext(), tracePlaintextAddOperands);
|
||||
patterns.add<TracingToCAPICallPattern<Tracing::TraceMessageOp,
|
||||
memref_trace_message>>(
|
||||
&getContext(), traceMessageAddOperands);
|
||||
|
||||
// Apply conversion
|
||||
if (mlir::applyPartialConversion(op, target, std::move(patterns))
|
||||
.failed()) {
|
||||
this->signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createConvertTracingToCAPIPass() {
|
||||
return std::make_unique<TracingToCAPIPass>();
|
||||
}
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
|
||||
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
|
||||
#include "concretelang/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include <mlir/IR/AffineExpr.h>
|
||||
#include <mlir/IR/AffineMap.h>
|
||||
@@ -30,6 +31,7 @@ using namespace mlir::tensor;
|
||||
namespace {
|
||||
|
||||
namespace BConcrete = mlir::concretelang::BConcrete;
|
||||
namespace Tracing = mlir::concretelang::Tracing;
|
||||
|
||||
template <typename TensorOp, typename MemrefOp>
|
||||
struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel<
|
||||
|
||||
@@ -5,3 +5,4 @@ add_subdirectory(Concrete)
|
||||
add_subdirectory(BConcrete)
|
||||
add_subdirectory(RT)
|
||||
add_subdirectory(SDFG)
|
||||
add_subdirectory(Tracing)
|
||||
|
||||
@@ -26,6 +26,7 @@
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
|
||||
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
#include "concretelang/Support/V0Parameters.h"
|
||||
#include "concretelang/Support/logging.h"
|
||||
|
||||
@@ -120,7 +121,8 @@ struct FunctionToDag {
|
||||
auto encrypted_inputs = encryptedInputs(op);
|
||||
if (!hasEncryptedResult(op)) {
|
||||
// This op is unrelated to FHE
|
||||
assert(encrypted_inputs.empty());
|
||||
assert(encrypted_inputs.empty() ||
|
||||
mlir::isa<mlir::concretelang::Tracing::TraceCiphertextOp>(op));
|
||||
return;
|
||||
}
|
||||
assert(op.getNumResults() == 1);
|
||||
|
||||
2
compiler/lib/Dialect/Tracing/CMakeLists.txt
Normal file
2
compiler/lib/Dialect/Tracing/CMakeLists.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
add_subdirectory(IR)
|
||||
add_subdirectory(Transforms)
|
||||
13
compiler/lib/Dialect/Tracing/IR/CMakeLists.txt
Normal file
13
compiler/lib/Dialect/Tracing/IR/CMakeLists.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
add_mlir_dialect_library(
|
||||
TracingDialect
|
||||
TracingDialect.cpp
|
||||
TracingOps.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Tracing
|
||||
DEPENDS
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRIR)
|
||||
|
||||
target_link_libraries(TracingDialect PUBLIC MLIRIR)
|
||||
20
compiler/lib/Dialect/Tracing/IR/TracingDialect.cpp
Normal file
20
compiler/lib/Dialect/Tracing/IR/TracingDialect.cpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOpsDialect.cpp.inc"
|
||||
|
||||
#include "concretelang/Support/Constants.h"
|
||||
|
||||
using namespace mlir::concretelang::Tracing;
|
||||
|
||||
void TracingDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.cpp.inc"
|
||||
>();
|
||||
}
|
||||
20
compiler/lib/Dialect/Tracing/IR/TracingOps.cpp
Normal file
20
compiler/lib/Dialect/Tracing/IR/TracingOps.cpp
Normal file
@@ -0,0 +1,20 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/IR/Region.h"
|
||||
|
||||
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
|
||||
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace Tracing {} // namespace Tracing
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.cpp.inc"
|
||||
@@ -0,0 +1,89 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
|
||||
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include <mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h>
|
||||
#include <mlir/Dialect/Func/IR/FuncOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/Transforms/RegionUtils.h>
|
||||
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
|
||||
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::bufferization;
|
||||
using namespace mlir::concretelang::Tracing;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename Op>
|
||||
struct TrivialBufferizableInterface
|
||||
: public BufferizableOpInterface::ExternalModel<
|
||||
TrivialBufferizableInterface<Op>, Op> {
|
||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
|
||||
const AnalysisState &state) const {
|
||||
return {};
|
||||
}
|
||||
|
||||
BufferRelation bufferRelation(Operation *op, OpResult opResult,
|
||||
const AnalysisState &state) const {
|
||||
return BufferRelation::None;
|
||||
}
|
||||
|
||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||
const BufferizationOptions &options) const {
|
||||
|
||||
mlir::SmallVector<mlir::Value> operands;
|
||||
for (auto &operand : op->getOpOperands()) {
|
||||
if (!operand.get().getType().isa<mlir::RankedTensorType>()) {
|
||||
operands.push_back(operand.get());
|
||||
} else {
|
||||
operands.push_back(
|
||||
bufferization::getBuffer(rewriter, operand.get(), options));
|
||||
}
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<Op>(op, mlir::TypeRange{}, operands,
|
||||
op->getAttrs());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace Tracing {
|
||||
void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) {
|
||||
registry.addExtension(+[](MLIRContext *ctx, TracingDialect *dialect) {
|
||||
// trace_ciphretext
|
||||
Tracing::TraceCiphertextOp::attachInterface<
|
||||
TrivialBufferizableInterface<Tracing::TraceCiphertextOp>>(*ctx);
|
||||
// trace_plaintext
|
||||
Tracing::TracePlaintextOp::attachInterface<
|
||||
TrivialBufferizableInterface<Tracing::TracePlaintextOp>>(*ctx);
|
||||
// trace_message
|
||||
Tracing::TraceMessageOp::attachInterface<
|
||||
TrivialBufferizableInterface<Tracing::TraceMessageOp>>(*ctx);
|
||||
});
|
||||
}
|
||||
} // namespace Tracing
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
16
compiler/lib/Dialect/Tracing/Transforms/CMakeLists.txt
Normal file
16
compiler/lib/Dialect/Tracing/Transforms/CMakeLists.txt
Normal file
@@ -0,0 +1,16 @@
|
||||
add_mlir_dialect_library(
|
||||
TracingDialectTransforms
|
||||
BufferizableOpInterfaceImpl.cpp
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Tracing
|
||||
DEPENDS
|
||||
mlir-headers
|
||||
LINK_LIBS
|
||||
PUBLIC
|
||||
MLIRArithmeticDialect
|
||||
MLIRBufferizationDialect
|
||||
MLIRBufferizationTransforms
|
||||
MLIRIR
|
||||
MLIRMemRefDialect
|
||||
MLIRPass
|
||||
MLIRTransforms)
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/seeder.h"
|
||||
#include <assert.h>
|
||||
#include <bitset>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
@@ -658,3 +659,32 @@ void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
|
||||
src_aligned[src_offset + i * src_stride];
|
||||
}
|
||||
}
|
||||
|
||||
void memref_trace_ciphertext(uint64_t *ct0_allocated, uint64_t *ct0_aligned,
|
||||
uint64_t ct0_offset, uint64_t ct0_size,
|
||||
uint64_t ct0_stride, char *message_ptr,
|
||||
uint32_t message_len, uint32_t msb) {
|
||||
std::string message{message_ptr, (size_t)message_len};
|
||||
std::cout << message << " : ";
|
||||
std::bitset<64> bits{ct0_aligned[ct0_offset + ct0_size - 1]};
|
||||
std::string bitstring = bits.to_string();
|
||||
bitstring.insert(msb, 1, ' ');
|
||||
std::cout << bitstring << std::endl;
|
||||
}
|
||||
|
||||
void memref_trace_plaintext(uint64_t input, uint64_t input_width,
|
||||
char *message_ptr, uint32_t message_len,
|
||||
uint32_t msb) {
|
||||
std::string message{message_ptr, (size_t)message_len};
|
||||
std::cout << message << " : ";
|
||||
std::bitset<64> bits{input};
|
||||
std::string bitstring = bits.to_string();
|
||||
bitstring.erase(0, 64 - input_width);
|
||||
bitstring.insert(msb, 1, ' ');
|
||||
std::cout << bitstring << std::endl;
|
||||
}
|
||||
|
||||
void memref_trace_message(char *message_ptr, uint32_t message_len) {
|
||||
std::string message{message_ptr, (size_t)message_len};
|
||||
std::cout << message << std::endl;
|
||||
}
|
||||
|
||||
@@ -38,6 +38,8 @@
|
||||
#include <concretelang/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h>
|
||||
#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
|
||||
#include <concretelang/Dialect/Tracing/IR/TracingDialect.h>
|
||||
#include <concretelang/Dialect/Tracing/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
#include <concretelang/Support/CompilerEngine.h>
|
||||
#include <concretelang/Support/Error.h>
|
||||
@@ -73,6 +75,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
|
||||
if (this->mlirContext == nullptr) {
|
||||
mlir::DialectRegistry registry;
|
||||
registry.insert<
|
||||
mlir::concretelang::Tracing::TracingDialect,
|
||||
mlir::concretelang::RT::RTDialect, mlir::concretelang::FHE::FHEDialect,
|
||||
mlir::concretelang::TFHE::TFHEDialect,
|
||||
mlir::concretelang::FHELinalg::FHELinalgDialect,
|
||||
@@ -83,6 +86,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
|
||||
mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect,
|
||||
mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>();
|
||||
BConcrete::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
Tracing::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
SDFG::registerSDFGConvertibleOpInterfaceExternalModels(registry);
|
||||
SDFG::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
arith::registerBufferizableOpInterfaceExternalModels(registry);
|
||||
|
||||
@@ -398,6 +398,8 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertBConcreteToCAPIPass(gpu),
|
||||
enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createConvertTracingToCAPIPass(), enablePass);
|
||||
|
||||
// Convert to MLIR LLVM Dialect
|
||||
addPotentiallyNestedPass(
|
||||
|
||||
10
compiler/tests/check_tests/Tracing/trace_ciphertext_crt.mlir
Normal file
10
compiler/tests/check_tests/Tracing/trace_ciphertext_crt.mlir
Normal file
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --force-encoding=crt --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: : [[BODY1:[01]{64}]]
|
||||
// CHECK-NEXT: : [[BODY2:[01]{64}]]
|
||||
// CHECK-NEXT: : [[BODY3:[01]{64}]]
|
||||
// CHECK-NEXT: 1
|
||||
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
"Tracing.trace_ciphertext"(%arg0): (!FHE.eint<5>) -> ()
|
||||
return %arg0: !FHE.eint<5>
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: : [[BODY:[01]{64}]]
|
||||
// CHECK-NEXT: 1
|
||||
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
"Tracing.trace_ciphertext"(%arg0): (!FHE.eint<5>) -> ()
|
||||
return %arg0: !FHE.eint<5>
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// RUN: concretecompiler --force-encoding=crt --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: Test : [[BODY01:[01]{3}]] [[BODY02:[01]{61}]]
|
||||
// CHECK-NEXT: Test : [[BODY11:[01]{3}]] [[BODY12:[01]{61}]]
|
||||
// CHECK-NEXT: Test : [[BODY21:[01]{3}]] [[BODY22:[01]{61}]]
|
||||
// CHECK-NEXT: 1
|
||||
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
"Tracing.trace_ciphertext"(%arg0){msg="Test", nmsb=3:i32}: (!FHE.eint<5>) -> ()
|
||||
return %arg0: !FHE.eint<5>
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: Test : [[BODY:[01]{3}]] [[BODY2:[01]{61}]]
|
||||
// CHECK-NEXT: 1
|
||||
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
"Tracing.trace_ciphertext"(%arg0){msg="Test", nmsb=3:i32}: (!FHE.eint<5>) -> ()
|
||||
return %arg0: !FHE.eint<5>
|
||||
}
|
||||
8
compiler/tests/check_tests/Tracing/trace_message.mlir
Normal file
8
compiler/tests/check_tests/Tracing/trace_message.mlir
Normal file
@@ -0,0 +1,8 @@
|
||||
// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: Arbitrary message
|
||||
// CHECK-NEXT: 1
|
||||
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
"Tracing.trace_message"(){msg="Arbitrary message"}: () -> ()
|
||||
return %arg0: !FHE.eint<5>
|
||||
}
|
||||
9
compiler/tests/check_tests/Tracing/trace_plaintext.mlir
Normal file
9
compiler/tests/check_tests/Tracing/trace_plaintext.mlir
Normal file
@@ -0,0 +1,9 @@
|
||||
// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: : 00000100
|
||||
// CHECK-NEXT: 1
|
||||
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
%0 = arith.constant 4 : i8
|
||||
"Tracing.trace_plaintext"(%0): (i8) -> ()
|
||||
return %arg0: !FHE.eint<5>
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK: Test : 00000100
|
||||
// CHECK-NEXT: 1
|
||||
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
|
||||
%0 = arith.constant 4 : i8
|
||||
"Tracing.trace_plaintext"(%0){msg="Test"}: (i8) -> ()
|
||||
return %arg0: !FHE.eint<5>
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
description: trace_ciphertext_without_attributes
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<1>) -> !FHE.eint<1> {
|
||||
"Tracing.trace_ciphertext"(%arg0): (!FHE.eint<1>) -> ()
|
||||
return %arg0: !FHE.eint<1>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
---
|
||||
description: trace_ciphertext_without_attributes_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
"Tracing.trace_ciphertext"(%arg0): (!FHE.eint<16>) -> ()
|
||||
return %arg0: !FHE.eint<16>
|
||||
}
|
||||
encoding: crt
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
---
|
||||
description: trace_ciphertextwith_attributes
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<1>) -> !FHE.eint<1> {
|
||||
"Tracing.trace_ciphertext"(%arg0){nmsb=4:i32, msg="test"}: (!FHE.eint<1>) -> ()
|
||||
return %arg0: !FHE.eint<1>
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
---
|
||||
description: trace_ciphertext_with_attributes_16bits
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
|
||||
"Tracing.trace_ciphertext"(%arg0){nmsb=4:i32, msg="test"}: (!FHE.eint<16>) -> ()
|
||||
return %arg0: !FHE.eint<16>
|
||||
}
|
||||
encoding: crt
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
---
|
||||
description: trace_plaintext_without_attributes
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<1>) -> i64 {
|
||||
%c0 = arith.constant 1 : i8
|
||||
"Tracing.trace_plaintext"(%c0): (i8) -> ()
|
||||
%c1 = arith.constant 1 : i64
|
||||
return %c1: i64
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
---
|
||||
description: trace_plaintext_with_attributes
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<1>) -> i64 {
|
||||
%c0 = arith.constant 1 : i8
|
||||
"Tracing.trace_plaintext"(%c0){nmsb=3:i32, msg="test"}: (i8) -> ()
|
||||
%c1 = arith.constant 1 : i64
|
||||
return %c1: i64
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
---
|
||||
description: trace_message
|
||||
program: |
|
||||
func.func @main(%arg0: !FHE.eint<1>) -> i64 {
|
||||
%c0 = arith.constant 1 : i8
|
||||
"Tracing.trace_plaintext"(%c0): (i8) -> ()
|
||||
"Tracing.trace_message"(){msg="Test"}: () -> ()
|
||||
%c1 = arith.constant 2 : i8
|
||||
"Tracing.trace_plaintext"(%c1): (i8) -> ()
|
||||
%c2 = arith.constant 1 : i64
|
||||
return %c2: i64
|
||||
}
|
||||
tests:
|
||||
- inputs:
|
||||
- scalar: 1
|
||||
outputs:
|
||||
- scalar: 1
|
||||
Reference in New Issue
Block a user