diff --git a/compiler/include/concretelang/Conversion/Passes.h b/compiler/include/concretelang/Conversion/Passes.h index c0b2d5d3b..6cdfb800b 100644 --- a/compiler/include/concretelang/Conversion/Passes.h +++ b/compiler/include/concretelang/Conversion/Passes.h @@ -24,11 +24,13 @@ #include "concretelang/Conversion/SDFGToStreamEmulator/Pass.h" #include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h" #include "concretelang/Conversion/TFHEToConcrete/Pass.h" +#include "concretelang/Conversion/TracingToCAPI/Pass.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" #include "concretelang/Dialect/SDFG/IR/SDFGDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" +#include "concretelang/Dialect/Tracing/IR/TracingDialect.h" #define GEN_PASS_CLASSES #include "concretelang/Conversion/Passes.h.inc" diff --git a/compiler/include/concretelang/Conversion/Passes.td b/compiler/include/concretelang/Conversion/Passes.td index 2280789dc..2fb409a0b 100644 --- a/compiler/include/concretelang/Conversion/Passes.td +++ b/compiler/include/concretelang/Conversion/Passes.td @@ -69,6 +69,13 @@ def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> { let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect"]; } +def TracingToCAPI : Pass<"tracing-to-capi", "mlir::ModuleOp"> { + let summary = "Lowers operations from the Tracing dialect to CAPI calls"; + let description = [{ Lowers operations from the Tracing dialect to CAPI calls }]; + let constructor = "mlir::concretelang::createConvertTracingToCAPIPass()"; + let dependentDialects = ["mlir::concretelang::Tracing::TracingDialect"]; +} + def SDFGToStreamEmulator : Pass<"sdfg-to-stream-emulator", "mlir::ModuleOp"> { let summary = "Lowers operations from the SDFG dialect to Stream Emulator calls"; let description = [{ Lowers operations from the SDFG dialect to Stream Emulator calls }]; diff --git a/compiler/include/concretelang/Conversion/TracingToCAPI/Pass.h b/compiler/include/concretelang/Conversion/TracingToCAPI/Pass.h new file mode 100644 index 000000000..54eb0e6e1 --- /dev/null +++ b/compiler/include/concretelang/Conversion/TracingToCAPI/Pass.h @@ -0,0 +1,18 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef ZAMALANG_CONVERSION_TRACINGTOCAPI_PASS_H_ +#define ZAMALANG_CONVERSION_TRACINGTOCAPI_PASS_H_ + +#include "mlir/Pass/Pass.h" + +namespace mlir { +namespace concretelang { +/// Create a pass to convert `Tracing` dialect to CAPI calls. +std::unique_ptr> createConvertTracingToCAPIPass(); +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Dialect/CMakeLists.txt b/compiler/include/concretelang/Dialect/CMakeLists.txt index f05954ed4..3889967e6 100644 --- a/compiler/include/concretelang/Dialect/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(Concrete) add_subdirectory(BConcrete) add_subdirectory(RT) add_subdirectory(SDFG) +add_subdirectory(Tracing) diff --git a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.h b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.h index e58216dc4..9f7d2a1db 100644 --- a/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.h +++ b/compiler/include/concretelang/Dialect/FHE/IR/FHEOps.h @@ -6,6 +6,7 @@ #ifndef CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H #define CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H +#include #include #include #include diff --git a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td index cd952e1aa..d7d6b4ce4 100644 --- a/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td +++ b/compiler/include/concretelang/Dialect/TFHE/IR/TFHEOps.td @@ -179,4 +179,5 @@ def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> { let results = (outs Type.predicate, HasStaticShapePred]>>:$result); } + #endif diff --git a/compiler/include/concretelang/Dialect/Tracing/CMakeLists.txt b/compiler/include/concretelang/Dialect/Tracing/CMakeLists.txt new file mode 100644 index 000000000..f33061b2d --- /dev/null +++ b/compiler/include/concretelang/Dialect/Tracing/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) diff --git a/compiler/include/concretelang/Dialect/Tracing/IR/CMakeLists.txt b/compiler/include/concretelang/Dialect/Tracing/IR/CMakeLists.txt new file mode 100644 index 000000000..fc5a84445 --- /dev/null +++ b/compiler/include/concretelang/Dialect/Tracing/IR/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_TARGET_DEFINITIONS TracingOps.td) +mlir_tablegen(TracingOps.h.inc -gen-op-decls) +mlir_tablegen(TracingOps.cpp.inc -gen-op-defs) +mlir_tablegen(TracingOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=Tracing) +mlir_tablegen(TracingOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=Tracing) +mlir_tablegen(TracingOpsDialect.h.inc -gen-dialect-decls -dialect=Tracing) +mlir_tablegen(TracingOpsDialect.cpp.inc -gen-dialect-defs -dialect=Tracing) +add_public_tablegen_target(MLIRTracingOpsIncGen) +add_dependencies(mlir-headers MLIRTracingOpsIncGen) diff --git a/compiler/include/concretelang/Dialect/Tracing/IR/TracingDialect.h b/compiler/include/concretelang/Dialect/Tracing/IR/TracingDialect.h new file mode 100644 index 000000000..7f835e239 --- /dev/null +++ b/compiler/include/concretelang/Dialect/Tracing/IR/TracingDialect.h @@ -0,0 +1,18 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACINGDIALECT_H +#define CONCRETELANG_DIALECT_TRACING_IR_TRACINGDIALECT_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" + +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" + +#include "concretelang/Dialect/Tracing/IR/TracingOpsDialect.h.inc" + +#endif diff --git a/compiler/include/concretelang/Dialect/Tracing/IR/TracingDialect.td b/compiler/include/concretelang/Dialect/Tracing/IR/TracingDialect.td new file mode 100644 index 000000000..dd06ed131 --- /dev/null +++ b/compiler/include/concretelang/Dialect/Tracing/IR/TracingDialect.td @@ -0,0 +1,23 @@ +//===- TracingDialect.td - Tracing dialect ----------------*- tablegen -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACING_DIALECT +#define CONCRETELANG_DIALECT_TRACING_IR_TRACING_DIALECT + +include "mlir/IR/OpBase.td" + +def Tracing_Dialect : Dialect { + let name = "Tracing"; + let summary = "Tracing dialect"; + let description = [{ + A dialect to print program values at runtime. + }]; + let cppNamespace = "::mlir::concretelang::Tracing"; +} + +#endif diff --git a/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.h b/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.h new file mode 100644 index 000000000..564d54fe9 --- /dev/null +++ b/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.h @@ -0,0 +1,18 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACINGOPS_H +#define CONCRETELANG_DIALECT_TRACING_IR_TRACINGOPS_H + +#include +#include +#include +#include +#include + +#define GET_OP_CLASSES +#include "concretelang/Dialect/Tracing/IR/TracingOps.h.inc" + +#endif diff --git a/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td b/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td new file mode 100644 index 000000000..186c83414 --- /dev/null +++ b/compiler/include/concretelang/Dialect/Tracing/IR/TracingOps.td @@ -0,0 +1,57 @@ +//===- TracingOps.td - Tracing dialect ops ----------------*- tablegen +//-*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACING_OPS +#define CONCRETELANG_DIALECT_TRACING_IR_TRACING_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" + +include "concretelang/Dialect/Tracing/IR/TracingDialect.td" +include "concretelang/Dialect/FHE/IR/FHETypes.td" +include "concretelang/Dialect/TFHE/IR/TFHETypes.td" +include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td" + +class Tracing_Op traits = []> + : Op; + +def Tracing_TraceCiphertextOp : Tracing_Op<"trace_ciphertext"> { + let summary = "Prints a ciphertext."; + + let arguments = (ins + Type.predicate, + MemRefRankOf<[I64], [1]>.predicate + ]>>: $ciphertext, + OptionalAttr: $msg, + OptionalAttr: $nmsb + ); +} + +def Tracing_TracePlaintextOp : Tracing_Op<"trace_plaintext"> { + let summary = "Prints a plaintext."; + + let arguments = (ins + AnyInteger: $plaintext, + OptionalAttr: $msg, + OptionalAttr: $nmsb + ); +} + +def Tracing_TraceMessageOp : Tracing_Op<"trace_message"> { + let summary = "Prints a message."; + + let arguments = (ins OptionalAttr : $msg); +} + +#endif diff --git a/compiler/include/concretelang/Dialect/Tracing/Transforms/BufferizableOpInterfaceImpl.h b/compiler/include/concretelang/Dialect/Tracing/Transforms/BufferizableOpInterfaceImpl.h new file mode 100644 index 000000000..23c3c1294 --- /dev/null +++ b/compiler/include/concretelang/Dialect/Tracing/Transforms/BufferizableOpInterfaceImpl.h @@ -0,0 +1,19 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_TRACING_BUFFERIZABLEOPINTERFACEIMPL_H +#define CONCRETELANG_DIALECT_TRACING_BUFFERIZABLEOPINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace concretelang { +namespace Tracing { +void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry); +} // namespace Tracing +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Runtime/wrappers.h b/compiler/include/concretelang/Runtime/wrappers.h index 200786e59..387708ba7 100644 --- a/compiler/include/concretelang/Runtime/wrappers.h +++ b/compiler/include/concretelang/Runtime/wrappers.h @@ -215,5 +215,18 @@ void memref_batched_bootstrap_lwe_cuda_u64( uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size, uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision, mlir::concretelang::RuntimeContext *context); + +// Tracing //////////////////////////////////////////////////////////////////// +void memref_trace_ciphertext(uint64_t *ct0_allocated, uint64_t *ct0_aligned, + uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, char *message_ptr, + uint32_t message_len, uint32_t msb); + +void memref_trace_plaintext(uint64_t input, uint64_t input_width, + char *message_ptr, uint32_t message_len, + uint32_t msb); + +void memref_trace_message(char *message_ptr, uint32_t message_len); } + #endif diff --git a/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp b/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp index b5f6ebe0a..de1729668 100644 --- a/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp +++ b/compiler/lib/Conversion/BConcreteToCAPI/BConcreteToCAPI.cpp @@ -48,6 +48,7 @@ char memref_encode_expand_lut_for_bootstrap[] = "memref_encode_expand_lut_for_bootstrap"; char memref_encode_expand_lut_for_woppbs[] = "memref_encode_expand_lut_for_woppbs"; +char memref_trace[] = "memref_trace"; mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, size_t rank) { @@ -185,6 +186,12 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI( {memref1DType, memref1DType, memref1DType, memref1DType, rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI1Type()}, {}); + } else if (funcName == memref_trace) { + funcType = mlir::FunctionType::get( + rewriter.getContext(), + {memref1DType, mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), + rewriter.getI32Type(), rewriter.getI32Type()}, + {}); } else { op->emitError("unknwon external function") << funcName; return mlir::failure(); @@ -431,6 +438,7 @@ struct BConcreteToCAPIPass : public BConcreteToCAPIBase { target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); + target.addLegalDialect(); // Make sure that no ops from `FHE` remain after the lowering target.addIllegalDialect(); diff --git a/compiler/lib/Conversion/CMakeLists.txt b/compiler/lib/Conversion/CMakeLists.txt index 999fc2107..e26073876 100644 --- a/compiler/lib/Conversion/CMakeLists.txt +++ b/compiler/lib/Conversion/CMakeLists.txt @@ -5,6 +5,7 @@ add_subdirectory(TFHEToConcrete) add_subdirectory(FHETensorOpsToLinalg) add_subdirectory(ConcreteToBConcrete) add_subdirectory(BConcreteToCAPI) +add_subdirectory(TracingToCAPI) add_subdirectory(SDFGToStreamEmulator) add_subdirectory(MLIRLowerableDialectsToLLVM) add_subdirectory(LinalgExtras) diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp index ec32fd1cf..bdbce6ee9 100644 --- a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp +++ b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp @@ -36,6 +36,7 @@ #include "concretelang/Dialect/Concrete/IR/ConcreteOps.h" #include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" #include "concretelang/Dialect/RT/IR/RTOps.h" +#include "concretelang/Dialect/Tracing/IR/TracingOps.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -43,6 +44,7 @@ namespace Concrete = ::mlir::concretelang::Concrete; namespace BConcrete = ::mlir::concretelang::BConcrete; +namespace Tracing = ::mlir::concretelang::Tracing; namespace { struct ConcreteToBConcretePass @@ -986,6 +988,16 @@ void ConcreteToBConcretePass::runOnOperation() { converter.isLegal(op->getOperandTypes()); }); + // Conversion of Tracing dialect + patterns.add, + mlir::concretelang::GenericTypeConverterPattern< + Tracing::TracePlaintextOp>>(&getContext(), converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + // Conversion of RT Dialect Ops patterns.add< mlir::concretelang::GenericTypeConverterPattern, diff --git a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp index 1b1ac5f18..f27dccfa6 100644 --- a/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp +++ b/compiler/lib/Conversion/FHEToTFHECrt/FHEToTFHECrt.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" @@ -29,9 +28,12 @@ #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHEOps.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" +#include "concretelang/Dialect/Tracing/IR/TracingDialect.h" +#include "concretelang/Dialect/Tracing/IR/TracingOps.h" namespace FHE = mlir::concretelang::FHE; namespace TFHE = mlir::concretelang::TFHE; +namespace Tracing = mlir::concretelang::Tracing; namespace concretelang = mlir::concretelang; namespace fhe_to_tfhe_crt_conversion { @@ -583,6 +585,45 @@ struct ApplyLookupTableEintOpPattern }; }; +/// Rewriter for the `Tracing::trace_ciphertext` operation. +struct TraceCiphertextOpPattern : CrtOpPattern { + + TraceCiphertextOpPattern(mlir::MLIRContext *context, + concretelang::CrtLoweringParameters params, + mlir::PatternBenefit benefit = 1) + : CrtOpPattern(context, params, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(Tracing::TraceCiphertextOp op, + Tracing::TraceCiphertextOp::Adaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + + typing::TypeConverter converter{loweringParameters}; + mlir::Type ciphertextScalarType = + converter.convertType(op.ciphertext().getType()) + .cast() + .getElementType(); + + for (size_t i = 0; i < (loweringParameters.nMods - 1); ++i) { + auto extractedCiphertext = rewriter.create( + op.getLoc(), ciphertextScalarType, adaptor.ciphertext(), + mlir::ValueRange{rewriter.create( + op.getLoc(), rewriter.getIndexAttr(i))}); + rewriter.create( + op.getLoc(), extractedCiphertext, op.msgAttr(), op.nmsbAttr()); + } + + auto extractedCiphertext = rewriter.create( + op.getLoc(), ciphertextScalarType, adaptor.ciphertext(), + mlir::ValueRange{rewriter.create( + op.getLoc(), rewriter.getIndexAttr(loweringParameters.nMods - 1))}); + rewriter.replaceOpWithNewOp( + op, extractedCiphertext, op.msgAttr(), op.nmsbAttr()); + + return mlir::success(); + } +}; + /// Rewriter for the `tensor::extract` operation. struct TensorExtractOpPattern : public CrtOpPattern { @@ -924,6 +965,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { target, converter); concretelang::addDynamicallyLegalTypeOp( target, converter); + concretelang::addDynamicallyLegalTypeOp( + target, converter); concretelang::addDynamicallyLegalTypeOp< concretelang::RT::MakeReadyFutureOp>(target, converter); concretelang::addDynamicallyLegalTypeOp( @@ -1006,6 +1049,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase { loweringParameters); patterns.add(patterns.getContext(), loweringParameters); + patterns.add(patterns.getContext(), + loweringParameters); patterns.add>(&getContext(), converter); diff --git a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp index 301a43fec..801a92b33 100644 --- a/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp +++ b/compiler/lib/Conversion/FHEToTFHEScalar/FHEToTFHEScalar.cpp @@ -28,9 +28,11 @@ #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHEOps.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" +#include "concretelang/Dialect/Tracing/IR/TracingOps.h" namespace FHE = mlir::concretelang::FHE; namespace TFHE = mlir::concretelang::TFHE; +namespace Tracing = mlir::concretelang::Tracing; namespace concretelang = mlir::concretelang; namespace fhe_to_tfhe_scalar_conversion { @@ -546,10 +548,15 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase { converter); concretelang::addDynamicallyLegalTypeOp(target, converter); - patterns.add>( &getContext(), converter); + // Patterns for `tracing` dialect. + patterns.add>(&getContext(), converter); + concretelang::addDynamicallyLegalTypeOp( + target, converter); + // Patterns for `bufferization` dialect operations. patterns.add>(patterns.getContext(), diff --git a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 2e875d5c7..c656280c5 100644 --- a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -15,6 +15,7 @@ #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHEOps.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" +#include "concretelang/Dialect/Tracing/IR/TracingOps.h" #include "concretelang/Support/Constants.h" #include @@ -332,6 +333,8 @@ void TFHEGlobalParametrizationPass::runOnOperation() { // Conversion of RT Dialect Ops patterns.add< + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::Tracing::TraceCiphertextOp>, mlir::concretelang::GenericTypeConverterPattern, mlir::concretelang::GenericTypeConverterPattern, mlir::concretelang::GenericTypeConverterPattern< @@ -351,6 +354,8 @@ void TFHEGlobalParametrizationPass::runOnOperation() { mlir::concretelang::GenericTypeConverterPattern< mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::Tracing::TraceCiphertextOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< mlir::concretelang::RT::MakeReadyFutureOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< diff --git a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index e51d202ed..d80424cc3 100644 --- a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -19,9 +19,11 @@ #include "concretelang/Dialect/RT/IR/RTOps.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" +#include "concretelang/Dialect/Tracing/IR/TracingOps.h" namespace TFHE = mlir::concretelang::TFHE; namespace Concrete = mlir::concretelang::Concrete; +namespace Tracing = mlir::concretelang::Tracing; namespace { struct TFHEToConcretePass : public TFHEToConcreteBase { @@ -136,6 +138,31 @@ private: mlir::TypeConverter &converter; }; +struct TracePlaintextOpPattern + : public mlir::OpRewritePattern { + TracePlaintextOpPattern(mlir::MLIRContext *context, + mlir::TypeConverter &converter, + mlir::PatternBenefit benefit = 100) + : mlir::OpRewritePattern(context, benefit) {} + + mlir::LogicalResult + matchAndRewrite(Tracing::TracePlaintextOp op, + mlir::PatternRewriter &rewriter) const override { + auto inputWidth = + op.plaintext().getType().cast().getWidth(); + if (inputWidth == 64) { + op->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth)); + return mlir::success(); + } + auto extendedInput = rewriter.create( + op.getLoc(), rewriter.getI64Type(), op.plaintext()); + auto newOp = rewriter.replaceOpWithNewOp( + op, extendedInput, op.msgAttr(), op.nmsbAttr()); + newOp->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth)); + return ::mlir::success(); + } +}; + void TFHEToConcretePass::runOnOperation() { auto op = this->getOperation(); @@ -171,6 +198,7 @@ void TFHEToConcretePass::runOnOperation() { return FunctionConstantOpConversion< TFHEToConcreteTypeConverter>::isLegal(op, converter); }); + // Add all patterns required to lower all ops from `TFHE` to // `Concrete` mlir::RewritePatternSet patterns(&getContext()); @@ -228,6 +256,19 @@ void TFHEToConcretePass::runOnOperation() { mlir::populateFunctionOpInterfaceTypeConversionPattern( patterns, converter); + // Conversion of Tracing dialect + patterns.add>(&getContext(), converter); + mlir::concretelang::addDynamicallyLegalTypeOp( + target, converter); + patterns.add(&getContext(), converter); + target.addLegalOp(); + target.addDynamicallyLegalOp( + [&](Tracing::TracePlaintextOp op) { + return (op.plaintext().getType().cast().getWidth() == + 64); + }); + // Conversion of RT Dialect Ops patterns.add< mlir::concretelang::GenericTypeConverterPattern, diff --git a/compiler/lib/Conversion/TracingToCAPI/CMakeLists.txt b/compiler/lib/Conversion/TracingToCAPI/CMakeLists.txt new file mode 100644 index 000000000..92e6a92ce --- /dev/null +++ b/compiler/lib/Conversion/TracingToCAPI/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library( + TracingToCAPI + TracingToCAPI.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Tracing + DEPENDS + TracingDialect + mlir-headers + LINK_LIBS + PUBLIC + MLIRIR + MLIRTransforms) + +target_link_libraries(TracingToCAPI PUBLIC TracingDialect MLIRIR) diff --git a/compiler/lib/Conversion/TracingToCAPI/TracingToCAPI.cpp b/compiler/lib/Conversion/TracingToCAPI/TracingToCAPI.cpp new file mode 100644 index 000000000..c5202a6b5 --- /dev/null +++ b/compiler/lib/Conversion/TracingToCAPI/TracingToCAPI.cpp @@ -0,0 +1,238 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include + +#include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Tools.h" +#include "concretelang/Dialect/Tracing/IR/TracingOps.h" +#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" + +namespace { + +namespace Tracing = mlir::concretelang::Tracing; +namespace arith = mlir::arith; +namespace func = mlir::func; +namespace memref = mlir::memref; + +char memref_trace_ciphertext[] = "memref_trace_ciphertext"; +char memref_trace_plaintext[] = "memref_trace_plaintext"; +char memref_trace_message[] = "memref_trace_message"; + +mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter, + size_t rank) { + std::vector shape(rank, -1); + mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0); + for (size_t i = 0; i < rank; i++) { + expr = expr + + (rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1)); + } + return mlir::MemRefType::get( + shape, rewriter.getI64Type(), + mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext())); +} + +// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>` +mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) { + mlir::Type valueType = value.getType(); + + if (auto memrefTy = valueType.dyn_cast_or_null()) { + return rewriter.create( + value.getLoc(), + getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()), + value); + } else { + return value; + } +} + +mlir::LogicalResult insertForwardDeclarationOfTheCAPI( + mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) { + + auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1); + + mlir::FunctionType funcType; + + if (funcName == memref_trace_ciphertext) { + funcType = mlir::FunctionType::get( + rewriter.getContext(), + {memref1DType, mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), + rewriter.getI32Type(), rewriter.getI32Type()}, + {}); + } else if (funcName == memref_trace_plaintext) { + funcType = mlir::FunctionType::get( + rewriter.getContext(), + {rewriter.getI64Type(), rewriter.getI64Type(), + mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), + rewriter.getI32Type(), rewriter.getI32Type()}, + {}); + } else if (funcName == memref_trace_message) { + funcType = mlir::FunctionType::get( + rewriter.getContext(), + {mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()), + rewriter.getI32Type()}, + {}); + } else { + op->emitError("unknwon external function") << funcName; + return mlir::failure(); + } + + return insertForwardDeclaration(op, rewriter, funcName, funcType); +} + +template +void addNoOperands(Op op, mlir::SmallVector &operands, + mlir::RewriterBase &rewriter) {} + +template +struct TracingToCAPICallPattern : public mlir::OpRewritePattern { + TracingToCAPICallPattern( + ::mlir::MLIRContext *context, + std::function &, + mlir::RewriterBase &)> + addOperands = addNoOperands, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(context, benefit), + addOperands(addOperands) {} + + ::mlir::LogicalResult + matchAndRewrite(Op op, ::mlir::PatternRewriter &rewriter) const override { + + // Create the operands + mlir::SmallVector operands; + // For all tensor operand get the corresponding casted buffer + for (auto &operand : op->getOpOperands()) { + mlir::Type type = operand.get().getType(); + if (!type.isa()) { + operands.push_back(operand.get()); + } else { + operands.push_back(getCastedMemRef(rewriter, operand.get())); + } + } + + // append additional argument + addOperands(op, operands, rewriter); + + // Insert forward declaration of the function + if (insertForwardDeclarationOfTheCAPI(op, rewriter, callee).failed()) { + return mlir::failure(); + } + + rewriter.replaceOpWithNewOp(op, callee, mlir::TypeRange{}, + operands); + + return ::mlir::success(); + }; + +private: + std::function &, + mlir::RewriterBase &)> + addOperands; +}; + +void traceCiphertextAddOperands(Tracing::TraceCiphertextOp op, + mlir::SmallVector &operands, + mlir::RewriterBase &rewriter) { + auto msg = op.msg().getValueOr(""); + auto nmsb = op.nmsb().getValueOr(0); + std::string msgName; + std::stringstream stream; + stream << rand(); + stream >> msgName; + auto messageVal = + mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg, + mlir::LLVM::linkage::Linkage::Linkonce); + operands.push_back(messageVal); + operands.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(msg.size()))); + operands.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(nmsb))); +} + +void tracePlaintextAddOperands(Tracing::TracePlaintextOp op, + mlir::SmallVector &operands, + mlir::RewriterBase &rewriter) { + auto msg = op.msg().getValueOr(""); + auto nmsb = op.nmsb().getValueOr(0); + std::string msgName; + std::stringstream stream; + stream << rand(); + stream >> msgName; + auto messageVal = + mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg, + mlir::LLVM::linkage::Linkage::Linkonce); + operands.push_back(rewriter.create( + op.getLoc(), op->getAttr("input_width"))); + operands.push_back(messageVal); + operands.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(msg.size()))); + operands.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(nmsb))); +} + +void traceMessageAddOperands(Tracing::TraceMessageOp op, + mlir::SmallVector &operands, + mlir::RewriterBase &rewriter) { + auto msg = op.msg().getValueOr(""); + std::string msgName; + std::stringstream stream; + stream << rand(); + stream >> msgName; + auto messageVal = + mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg, + mlir::LLVM::linkage::Linkage::Linkonce); + operands.push_back(messageVal); + operands.push_back(rewriter.create( + op.getLoc(), rewriter.getI32IntegerAttr(msg.size()))); +} + +struct TracingToCAPIPass : public TracingToCAPIBase { + + TracingToCAPIPass() {} + + void runOnOperation() override { + auto op = this->getOperation(); + + mlir::ConversionTarget target(getContext()); + mlir::RewritePatternSet patterns(&getContext()); + + // Mark ops from the target dialect as legal operations + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + + // Make sure that no ops from `Tracing` remain after the lowering + target.addIllegalDialect(); + + // Add patterns to transform Tracing operators to CAPI call + patterns.add>( + &getContext(), traceCiphertextAddOperands); + patterns.add>( + &getContext(), tracePlaintextAddOperands); + patterns.add>( + &getContext(), traceMessageAddOperands); + + // Apply conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)) + .failed()) { + this->signalPassFailure(); + } + } +}; + +} // namespace + +namespace mlir { +namespace concretelang { +std::unique_ptr> createConvertTracingToCAPIPass() { + return std::make_unique(); +} +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp index f95286660..b5c9b31ba 100644 --- a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp @@ -18,6 +18,7 @@ #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h" #include "concretelang/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.h" +#include "concretelang/Dialect/Tracing/IR/TracingOps.h" #include "concretelang/Support/CompilerEngine.h" #include #include @@ -30,6 +31,7 @@ using namespace mlir::tensor; namespace { namespace BConcrete = mlir::concretelang::BConcrete; +namespace Tracing = mlir::concretelang::Tracing; template struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel< diff --git a/compiler/lib/Dialect/CMakeLists.txt b/compiler/lib/Dialect/CMakeLists.txt index d0ca5b248..76ea8788b 100644 --- a/compiler/lib/Dialect/CMakeLists.txt +++ b/compiler/lib/Dialect/CMakeLists.txt @@ -5,3 +5,4 @@ add_subdirectory(Concrete) add_subdirectory(BConcrete) add_subdirectory(RT) add_subdirectory(SDFG) +add_subdirectory(Tracing) diff --git a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp index a7738b555..69110193c 100644 --- a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -26,6 +26,7 @@ #include "concretelang/Dialect/FHE/IR/FHEOps.h" #include "concretelang/Dialect/FHE/IR/FHETypes.h" #include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" +#include "concretelang/Dialect/Tracing/IR/TracingOps.h" #include "concretelang/Support/V0Parameters.h" #include "concretelang/Support/logging.h" @@ -120,7 +121,8 @@ struct FunctionToDag { auto encrypted_inputs = encryptedInputs(op); if (!hasEncryptedResult(op)) { // This op is unrelated to FHE - assert(encrypted_inputs.empty()); + assert(encrypted_inputs.empty() || + mlir::isa(op)); return; } assert(op.getNumResults() == 1); diff --git a/compiler/lib/Dialect/Tracing/CMakeLists.txt b/compiler/lib/Dialect/Tracing/CMakeLists.txt new file mode 100644 index 000000000..9f57627c3 --- /dev/null +++ b/compiler/lib/Dialect/Tracing/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/compiler/lib/Dialect/Tracing/IR/CMakeLists.txt b/compiler/lib/Dialect/Tracing/IR/CMakeLists.txt new file mode 100644 index 000000000..f09c37a60 --- /dev/null +++ b/compiler/lib/Dialect/Tracing/IR/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library( + TracingDialect + TracingDialect.cpp + TracingOps.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Tracing + DEPENDS + mlir-headers + LINK_LIBS + PUBLIC + MLIRIR) + +target_link_libraries(TracingDialect PUBLIC MLIRIR) diff --git a/compiler/lib/Dialect/Tracing/IR/TracingDialect.cpp b/compiler/lib/Dialect/Tracing/IR/TracingDialect.cpp new file mode 100644 index 000000000..6b05fffc4 --- /dev/null +++ b/compiler/lib/Dialect/Tracing/IR/TracingDialect.cpp @@ -0,0 +1,20 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Dialect/Tracing/IR/TracingDialect.h" +#include "concretelang/Dialect/Tracing/IR/TracingOps.h" + +#include "concretelang/Dialect/Tracing/IR/TracingOpsDialect.cpp.inc" + +#include "concretelang/Support/Constants.h" + +using namespace mlir::concretelang::Tracing; + +void TracingDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "concretelang/Dialect/Tracing/IR/TracingOps.cpp.inc" + >(); +} diff --git a/compiler/lib/Dialect/Tracing/IR/TracingOps.cpp b/compiler/lib/Dialect/Tracing/IR/TracingOps.cpp new file mode 100644 index 000000000..4fced8204 --- /dev/null +++ b/compiler/lib/Dialect/Tracing/IR/TracingOps.cpp @@ -0,0 +1,20 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "mlir/IR/Region.h" + +#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h" +#include "concretelang/Dialect/FHE/IR/FHEOps.h" +#include "concretelang/Dialect/TFHE/IR/TFHETypes.h" +#include "concretelang/Dialect/Tracing/IR/TracingOps.h" + +namespace mlir { +namespace concretelang { +namespace Tracing {} // namespace Tracing +} // namespace concretelang +} // namespace mlir + +#define GET_OP_CLASSES +#include "concretelang/Dialect/Tracing/IR/TracingOps.cpp.inc" diff --git a/compiler/lib/Dialect/Tracing/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/Tracing/Transforms/BufferizableOpInterfaceImpl.cpp new file mode 100644 index 000000000..f32aa8e76 --- /dev/null +++ b/compiler/lib/Dialect/Tracing/Transforms/BufferizableOpInterfaceImpl.cpp @@ -0,0 +1,89 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include +#include +#include +#include + +#include "concretelang/Dialect/Tracing/IR/TracingDialect.h" +#include "concretelang/Dialect/Tracing/IR/TracingOps.h" + +using namespace mlir; +using namespace mlir::bufferization; +using namespace mlir::concretelang::Tracing; + +namespace { + +template +struct TrivialBufferizableInterface + : public BufferizableOpInterface::ExternalModel< + TrivialBufferizableInterface, Op> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return true; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return {}; + } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const AnalysisState &state) const { + return BufferRelation::None; + } + + LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + const BufferizationOptions &options) const { + + mlir::SmallVector operands; + for (auto &operand : op->getOpOperands()) { + if (!operand.get().getType().isa()) { + operands.push_back(operand.get()); + } else { + operands.push_back( + bufferization::getBuffer(rewriter, operand.get(), options)); + } + } + + rewriter.replaceOpWithNewOp(op, mlir::TypeRange{}, operands, + op->getAttrs()); + + return success(); + } +}; + +} // namespace + +namespace mlir { +namespace concretelang { +namespace Tracing { +void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, TracingDialect *dialect) { + // trace_ciphretext + Tracing::TraceCiphertextOp::attachInterface< + TrivialBufferizableInterface>(*ctx); + // trace_plaintext + Tracing::TracePlaintextOp::attachInterface< + TrivialBufferizableInterface>(*ctx); + // trace_message + Tracing::TraceMessageOp::attachInterface< + TrivialBufferizableInterface>(*ctx); + }); +} +} // namespace Tracing +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Dialect/Tracing/Transforms/CMakeLists.txt b/compiler/lib/Dialect/Tracing/Transforms/CMakeLists.txt new file mode 100644 index 000000000..66eab1ba4 --- /dev/null +++ b/compiler/lib/Dialect/Tracing/Transforms/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library( + TracingDialectTransforms + BufferizableOpInterfaceImpl.cpp + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Tracing + DEPENDS + mlir-headers + LINK_LIBS + PUBLIC + MLIRArithmeticDialect + MLIRBufferizationDialect + MLIRBufferizationTransforms + MLIRIR + MLIRMemRefDialect + MLIRPass + MLIRTransforms) diff --git a/compiler/lib/Runtime/wrappers.cpp b/compiler/lib/Runtime/wrappers.cpp index 3dc732f2e..3fc914557 100644 --- a/compiler/lib/Runtime/wrappers.cpp +++ b/compiler/lib/Runtime/wrappers.cpp @@ -7,6 +7,7 @@ #include "concretelang/Common/Error.h" #include "concretelang/Runtime/seeder.h" #include +#include #include #include #include @@ -658,3 +659,32 @@ void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned, src_aligned[src_offset + i * src_stride]; } } + +void memref_trace_ciphertext(uint64_t *ct0_allocated, uint64_t *ct0_aligned, + uint64_t ct0_offset, uint64_t ct0_size, + uint64_t ct0_stride, char *message_ptr, + uint32_t message_len, uint32_t msb) { + std::string message{message_ptr, (size_t)message_len}; + std::cout << message << " : "; + std::bitset<64> bits{ct0_aligned[ct0_offset + ct0_size - 1]}; + std::string bitstring = bits.to_string(); + bitstring.insert(msb, 1, ' '); + std::cout << bitstring << std::endl; +} + +void memref_trace_plaintext(uint64_t input, uint64_t input_width, + char *message_ptr, uint32_t message_len, + uint32_t msb) { + std::string message{message_ptr, (size_t)message_len}; + std::cout << message << " : "; + std::bitset<64> bits{input}; + std::string bitstring = bits.to_string(); + bitstring.erase(0, 64 - input_width); + bitstring.insert(msb, 1, ' '); + std::cout << bitstring << std::endl; +} + +void memref_trace_message(char *message_ptr, uint32_t message_len) { + std::string message{message_ptr, (size_t)message_len}; + std::cout << message << std::endl; +} diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 6c9199d60..78d338503 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -38,6 +38,8 @@ #include #include #include +#include +#include #include #include #include @@ -73,6 +75,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() { if (this->mlirContext == nullptr) { mlir::DialectRegistry registry; registry.insert< + mlir::concretelang::Tracing::TracingDialect, mlir::concretelang::RT::RTDialect, mlir::concretelang::FHE::FHEDialect, mlir::concretelang::TFHE::TFHEDialect, mlir::concretelang::FHELinalg::FHELinalgDialect, @@ -83,6 +86,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() { mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect, mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>(); BConcrete::registerBufferizableOpInterfaceExternalModels(registry); + Tracing::registerBufferizableOpInterfaceExternalModels(registry); SDFG::registerSDFGConvertibleOpInterfaceExternalModels(registry); SDFG::registerBufferizableOpInterfaceExternalModels(registry); arith::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 123d79faf..a82802ab3 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -398,6 +398,8 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, addPotentiallyNestedPass( pm, mlir::concretelang::createConvertBConcreteToCAPIPass(gpu), enablePass); + addPotentiallyNestedPass( + pm, mlir::concretelang::createConvertTracingToCAPIPass(), enablePass); // Convert to MLIR LLVM Dialect addPotentiallyNestedPass( diff --git a/compiler/tests/check_tests/Tracing/trace_ciphertext_crt.mlir b/compiler/tests/check_tests/Tracing/trace_ciphertext_crt.mlir new file mode 100644 index 000000000..8862a707a --- /dev/null +++ b/compiler/tests/check_tests/Tracing/trace_ciphertext_crt.mlir @@ -0,0 +1,10 @@ +// RUN: concretecompiler --force-encoding=crt --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s + +// CHECK: : [[BODY1:[01]{64}]] +// CHECK-NEXT: : [[BODY2:[01]{64}]] +// CHECK-NEXT: : [[BODY3:[01]{64}]] +// CHECK-NEXT: 1 +func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> { + "Tracing.trace_ciphertext"(%arg0): (!FHE.eint<5>) -> () + return %arg0: !FHE.eint<5> +} diff --git a/compiler/tests/check_tests/Tracing/trace_ciphertext_native.mlir b/compiler/tests/check_tests/Tracing/trace_ciphertext_native.mlir new file mode 100644 index 000000000..de8d4d7b9 --- /dev/null +++ b/compiler/tests/check_tests/Tracing/trace_ciphertext_native.mlir @@ -0,0 +1,8 @@ +// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s + +// CHECK: : [[BODY:[01]{64}]] +// CHECK-NEXT: 1 +func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> { + "Tracing.trace_ciphertext"(%arg0): (!FHE.eint<5>) -> () + return %arg0: !FHE.eint<5> +} diff --git a/compiler/tests/check_tests/Tracing/trace_ciphertext_with_args_crt.mlir b/compiler/tests/check_tests/Tracing/trace_ciphertext_with_args_crt.mlir new file mode 100644 index 000000000..573a277c7 --- /dev/null +++ b/compiler/tests/check_tests/Tracing/trace_ciphertext_with_args_crt.mlir @@ -0,0 +1,10 @@ +// RUN: concretecompiler --force-encoding=crt --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s + +// CHECK: Test : [[BODY01:[01]{3}]] [[BODY02:[01]{61}]] +// CHECK-NEXT: Test : [[BODY11:[01]{3}]] [[BODY12:[01]{61}]] +// CHECK-NEXT: Test : [[BODY21:[01]{3}]] [[BODY22:[01]{61}]] +// CHECK-NEXT: 1 +func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> { + "Tracing.trace_ciphertext"(%arg0){msg="Test", nmsb=3:i32}: (!FHE.eint<5>) -> () + return %arg0: !FHE.eint<5> +} diff --git a/compiler/tests/check_tests/Tracing/trace_ciphertext_with_args_native.mlir b/compiler/tests/check_tests/Tracing/trace_ciphertext_with_args_native.mlir new file mode 100644 index 000000000..282fb74e5 --- /dev/null +++ b/compiler/tests/check_tests/Tracing/trace_ciphertext_with_args_native.mlir @@ -0,0 +1,8 @@ +// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s + +// CHECK: Test : [[BODY:[01]{3}]] [[BODY2:[01]{61}]] +// CHECK-NEXT: 1 +func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> { + "Tracing.trace_ciphertext"(%arg0){msg="Test", nmsb=3:i32}: (!FHE.eint<5>) -> () + return %arg0: !FHE.eint<5> +} diff --git a/compiler/tests/check_tests/Tracing/trace_message.mlir b/compiler/tests/check_tests/Tracing/trace_message.mlir new file mode 100644 index 000000000..4973ef98e --- /dev/null +++ b/compiler/tests/check_tests/Tracing/trace_message.mlir @@ -0,0 +1,8 @@ +// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s + +// CHECK: Arbitrary message +// CHECK-NEXT: 1 +func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> { + "Tracing.trace_message"(){msg="Arbitrary message"}: () -> () + return %arg0: !FHE.eint<5> +} diff --git a/compiler/tests/check_tests/Tracing/trace_plaintext.mlir b/compiler/tests/check_tests/Tracing/trace_plaintext.mlir new file mode 100644 index 000000000..0e90917b7 --- /dev/null +++ b/compiler/tests/check_tests/Tracing/trace_plaintext.mlir @@ -0,0 +1,9 @@ +// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s + +// CHECK: : 00000100 +// CHECK-NEXT: 1 +func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> { + %0 = arith.constant 4 : i8 + "Tracing.trace_plaintext"(%0): (i8) -> () + return %arg0: !FHE.eint<5> +} diff --git a/compiler/tests/check_tests/Tracing/trace_plaintext_with_args.mlir b/compiler/tests/check_tests/Tracing/trace_plaintext_with_args.mlir new file mode 100644 index 000000000..fa40deb05 --- /dev/null +++ b/compiler/tests/check_tests/Tracing/trace_plaintext_with_args.mlir @@ -0,0 +1,9 @@ +// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s + +// CHECK: Test : 00000100 +// CHECK-NEXT: 1 +func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> { + %0 = arith.constant 4 : i8 + "Tracing.trace_plaintext"(%0){msg="Test"}: (i8) -> () + return %arg0: !FHE.eint<5> +} diff --git a/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_trace.yaml b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_trace.yaml new file mode 100644 index 000000000..56a05a2a3 --- /dev/null +++ b/compiler/tests/end_to_end_fixture/tests_cpu/end_to_end_trace.yaml @@ -0,0 +1,94 @@ +description: trace_ciphertext_without_attributes +program: | + func.func @main(%arg0: !FHE.eint<1>) -> !FHE.eint<1> { + "Tracing.trace_ciphertext"(%arg0): (!FHE.eint<1>) -> () + return %arg0: !FHE.eint<1> + } +tests: + - inputs: + - scalar: 1 + outputs: + - scalar: 1 +--- +description: trace_ciphertext_without_attributes_16bits +program: | + func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> { + "Tracing.trace_ciphertext"(%arg0): (!FHE.eint<16>) -> () + return %arg0: !FHE.eint<16> + } +encoding: crt +tests: + - inputs: + - scalar: 1 + outputs: + - scalar: 1 +--- +description: trace_ciphertextwith_attributes +program: | + func.func @main(%arg0: !FHE.eint<1>) -> !FHE.eint<1> { + "Tracing.trace_ciphertext"(%arg0){nmsb=4:i32, msg="test"}: (!FHE.eint<1>) -> () + return %arg0: !FHE.eint<1> + } +tests: + - inputs: + - scalar: 1 + outputs: + - scalar: 1 +--- +description: trace_ciphertext_with_attributes_16bits +program: | + func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> { + "Tracing.trace_ciphertext"(%arg0){nmsb=4:i32, msg="test"}: (!FHE.eint<16>) -> () + return %arg0: !FHE.eint<16> + } +encoding: crt +tests: + - inputs: + - scalar: 1 + outputs: + - scalar: 1 +--- +description: trace_plaintext_without_attributes +program: | + func.func @main(%arg0: !FHE.eint<1>) -> i64 { + %c0 = arith.constant 1 : i8 + "Tracing.trace_plaintext"(%c0): (i8) -> () + %c1 = arith.constant 1 : i64 + return %c1: i64 + } +tests: + - inputs: + - scalar: 1 + outputs: + - scalar: 1 +--- +description: trace_plaintext_with_attributes +program: | + func.func @main(%arg0: !FHE.eint<1>) -> i64 { + %c0 = arith.constant 1 : i8 + "Tracing.trace_plaintext"(%c0){nmsb=3:i32, msg="test"}: (i8) -> () + %c1 = arith.constant 1 : i64 + return %c1: i64 + } +tests: + - inputs: + - scalar: 1 + outputs: + - scalar: 1 +--- +description: trace_message +program: | + func.func @main(%arg0: !FHE.eint<1>) -> i64 { + %c0 = arith.constant 1 : i8 + "Tracing.trace_plaintext"(%c0): (i8) -> () + "Tracing.trace_message"(){msg="Test"}: () -> () + %c1 = arith.constant 2 : i8 + "Tracing.trace_plaintext"(%c1): (i8) -> () + %c2 = arith.constant 1 : i64 + return %c2: i64 + } +tests: + - inputs: + - scalar: 1 + outputs: + - scalar: 1