diff --git a/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h b/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h index eea6d3b65..603fdd2b8 100644 --- a/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h +++ b/compiler/include/zamalang/Conversion/Utils/GlobalFHEContext.h @@ -26,6 +26,7 @@ struct V0Parameter { : k(k), polynomialSize(polynomialSize), nSmall(nSmall), brLevel(brLevel), brLogBase(brLogBase), ksLevel(ksLevel), ksLogBase(ksLogBase) {} + // TODO remove the shift when we have true polynomial size size_t getNBigGlweSize() { return k * (1 << polynomialSize); } }; diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp index cad6e8c9f..dd67b0032 100644 --- a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -104,7 +104,8 @@ struct MidLFHEApplyLookupTableParametrizationPattern rewriter.getI32IntegerAttr(v0Parameter.k)), mlir::NamedAttribute( rewriter.getIdentifier("polynomialSize"), - rewriter.getI32IntegerAttr(v0Parameter.polynomialSize)), + // TODO remove the shift when we have true polynomial size + rewriter.getI32IntegerAttr(1 << v0Parameter.polynomialSize)), mlir::NamedAttribute(rewriter.getIdentifier("levelKS"), rewriter.getI32IntegerAttr(v0Parameter.ksLevel)), mlir::NamedAttribute(rewriter.getIdentifier("baseLogKS"), @@ -145,7 +146,8 @@ void populateWithMidLFHEApplyLookupTableParametrizationPattern( target.addDynamicallyLegalOp( [&](mlir::zamalang::MidLFHE::ApplyLookupTable op) { if (op.k() != v0Parameter.k || - op.polynomialSize() != v0Parameter.polynomialSize || + // TODO remove the shift when we have true polynomial size + op.polynomialSize() != (1 << v0Parameter.polynomialSize) || op.levelKS() != v0Parameter.ksLevel || op.baseLogKS() != v0Parameter.ksLogBase || op.levelBS() != v0Parameter.brLevel ||