diff --git a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h index 132766d14..ef72b0f6a 100644 --- a/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h +++ b/compiler/include/zamalang/Dialect/HLFHE/IR/HLFHEOps.h @@ -19,6 +19,22 @@ bool verifyEncryptedIntegerAndIntegerInputsConsistency(OpState &op, EncryptedIntegerType &a, IntegerType &b); +/** Shared error message for all ApplyLookupTable variant Op (several Dialect) + * E.g. HLFHE.apply_lookup_table(input, lut) + * Message when the lut tensor has an invalid size, + * i.e. it cannot accomodate the input elements bitwidth + */ +template +void emitErrorBadLutSize(Op &op, std::string lutName, std::string inputName, + int expectedSize, int bitWidth) { + auto s = op.emitOpError(); + s << ": `" << lutName << "` (operand #2)" + << " inner dimension should have size " << expectedSize << "(=2^" + << bitWidth << ") to match " + << "`" << inputName << "` (operand #1)" + << " elements bitwidth (" << bitWidth << ")"; +} + } // namespace HLFHE } // namespace zamalang } // namespace mlir diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td index 88455ba63..5c8e74498 100644 --- a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td @@ -342,6 +342,62 @@ def ApplyMultiLookupTableEintOp : HLFHELinalg_Op<"apply_multi_lookup_table", []> }]; } +def ApplyMappedLookupTableEintOp : HLFHELinalg_Op<"apply_mapped_lookup_table", []> { + let summary = "Returns a tensor that contains the result of the lookup on a table, using a different lookup table for each element, specified by a map."; + + let description = [{ + Performs for each encrypted indice a lookup on a table of clear integers. Multiple lookup tables are passed, and the application of lookup tables + is performed following the broadcasting rules. The precise lookup is specified by a map. + + ```mlir + // The result of this operation, is a tensor that contains the result of the lookup on different tables. + // i.e. %res[i, ..., k] = %luts[ %map[i, ..., k] ][ %t[i, ..., k] ] + %res = HLFHELinalg.apply_mapped_lookup_table(%t, %luts, %map): tensor>, tensor, tensor -> tensor> + ``` + + Examples: + ```mlir + + // Returns the lookup of 3x2 matrix of encrypted indices of width 2 on a vector of 2 tables of size 4=2^2 of clear integers. + // + // [0,1] [0, 1] = [1,2] + // [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0] + // [2,3] [0, 1] = [5,6] + "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : (tensor<3x2x!HLFHE.eint<2>>, tensor<2x4xi64>, tensor<3x2xindex>) -> tensor<3x2x!HLFHE.eint<3>> + ``` + + Others examples: + // [0,1] [1, 0] = [3,2] + // [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0] + // [2,3] [1, 0] = [4,7] + + // [0,1] [0, 0] = [1,3] + // [3,0] lut [[1,3,5,7], [0,2,4,6]] with [1, 1] = [6,0] + // [2,3] [1, 0] = [4,7] + + // [0,1] [0] = [1,3] + // [3,0] lut [[1,3,5,7], [0,2,4,6]] with [1] = [6,0] + // [2,3] [0] = [5,7] + + // [0,1] = [1,2] + // [3,0] lut [[1,3,5,7], [0,2,4,6]] with [0, 1] = [7,0] + // [2,3] = [5,6] + + }]; + + let arguments = (ins + Type.predicate, HasStaticShapePred]>>:$t, + Type.predicate, HasStaticShapePred]>>:$luts, + Type.predicate, HasStaticShapePred]>>:$map + ); + + let results = (outs Type.predicate, HasStaticShapePred]>>); + + let verifier = [{ + return ::mlir::zamalang::HLFHELinalg::verifyApplyMappedLookupTable(*this); + }]; +} + // Dot product def Dot : HLFHELinalg_Op<"dot_eint_int"> { let summary = "Returns the encrypted dot product between a vector of encrypted integers and a vector of clean integers."; diff --git a/compiler/include/zamalang/Support/LambdaArgument.h b/compiler/include/zamalang/Support/LambdaArgument.h index c2b15d600..18495f2c6 100644 --- a/compiler/include/zamalang/Support/LambdaArgument.h +++ b/compiler/include/zamalang/Support/LambdaArgument.h @@ -136,6 +136,15 @@ public: llvm::ArrayRef value) : TensorLambdaArgument(value, {(int64_t)value.size()}) {} + template + TensorLambdaArgument( + typename ScalarArgumentT::value_type (&a)[size1][size2]) { + dimensions = {size1, size2}; + auto value = llvm::MutableArrayRef( + (typename ScalarArgumentT::value_type *)a, size1 * size2); + std::copy(value.begin(), value.end(), std::back_inserter(this->value)); + } + const std::vector &getDimensions() const { return this->dimensions; } // Returns the total number of elements in the tensor. If the number diff --git a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp index b1103504a..9a10911d5 100644 --- a/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp +++ b/compiler/lib/Conversion/HLFHETensorOpsToLinalg/TensorOpsToLinalg.cpp @@ -276,6 +276,172 @@ struct HLFHELinalgOpToLinalgGeneric }; }; +template inline mlir::RankedTensorType getRankedTensorType(T v) { + return ((mlir::Type)v.getType()).cast(); +} + +llvm::SmallVector parallelIteratorType(int n) { + return llvm::SmallVector(n, "parallel"); +} + +// This class rewrite pattern transforms any instance of +// operators `HLFHELinalg.ApplyMappedLookupTableEintOp` that implements the +// broadasting rules to an instance of `linalg.generic` with an appropriate +// region using `HLFHE.ApplyLookupTableEintOp` operation, an appropriate +// specification for the iteration dimensions and appropriate operations +// managing the accumulator of `linalg.generic`. +// +// The current implementation does not rely on 'tensor.extract_slice' +// because of a bug in lowering this operation. +// +// Example: +// %res = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) +// : (tensor<2x3x!HLFHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>) +// -> tensor<2x3x!HLFHE.eint<2>> +// +// becomes: +// +// #map = affine_map<(d0, d1) -> (d0, d1)> +// %init = linalg.init_tensor [2, 3] : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>> +// %output = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types +// = ["parallel", "parallel"]} ins(%arg0, %arg2 : +// tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 : +// tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>) { +// ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5: +// !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors +// // SHOULD BE +// %lut = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1,4] [1, 1] +// : tensor<5x4xi64> to tensor<4xi64> +// // BUT IS +// %i0 = arith.constant 0 : index +// ... +// %i3 = arith.constant 3 : index +// %e0 = tensor.extract %arg5[%lut_idx, %i0] : tensor<5x4xi64> +// ... +// %e3 = tensor.extract %arg5[%lut_idx, %i3] : tensor<5x4xi64> +// %lut = tensor.from_elements %e0, ..., %e3 : tensor<4xi64> +// %res = "MidLFHE.apply_lookup_table"(%arg3, %[[LUT]]) +// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, k = -1 : i32, +// levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = +// -1 : i32, polynomialSize = -1 : i32} +// : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> +// !MidLFHE.glwe<{_,_,_}{2}> linalg.yield %res : +// !MidLFHE.glwe<{_,_,_}{2}> +// } -> tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>> + +namespace HLFHELinalg = mlir::zamalang::HLFHELinalg; + +struct HLFHELinalgApplyMappedLookupTableToLinalgGeneric + : public mlir::OpRewritePattern { + HLFHELinalgApplyMappedLookupTableToLinalgGeneric( + ::mlir::MLIRContext *context, mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern( + context, benefit) {} + + ::mlir::LogicalResult + matchAndRewrite(HLFHELinalg::ApplyMappedLookupTableEintOp mappedLookup, + ::mlir::PatternRewriter &rewriter) const override { + namespace arith = mlir::arith; + namespace linalg = mlir::linalg; + namespace tensor = mlir::tensor; + namespace HLFHE = mlir::zamalang::HLFHE; + using Values = llvm::SmallVector; + using Types = llvm::SmallVector; + using AffineMaps = llvm::SmallVector; + using sliceArg = llvm::SmallVector; + + auto input = mappedLookup.t(); + auto luts = mappedLookup.luts(); + auto map = mappedLookup.map(); + + auto loc = mappedLookup.getLoc(); + auto tensorTy = getRankedTensorType(input); + auto lutsTy = getRankedTensorType(luts); + auto resultTy = getRankedTensorType(mappedLookup->getResult(0)); + auto elementTy = resultTy.getElementType(); + auto lutElmtTy = lutsTy.getElementType(); + auto lutsShape = lutsTy.getShape(); + auto lutSize = lutsShape[lutsShape.size() - 1]; + auto resultShape = resultTy.getShape(); + + auto integer = [&](auto v) -> mlir::Attribute { + return rewriter.getI64IntegerAttr(v); + }; + + auto _0_ = integer(0); + auto _1_ = integer(1); + auto lutSizeValue = integer(lutSize); + + // Create the body of the `linalg.generic` op + // %arg0 is an element of t (encrypted int) + // %arg1 is an element of map (i64) + // %arg2 is the output element + auto lambdaBlock = [&](mlir::OpBuilder &nestedBuilder, + mlir::Location nestedLoc, + mlir::ValueRange blockArgs) { + auto tElmt = blockArgs[0]; + auto lutIdx = blockArgs[1]; + auto indexTy = rewriter.getIndexType(); + + // %lut = extract_slice %luts[%lutIdx, 0][1, lutSize][1, 1] : + // tensor to tensor + mlir::Value lut; + const bool WORKAROUND_EXTRACT_SLICE = true; + if (!WORKAROUND_EXTRACT_SLICE) { + sliceArg offsets{lutIdx, _0_}; + sliceArg sizes{_1_, lutSizeValue}; + sliceArg strides{_1_, _1_}; + auto lutTy = mlir::RankedTensorType::get( + {static_cast(lutSize)}, lutElmtTy); + lut = nestedBuilder.create( + loc, lutTy, luts, offsets, sizes, strides); + } else { + // WORKAROUND BEGIN + // A bug in linalg-bufferize prevents rank reduction in extract_slice + // Reshaping does not work either or is too complicated so let's rebuild + // the tensor from scratch + llvm::SmallVector consts; + llvm::SmallVector extracts; + for (int i = 0; i < lutSize; i++) { + consts.push_back( + // %5 = arith.constant( : index) : index + nestedBuilder.create( + loc, indexTy, rewriter.getIndexAttr(i))); + } + for (int i = 0; i < lutSize; i++) { + extracts.push_back( + // %8 = tensor.extract %luts[, ] : ... + nestedBuilder.create( + loc, luts, mlir::ValueRange({lutIdx, consts[i]}))); + } + // %12 = tensor.from_elements %8, ... : ... + lut = nestedBuilder.create(loc, extracts); + } // WORKAROUND END + // %res1 = apply_lookup_table %arg0 %lut + auto lookup = nestedBuilder.create( + loc, elementTy, tElmt, lut); + // linalg.yield %res1 : !HLFHE.eint<2> + nestedBuilder.create(loc, lookup.getResult()); + }; + + auto output = + rewriter.create(loc, resultShape, elementTy); + + // Create the `linalg.g eneric` op + Types resTys{resultTy}; + Values ins{input, map}; + Values outs{output}; + auto indexOfInput = getBroadcastedAffineMap(resultTy, tensorTy, rewriter); + AffineMaps affineMaps{indexOfInput, indexOfInput, indexOfInput}; + auto iteratorTypes = parallelIteratorType(resultShape.size()); + auto genericOp = rewriter.create( + loc, resTys, ins, outs, affineMaps, iteratorTypes, lambdaBlock); + rewriter.replaceOp(mappedLookup, {genericOp.getResult(0)}); + + return ::mlir::success(); + }; +}; + // This class rewrite pattern transforms any instance of // operators `HLFHELinalg.ApplyMultiLookupTableEintOp` that implements the // broadasting rules to an instance of `linalg.generic` with an appropriate @@ -847,6 +1013,8 @@ void HLFHETensorOpsToLinalg::runOnFunction() { }); patterns.insert( &getContext()); + patterns.insert( + &getContext()); patterns.insert(&getContext()); if (mlir::applyPartialConversion(function, target, std::move(patterns)) diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index f7f340625..55aeddb5b 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -902,7 +902,8 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { norm2SqEquiv = getSqMANP(matmulIntEintOp, operands); } else if (llvm::isa< mlir::zamalang::HLFHELinalg::ApplyLookupTableEintOp, - mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp>( + mlir::zamalang::HLFHELinalg::ApplyMultiLookupTableEintOp, + mlir::zamalang::HLFHELinalg::ApplyMappedLookupTableEintOp>( op)) { norm2SqEquiv = llvm::APInt{1, 1, false}; } diff --git a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp index 54decddf9..8bb67ae93 100644 --- a/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp +++ b/compiler/lib/Dialect/HLFHE/IR/HLFHEOps.cpp @@ -108,12 +108,11 @@ bool verifyEncryptedIntegerInputsConsistency(::mlir::OpState &op, // Check the shape of l_cst argument auto width = ct.getWidth(); + auto expectedSize = 1 << width; auto lCstShape = l_cst.getShape(); - mlir::SmallVector expectedShape{1 << width}; + mlir::SmallVector expectedShape{expectedSize}; if (!l_cst.hasStaticShape(expectedShape)) { - op.emitOpError() << " should have as `l_cst` argument a shape of one " - "dimension equals to 2^p, where p is the width of the " - "`ct` argument."; + emitErrorBadLutSize(op, "l_cst", "ct", expectedSize, width); return mlir::failure(); } if (!l_cst.getElementType().isInteger(64)) { diff --git a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp index e9cc7114a..d7974ee31 100644 --- a/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp +++ b/compiler/lib/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp @@ -281,6 +281,85 @@ verifyApplyMultiLookupTable(ApplyMultiLookupTableEintOp &op) { return mlir::success(); } +mlir::RankedTensorType getTensorType(::mlir::Value value) { + return value.getType().cast(); +} + +template T getElmentType(::mlir::Value value) { + auto tTy = getTensorType(value); + return tTy.getElementType().cast(); +} + +mlir::IntegerType getClearElmentType(::mlir::Value value) { + return getElmentType(value); +} + +HLFHE::EncryptedIntegerType getEncryptedElmentType(::mlir::Value value) { + using namespace mlir::zamalang::HLFHE; + return getElmentType(value); +} + +mlir::LogicalResult verifyMapHasRightShape(ApplyMappedLookupTableEintOp &op, + ::mlir::Value &lut_input, + ::mlir::Value &lut_map) { + auto input_shape = getTensorType(lut_input).getShape(); + auto map_shape = getTensorType(lut_map).getShape(); + if (input_shape.equals(map_shape)) { + return mlir::success(); + } + std::string error; + int input_rank = input_shape.size(); + int map_rank = map_shape.size(); + std::string input_name = "'t' (operand #1)"; + std::string map_name = "'lut_map.getName()' (operand #3)"; + if (input_rank == map_rank) { + error = ": " + input_name + " dimensions differs from " + map_name; + } else { + error = ": " + input_name + " rank (=" + std::to_string(input_rank) + + ") differs from " + map_name + + " rank (=" + std::to_string(map_rank) + ")"; + } + op.emitOpError() << error; + return mlir::failure(); +} + +mlir::LogicalResult verifyLutsSize(ApplyMappedLookupTableEintOp &op, + ::mlir::Value &encryptedIndex, + ::mlir::Value &luts) { + auto index_width = getEncryptedElmentType(encryptedIndex).getWidth(); + auto actual_lut_size = getTensorType(luts).getShape().back(); + auto expected_lut_size = 1 << index_width; + if (actual_lut_size == expected_lut_size) { + return mlir::success(); + } + HLFHE::emitErrorBadLutSize(op, "luts", "ct", expected_lut_size, index_width); + return mlir::failure(); +} + +mlir::LogicalResult +verifyApplyMappedLookupTable(ApplyMappedLookupTableEintOp &op) { + auto t = op.t(); + auto luts = op.luts(); + auto map = op.map(); + auto result = op.getResult(); + + auto t_shape = getTensorType(t).getShape(); + if (!getTensorType(result).hasStaticShape(t_shape)) { + op.emitOpError() + << ": `t` (operand #1) and `map` (operand #2) must have the same shape"; + return mlir::failure(); + } + + if (!getTensorType(map).getElementType().isIndex()) { + op.emitOpError() + << ": `map` (operand #3) should contains elements of type `index`"; + return mlir::failure(); + } + + return mlir::success(verifyMapHasRightShape(op, t, map).succeeded() && + verifyLutsSize(op, t, luts).succeeded()); +} + ::mlir::LogicalResult verifyDotEintInt(Dot &op) { if (::mlir::failed(mlir::verifyCompatibleShape(op.lhs().getType(), op.rhs().getType()))) { diff --git a/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp b/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp index 0e8ac18e9..745b4b9e6 100644 --- a/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp +++ b/compiler/lib/Dialect/MidLFHE/IR/MidLFHEOps.cpp @@ -1,5 +1,6 @@ #include "mlir/IR/Region.h" +#include "zamalang/Dialect/HLFHE/IR/HLFHEOps.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHEOps.h" #include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h" @@ -146,11 +147,10 @@ mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTable &op) { // Check the shape of l_cst argument auto width = ct.getP(); auto lCstShape = l_cst.getShape(); - mlir::SmallVector expectedShape{1 << width}; + auto expectedSize = 1 << width; + mlir::SmallVector expectedShape{expectedSize}; if (!l_cst.hasStaticShape(expectedShape)) { - op.emitOpError() << "should have as `l_cst` argument a shape of one " - "dimension equals to 2^p, where p is the width of the " - "`ct` argument."; + HLFHE::emitErrorBadLutSize(op, "l_cst", "ct", expectedSize, width); return mlir::failure(); } if (!l_cst.getElementType().isInteger(64)) { diff --git a/compiler/lib/Support/JitCompilerEngine.cpp b/compiler/lib/Support/JitCompilerEngine.cpp index 6e04991a5..8482be74e 100644 --- a/compiler/lib/Support/JitCompilerEngine.cpp +++ b/compiler/lib/Support/JitCompilerEngine.cpp @@ -92,8 +92,7 @@ JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName, llvm::InitializeNativeTargetAsmPrinter(); mlir::registerLLVMDialectTranslation(mlirContext); - llvm::function_ref optPipeline = - mlir::makeOptimizingTransformer(3, 0, nullptr); + auto optPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr); llvm::Expected> lambdaOrErr = mlir::zamalang::JITLambda::create(funcName, module, optPipeline, diff --git a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir index d05921ccf..caf8806ff 100644 --- a/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir +++ b/compiler/tests/Dialect/HLFHE/op_apply_lookup_table_bad_dimension.mlir @@ -1,6 +1,6 @@ // RUN: not zamacompiler --action=roundtrip %s 2>&1| FileCheck %s -// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument. +// CHECK-LABEL: error: 'HLFHE.apply_lookup_table' op : `l_cst` (operand #2) inner dimension should have size 4(=2^2) to match `ct` (operand #1) elements bitwidth (2) func @apply_lookup_table(%arg0: !HLFHE.eint<2>, %arg1: tensor<8xi3>) -> !HLFHE.eint<2> { %1 = "HLFHE.apply_lookup_table"(%arg0, %arg1): (!HLFHE.eint<2>, tensor<8xi3>) -> (!HLFHE.eint<2>) return %1: !HLFHE.eint<2> diff --git a/compiler/tests/Dialect/HLFHELinalg/apply_mapped_lut_to_linalg.mlir b/compiler/tests/Dialect/HLFHELinalg/apply_mapped_lut_to_linalg.mlir new file mode 100644 index 000000000..e1593b37d --- /dev/null +++ b/compiler/tests/Dialect/HLFHELinalg/apply_mapped_lut_to_linalg.mlir @@ -0,0 +1,31 @@ +// RUN: zamacompiler %s --action=dump-midlfhe 2>&1 | FileCheck %s + + +//CHECK-LABEL: #map = affine_map<(d0, d1) -> (d0, d1)> +//CHECK-NEXT:module { +//CHECK-NEXT: func @mapped_lut(%arg0: tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>, %[[LUTS:.*]]: tensor<5x4xi64>, %arg2: tensor<2x3xindex>) -> tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>> { +//CHECK-NEXT: %[[V0:.*]] = linalg.init_tensor [2, 3] : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: %[[V1:.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg2 : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>, tensor<2x3xindex>) outs(%0 : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>>) { +//CHECK-NEXT: ^bb0(%arg3: !MidLFHE.glwe<{_,_,_}{2}>, %[[LUTIDX:.*]]: index, %arg5: !MidLFHE.glwe<{_,_,_}{2}>): // no predecessors +//DISABLED-CHECK-NEXT: %[[V3:.*]] = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1, 4] [1, 1] : tensor<5x4xi64> to tensor<4xi64> +//WORKAROUND BEGIN +//CHECK-NEXT: %[[C0:.*]] = arith.constant 0 : index +//CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +//CHECK-NEXT: %[[C2:.*]] = arith.constant 2 : index +//CHECK-NEXT: %[[C3:.*]] = arith.constant 3 : index +//CHECK-NEXT: %[[E0:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C0]]] : tensor<5x4xi64> +//CHECK-NEXT: %[[E1:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C1]]] : tensor<5x4xi64> +//CHECK-NEXT: %[[E2:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C2]]] : tensor<5x4xi64> +//CHECK-NEXT: %[[E3:.*]] = tensor.extract %[[LUTS]][%[[LUTIDX]], %[[C3]]] : tensor<5x4xi64> +//CHECK-NEXT: %[[LUT:.*]] = tensor.from_elements %[[E0]], %[[E1]], %[[E2]], %[[E3]] : tensor<4xi64> +//WORKAROUND END +//CHECK-NEXT: %[[V4:.*]] = "MidLFHE.apply_lookup_table"(%arg3, %[[LUT]]) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1 : i32, polynomialSize = -1 : i32} : (!MidLFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) -> !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: linalg.yield %[[V4]] : !MidLFHE.glwe<{_,_,_}{2}> +//CHECK-NEXT: } -> tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: return %[[V1]] : tensor<2x3x!MidLFHE.glwe<{_,_,_}{2}>> +//CHECK-NEXT: } +//CHECK-NEXT: } +func @mapped_lut(%t: tensor<2x3x!HLFHE.eint<2>>, %luts: tensor<5x4xi64>, %map: tensor<2x3xindex>) -> tensor<2x3x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map): (tensor<2x3x!HLFHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>) -> tensor<2x3x!HLFHE.eint<2>> + return %1: tensor<2x3x!HLFHE.eint<2>> +} diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir index e828353dd..fb34a92b9 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.invalid.mlir @@ -158,6 +158,47 @@ func @apply_multi_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tenso return %1: tensor<2x3x4x!HLFHE.eint<2>> } +// ----- + + +///////////////////////////////////////////////// +// HLFHELinalg.apply_mapped_lookup_table +///////////////////////////////////////////////// +func @apply_mapped_lookup_table_bad_lut_size_127_vs_128( + %input: tensor<2x3x4x!HLFHE.eint<7>>, + %luts: tensor<127xi64>, + %map: tensor<2x3x4xindex> +) -> tensor<2x3x4x!HLFHE.eint<7>> { + // expected-error @+1 {{'HLFHELinalg.apply_mapped_lookup_table' op : `luts` (operand #2) inner dimension should have size 128(=2^7) to match `ct` (operand #1) elements bitwidth (7)}} + %1 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<127xi64>, tensor<2x3x4xindex>) -> tensor<2x3x4x!HLFHE.eint<7>> + return %1: tensor<2x3x4x!HLFHE.eint<7>> +} + +// ----- + +func @apply_mapped_lookup_table_bad_map_size( + %input: tensor<2x3x4x!HLFHE.eint<7>>, + %luts: tensor<128xi64>, + %map: tensor<2x3xindex> +) -> tensor<2x3x4x!HLFHE.eint<7>> { + // expected-error @+1 {{'HLFHELinalg.apply_mapped_lookup_table' op : 't' (operand #1) rank (=3) differs from 'lut_map.getName()' (operand #3) rank (=2)}} + %1 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<128xi64>, tensor<2x3xindex>) -> tensor<2x3x4x!HLFHE.eint<7>> + return %1: tensor<2x3x4x!HLFHE.eint<7>> +} + +// ----- + +func @apply_mapped_lookup_table_bad_map_elmt_type( + %input: tensor<2x3x4x!HLFHE.eint<7>>, + %luts: tensor<128xi64>, + %map: tensor<2x3xindex> +) -> tensor<2x3x4x!HLFHE.eint<7>> { + // expected-error @+1 {{'HLFHELinalg.apply_mapped_lookup_table' op : 't' (operand #1) rank (=3) differs from 'lut_map.getName()' (operand #3) rank (=2)}} + %1 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<128xi64>, tensor<2x3xindex>) -> tensor<2x3x4x!HLFHE.eint<7>> + return %1: tensor<2x3x4x!HLFHE.eint<7>> +} + + // ----- ///////////////////////////////////////////////// diff --git a/compiler/tests/Dialect/HLFHELinalg/ops.mlir b/compiler/tests/Dialect/HLFHELinalg/ops.mlir index d9db4af31..8a3883118 100644 --- a/compiler/tests/Dialect/HLFHELinalg/ops.mlir +++ b/compiler/tests/Dialect/HLFHELinalg/ops.mlir @@ -288,6 +288,22 @@ func @apply_multi_lookup_table_broadcast(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %a return %1: tensor<2x3x4x!HLFHE.eint<2>> } +///////////////////////////////////////////////// +// HLFHELinalg.apply_mapped_lookup_table +///////////////////////////////////////////////// + +// CHECK-LABEL: func @apply_mapped_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<7>>, %arg1: tensor<10x128xi64>, %arg2: tensor<2x3x4xindex>) -> tensor<2x3x4x!HLFHE.eint<7>> { +func @apply_mapped_lookup_table( + %input: tensor<2x3x4x!HLFHE.eint<7>>, + %luts: tensor<10x128xi64>, + %map: tensor<2x3x4xindex> +) -> tensor<2x3x4x!HLFHE.eint<7>> { + // CHECK-NEXT: %0 = "HLFHELinalg.apply_mapped_lookup_table"(%arg0, %arg1, %arg2) : (tensor<2x3x4x!HLFHE.eint<7>>, tensor<10x128xi64>, tensor<2x3x4xindex>) -> tensor<2x3x4x!HLFHE.eint<7>> + // CHECK-NEXT: return %0 : tensor<2x3x4x!HLFHE.eint<7>> + %0 = "HLFHELinalg.apply_mapped_lookup_table"(%input, %luts, %map): (tensor<2x3x4x!HLFHE.eint<7>>, tensor<10x128xi64>, tensor<2x3x4xindex>) -> (tensor<2x3x4x!HLFHE.eint<7>>) + return %0: tensor<2x3x4x!HLFHE.eint<7>> +} + ///////////////////////////////////////////////// // HLFHELinalg.dot_eint_int ///////////////////////////////////////////////// diff --git a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir index a99754f2d..56f5a0d34 100644 --- a/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir +++ b/compiler/tests/Dialect/MidLFHE/op_apply_lookup_table.invalid.mlir @@ -2,7 +2,7 @@ // Bad dimension of the lookup table func @apply_lookup_table(%arg0: !MidLFHE.glwe<{1024,12,64}{7}>, %arg1: tensor<4xi2>) -> !MidLFHE.glwe<{512,10,64}{2}> { - // expected-error @+1 {{'MidLFHE.apply_lookup_table' op should have as `l_cst` argument a shape of one dimension equals to 2^p, where p is the width of the `ct` argument}} + // expected-error @+1 {{'MidLFHE.apply_lookup_table' op : `l_cst` (operand #2) inner dimension should have size 128(=2^7) to match `ct` (operand #1) elements bitwidth (7)}} %1 = "MidLFHE.apply_lookup_table"(%arg0, %arg1) {glweDimension = 1 : i32, polynomialSize = 1024 : i32, levelKS = 2 : i32, baseLogKS = -82 : i32, levelBS = 3 : i32, baseLogBS = -83 : i32, outputSizeKS = 600 : i32}: (!MidLFHE.glwe<{1024,12,64}{7}>, tensor<4xi2>) -> (!MidLFHE.glwe<{512,10,64}{2}>) return %1: !MidLFHE.glwe<{512,10,64}{2}> } diff --git a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc index ea63465c7..15088472f 100644 --- a/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc +++ b/compiler/tests/unittest/end_to_end_jit_hlfhelinalg.cc @@ -1,5 +1,9 @@ #include "end_to_end_jit_test.h" +namespace Z = mlir::zamalang; +template +using tensorArgTy = Z::TensorLambdaArgument>; + #define GET_3D(tensor, i, j, k, di, dj, dk) (tensor)[i * dj * dk + j * dk + k] #define GET_2D(tensor, i, j, di, dj) (tensor)[i * dj + j] @@ -1149,6 +1153,110 @@ TEST(End2EndJit_HLFHELinalg, apply_multi_lookup_table_with_boradcast) { } } +/////////////////////////////////////////////////////////////////////////////// +// HLFHELinalg apply_mapped_lookup_table ///////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(End2EndJit_HLFHELinalg, apply_mapped_lookup_table_sequential) { + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + // Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers. + func @main(%t: tensor<3x3x!HLFHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : + (tensor<3x3x!HLFHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>> + return %1: tensor<3x3x!HLFHE.eint<2>> + } +)XXX"); + uint8_t t[3][3]{ + {0, 1, 2}, + {3, 0, 1}, + {2, 3, 0}, + }; + uint64_t luts[9][4]{ + {3, 0, 0, 0}, {0, 3, 0, 0}, {0, 0, 3, 0}, + {0, 0, 0, 3}, {3, 0, 0, 0}, {0, 3, 0, 0}, + {0, 0, 3, 0}, {0, 0, 0, 3}, {3, 0, 0, 0}, + }; + uint64_t map[3][3]{ + {0, 1, 2}, + {3, 4, 5}, + {6, 7, 8}, + }; + uint8_t expected[3][3]{ + {3, 3, 3}, + {3, 3, 3}, + {3, 3, 3}, + }; + + tensorArgTy tArg(t); + tensorArgTy lutsArg(luts), mapArg(map); + + llvm::Expected> res = + lambda.operator()>( + {&tArg, &lutsArg, &mapArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), 3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + +TEST(End2EndJit_HLFHELinalg, apply_mapped_lookup_table_same_lut) { + + mlir::zamalang::JitCompilerEngine::Lambda lambda = checkedJit(R"XXX( + // Returns the lookup of 3x3 matrix of encrypted indices of width 2 of a 3x3 matrix of tables of size 4=2² of clear integers. + func @main(%t: tensor<3x3x!HLFHE.eint<2>>, %luts: tensor<9x4xi64>, %map: tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>> { + %1 = "HLFHELinalg.apply_mapped_lookup_table"(%t, %luts, %map) : + (tensor<3x3x!HLFHE.eint<2>>, tensor<9x4xi64>, tensor<3x3xindex>) -> tensor<3x3x!HLFHE.eint<2>> + return %1: tensor<3x3x!HLFHE.eint<2>> + } +)XXX"); + uint8_t t[3][3]{ + {0, 1, 2}, + {3, 0, 1}, + {2, 3, 0}, + }; + uint64_t luts[9][4]{ + {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, + {0, 0, 0, 0}, {1, 2, 3, 1}, {0, 0, 0, 0}, + {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, + }; + uint64_t map[3][3]{ + {4, 4, 4}, + {4, 4, 4}, + {4, 4, 4}, + }; + uint8_t expected[3][3]{ + {1, 2, 3}, + {1, 1, 2}, + {3, 1, 1}, + }; + + tensorArgTy tArg(t); + tensorArgTy lutsArg(luts), mapArg(map); + + llvm::Expected> res = + lambda.operator()>( + {&tArg, &lutsArg, &mapArg}); + + ASSERT_EXPECTED_SUCCESS(res); + + ASSERT_EQ(res->size(), 3 * 3); + + for (size_t i = 0; i < 3; i++) { + for (size_t j = 0; j < 3; j++) { + EXPECT_EQ((*res)[i * 3 + j], expected[i][j]) + << ", at pos(" << i << "," << j << ")"; + } + } +} + /////////////////////////////////////////////////////////////////////////////// // HLFHELinalg dot_eint_int /////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////