diff --git a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
index aa6da0a4f..4ce0dcebb 100644
--- a/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
+++ b/compiler/lib/Conversion/FHETensorOpsToLinalg/TensorOpsToLinalg.cpp
@@ -299,9 +299,6 @@ llvm::SmallVector parallelIteratorType(int n) {
 /// specification for the iteration dimensions and appropriate operations
 /// managing the accumulator of `linalg.generic`.
 ///
-/// The current implementation does not rely on 'tensor.extract_slice'
-/// because of a bug in lowering this operation.
-///
 /// Example:
 /// %res = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map)
 ///   : (tensor<2x3x!FHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>)
@@ -317,17 +314,8 @@ llvm::SmallVector parallelIteratorType(int n) {
 ///        tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>) {
 ///   ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5:
 ///        !TFHE.glwe<{_,_,_}{2}>):  // no predecessors
-///     // SHOULD BE
 ///     %lut = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1,4] [1, 1]
 ///       : tensor<5x4xi64> to tensor<4xi64>
-///     // BUT IS
-///     %i0 = arith.constant 0 : index
-///     ...
-///     %i3 = arith.constant 3 : index
-///     %e0 = tensor.extract %arg5[%lut_idx, %i0] : tensor<5x4xi64>
-///     ...
-///     %e3 = tensor.extract %arg5[%lut_idx, %i3] : tensor<5x4xi64>
-///     %lut = tensor.from_elements %e0, ..., %e3 : tensor<4xi64>
 ///     %res = "TFHE.apply_lookup_table"(%arg3, %[[LUT]])
 ///       {baseLogBS = -1 : i32, baseLogKS = -1 : i32,
 ///        glweDimension = -1 : i32,
@@ -383,49 +371,23 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
     // Create the body of the `linalg.generic` op
     // %arg0 is an element of t (encrypted int)
-    // %arg1 is an element of map (i64)
+    // %arg1 is the lut index (i64)
     // %arg2 is the output element
     auto lambdaBlock = [&](mlir::OpBuilder &nestedBuilder,
                            mlir::Location nestedLoc,
                            mlir::ValueRange blockArgs) {
       auto tElmt = blockArgs[0];
       auto lutIdx = blockArgs[1];
-      auto indexTy = rewriter.getIndexType();
 
       // %lut = extract_slice %luts[%lutIdx, 0][1, lutSize][1, 1] :
       //   tensor to tensor
-      mlir::Value lut;
-      const bool WORKAROUND_EXTRACT_SLICE = true;
-      if (!WORKAROUND_EXTRACT_SLICE) {
-        sliceArg offsets{lutIdx, _0_};
-        sliceArg sizes{_1_, lutSizeValue};
-        sliceArg strides{_1_, _1_};
-        auto lutTy = mlir::RankedTensorType::get(
-            {static_cast<int64_t>(lutSize)}, lutElmtTy);
-        lut = nestedBuilder.create<mlir::tensor::ExtractSliceOp>(
-            loc, lutTy, luts, offsets, sizes, strides);
-      } else {
-        // WORKAROUND BEGIN
-        // A bug in linalg-bufferize prevents rank reduction in extract_slice
-        // Reshaping does not work either or is too complicated so let's rebuild
-        // the tensor from scratch
-        llvm::SmallVector<mlir::Value> consts;
-        llvm::SmallVector<mlir::Value> extracts;
-        for (int i = 0; i < lutSize; i++) {
-          consts.push_back(
-              // %5 = arith.constant( : index) : index
-              nestedBuilder.create<mlir::arith::ConstantOp>(
-                  loc, indexTy, rewriter.getIndexAttr(i)));
-        }
-        for (int i = 0; i < lutSize; i++) {
-          extracts.push_back(
-              // %8 = tensor.extract %luts[, ] : ...
-              nestedBuilder.create<mlir::tensor::ExtractOp>(
-                  loc, luts, mlir::ValueRange({lutIdx, consts[i]})));
-        }
-        // %12 = tensor.from_elements %8, ... : ...
-        lut = nestedBuilder.create<mlir::tensor::FromElementsOp>(loc, extracts);
-      } // WORKAROUND END
+      sliceArg offsets{lutIdx, _0_};
+      sliceArg sizes{_1_, lutSizeValue};
+      sliceArg strides{_1_, _1_};
+      auto lutTy = mlir::RankedTensorType::get({static_cast<int64_t>(lutSize)},
+                                               lutElmtTy);
+      mlir::Value lut = nestedBuilder.create<mlir::tensor::ExtractSliceOp>(
+          loc, lutTy, luts, offsets, sizes, strides);
 
       // %res1 = apply_lookup_table %arg0 %lut
       auto lookup = nestedBuilder.create(
          loc, elementTy, tElmt, lut);
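For reference, the added code builds a rank-reducing tensor.extract_slice that pulls a single LUT row out of the 2-D table. Below is a minimal, self-contained sketch of the same construction with a plain mlir::OpBuilder; the helper name extractLutRow, the include paths, and the OpFoldResult-based builder overload (available in recent MLIR) are illustrative assumptions and not part of the patch, which goes through its local sliceArg alias and the constant Values _0_ / _1_.

// Sketch only (hypothetical helper): extract row `lutIdx` of a
// tensor<numLuts x lutSize x i64> as a rank-reduced tensor<lutSize x i64>.
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"

static mlir::Value extractLutRow(mlir::OpBuilder &b, mlir::Location loc,
                                 mlir::Value luts, mlir::Value lutIdx,
                                 int64_t lutSize) {
  // The result type drops the leading unit dimension: slicing [1, lutSize]
  // into tensor<lutSize x i64> is what makes the extract_slice rank-reducing.
  auto rowTy = mlir::RankedTensorType::get({lutSize}, b.getI64Type());
  llvm::SmallVector<mlir::OpFoldResult> offsets{lutIdx, b.getIndexAttr(0)};
  llvm::SmallVector<mlir::OpFoldResult> sizes{b.getIndexAttr(1),
                                              b.getIndexAttr(lutSize)};
  llvm::SmallVector<mlir::OpFoldResult> strides{b.getIndexAttr(1),
                                                b.getIndexAttr(1)};
  return b.create<mlir::tensor::ExtractSliceOp>(loc, rowTy, luts, offsets,
                                                sizes, strides);
}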