fix(compiler): Remove the type conversion workaround, actually the order of addConversion calls matters (lifo)

This commit is contained in:
Quentin Bourgerie
2021-08-13 11:33:49 +02:00
parent 42f12b22da
commit ed7dd36e70
2 changed files with 8 additions and 37 deletions

View File

@@ -26,11 +26,12 @@ class HLFHEToMidLFHETypeConverter : public mlir::TypeConverter {
public:
HLFHEToMidLFHETypeConverter() {
addConversion([&](EncryptedIntegerType type) {
addConversion([](mlir::Type type) { return type; });
addConversion([](EncryptedIntegerType type) {
return mlir::zamalang::convertTypeEncryptedIntegerToGLWE(
type.getContext(), type);
});
addConversion([&](mlir::MemRefType type) {
addConversion([](mlir::MemRefType type) {
auto eint =
type.getElementType().dyn_cast_or_null<EncryptedIntegerType>();
if (eint == nullptr) {
@@ -44,25 +45,6 @@ public:
return r;
});
}
/// [workaround] as converter.isLegal returns unexpected false for glwe with
/// same parameters.
static bool _isLegal(mlir::Type type) {
if (type.isa<EncryptedIntegerType>()) {
return false;
}
auto memref = type.dyn_cast_or_null<mlir::MemRefType>();
if (memref != nullptr) {
return _isLegal(memref.getElementType());
}
return true;
}
// [workaround]
template <typename TypeRangeT> static bool _isLegal(TypeRangeT &&types) {
return llvm::all_of(types,
[&](const mlir::Type ty) { return _isLegal(ty); });
}
};
void HLFHEToMidLFHEPass::runOnOperation() {
@@ -86,16 +68,9 @@ void HLFHEToMidLFHEPass::runOnOperation() {
});
// Make sure that func has legal signature
target.addDynamicallyLegalOp<mlir::FuncOp>([](mlir::FuncOp funcOp) {
HLFHEToMidLFHETypeConverter converter;
// [workaround] should be this commented code but for an unknown reasons
// converter.isLegal returns false for glwe with same parameters.
//
// return converter.isSignatureLegal(op.getType()) &&
// converter.isLegal(&op.getBody());
auto funcType = funcOp.getType();
return HLFHEToMidLFHETypeConverter::_isLegal(funcType.getInputs()) &&
HLFHEToMidLFHETypeConverter::_isLegal(funcType.getResults());
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getType()) &&
converter.isLegal(&funcOp.getBody());
});
// Add all patterns required to lower all ops from `HLFHE` to
// `MidLFHE`

View File

@@ -27,6 +27,7 @@ class MidLFHEToLowLFHETypeConverter : public mlir::TypeConverter {
public:
MidLFHEToLowLFHETypeConverter() {
addConversion([](mlir::Type type) { return type; });
addConversion([&](GLWECipherTextType type) {
return mlir::zamalang::convertTypeGLWEToLWE(type.getContext(), type);
});
@@ -41,10 +42,6 @@ public:
type.getAffineMaps(), type.getMemorySpace());
return r;
});
// [workaround] need these converters to consider those types legal
addConversion([&](mlir::IntegerType type) { return type; });
addConversion(
[&](mlir::zamalang::LowLFHE::LweCiphertextType type) { return type; });
}
};
@@ -69,8 +66,7 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
});
// Make sure that func has legal signature
target.addDynamicallyLegalOp<mlir::FuncOp>([](mlir::FuncOp funcOp) {
MidLFHEToLowLFHETypeConverter converter;
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getType()) &&
converter.isLegal(&funcOp.getBody());
});