enhance(compiler): Add support for tensor.generate in lowering pipeline

Operations with regions currently need explicit patterns for type
conversion throughout lowering. This change adds the required patterns
for `tensor.generate` to the lowering passes, such that the operation
can be used starting from the lowering from HLFHE to MidLFHE.
This commit is contained in:
Andi Drebes
2021-12-01 10:40:22 +01:00
parent 975ee86a5e
commit 0b151724b8
4 changed files with 33 additions and 15 deletions

View File

@@ -59,12 +59,14 @@ void HLFHEToMidLFHEPass::runOnOperation() {
target.addIllegalDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
// Make sure that no ops `linalg.generic` that have illegal types
target.addDynamicallyLegalOp<mlir::linalg::GenericOp>(
[&](mlir::linalg::GenericOp op) {
return (converter.isLegal(op.getOperandTypes()) &&
converter.isLegal(op.getResultTypes()) &&
target
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
[&](mlir::Operation *op) {
return (
converter.isLegal(op->getOperandTypes()) &&
converter.isLegal(op->getResultTypes()) &&
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
});
});
// Make sure that func has legal signature
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
@@ -79,6 +81,9 @@ void HLFHEToMidLFHEPass::runOnOperation() {
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
HLFHEToMidLFHETypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
HLFHEToMidLFHETypeConverter>>(
&getContext(), converter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);

View File

@@ -82,15 +82,20 @@ void LowLFHEUnparametrizePass::runOnOperation() {
LowLFHEUnparametrizeTypeConverter converter;
// Conversion of linalg.generic operation
target.addDynamicallyLegalOp<mlir::linalg::GenericOp>(
[&](mlir::linalg::GenericOp op) {
return (converter.isLegal(op.getOperandTypes()) &&
converter.isLegal(op.getResultTypes()) &&
target
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
[&](mlir::Operation *op) {
return (
converter.isLegal(op->getOperandTypes()) &&
converter.isLegal(op->getResultTypes()) &&
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
});
});
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
LowLFHEUnparametrizeTypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
LowLFHEUnparametrizeTypeConverter>>(
&getContext(), converter);
// Conversion of function signature and arguments
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {

View File

@@ -288,6 +288,9 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() {
patterns.add<RegionOpTypeConverterPattern<
mlir::linalg::GenericOp, MidLFHEGlobalParametrizationTypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<
mlir::tensor::GenerateOp, MidLFHEGlobalParametrizationTypeConverter>>(
&getContext(), converter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);

View File

@@ -58,12 +58,14 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
target.addIllegalDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
// Make sure that no ops `linalg.generic` that have illegal types
target.addDynamicallyLegalOp<mlir::linalg::GenericOp>(
[&](mlir::linalg::GenericOp op) {
return (converter.isLegal(op.getOperandTypes()) &&
converter.isLegal(op.getResultTypes()) &&
target
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
[&](mlir::Operation *op) {
return (
converter.isLegal(op->getOperandTypes()) &&
converter.isLegal(op->getResultTypes()) &&
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
});
});
// Make sure that func has legal signature
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
@@ -78,6 +80,9 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
MidLFHEToLowLFHETypeConverter>>(
&getContext(), converter);
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
MidLFHEToLowLFHETypeConverter>>(
&getContext(), converter);
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
converter);
mlir::populateFuncOpTypeConversionPattern(patterns, converter);