feat(concrete-compiler): adds a tracing op in all dialects.

This commit is contained in:
Alexandre Péré
2023-02-17 09:04:42 +01:00
committed by GitHub
parent 5d3af16617
commit 52ad40c9cf
43 changed files with 919 additions and 3 deletions

View File

@@ -24,11 +24,13 @@
#include "concretelang/Conversion/SDFGToStreamEmulator/Pass.h"
#include "concretelang/Conversion/TFHEGlobalParametrization/Pass.h"
#include "concretelang/Conversion/TFHEToConcrete/Pass.h"
#include "concretelang/Conversion/TracingToCAPI/Pass.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
#include "concretelang/Dialect/FHE/IR/FHEDialect.h"
#include "concretelang/Dialect/SDFG/IR/SDFGDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
#define GEN_PASS_CLASSES
#include "concretelang/Conversion/Passes.h.inc"

View File

@@ -69,6 +69,13 @@ def BConcreteToCAPI : Pass<"bconcrete-to-capi", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::concretelang::BConcrete::BConcreteDialect"];
}
def TracingToCAPI : Pass<"tracing-to-capi", "mlir::ModuleOp"> {
let summary = "Lowers operations from the Tracing dialect to CAPI calls";
let description = [{ Lowers operations from the Tracing dialect to CAPI calls }];
let constructor = "mlir::concretelang::createConvertTracingToCAPIPass()";
let dependentDialects = ["mlir::concretelang::Tracing::TracingDialect"];
}
def SDFGToStreamEmulator : Pass<"sdfg-to-stream-emulator", "mlir::ModuleOp"> {
let summary = "Lowers operations from the SDFG dialect to Stream Emulator calls";
let description = [{ Lowers operations from the SDFG dialect to Stream Emulator calls }];

View File

@@ -0,0 +1,18 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef ZAMALANG_CONVERSION_TRACINGTOCAPI_PASS_H_
#define ZAMALANG_CONVERSION_TRACINGTOCAPI_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace concretelang {
/// Create a pass to convert `Tracing` dialect to CAPI calls.
std::unique_ptr<OperationPass<ModuleOp>> createConvertTracingToCAPIPass();
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -5,3 +5,4 @@ add_subdirectory(Concrete)
add_subdirectory(BConcrete)
add_subdirectory(RT)
add_subdirectory(SDFG)
add_subdirectory(Tracing)

View File

@@ -6,6 +6,7 @@
#ifndef CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H
#define CONCRETELANG_DIALECT_FHE_IR_FHEOPS_H
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Interfaces/ControlFlowInterfaces.h>

View File

@@ -179,4 +179,5 @@ def TFHE_WopPBSGLWEOp : TFHE_Op<"wop_pbs_glwe"> {
let results = (outs Type<And<[TensorOf<[TFHE_GLWECipherTextType]>.predicate, HasStaticShapePred]>>:$result);
}
#endif

View File

@@ -0,0 +1 @@
add_subdirectory(IR)

View File

@@ -0,0 +1,9 @@
set(LLVM_TARGET_DEFINITIONS TracingOps.td)
mlir_tablegen(TracingOps.h.inc -gen-op-decls)
mlir_tablegen(TracingOps.cpp.inc -gen-op-defs)
mlir_tablegen(TracingOpsTypes.h.inc -gen-typedef-decls -typedefs-dialect=Tracing)
mlir_tablegen(TracingOpsTypes.cpp.inc -gen-typedef-defs -typedefs-dialect=Tracing)
mlir_tablegen(TracingOpsDialect.h.inc -gen-dialect-decls -dialect=Tracing)
mlir_tablegen(TracingOpsDialect.cpp.inc -gen-dialect-defs -dialect=Tracing)
add_public_tablegen_target(MLIRTracingOpsIncGen)
add_dependencies(mlir-headers MLIRTracingOpsIncGen)

View File

@@ -0,0 +1,18 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACINGDIALECT_H
#define CONCRETELANG_DIALECT_TRACING_IR_TRACINGDIALECT_H
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "concretelang/Dialect/Tracing/IR/TracingOpsDialect.h.inc"
#endif

View File

@@ -0,0 +1,23 @@
//===- TracingDialect.td - Tracing dialect ----------------*- tablegen -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACING_DIALECT
#define CONCRETELANG_DIALECT_TRACING_IR_TRACING_DIALECT
include "mlir/IR/OpBase.td"
def Tracing_Dialect : Dialect {
let name = "Tracing";
let summary = "Tracing dialect";
let description = [{
A dialect to print program values at runtime.
}];
let cppNamespace = "::mlir::concretelang::Tracing";
}
#endif

View File

@@ -0,0 +1,18 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACINGOPS_H
#define CONCRETELANG_DIALECT_TRACING_IR_TRACINGOPS_H
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Interfaces/ControlFlowInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#define GET_OP_CLASSES
#include "concretelang/Dialect/Tracing/IR/TracingOps.h.inc"
#endif

View File

@@ -0,0 +1,57 @@
//===- TracingOps.td - Tracing dialect ops ----------------*- tablegen
//-*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef CONCRETELANG_DIALECT_TRACING_IR_TRACING_OPS
#define CONCRETELANG_DIALECT_TRACING_IR_TRACING_OPS
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "concretelang/Dialect/Tracing/IR/TracingDialect.td"
include "concretelang/Dialect/FHE/IR/FHETypes.td"
include "concretelang/Dialect/TFHE/IR/TFHETypes.td"
include "concretelang/Dialect/Concrete/IR/ConcreteTypes.td"
class Tracing_Op<string mnemonic, list<Trait> traits = []>
: Op<Tracing_Dialect, mnemonic, traits>;
def Tracing_TraceCiphertextOp : Tracing_Op<"trace_ciphertext"> {
let summary = "Prints a ciphertext.";
let arguments = (ins
Type<Or<[
FHE_EncryptedIntegerType.predicate,
FHE_EncryptedSignedIntegerType.predicate,
TFHE_GLWECipherTextType.predicate,
Concrete_LweCiphertextType.predicate,
1DTensorOf<[I64]>.predicate,
MemRefRankOf<[I64], [1]>.predicate
]>>: $ciphertext,
OptionalAttr<StrAttr>: $msg,
OptionalAttr<I32Attr>: $nmsb
);
}
def Tracing_TracePlaintextOp : Tracing_Op<"trace_plaintext"> {
let summary = "Prints a plaintext.";
let arguments = (ins
AnyInteger: $plaintext,
OptionalAttr<StrAttr>: $msg,
OptionalAttr<I32Attr>: $nmsb
);
}
def Tracing_TraceMessageOp : Tracing_Op<"trace_message"> {
let summary = "Prints a message.";
let arguments = (ins OptionalAttr<StrAttr> : $msg);
}
#endif

View File

@@ -0,0 +1,19 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_DIALECT_TRACING_BUFFERIZABLEOPINTERFACEIMPL_H
#define CONCRETELANG_DIALECT_TRACING_BUFFERIZABLEOPINTERFACEIMPL_H
namespace mlir {
class DialectRegistry;
namespace concretelang {
namespace Tracing {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
} // namespace Tracing
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -215,5 +215,18 @@ void memref_batched_bootstrap_lwe_cuda_u64(
uint64_t tlu_stride, uint32_t input_lwe_dim, uint32_t poly_size,
uint32_t level, uint32_t base_log, uint32_t glwe_dim, uint32_t precision,
mlir::concretelang::RuntimeContext *context);
// Tracing ////////////////////////////////////////////////////////////////////
void memref_trace_ciphertext(uint64_t *ct0_allocated, uint64_t *ct0_aligned,
uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, char *message_ptr,
uint32_t message_len, uint32_t msb);
void memref_trace_plaintext(uint64_t input, uint64_t input_width,
char *message_ptr, uint32_t message_len,
uint32_t msb);
void memref_trace_message(char *message_ptr, uint32_t message_len);
}
#endif

View File

@@ -48,6 +48,7 @@ char memref_encode_expand_lut_for_bootstrap[] =
"memref_encode_expand_lut_for_bootstrap";
char memref_encode_expand_lut_for_woppbs[] =
"memref_encode_expand_lut_for_woppbs";
char memref_trace[] = "memref_trace";
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
@@ -185,6 +186,12 @@ mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
{memref1DType, memref1DType, memref1DType, memref1DType,
rewriter.getI32Type(), rewriter.getI32Type(), rewriter.getI1Type()},
{});
} else if (funcName == memref_trace) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref1DType, mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getI32Type(), rewriter.getI32Type()},
{});
} else {
op->emitError("unknwon external function") << funcName;
return mlir::failure();
@@ -431,6 +438,7 @@ struct BConcreteToCAPIPass : public BConcreteToCAPIBase<BConcreteToCAPIPass> {
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<arith::ArithmeticDialect>();
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
// Make sure that no ops from `FHE` remain after the lowering
target.addIllegalDialect<BConcrete::BConcreteDialect>();

View File

@@ -5,6 +5,7 @@ add_subdirectory(TFHEToConcrete)
add_subdirectory(FHETensorOpsToLinalg)
add_subdirectory(ConcreteToBConcrete)
add_subdirectory(BConcreteToCAPI)
add_subdirectory(TracingToCAPI)
add_subdirectory(SDFGToStreamEmulator)
add_subdirectory(MLIRLowerableDialectsToLLVM)
add_subdirectory(LinalgExtras)

View File

@@ -36,6 +36,7 @@
#include "concretelang/Dialect/Concrete/IR/ConcreteOps.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -43,6 +44,7 @@
namespace Concrete = ::mlir::concretelang::Concrete;
namespace BConcrete = ::mlir::concretelang::BConcrete;
namespace Tracing = ::mlir::concretelang::Tracing;
namespace {
struct ConcreteToBConcretePass
@@ -986,6 +988,16 @@ void ConcreteToBConcretePass::runOnOperation() {
converter.isLegal(op->getOperandTypes());
});
// Conversion of Tracing dialect
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
Tracing::TraceCiphertextOp>,
mlir::concretelang::GenericTypeConverterPattern<
Tracing::TracePlaintextOp>>(&getContext(), converter);
mlir::concretelang::addDynamicallyLegalTypeOp<Tracing::TraceCiphertextOp>(
target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<Tracing::TracePlaintextOp>(
target, converter);
// Conversion of RT Dialect Ops
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,

View File

@@ -7,7 +7,6 @@
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/Dialect/Linalg/IR/Linalg.h>
#include <mlir/IR/Operation.h>
#include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h"
@@ -29,9 +28,12 @@
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
namespace FHE = mlir::concretelang::FHE;
namespace TFHE = mlir::concretelang::TFHE;
namespace Tracing = mlir::concretelang::Tracing;
namespace concretelang = mlir::concretelang;
namespace fhe_to_tfhe_crt_conversion {
@@ -583,6 +585,45 @@ struct ApplyLookupTableEintOpPattern
};
};
/// Rewriter for the `Tracing::trace_ciphertext` operation.
struct TraceCiphertextOpPattern : CrtOpPattern<Tracing::TraceCiphertextOp> {
TraceCiphertextOpPattern(mlir::MLIRContext *context,
concretelang::CrtLoweringParameters params,
mlir::PatternBenefit benefit = 1)
: CrtOpPattern<Tracing::TraceCiphertextOp>(context, params, benefit) {}
::mlir::LogicalResult
matchAndRewrite(Tracing::TraceCiphertextOp op,
Tracing::TraceCiphertextOp::Adaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
typing::TypeConverter converter{loweringParameters};
mlir::Type ciphertextScalarType =
converter.convertType(op.ciphertext().getType())
.cast<mlir::RankedTensorType>()
.getElementType();
for (size_t i = 0; i < (loweringParameters.nMods - 1); ++i) {
auto extractedCiphertext = rewriter.create<mlir::tensor::ExtractOp>(
op.getLoc(), ciphertextScalarType, adaptor.ciphertext(),
mlir::ValueRange{rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(i))});
rewriter.create<Tracing::TraceCiphertextOp>(
op.getLoc(), extractedCiphertext, op.msgAttr(), op.nmsbAttr());
}
auto extractedCiphertext = rewriter.create<mlir::tensor::ExtractOp>(
op.getLoc(), ciphertextScalarType, adaptor.ciphertext(),
mlir::ValueRange{rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(loweringParameters.nMods - 1))});
rewriter.replaceOpWithNewOp<Tracing::TraceCiphertextOp>(
op, extractedCiphertext, op.msgAttr(), op.nmsbAttr());
return mlir::success();
}
};
/// Rewriter for the `tensor::extract` operation.
struct TensorExtractOpPattern : public CrtOpPattern<mlir::tensor::ExtractOp> {
@@ -924,6 +965,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
target, converter);
concretelang::addDynamicallyLegalTypeOp<mlir::tensor::CollapseShapeOp>(
target, converter);
concretelang::addDynamicallyLegalTypeOp<Tracing::TraceCiphertextOp>(
target, converter);
concretelang::addDynamicallyLegalTypeOp<
concretelang::RT::MakeReadyFutureOp>(target, converter);
concretelang::addDynamicallyLegalTypeOp<concretelang::RT::AwaitFutureOp>(
@@ -1006,6 +1049,8 @@ struct FHEToTFHECrtPass : public FHEToTFHECrtBase<FHEToTFHECrtPass> {
loweringParameters);
patterns.add<lowering::InsertSliceOpPattern>(patterns.getContext(),
loweringParameters);
patterns.add<lowering::TraceCiphertextOpPattern>(patterns.getContext(),
loweringParameters);
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::tensor::GenerateOp, true>>(&getContext(), converter);

View File

@@ -28,9 +28,11 @@
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
namespace FHE = mlir::concretelang::FHE;
namespace TFHE = mlir::concretelang::TFHE;
namespace Tracing = mlir::concretelang::Tracing;
namespace concretelang = mlir::concretelang;
namespace fhe_to_tfhe_scalar_conversion {
@@ -546,10 +548,15 @@ struct FHEToTFHEScalarPass : public FHEToTFHEScalarBase<FHEToTFHEScalarPass> {
converter);
concretelang::addDynamicallyLegalTypeOp<mlir::scf::ForOp>(target,
converter);
patterns.add<FunctionConstantOpConversion<typing::TypeConverter>>(
&getContext(), converter);
// Patterns for `tracing` dialect.
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
Tracing::TraceCiphertextOp, true>>(&getContext(), converter);
concretelang::addDynamicallyLegalTypeOp<Tracing::TraceCiphertextOp>(
target, converter);
// Patterns for `bufferization` dialect operations.
patterns.add<concretelang::TypeConvertingReinstantiationPattern<
mlir::bufferization::AllocTensorOp, true>>(patterns.getContext(),

View File

@@ -15,6 +15,7 @@
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
#include "concretelang/Support/Constants.h"
#include <mlir/Dialect/Bufferization/IR/Bufferization.h>
@@ -332,6 +333,8 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
// Conversion of RT Dialect Ops
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::Tracing::TraceCiphertextOp>,
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,
mlir::concretelang::GenericTypeConverterPattern<mlir::scf::YieldOp>,
mlir::concretelang::GenericTypeConverterPattern<
@@ -351,6 +354,8 @@ void TFHEGlobalParametrizationPass::runOnOperation() {
mlir::concretelang::GenericTypeConverterPattern<
mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(),
converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::Tracing::TraceCiphertextOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<
mlir::concretelang::RT::MakeReadyFutureOp>(target, converter);
mlir::concretelang::addDynamicallyLegalTypeOp<

View File

@@ -19,9 +19,11 @@
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
namespace TFHE = mlir::concretelang::TFHE;
namespace Concrete = mlir::concretelang::Concrete;
namespace Tracing = mlir::concretelang::Tracing;
namespace {
struct TFHEToConcretePass : public TFHEToConcreteBase<TFHEToConcretePass> {
@@ -136,6 +138,31 @@ private:
mlir::TypeConverter &converter;
};
struct TracePlaintextOpPattern
: public mlir::OpRewritePattern<Tracing::TracePlaintextOp> {
TracePlaintextOpPattern(mlir::MLIRContext *context,
mlir::TypeConverter &converter,
mlir::PatternBenefit benefit = 100)
: mlir::OpRewritePattern<Tracing::TracePlaintextOp>(context, benefit) {}
mlir::LogicalResult
matchAndRewrite(Tracing::TracePlaintextOp op,
mlir::PatternRewriter &rewriter) const override {
auto inputWidth =
op.plaintext().getType().cast<mlir::IntegerType>().getWidth();
if (inputWidth == 64) {
op->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth));
return mlir::success();
}
auto extendedInput = rewriter.create<mlir::arith::ExtUIOp>(
op.getLoc(), rewriter.getI64Type(), op.plaintext());
auto newOp = rewriter.replaceOpWithNewOp<Tracing::TracePlaintextOp>(
op, extendedInput, op.msgAttr(), op.nmsbAttr());
newOp->setAttr("input_width", rewriter.getI64IntegerAttr(inputWidth));
return ::mlir::success();
}
};
void TFHEToConcretePass::runOnOperation() {
auto op = this->getOperation();
@@ -171,6 +198,7 @@ void TFHEToConcretePass::runOnOperation() {
return FunctionConstantOpConversion<
TFHEToConcreteTypeConverter>::isLegal(op, converter);
});
// Add all patterns required to lower all ops from `TFHE` to
// `Concrete`
mlir::RewritePatternSet patterns(&getContext());
@@ -228,6 +256,19 @@ void TFHEToConcretePass::runOnOperation() {
mlir::populateFunctionOpInterfaceTypeConversionPattern<mlir::func::FuncOp>(
patterns, converter);
// Conversion of Tracing dialect
patterns.add<mlir::concretelang::GenericTypeConverterPattern<
Tracing::TraceCiphertextOp>>(&getContext(), converter);
mlir::concretelang::addDynamicallyLegalTypeOp<Tracing::TraceCiphertextOp>(
target, converter);
patterns.add<TracePlaintextOpPattern>(&getContext(), converter);
target.addLegalOp<mlir::arith::ExtUIOp>();
target.addDynamicallyLegalOp<Tracing::TracePlaintextOp>(
[&](Tracing::TracePlaintextOp op) {
return (op.plaintext().getType().cast<mlir::IntegerType>().getWidth() ==
64);
});
// Conversion of RT Dialect Ops
patterns.add<
mlir::concretelang::GenericTypeConverterPattern<mlir::func::ReturnOp>,

View File

@@ -0,0 +1,14 @@
add_mlir_dialect_library(
TracingToCAPI
TracingToCAPI.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Tracing
DEPENDS
TracingDialect
mlir-headers
LINK_LIBS
PUBLIC
MLIRIR
MLIRTransforms)
target_link_libraries(TracingToCAPI PUBLIC TracingDialect MLIRIR)

View File

@@ -0,0 +1,238 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <mlir/Pass/Pass.h>
#include <mlir/Transforms/DialectConversion.h>
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Tools.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
#include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h"
namespace {
namespace Tracing = mlir::concretelang::Tracing;
namespace arith = mlir::arith;
namespace func = mlir::func;
namespace memref = mlir::memref;
char memref_trace_ciphertext[] = "memref_trace_ciphertext";
char memref_trace_plaintext[] = "memref_trace_plaintext";
char memref_trace_message[] = "memref_trace_message";
mlir::Type getDynamicMemrefWithUnknownOffset(mlir::RewriterBase &rewriter,
size_t rank) {
std::vector<int64_t> shape(rank, -1);
mlir::AffineExpr expr = rewriter.getAffineSymbolExpr(0);
for (size_t i = 0; i < rank; i++) {
expr = expr +
(rewriter.getAffineDimExpr(i) * rewriter.getAffineSymbolExpr(i + 1));
}
return mlir::MemRefType::get(
shape, rewriter.getI64Type(),
mlir::AffineMap::get(rank, rank + 1, expr, rewriter.getContext()));
}
// Returns `memref.cast %0 : memref<...xAxT> to memref<...x?xT>`
mlir::Value getCastedMemRef(mlir::RewriterBase &rewriter, mlir::Value value) {
mlir::Type valueType = value.getType();
if (auto memrefTy = valueType.dyn_cast_or_null<mlir::MemRefType>()) {
return rewriter.create<mlir::memref::CastOp>(
value.getLoc(),
getDynamicMemrefWithUnknownOffset(rewriter, memrefTy.getShape().size()),
value);
} else {
return value;
}
}
mlir::LogicalResult insertForwardDeclarationOfTheCAPI(
mlir::Operation *op, mlir::RewriterBase &rewriter, char const *funcName) {
auto memref1DType = getDynamicMemrefWithUnknownOffset(rewriter, 1);
mlir::FunctionType funcType;
if (funcName == memref_trace_ciphertext) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{memref1DType, mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getI32Type(), rewriter.getI32Type()},
{});
} else if (funcName == memref_trace_plaintext) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{rewriter.getI64Type(), rewriter.getI64Type(),
mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getI32Type(), rewriter.getI32Type()},
{});
} else if (funcName == memref_trace_message) {
funcType = mlir::FunctionType::get(
rewriter.getContext(),
{mlir::LLVM::LLVMPointerType::get(rewriter.getI8Type()),
rewriter.getI32Type()},
{});
} else {
op->emitError("unknwon external function") << funcName;
return mlir::failure();
}
return insertForwardDeclaration(op, rewriter, funcName, funcType);
}
template <typename Op>
void addNoOperands(Op op, mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {}
template <typename Op, char const *callee>
struct TracingToCAPICallPattern : public mlir::OpRewritePattern<Op> {
TracingToCAPICallPattern(
::mlir::MLIRContext *context,
std::function<void(Op op, llvm::SmallVector<mlir::Value> &,
mlir::RewriterBase &)>
addOperands = addNoOperands<Op>,
mlir::PatternBenefit benefit = 1)
: ::mlir::OpRewritePattern<Op>(context, benefit),
addOperands(addOperands) {}
::mlir::LogicalResult
matchAndRewrite(Op op, ::mlir::PatternRewriter &rewriter) const override {
// Create the operands
mlir::SmallVector<mlir::Value> operands;
// For all tensor operand get the corresponding casted buffer
for (auto &operand : op->getOpOperands()) {
mlir::Type type = operand.get().getType();
if (!type.isa<mlir::MemRefType>()) {
operands.push_back(operand.get());
} else {
operands.push_back(getCastedMemRef(rewriter, operand.get()));
}
}
// append additional argument
addOperands(op, operands, rewriter);
// Insert forward declaration of the function
if (insertForwardDeclarationOfTheCAPI(op, rewriter, callee).failed()) {
return mlir::failure();
}
rewriter.replaceOpWithNewOp<func::CallOp>(op, callee, mlir::TypeRange{},
operands);
return ::mlir::success();
};
private:
std::function<void(Op op, llvm::SmallVector<mlir::Value> &,
mlir::RewriterBase &)>
addOperands;
};
void traceCiphertextAddOperands(Tracing::TraceCiphertextOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
auto msg = op.msg().getValueOr("");
auto nmsb = op.nmsb().getValueOr(0);
std::string msgName;
std::stringstream stream;
stream << rand();
stream >> msgName;
auto messageVal =
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
mlir::LLVM::linkage::Linkage::Linkonce);
operands.push_back(messageVal);
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(msg.size())));
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(nmsb)));
}
void tracePlaintextAddOperands(Tracing::TracePlaintextOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
auto msg = op.msg().getValueOr("");
auto nmsb = op.nmsb().getValueOr(0);
std::string msgName;
std::stringstream stream;
stream << rand();
stream >> msgName;
auto messageVal =
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
mlir::LLVM::linkage::Linkage::Linkonce);
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), op->getAttr("input_width")));
operands.push_back(messageVal);
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(msg.size())));
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(nmsb)));
}
void traceMessageAddOperands(Tracing::TraceMessageOp op,
mlir::SmallVector<mlir::Value> &operands,
mlir::RewriterBase &rewriter) {
auto msg = op.msg().getValueOr("");
std::string msgName;
std::stringstream stream;
stream << rand();
stream >> msgName;
auto messageVal =
mlir::LLVM::createGlobalString(op.getLoc(), rewriter, msgName, msg,
mlir::LLVM::linkage::Linkage::Linkonce);
operands.push_back(messageVal);
operands.push_back(rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), rewriter.getI32IntegerAttr(msg.size())));
}
struct TracingToCAPIPass : public TracingToCAPIBase<TracingToCAPIPass> {
TracingToCAPIPass() {}
void runOnOperation() override {
auto op = this->getOperation();
mlir::ConversionTarget target(getContext());
mlir::RewritePatternSet patterns(&getContext());
// Mark ops from the target dialect as legal operations
target.addLegalDialect<func::FuncDialect>();
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<arith::ArithmeticDialect>();
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
// Make sure that no ops from `Tracing` remain after the lowering
target.addIllegalDialect<Tracing::TracingDialect>();
// Add patterns to transform Tracing operators to CAPI call
patterns.add<TracingToCAPICallPattern<Tracing::TraceCiphertextOp,
memref_trace_ciphertext>>(
&getContext(), traceCiphertextAddOperands);
patterns.add<TracingToCAPICallPattern<Tracing::TracePlaintextOp,
memref_trace_plaintext>>(
&getContext(), tracePlaintextAddOperands);
patterns.add<TracingToCAPICallPattern<Tracing::TraceMessageOp,
memref_trace_message>>(
&getContext(), traceMessageAddOperands);
// Apply conversion
if (mlir::applyPartialConversion(op, target, std::move(patterns))
.failed()) {
this->signalPassFailure();
}
}
};
} // namespace
namespace mlir {
namespace concretelang {
std::unique_ptr<OperationPass<ModuleOp>> createConvertTracingToCAPIPass() {
return std::make_unique<TracingToCAPIPass>();
}
} // namespace concretelang
} // namespace mlir

View File

@@ -18,6 +18,7 @@
#include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h"
#include "concretelang/Dialect/BConcrete/IR/BConcreteOps.h"
#include "concretelang/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
#include "concretelang/Support/CompilerEngine.h"
#include <mlir/IR/AffineExpr.h>
#include <mlir/IR/AffineMap.h>
@@ -30,6 +31,7 @@ using namespace mlir::tensor;
namespace {
namespace BConcrete = mlir::concretelang::BConcrete;
namespace Tracing = mlir::concretelang::Tracing;
template <typename TensorOp, typename MemrefOp>
struct TensorToMemrefOp : public BufferizableOpInterface::ExternalModel<

View File

@@ -5,3 +5,4 @@ add_subdirectory(Concrete)
add_subdirectory(BConcrete)
add_subdirectory(RT)
add_subdirectory(SDFG)
add_subdirectory(Tracing)

View File

@@ -26,6 +26,7 @@
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/FHE/IR/FHETypes.h"
#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
#include "concretelang/Support/V0Parameters.h"
#include "concretelang/Support/logging.h"
@@ -120,7 +121,8 @@ struct FunctionToDag {
auto encrypted_inputs = encryptedInputs(op);
if (!hasEncryptedResult(op)) {
// This op is unrelated to FHE
assert(encrypted_inputs.empty());
assert(encrypted_inputs.empty() ||
mlir::isa<mlir::concretelang::Tracing::TraceCiphertextOp>(op));
return;
}
assert(op.getNumResults() == 1);

View File

@@ -0,0 +1,2 @@
add_subdirectory(IR)
add_subdirectory(Transforms)

View File

@@ -0,0 +1,13 @@
add_mlir_dialect_library(
TracingDialect
TracingDialect.cpp
TracingOps.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Tracing
DEPENDS
mlir-headers
LINK_LIBS
PUBLIC
MLIRIR)
target_link_libraries(TracingDialect PUBLIC MLIRIR)

View File

@@ -0,0 +1,20 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
#include "concretelang/Dialect/Tracing/IR/TracingOpsDialect.cpp.inc"
#include "concretelang/Support/Constants.h"
using namespace mlir::concretelang::Tracing;
void TracingDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "concretelang/Dialect/Tracing/IR/TracingOps.cpp.inc"
>();
}

View File

@@ -0,0 +1,20 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/IR/Region.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteTypes.h"
#include "concretelang/Dialect/FHE/IR/FHEOps.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
namespace mlir {
namespace concretelang {
namespace Tracing {} // namespace Tracing
} // namespace concretelang
} // namespace mlir
#define GET_OP_CLASSES
#include "concretelang/Dialect/Tracing/IR/TracingOps.cpp.inc"

View File

@@ -0,0 +1,89 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include <mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h>
#include <mlir/Dialect/Func/IR/FuncOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Transforms/RegionUtils.h>
#include "concretelang/Dialect/Tracing/IR/TracingDialect.h"
#include "concretelang/Dialect/Tracing/IR/TracingOps.h"
using namespace mlir;
using namespace mlir::bufferization;
using namespace mlir::concretelang::Tracing;
namespace {
template <typename Op>
struct TrivialBufferizableInterface
: public BufferizableOpInterface::ExternalModel<
TrivialBufferizableInterface<Op>, Op> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return true;
}
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return false;
}
SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
const AnalysisState &state) const {
return {};
}
BufferRelation bufferRelation(Operation *op, OpResult opResult,
const AnalysisState &state) const {
return BufferRelation::None;
}
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
const BufferizationOptions &options) const {
mlir::SmallVector<mlir::Value> operands;
for (auto &operand : op->getOpOperands()) {
if (!operand.get().getType().isa<mlir::RankedTensorType>()) {
operands.push_back(operand.get());
} else {
operands.push_back(
bufferization::getBuffer(rewriter, operand.get(), options));
}
}
rewriter.replaceOpWithNewOp<Op>(op, mlir::TypeRange{}, operands,
op->getAttrs());
return success();
}
};
} // namespace
namespace mlir {
namespace concretelang {
namespace Tracing {
void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, TracingDialect *dialect) {
// trace_ciphretext
Tracing::TraceCiphertextOp::attachInterface<
TrivialBufferizableInterface<Tracing::TraceCiphertextOp>>(*ctx);
// trace_plaintext
Tracing::TracePlaintextOp::attachInterface<
TrivialBufferizableInterface<Tracing::TracePlaintextOp>>(*ctx);
// trace_message
Tracing::TraceMessageOp::attachInterface<
TrivialBufferizableInterface<Tracing::TraceMessageOp>>(*ctx);
});
}
} // namespace Tracing
} // namespace concretelang
} // namespace mlir

View File

@@ -0,0 +1,16 @@
add_mlir_dialect_library(
TracingDialectTransforms
BufferizableOpInterfaceImpl.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/concretelang/Dialect/Tracing
DEPENDS
mlir-headers
LINK_LIBS
PUBLIC
MLIRArithmeticDialect
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRIR
MLIRMemRefDialect
MLIRPass
MLIRTransforms)

View File

@@ -7,6 +7,7 @@
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/seeder.h"
#include <assert.h>
#include <bitset>
#include <cmath>
#include <functional>
#include <iostream>
@@ -658,3 +659,32 @@ void memref_copy_one_rank(uint64_t *src_allocated, uint64_t *src_aligned,
src_aligned[src_offset + i * src_stride];
}
}
void memref_trace_ciphertext(uint64_t *ct0_allocated, uint64_t *ct0_aligned,
uint64_t ct0_offset, uint64_t ct0_size,
uint64_t ct0_stride, char *message_ptr,
uint32_t message_len, uint32_t msb) {
std::string message{message_ptr, (size_t)message_len};
std::cout << message << " : ";
std::bitset<64> bits{ct0_aligned[ct0_offset + ct0_size - 1]};
std::string bitstring = bits.to_string();
bitstring.insert(msb, 1, ' ');
std::cout << bitstring << std::endl;
}
void memref_trace_plaintext(uint64_t input, uint64_t input_width,
char *message_ptr, uint32_t message_len,
uint32_t msb) {
std::string message{message_ptr, (size_t)message_len};
std::cout << message << " : ";
std::bitset<64> bits{input};
std::string bitstring = bits.to_string();
bitstring.erase(0, 64 - input_width);
bitstring.insert(msb, 1, ' ');
std::cout << bitstring << std::endl;
}
void memref_trace_message(char *message_ptr, uint32_t message_len) {
std::string message{message_ptr, (size_t)message_len};
std::cout << message << std::endl;
}

View File

@@ -38,6 +38,8 @@
#include <concretelang/Dialect/SDFG/Transforms/BufferizableOpInterfaceImpl.h>
#include <concretelang/Dialect/SDFG/Transforms/SDFGConvertibleOpInterfaceImpl.h>
#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
#include <concretelang/Dialect/Tracing/IR/TracingDialect.h>
#include <concretelang/Dialect/Tracing/Transforms/BufferizableOpInterfaceImpl.h>
#include <concretelang/Runtime/DFRuntime.hpp>
#include <concretelang/Support/CompilerEngine.h>
#include <concretelang/Support/Error.h>
@@ -73,6 +75,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
if (this->mlirContext == nullptr) {
mlir::DialectRegistry registry;
registry.insert<
mlir::concretelang::Tracing::TracingDialect,
mlir::concretelang::RT::RTDialect, mlir::concretelang::FHE::FHEDialect,
mlir::concretelang::TFHE::TFHEDialect,
mlir::concretelang::FHELinalg::FHELinalgDialect,
@@ -83,6 +86,7 @@ mlir::MLIRContext *CompilationContext::getMLIRContext() {
mlir::LLVM::LLVMDialect, mlir::scf::SCFDialect,
mlir::omp::OpenMPDialect, mlir::bufferization::BufferizationDialect>();
BConcrete::registerBufferizableOpInterfaceExternalModels(registry);
Tracing::registerBufferizableOpInterfaceExternalModels(registry);
SDFG::registerSDFGConvertibleOpInterfaceExternalModels(registry);
SDFG::registerBufferizableOpInterfaceExternalModels(registry);
arith::registerBufferizableOpInterfaceExternalModels(registry);

View File

@@ -398,6 +398,8 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertBConcreteToCAPIPass(gpu),
enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createConvertTracingToCAPIPass(), enablePass);
// Convert to MLIR LLVM Dialect
addPotentiallyNestedPass(

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --force-encoding=crt --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
// CHECK: : [[BODY1:[01]{64}]]
// CHECK-NEXT: : [[BODY2:[01]{64}]]
// CHECK-NEXT: : [[BODY3:[01]{64}]]
// CHECK-NEXT: 1
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
"Tracing.trace_ciphertext"(%arg0): (!FHE.eint<5>) -> ()
return %arg0: !FHE.eint<5>
}

View File

@@ -0,0 +1,8 @@
// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
// CHECK: : [[BODY:[01]{64}]]
// CHECK-NEXT: 1
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
"Tracing.trace_ciphertext"(%arg0): (!FHE.eint<5>) -> ()
return %arg0: !FHE.eint<5>
}

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --force-encoding=crt --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
// CHECK: Test : [[BODY01:[01]{3}]] [[BODY02:[01]{61}]]
// CHECK-NEXT: Test : [[BODY11:[01]{3}]] [[BODY12:[01]{61}]]
// CHECK-NEXT: Test : [[BODY21:[01]{3}]] [[BODY22:[01]{61}]]
// CHECK-NEXT: 1
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
"Tracing.trace_ciphertext"(%arg0){msg="Test", nmsb=3:i32}: (!FHE.eint<5>) -> ()
return %arg0: !FHE.eint<5>
}

View File

@@ -0,0 +1,8 @@
// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
// CHECK: Test : [[BODY:[01]{3}]] [[BODY2:[01]{61}]]
// CHECK-NEXT: 1
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
"Tracing.trace_ciphertext"(%arg0){msg="Test", nmsb=3:i32}: (!FHE.eint<5>) -> ()
return %arg0: !FHE.eint<5>
}

View File

@@ -0,0 +1,8 @@
// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
// CHECK: Arbitrary message
// CHECK-NEXT: 1
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
"Tracing.trace_message"(){msg="Arbitrary message"}: () -> ()
return %arg0: !FHE.eint<5>
}

View File

@@ -0,0 +1,9 @@
// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
// CHECK: : 00000100
// CHECK-NEXT: 1
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = arith.constant 4 : i8
"Tracing.trace_plaintext"(%0): (i8) -> ()
return %arg0: !FHE.eint<5>
}

View File

@@ -0,0 +1,9 @@
// RUN: concretecompiler --action=jit-invoke --jit-args=1 --funcname=main %s 2>&1| FileCheck %s
// CHECK: Test : 00000100
// CHECK-NEXT: 1
func.func @main(%arg0: !FHE.eint<5>) -> !FHE.eint<5> {
%0 = arith.constant 4 : i8
"Tracing.trace_plaintext"(%0){msg="Test"}: (i8) -> ()
return %arg0: !FHE.eint<5>
}

View File

@@ -0,0 +1,94 @@
description: trace_ciphertext_without_attributes
program: |
func.func @main(%arg0: !FHE.eint<1>) -> !FHE.eint<1> {
"Tracing.trace_ciphertext"(%arg0): (!FHE.eint<1>) -> ()
return %arg0: !FHE.eint<1>
}
tests:
- inputs:
- scalar: 1
outputs:
- scalar: 1
---
description: trace_ciphertext_without_attributes_16bits
program: |
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
"Tracing.trace_ciphertext"(%arg0): (!FHE.eint<16>) -> ()
return %arg0: !FHE.eint<16>
}
encoding: crt
tests:
- inputs:
- scalar: 1
outputs:
- scalar: 1
---
description: trace_ciphertextwith_attributes
program: |
func.func @main(%arg0: !FHE.eint<1>) -> !FHE.eint<1> {
"Tracing.trace_ciphertext"(%arg0){nmsb=4:i32, msg="test"}: (!FHE.eint<1>) -> ()
return %arg0: !FHE.eint<1>
}
tests:
- inputs:
- scalar: 1
outputs:
- scalar: 1
---
description: trace_ciphertext_with_attributes_16bits
program: |
func.func @main(%arg0: !FHE.eint<16>) -> !FHE.eint<16> {
"Tracing.trace_ciphertext"(%arg0){nmsb=4:i32, msg="test"}: (!FHE.eint<16>) -> ()
return %arg0: !FHE.eint<16>
}
encoding: crt
tests:
- inputs:
- scalar: 1
outputs:
- scalar: 1
---
description: trace_plaintext_without_attributes
program: |
func.func @main(%arg0: !FHE.eint<1>) -> i64 {
%c0 = arith.constant 1 : i8
"Tracing.trace_plaintext"(%c0): (i8) -> ()
%c1 = arith.constant 1 : i64
return %c1: i64
}
tests:
- inputs:
- scalar: 1
outputs:
- scalar: 1
---
description: trace_plaintext_with_attributes
program: |
func.func @main(%arg0: !FHE.eint<1>) -> i64 {
%c0 = arith.constant 1 : i8
"Tracing.trace_plaintext"(%c0){nmsb=3:i32, msg="test"}: (i8) -> ()
%c1 = arith.constant 1 : i64
return %c1: i64
}
tests:
- inputs:
- scalar: 1
outputs:
- scalar: 1
---
description: trace_message
program: |
func.func @main(%arg0: !FHE.eint<1>) -> i64 {
%c0 = arith.constant 1 : i8
"Tracing.trace_plaintext"(%c0): (i8) -> ()
"Tracing.trace_message"(){msg="Test"}: () -> ()
%c1 = arith.constant 2 : i8
"Tracing.trace_plaintext"(%c1): (i8) -> ()
%c2 = arith.constant 1 : i64
return %c2: i64
}
tests:
- inputs:
- scalar: 1
outputs:
- scalar: 1