diff --git a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp index ed5c1c1db..0d471aba0 100644 --- a/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/MidLFHEGlobalParametrization/MidLFHEGlobalParametrization.cpp @@ -129,6 +129,70 @@ private: mlir::zamalang::V0Parameter &v0Parameter; }; +struct MidLFHEApplyLookupTablePaddingPattern + : public mlir::OpRewritePattern { + MidLFHEApplyLookupTablePaddingPattern(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) + : mlir::OpRewritePattern( + context, benefit), + typeConverter(typeConverter), v0Parameter(v0Parameter) {} + + mlir::LogicalResult + matchAndRewrite(mlir::zamalang::MidLFHE::ApplyLookupTable op, + mlir::PatternRewriter &rewriter) const override { + auto glweInType = op.getOperandTypes()[0] + .cast(); + auto tabulatedLambdaType = + op.l_cst().getType().cast(); + auto glweOutType = + op.getType().cast(); + auto expectedSize = 1 << glweInType.getP(); + if (tabulatedLambdaType.getShape()[0] < expectedSize) { + auto constantOp = + mlir::dyn_cast_or_null(op.l_cst().getDefiningOp()); + if (constantOp == nullptr) { + op.emitError() << "padding for non-constant operator is NYI"; + return mlir::failure(); + } + mlir::DenseIntElementsAttr denseVals = + constantOp->getAttrOfType("value"); + if (denseVals == nullptr) { + op.emitError() << "value should be dense"; + return mlir::failure(); + } + // Create the new constant dense op with padding + auto integerSize = 64; + llvm::SmallVector rawNewDenseVals( + expectedSize, llvm::APInt(integerSize, 0)); + for (auto i = 0; i < expectedSize; i++) { + rawNewDenseVals[i] = llvm::APInt( + integerSize, + denseVals.getFlatValue(i % denseVals.size()) + .getZExtValue()); + } + auto newDenseValsType = mlir::RankedTensorType::get( + {expectedSize}, rewriter.getIntegerType(integerSize)); + auto newDenseVals = + mlir::DenseIntElementsAttr::get(newDenseValsType, rawNewDenseVals); + auto newConstantOp = + rewriter.create(constantOp.getLoc(), newDenseVals); + // Replace the apply_lookup_table with the new constant + mlir::SmallVector newResultTypes{op.getType()}; + llvm::SmallVector newOperands{op.ct(), newConstantOp}; + llvm::ArrayRef newAttrs = op->getAttrs(); + rewriter.replaceOpWithNewOp( + op, newResultTypes, newOperands, newAttrs); + return mlir::success(); + } + + return mlir::success(); + }; + +private: + mlir::TypeConverter &typeConverter; + mlir::zamalang::V0Parameter &v0Parameter; +}; + template void populateWithMidLFHEOpTypeConversionPattern( mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target, @@ -160,6 +224,24 @@ void populateWithMidLFHEApplyLookupTableParametrizationPattern( }); } +void populateWithMidLFHEApplyLookupTablePaddingPattern( + mlir::RewritePatternSet &patterns, mlir::ConversionTarget &target) { + patterns.add(patterns.getContext()); + target.addLegalOp(); + target.addDynamicallyLegalOp( + [&](mlir::zamalang::MidLFHE::ApplyLookupTable op) { + auto glweInType = + op.getOperandTypes()[0] + .cast(); + auto tabulatedLambdaType = + op.getOperandTypes()[1].cast(); + auto glweOutType = + op.getType().cast(); + + return tabulatedLambdaType.getShape()[0] == 1 << glweInType.getP(); + }); +} + /// Populate the RewritePatternSet with all patterns that rewrite LowLFHE /// operators to the corresponding function call to the `Concrete C API`. void populateWithMidLFHEOpTypeConversionPatterns( @@ -183,31 +265,47 @@ void populateWithMidLFHEOpTypeConversionPatterns( void MidLFHEGlobalParametrizationPass::runOnOperation() { auto op = this->getOperation(); - mlir::ConversionTarget target(getContext()); MidLFHEGlobalParametrizationTypeConverter converter(fheContext); - // Make sure that no ops from `MidLFHE` remain after the lowering - target.addIllegalDialect(); + // Parametrize + { + mlir::ConversionTarget target(getContext()); + mlir::OwningRewritePatternList patterns(&getContext()); - // Make sure func has legal signature - target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { - return converter.isSignatureLegal(funcOp.getType()) && - converter.isLegal(&funcOp.getBody()); - }); - // Add all patterns required to lower all ops from `MidLFHE` to - // `LowLFHE` - mlir::OwningRewritePatternList patterns(&getContext()); - populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter, - fheContext.parameter); - patterns.add>(&getContext(), converter); - mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, - converter); - mlir::populateFuncOpTypeConversionPattern(patterns, converter); + // function signature + target.addDynamicallyLegalOp([&](mlir::FuncOp funcOp) { + return converter.isSignatureLegal(funcOp.getType()) && + converter.isLegal(&funcOp.getBody()); + }); + mlir::populateFuncOpTypeConversionPattern(patterns, converter); - // Apply conversion - if (mlir::applyPartialConversion(op, target, std::move(patterns)).failed()) { - this->signalPassFailure(); + // Add all patterns to convert MidLFHE types + populateWithMidLFHEOpTypeConversionPatterns(patterns, target, converter, + fheContext.parameter); + patterns.add>(&getContext(), converter); + mlir::zamalang::populateWithTensorTypeConverterPatterns(patterns, target, + converter); + + // Apply conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)) + .failed()) { + this->signalPassFailure(); + } + } + + // Pad lookup table + { + mlir::ConversionTarget target(getContext()); + mlir::OwningRewritePatternList patterns(&getContext()); + + populateWithMidLFHEApplyLookupTablePaddingPattern(patterns, target); + + // Apply conversion + if (mlir::applyPartialConversion(op, target, std::move(patterns)) + .failed()) { + this->signalPassFailure(); + } } } diff --git a/compiler/tests/unittest/end_to_end_jit_test.cc b/compiler/tests/unittest/end_to_end_jit_test.cc index 92859c509..2100bee09 100644 --- a/compiler/tests/unittest/end_to_end_jit_test.cc +++ b/compiler/tests/unittest/end_to_end_jit_test.cc @@ -414,6 +414,23 @@ func @main(%arg0: tensor<4x!HLFHE.eint<7>>, ASSERT_EQ(res, 14); } +TEST(CompileAndRunTLU, identity_func_5) { + mlir::zamalang::CompilerEngine engine; + auto mlirStr = R"XXX( +func @main(%arg0: !HLFHE.eint<5>) -> !HLFHE.eint<5> { + %tlu = std.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]> : tensor<32xi64> + %1 = "HLFHE.apply_lookup_table"(%arg0, %tlu): (!HLFHE.eint<5>, tensor<32xi64>) -> (!HLFHE.eint<5>) + return %1: !HLFHE.eint<5> +} +)XXX"; + ASSERT_FALSE(engine.compile(mlirStr)); + uint64_t expected = 3; + auto maybeResult = engine.run({expected}); + ASSERT_TRUE((bool)maybeResult); + uint64_t result = maybeResult.get(); + ASSERT_EQ(result, expected); +} + TEST(CompileAndRunTLU, identity_func) { mlir::zamalang::CompilerEngine engine; auto mlirStr = R"XXX(