mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix(compiler): Remove the type conversion workaround, actually the order of addConversion calls matters (lifo)
This commit is contained in:
@@ -26,11 +26,12 @@ class HLFHEToMidLFHETypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
HLFHEToMidLFHETypeConverter() {
|
||||
addConversion([&](EncryptedIntegerType type) {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([](EncryptedIntegerType type) {
|
||||
return mlir::zamalang::convertTypeEncryptedIntegerToGLWE(
|
||||
type.getContext(), type);
|
||||
});
|
||||
addConversion([&](mlir::MemRefType type) {
|
||||
addConversion([](mlir::MemRefType type) {
|
||||
auto eint =
|
||||
type.getElementType().dyn_cast_or_null<EncryptedIntegerType>();
|
||||
if (eint == nullptr) {
|
||||
@@ -44,25 +45,6 @@ public:
|
||||
return r;
|
||||
});
|
||||
}
|
||||
|
||||
/// [workaround] as converter.isLegal returns unexpected false for glwe with
|
||||
/// same parameters.
|
||||
static bool _isLegal(mlir::Type type) {
|
||||
if (type.isa<EncryptedIntegerType>()) {
|
||||
return false;
|
||||
}
|
||||
auto memref = type.dyn_cast_or_null<mlir::MemRefType>();
|
||||
if (memref != nullptr) {
|
||||
return _isLegal(memref.getElementType());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// [workaround]
|
||||
template <typename TypeRangeT> static bool _isLegal(TypeRangeT &&types) {
|
||||
return llvm::all_of(types,
|
||||
[&](const mlir::Type ty) { return _isLegal(ty); });
|
||||
}
|
||||
};
|
||||
|
||||
void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
@@ -86,16 +68,9 @@ void HLFHEToMidLFHEPass::runOnOperation() {
|
||||
});
|
||||
|
||||
// Make sure that func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([](mlir::FuncOp funcOp) {
|
||||
HLFHEToMidLFHETypeConverter converter;
|
||||
// [workaround] should be this commented code but for an unknown reasons
|
||||
// converter.isLegal returns false for glwe with same parameters.
|
||||
//
|
||||
// return converter.isSignatureLegal(op.getType()) &&
|
||||
// converter.isLegal(&op.getBody());
|
||||
auto funcType = funcOp.getType();
|
||||
return HLFHEToMidLFHETypeConverter::_isLegal(funcType.getInputs()) &&
|
||||
HLFHEToMidLFHETypeConverter::_isLegal(funcType.getResults());
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
// Add all patterns required to lower all ops from `HLFHE` to
|
||||
// `MidLFHE`
|
||||
|
||||
@@ -27,6 +27,7 @@ class MidLFHEToLowLFHETypeConverter : public mlir::TypeConverter {
|
||||
|
||||
public:
|
||||
MidLFHEToLowLFHETypeConverter() {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([&](GLWECipherTextType type) {
|
||||
return mlir::zamalang::convertTypeGLWEToLWE(type.getContext(), type);
|
||||
});
|
||||
@@ -41,10 +42,6 @@ public:
|
||||
type.getAffineMaps(), type.getMemorySpace());
|
||||
return r;
|
||||
});
|
||||
// [workaround] need these converters to consider those types legal
|
||||
addConversion([&](mlir::IntegerType type) { return type; });
|
||||
addConversion(
|
||||
[&](mlir::zamalang::LowLFHE::LweCiphertextType type) { return type; });
|
||||
}
|
||||
};
|
||||
|
||||
@@ -69,8 +66,7 @@ void MidLFHEToLowLFHEPass::runOnOperation() {
|
||||
});
|
||||
|
||||
// Make sure that func has legal signature
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([](mlir::FuncOp funcOp) {
|
||||
MidLFHEToLowLFHETypeConverter converter;
|
||||
target.addDynamicallyLegalOp<mlir::FuncOp>([&](mlir::FuncOp funcOp) {
|
||||
return converter.isSignatureLegal(funcOp.getType()) &&
|
||||
converter.isLegal(&funcOp.getBody());
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user