diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index 4ce0dcebb..abbb9a80f 100644
--- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -430,10 +430,6 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
///
/// #maps_0 = [
/// affine_map<(d0, d1) -> (d0, d1)>
-/// affine_map<(d0, d1) -> (d1, 0)>
-/// affine_map<(d0, d1) -> (d1, 1)>
-/// affine_map<(d0, d1) -> (d1, 2)>
-/// affine_map<(d0, d1) -> (d1, 3)>
/// ]
/// #attributes_0 {
/// indexing_maps = #maps_0,
@@ -442,18 +438,18 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
/// %init = linalg.init_tensor [4, 3]
/// : tensor<4x3x!FHE.eint<2>>
/// %res = linalg.generic {
-///        ins(%t, %luts, %luts, %luts, %luts: tensor<4x3x!FHE.eint<2>>,
-///            tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>)
+///        ins(%t, %luts: tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>)
/// outs(%init : tensor<4x3x!FHE.eint<2>>)
/// {
-/// ^bb0(%arg0: !FHE.eint<2>, %arg1: i64, %arg2: i64, %arg3: i64,
-/// %arg4: i64, %arg5: !FHE.eint<2>):
-/// %lut = tensor.from_elements %arg1, %arg2, %arg3, %arg4 :
-/// tensor<4xi64> %0 = "TFHE.apply_lookup_table"(%arg0, %lut)
-/// {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension = -1
-/// : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS = -1
-/// : i32, polynomialSize = -1 : i32} : (!TFHE.glwe<{_,_,_}{2}>,
-/// tensor<4xi64>) -> !TFHE.glwe<{_,_,_}{2}>
+/// ^bb0(%arg0: !FHE.eint<2>):
+///        %i_lut = linalg.index 1 : index
+///        %lut = tensor.extract_slice %luts[%i_lut, 0] [1, lut_size]
+///        [1, 1] : tensor<3x4xi64> to tensor<4xi64>
+///        %0 = "TFHE.apply_lookup_table"(%arg0,
+/// %lut) {baseLogBS = -1 : i32, baseLogKS = -1 : i32, glweDimension
+/// = -1 : i32, levelBS = -1 : i32, levelKS = -1 : i32, outputSizeKS
+/// = -1 : i32, polynomialSize = -1 : i32} :
+/// (!TFHE.glwe<{_,_,_}{2}>, tensor<4xi64>) ->
+/// !TFHE.glwe<{_,_,_}{2}>
/// linalg.yield %0 : !FHE.eint<2>
/// }
/// }
@@ -477,55 +473,83 @@ struct FHELinalgApplyMultiLookupTableToLinalgGeneric
                                           .cast<mlir::RankedTensorType>();
     mlir::RankedTensorType tensorTy = ((mlir::Type)fheLinalgLutOp.t().getType())
                                           .cast<mlir::RankedTensorType>();
-    mlir::RankedTensorType lutsTy =
-        ((mlir::Type)fheLinalgLutOp.luts().getType())
-            .cast<mlir::RankedTensorType>();
+    auto luts = fheLinalgLutOp.luts();
+    mlir::RankedTensorType lutsTy = getRankedTensorType(luts);
+    auto lutElmtTy = lutsTy.getElementType();
// linalg.init_tensor for initial value
mlir::Value init = rewriter.create(
fheLinalgLutOp.getLoc(), resultTy, mlir::ValueRange{});
auto lutsShape = lutsTy.getShape();
auto lut_size = lutsShape[lutsShape.size() - 1];
+ auto indexOfInput = getBroadcastedAffineMap(resultTy, tensorTy, rewriter);
// Create the affine maps
-    llvm::SmallVector<mlir::AffineMap> maps{
- // Input tensor map
- getBroadcastedAffineMap(resultTy, tensorTy, rewriter)};
- maps.reserve(lut_size + 1);
- // Create as much affine maps as the size of the lut dimension
- for (int64_t i = 0; i < lut_size; i++)
- maps.push_back(
- getBroadcastedAffineMapMultiLUT(resultTy, lutsTy, i, rewriter));
- // Result map
- maps.push_back(getBroadcastedAffineMap(resultTy, resultTy, rewriter));
+    llvm::SmallVector<mlir::AffineMap> maps{indexOfInput, indexOfInput};
// Create the iterator_types
-    llvm::SmallVector<llvm::StringRef> iteratorTypes(resultTy.getShape().size(),
-                                                     "parallel");
+    auto iteratorTypes = parallelIteratorType(resultTy.getShape().size());
+ auto integer = [&](auto v) -> mlir::Attribute {
+ return rewriter.getI64IntegerAttr(v);
+ };
+
+    // We need to know which linalg.generic index to use for the lut.
+    // In the broadcast case the lut index is given by the inner dimensions of
+    // the tensor index.
+ auto tensorShape = tensorTy.getShape();
+ auto tensorRank = tensorTy.getShape().size();
+ auto lutsRank = lutsShape.size() - 1; // do not count inner dim of luts
+ auto lutIndexDimAt = tensorRank - lutsRank;
+    llvm::SmallVector<int64_t> indexLutsToLinalg(lutsRank);
+ for (auto lutsIndex = 0u; lutsIndex < lutsRank; lutsIndex++) {
+ auto tensorIndex = lutIndexDimAt + lutsIndex;
+ if (tensorShape[tensorIndex] != lutsShape[lutsIndex]) {
+ llvm::errs() << "ERROR: Broadcast only works by having more outer "
+ "dims.\nConflict: "
+ << tensorShape[tensorIndex] << " (tensor dim "
+ << tensorIndex << ") is not compatible with "
+ << lutsShape[lutsIndex] << " (luts dim " << lutsIndex
+ << ")\n\n";
+ return ::mlir::LogicalResult::failure();
+ };
+ indexLutsToLinalg[lutsIndex] = tensorIndex;
+ }
+
+ auto _0_ = integer(0);
+ auto _1_ = integer(1);
+ auto lutSizeValue = integer(lut_size);
// Create the body of the `linalg.generic` op
auto bodyBuilder = [&](mlir::OpBuilder &nestedBuilder,
mlir::Location nestedLoc,
mlir::ValueRange blockArgs) {
-      mlir::tensor::FromElementsOp lut =
-          nestedBuilder.create<mlir::tensor::FromElementsOp>(
-              fheLinalgLutOp.getLoc(), blockArgs.slice(1, lut_size));
-      mlir::concretelang::FHE::ApplyLookupTableEintOp lutOp =
-          nestedBuilder.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
-              fheLinalgLutOp.getLoc(), resultTy.getElementType(), blockArgs[0],
-              lut.result());
+ auto loc = fheLinalgLutOp.getLoc();
+ auto tElmt = blockArgs[0];
-      nestedBuilder.create<mlir::linalg::YieldOp>(fheLinalgLutOp.getLoc(),
-                                                  lutOp.getResult());
+ // %lut = extract_slice %luts[%lutIdx, 0][1, lutSize][1, 1] :
+ // tensor to tensor
+ auto sliceArgDim = lutsShape.size();
+      using sliceArg = llvm::SmallVector<mlir::OpFoldResult>;
+ sliceArg offsets(sliceArgDim, _0_);
+ auto lutsIndex = 0;
+ for (auto index : indexLutsToLinalg) {
+        auto offset =
+            nestedBuilder.create<mlir::linalg::IndexOp>(loc, index);
+ offsets[lutsIndex++] = (mlir::OpFoldResult)offset;
+ }
+ sliceArg sizes(sliceArgDim, _1_);
+ sizes[sliceArgDim - 1] = lutSizeValue;
+ sliceArg strides(sliceArgDim, _1_);
+      auto lutTy = mlir::RankedTensorType::get({static_cast<int64_t>(lut_size)},
+                                               lutElmtTy);
+      mlir::Value lut = nestedBuilder.create<mlir::tensor::ExtractSliceOp>(
+          loc, lutTy, luts, offsets, sizes, strides);
+      auto lutOp =
+          nestedBuilder.create<mlir::concretelang::FHE::ApplyLookupTableEintOp>(
+              loc, resultTy.getElementType(), tElmt, lut);
+
+      nestedBuilder.create<mlir::linalg::YieldOp>(loc, lutOp.getResult());
};
// Create the `linalg.generic` op
llvm::SmallVector resTypes{init.getType()};
llvm::SmallVector ins{fheLinalgLutOp.t()};
- ins.reserve(lut_size + 2);
- // We extract one value at a time from one LUT using different maps, so we
- // need to pass the LUT `lut_size` time
- for (auto i = 0; i < lut_size; i++)
- ins.push_back(fheLinalgLutOp.luts());
llvm::SmallVector outs{init};
llvm::StringRef doc{""};
llvm::StringRef call{""};
diff --git a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg.mlir b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg.mlir
index 8f0f7a8fb..df21a3c21 100644
--- a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg.mlir
+++ b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg.mlir
@@ -1,21 +1,20 @@
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
-//CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
-//CHECK: #map1 = affine_map<(d0, d1) -> (d0, d1, 0)>
-//CHECK: #map2 = affine_map<(d0, d1) -> (d0, d1, 1)>
-//CHECK: #map3 = affine_map<(d0, d1) -> (d0, d1, 2)>
-//CHECK: #map4 = affine_map<(d0, d1) -> (d0, d1, 3)>
-//CHECK: func.func @multi_lut(%[[A0:.*]]: tensor<4x4x!FHE.eint<2>>, %[[A1:.*]]: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> {
-//CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<4x4x!FHE.eint<2>>
-//CHECK: %[[V1:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3, #map4, #map0], iterator_types = ["parallel", "parallel"]} ins(%[[A0]], %[[A1]], %arg1, %arg1, %arg1 : tensor<4x4x!FHE.eint<2>>, tensor<4x4x4xi64>, tensor<4x4x4xi64>, tensor<4x4x4xi64>, tensor<4x4x4xi64>) outs(%[[V0]] : tensor<4x4x!FHE.eint<2>>) {
-//CHECK: ^bb0(%[[A2:.*]]: !FHE.eint<2>, %[[A3:.*]]: i64, %[[A4:.*]]: i64, %[[A5:.*]]: i64, %[[A6:.*]]: i64, %[[A7:.*]]: !FHE.eint<2>):
-//CHECK: %[[V2:.*]] = tensor.from_elements %[[A3]], %[[A4]], %[[A5]], %[[A6]] : tensor<4xi64>
-//CHECK: %[[V3:.*]] = "FHE.apply_lookup_table"(%[[A2]], %[[V2]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
-//CHECK: linalg.yield %[[V3]] : !FHE.eint<2>
+
+//CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+//CHECK: func.func @multi_lut(%arg0: tensor<4x4x!FHE.eint<2>>, %[[LUTS:.*]]: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> {
+//CHECK: %[[MEM:.*]] = bufferization.alloc_tensor() : tensor<4x4x!FHE.eint<2>>
+//CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4x4x!FHE.eint<2>>) outs(%[[MEM]] : tensor<4x4x!FHE.eint<2>>) {
+//CHECK: ^bb0(%[[IN:.*]]: !FHE.eint<2>, %[[UNUSED:.*]]: !FHE.eint<2>):
+//CHECK: %[[INDEXA:.*]] = linalg.index 0 : index
+//CHECK: %[[INDEXB:.*]] = linalg.index 1 : index
+//CHECK: %[[LUT:.*]] = tensor.extract_slice %[[LUTS]][%[[INDEXA]], %[[INDEXB]], 0] [1, 1, 4] [1, 1, 1] : tensor<4x4x4xi64> to tensor<4xi64>
+//CHECK: %[[V:.*]] = "FHE.apply_lookup_table"(%[[IN]], %[[LUT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
+//CHECK: linalg.yield %[[V]] : !FHE.eint<2>
//CHECK: } -> tensor<4x4x!FHE.eint<2>>
-//CHECK: return %[[V1]] : tensor<4x4x!FHE.eint<2>>
+//CHECK: return %[[R]] : tensor<4x4x!FHE.eint<2>>
//CHECK: }
-func.func @multi_lut(%arg0: tensor<4x4x!FHE.eint<2>>, %arg1: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> {
- %0 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<4x4x!FHE.eint<2>>, tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>>
+func.func @multi_lut(%arg0: tensor<4x4x!FHE.eint<2>>, %luts: tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>> {
+ %0 = "FHELinalg.apply_multi_lookup_table"(%arg0, %luts): (tensor<4x4x!FHE.eint<2>>, tensor<4x4x4xi64>) -> tensor<4x4x!FHE.eint<2>>
return %0: tensor<4x4x!FHE.eint<2>>
}
diff --git a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg_broadcast.mlir b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg_broadcast.mlir
index dffeb009f..8361c193f 100644
--- a/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg_broadcast.mlir
+++ b/compiler/tests/check_tests/Conversion/FHELinalgToLinalg/apply_multi_lut_to_linalg_broadcast.mlir
@@ -1,21 +1,18 @@
// RUN: concretecompiler %s --action=dump-tfhe --passes fhe-tensor-ops-to-linalg 2>&1 | FileCheck %s
-//CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
-//CHECK: #map1 = affine_map<(d0, d1) -> (d1, 0)>
-//CHECK: #map2 = affine_map<(d0, d1) -> (d1, 1)>
-//CHECK: #map3 = affine_map<(d0, d1) -> (d1, 2)>
-//CHECK: #map4 = affine_map<(d0, d1) -> (d1, 3)>
-//CHECK: func.func @multi_lut(%[[A0:.*]]: tensor<4x3x!FHE.eint<2>>, %[[A1:.*]]: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> {
-//CHECK: %[[V0:.*]] = bufferization.alloc_tensor() : tensor<4x3x!FHE.eint<2>>
-//CHECK: %[[V1:.*]] = linalg.generic {indexing_maps = [#map0, #map1, #map2, #map3, #map4, #map0], iterator_types = ["parallel", "parallel"]} ins(%[[A0]], %[[A1]], %arg1, %arg1, %arg1 : tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>, tensor<3x4xi64>) outs(%[[V0]] : tensor<4x3x!FHE.eint<2>>) {
-//CHECK: ^bb0(%[[A2:.*]]: !FHE.eint<2>, %[[A3:.*]]: i64, %[[A4:.*]]: i64, %[[A5:.*]]: i64, %[[A6:.*]]: i64, %[[A7:.*]]: !FHE.eint<2>):
-//CHECK: %[[V2:.*]] = tensor.from_elements %[[A3]], %[[A4]], %[[A5]], %[[A6]] : tensor<4xi64>
-//CHECK: %[[V3:.*]] = "FHE.apply_lookup_table"(%[[A2]], %[[V2]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
-//CHECK: linalg.yield %[[V3]] : !FHE.eint<2>
+//CHECK: #map = affine_map<(d0, d1) -> (d0, d1)>
+//CHECK: func.func @multi_lut(%arg0: tensor<4x3x!FHE.eint<2>>, %[[LUTS:.*]]: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> {
+//CHECK: %[[MEM:.*]] = bufferization.alloc_tensor() : tensor<4x3x!FHE.eint<2>>
+//CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel"]} ins(%arg0 : tensor<4x3x!FHE.eint<2>>) outs(%[[MEM]] : tensor<4x3x!FHE.eint<2>>) {
+//CHECK: ^bb0(%[[IN:.*]]: !FHE.eint<2>, %[[UNUSED:.*]]: !FHE.eint<2>):
+//CHECK: %[[INDEX:.*]] = linalg.index 1 : index
+//CHECK: %[[LUT:.*]] = tensor.extract_slice %[[LUTS]][%[[INDEX]], 0] [1, 4] [1, 1] : tensor<3x4xi64> to tensor<4xi64>
+//CHECK: %[[V:.*]] = "FHE.apply_lookup_table"(%[[IN]], %[[LUT]]) : (!FHE.eint<2>, tensor<4xi64>) -> !FHE.eint<2>
+//CHECK: linalg.yield %[[V]] : !FHE.eint<2>
//CHECK: } -> tensor<4x3x!FHE.eint<2>>
-//CHECK: return %[[V1]] : tensor<4x3x!FHE.eint<2>>
+//CHECK: return %[[R]] : tensor<4x3x!FHE.eint<2>>
//CHECK: }
-func.func @multi_lut(%arg0: tensor<4x3x!FHE.eint<2>>, %arg1: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> {
- %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %arg1): (tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>>
+func.func @multi_lut(%arg0: tensor<4x3x!FHE.eint<2>>, %luts: tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>> {
+ %1 = "FHELinalg.apply_multi_lookup_table"(%arg0, %luts): (tensor<4x3x!FHE.eint<2>>, tensor<3x4xi64>) -> tensor<4x3x!FHE.eint<2>>
return %1: tensor<4x3x!FHE.eint<2>>
}