mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
@@ -299,9 +299,6 @@ llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
|
||||
/// specification for the iteration dimensions and appropriate operations
|
||||
/// managing the accumulator of `linalg.generic`.
|
||||
///
|
||||
/// The current implementation does not rely on 'tensor.extract_slice'
|
||||
/// because of a bug in lowering this operation.
|
||||
///
|
||||
/// Example:
|
||||
/// %res = "FHELinalg.apply_mapped_lookup_table"(%t, %luts, %map)
|
||||
/// : (tensor<2x3x!FHE.eint<2>>, tensor<5x4xi64>, tensor<2x3xindex>)
|
||||
@@ -317,17 +314,8 @@ llvm::SmallVector<llvm::StringRef> parallelIteratorType(int n) {
|
||||
/// tensor<2x3x!TFHE.glwe<{_,_,_}{2}>>) {
|
||||
/// ^bb0(%arg3: !TFHE.glwe<{_,_,_}{2}>, %lut_idx: index, %arg5:
|
||||
/// !TFHE.glwe<{_,_,_}{2}>): // no predecessors
|
||||
/// // SHOULD BE
|
||||
/// %lut = tensor.extract_slice %arg1[%[[LUTIDX]], 0] [1,4] [1, 1]
|
||||
/// : tensor<5x4xi64> to tensor<4xi64>
|
||||
/// // BUT IS
|
||||
/// %i0 = arith.constant 0 : index
|
||||
/// ...
|
||||
/// %i3 = arith.constant 3 : index
|
||||
/// %e0 = tensor.extract %arg5[%lut_idx, %i0] : tensor<5x4xi64>
|
||||
/// ...
|
||||
/// %e3 = tensor.extract %arg5[%lut_idx, %i3] : tensor<5x4xi64>
|
||||
/// %lut = tensor.from_elements %e0, ..., %e3 : tensor<4xi64>
|
||||
/// %res = "TFHE.apply_lookup_table"(%arg3, %[[LUT]])
|
||||
/// {baseLogBS = -1 : i32, baseLogKS = -1 : i32,
|
||||
/// glweDimension = -1 : i32,
|
||||
@@ -383,49 +371,23 @@ struct FHELinalgApplyMappedLookupTableToLinalgGeneric
|
||||
|
||||
// Create the body of the `linalg.generic` op
|
||||
// %arg0 is an element of t (encrypted int)
|
||||
// %arg1 is an element of map (i64)
|
||||
// %arg1 is the lut index (i64)
|
||||
// %arg2 is the output element
|
||||
auto lambdaBlock = [&](mlir::OpBuilder &nestedBuilder,
|
||||
mlir::Location nestedLoc,
|
||||
mlir::ValueRange blockArgs) {
|
||||
auto tElmt = blockArgs[0];
|
||||
auto lutIdx = blockArgs[1];
|
||||
auto indexTy = rewriter.getIndexType();
|
||||
|
||||
// %lut = extract_slice %luts[%lutIdx, 0][1, lutSize][1, 1] :
|
||||
// tensor<NxKxi64> to tensor<Kxi64>
|
||||
mlir::Value lut;
|
||||
const bool WORKAROUND_EXTRACT_SLICE = true;
|
||||
if (!WORKAROUND_EXTRACT_SLICE) {
|
||||
sliceArg offsets{lutIdx, _0_};
|
||||
sliceArg sizes{_1_, lutSizeValue};
|
||||
sliceArg strides{_1_, _1_};
|
||||
auto lutTy = mlir::RankedTensorType::get(
|
||||
{static_cast<int64_t>(lutSize)}, lutElmtTy);
|
||||
lut = nestedBuilder.create<tensor::ExtractSliceOp>(
|
||||
loc, lutTy, luts, offsets, sizes, strides);
|
||||
} else {
|
||||
// WORKAROUND BEGIN
|
||||
// A bug in linalg-bufferize prevents rank reduction in extract_slice
|
||||
// Reshaping does not work either or is too complicated so let's rebuild
|
||||
// the tensor from scratch
|
||||
llvm::SmallVector<mlir::Value> consts;
|
||||
llvm::SmallVector<mlir::Value> extracts;
|
||||
for (int i = 0; i < lutSize; i++) {
|
||||
consts.push_back(
|
||||
// %5 = arith.constant(<i> : index) : index
|
||||
nestedBuilder.create<mlir::arith::ConstantOp>(
|
||||
loc, indexTy, rewriter.getIndexAttr(i)));
|
||||
}
|
||||
for (int i = 0; i < lutSize; i++) {
|
||||
extracts.push_back(
|
||||
// %8 = tensor.extract %luts[<lutIdx>, <i>] : ...
|
||||
nestedBuilder.create<tensor::ExtractOp>(
|
||||
loc, luts, mlir::ValueRange({lutIdx, consts[i]})));
|
||||
}
|
||||
// %12 = tensor.from_elements %8, ... : ...
|
||||
lut = nestedBuilder.create<tensor::FromElementsOp>(loc, extracts);
|
||||
} // WORKAROUND END
|
||||
sliceArg offsets{lutIdx, _0_};
|
||||
sliceArg sizes{_1_, lutSizeValue};
|
||||
sliceArg strides{_1_, _1_};
|
||||
auto lutTy = mlir::RankedTensorType::get({static_cast<int64_t>(lutSize)},
|
||||
lutElmtTy);
|
||||
mlir::Value lut = nestedBuilder.create<tensor::ExtractSliceOp>(
|
||||
loc, lutTy, luts, offsets, sizes, strides);
|
||||
// %res1 = apply_lookup_table %arg0 %lut
|
||||
auto lookup = nestedBuilder.create<FHE::ApplyLookupTableEintOp>(
|
||||
loc, elementTy, tElmt, lut);
|
||||
|
||||
Reference in New Issue
Block a user