mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
enhance(compiler): Add tensor ops type rewriting on high level pipepline
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
#ifndef ZAMALANG_CONVERSION_GENERICOPTYPECONVERSIONPATTERN_H_
|
||||
#define ZAMALANG_CONVERSION_GENERICOPTYPECONVERSIONPATTERN_H_
|
||||
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
template <typename Op>
|
||||
struct GenericTypeConverterPattern : public mlir::OpRewritePattern<Op> {
|
||||
GenericTypeConverterPattern(mlir::MLIRContext *context,
|
||||
mlir::TypeConverter &converter,
|
||||
mlir::PatternBenefit benefit = 100)
|
||||
: mlir::OpRewritePattern<Op>(context, benefit), converter(converter) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(Op op, mlir::PatternRewriter &rewriter) const override {
|
||||
|
||||
rewriter.startRootUpdate(op);
|
||||
// Rewrite arguments
|
||||
{
|
||||
for (auto i = 0; i < op->getNumOperands(); i++) {
|
||||
auto operand = op->getOperand(i);
|
||||
mlir::Type type = converter.convertType(operand.getType());
|
||||
if (type != mlir::Type()) {
|
||||
operand.setType(type);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Rewrite results
|
||||
{
|
||||
for (auto i = 0; i < op->getNumResults(); i++) {
|
||||
auto result = op->getResult(i);
|
||||
mlir::Type type = converter.convertType(result.getType());
|
||||
if (type != mlir::Type()) {
|
||||
result.setType(type);
|
||||
}
|
||||
}
|
||||
}
|
||||
rewriter.finalizeRootUpdate(op);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
private:
|
||||
mlir::TypeConverter &converter;
|
||||
};
|
||||
|
||||
template <typename Op>
|
||||
void addDynamicallyLegalTypeOp(mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter) {
|
||||
target.addDynamicallyLegalOp<Op>([&](Op op) {
|
||||
return typeConverter.isLegal(op->getOperandTypes()) &&
|
||||
typeConverter.isLegal(op->getResultTypes());
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,27 @@
|
||||
#ifndef ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_
|
||||
#define ZAMALANG_CONVERSION_TENSOROPTYPECONVERSIONPATTERN_H_
|
||||
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "zamalang/Conversion/Utils/GenericOpTypeConversionPattern.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
inline void
|
||||
populateWithTensorTypeConverterPatterns(mlir::RewritePatternSet &patterns,
|
||||
mlir::ConversionTarget &target,
|
||||
mlir::TypeConverter &typeConverter) {
|
||||
patterns.add<GenericTypeConverterPattern<mlir::tensor::ExtractOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::ExtractOp>(target, typeConverter);
|
||||
patterns.add<GenericTypeConverterPattern<mlir::tensor::FromElementsOp>>(
|
||||
patterns.getContext(), typeConverter);
|
||||
addDynamicallyLegalTypeOp<mlir::tensor::FromElementsOp>(target,
|
||||
typeConverter);
|
||||
}
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "zamalang/Conversion/HLFHEToMidLFHE/Patterns.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h"
|
||||
#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
|
||||
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
@@ -77,6 +78,8 @@ void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
populateWithGeneratedHLFHEToMidLFHE(patterns);
|
||||
patterns.add<LinalgGenericTypeConverterPattern<HLFHEToMidLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
// Apply conversion
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "zamalang/Conversion/MidLFHEToLowLFHE/Patterns.h"
|
||||
#include "zamalang/Conversion/Passes.h"
|
||||
#include "zamalang/Conversion/Utils/LinalgGenericTypeConverterPattern.h"
|
||||
#include "zamalang/Conversion/Utils/TensorOpTypeConversion.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHEDialect.h"
|
||||
#include "zamalang/Dialect/LowLFHE/IR/LowLFHETypes.h"
|
||||
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
|
||||
@@ -77,6 +78,8 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
|
||||
patterns
|
||||
.add<LinalgGenericTypeConverterPattern<MidLFHEToLowLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
// Apply conversion
|
||||
|
||||
Reference in New Issue
Block a user