mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
enhance(compiler): Add support for tensor.generate in lowering pipeline
Operations with regions currently need explicit patterns for type conversion throughout lowering. This change adds the required patterns for `tensor.generate` to the lowering passes, such that the operation can be used starting from the lowering from HLFHE to MidLFHE.
This commit is contained in:
@@ -59,12 +59,14 @@ void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
target.addIllegalDialect<mlir::zamalang::HLFHE::HLFHEDialect>();
|
||||
|
||||
// Make sure that no ops `linalg.generic` that have illegal types
|
||||
target.addDynamicallyLegalOp<mlir::linalg::GenericOp>(
|
||||
[&](mlir::linalg::GenericOp op) {
|
||||
return (converter.isLegal(op.getOperandTypes()) &&
|
||||
converter.isLegal(op.getResultTypes()) &&
|
||||
target
|
||||
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (
|
||||
converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
});
|
||||
|
||||
// Make sure that func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
@@ -79,6 +81,9 @@ void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
HLFHEToMidLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
HLFHEToMidLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
@@ -82,15 +82,20 @@ void LowLFHEUnparametrizePass::runOnOperation() {
|
||||
LowLFHEUnparametrizeTypeConverter converter;
|
||||
|
||||
// Conversion of linalg.generic operation
|
||||
target.addDynamicallyLegalOp<mlir::linalg::GenericOp>(
|
||||
[&](mlir::linalg::GenericOp op) {
|
||||
return (converter.isLegal(op.getOperandTypes()) &&
|
||||
converter.isLegal(op.getResultTypes()) &&
|
||||
target
|
||||
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (
|
||||
converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
});
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
LowLFHEUnparametrizeTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
LowLFHEUnparametrizeTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
|
||||
// Conversion of function signature and arguments
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
|
||||
@@ -288,6 +288,9 @@ void MidLFHEGlobalParametrizationPass::runOnOperation() {
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::linalg::GenericOp, MidLFHEGlobalParametrizationTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<
|
||||
mlir::tensor::GenerateOp, MidLFHEGlobalParametrizationTypeConverter>>(
|
||||
&getContext(), converter);
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
|
||||
|
||||
@@ -58,12 +58,14 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
|
||||
target.addIllegalDialect<mlir::zamalang::MidLFHE::MidLFHEDialect>();
|
||||
|
||||
// Make sure that no ops `linalg.generic` that have illegal types
|
||||
target.addDynamicallyLegalOp<mlir::linalg::GenericOp>(
|
||||
[&](mlir::linalg::GenericOp op) {
|
||||
return (converter.isLegal(op.getOperandTypes()) &&
|
||||
converter.isLegal(op.getResultTypes()) &&
|
||||
target
|
||||
.addDynamicallyLegalOp<mlir::linalg::GenericOp, mlir::tensor::GenerateOp>(
|
||||
[&](mlir::Operation *op) {
|
||||
return (
|
||||
converter.isLegal(op->getOperandTypes()) &&
|
||||
converter.isLegal(op->getResultTypes()) &&
|
||||
converter.isLegal(op->getRegion(0).front().getArgumentTypes()));
|
||||
});
|
||||
});
|
||||
|
||||
// Make sure that func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
@@ -78,6 +80,9 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::linalg::GenericOp,
|
||||
MidLFHEToLowLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
patterns.add<RegionOpTypeConverterPattern<mlir::tensor::GenerateOp,
|
||||
MidLFHEToLowLFHETypeConverter>>(
|
||||
&getContext(), converter);
|
||||
mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target,
|
||||
converter);
|
||||
mlir::populateFuncOpTypeConversionPattern(patterns, converter);
|
||||
|
||||
Reference in New Issue
Block a user