mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(compiler): HLFHELinalg.apply_lookup_table definition
This commit is contained in:
committed by
Andi Drebes
parent
f72d51d98d
commit
dea1be9d52
@@ -220,7 +220,45 @@ def MulEintIntOp : HLFHELinalg_Op<"mul_eint_int", [TensorBroadcastingRules, Tens
|
||||
);
|
||||
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
}
|
||||
|
||||
def ApplyLookupTableEintOp : HLFHELinalg_Op<"apply_lookup_table", []> {
|
||||
let summary = "Returns a tensor that contains the result of the lookup on a table.";
|
||||
|
||||
let description = [{
|
||||
Performs for each encrypted indices a lookup on a table of clear integers.
|
||||
|
||||
```mlir
|
||||
// The result of this operation, is a tensor that contains the result of the lookup on a table.
|
||||
// i.e. %res[i, ..., k] = %lut[%t[i, ..., k]]
|
||||
%res = HLFHELinalg.apply_lookup_table(%t, %lut): tensor<DNx...xD1x!HLFHE.eint<$p>>, tensor<D2^$pxi64> -> tensor<DNx...xD1x!HLFHE.eint<$p>>
|
||||
```
|
||||
|
||||
The `%lut` argument should be a tensor with one dimension, where its dimension is equals to `2^p` where `p` is the width of the encrypted integers.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
|
||||
// Returns the lookup of 3x3 matrix of encrypted indices of with 2 on a table of size 4=2² of clear integers.
|
||||
//
|
||||
// [0,1,2] [2,4,6]
|
||||
// [3,0,1] lut [2,4,6,8] = [8,2,4]
|
||||
// [2,3,0] [6,8,0]
|
||||
//
|
||||
"HLFHELinalg.apply_lookup_table(%t, %lut)" : (tensor<3x3x!HLFHE.eint<2>>, tensor<4xi4>) -> tensor<3x3x!HLFHE.eint<3>>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>:$t,
|
||||
Type<And<[TensorOf<[AnyInteger]>.predicate, HasStaticShapePred]>>:$lut
|
||||
);
|
||||
|
||||
let results = (outs Type<And<[TensorOf<[EncryptedIntegerType]>.predicate, HasStaticShapePred]>>);
|
||||
|
||||
let verifier = [{
|
||||
return ::mlir::zamalang::HLFHELinalg::verifyApplyLookupTable(*this);
|
||||
}];
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -202,5 +202,38 @@ LogicalResult verifyTensorBinaryEint(mlir::Operation *op) {
|
||||
} // namespace OpTrait
|
||||
} // namespace mlir
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
namespace HLFHELinalg {
|
||||
|
||||
mlir::LogicalResult verifyApplyLookupTable(ApplyLookupTableEintOp &op) {
|
||||
auto tTy = op.t().getType().cast<mlir::RankedTensorType>();
|
||||
auto tEltTy =
|
||||
tTy.getElementType().cast<mlir::zamalang::HLFHE::EncryptedIntegerType>();
|
||||
auto lutTy = op.lut().getType().cast<mlir::RankedTensorType>();
|
||||
auto lutEltTy = lutTy.getElementType().cast<mlir::IntegerType>();
|
||||
auto resultTy = op.getResult().getType().cast<mlir::RankedTensorType>();
|
||||
|
||||
// Check the shape of lut argument
|
||||
auto tEltwidth = tEltTy.getWidth();
|
||||
mlir::SmallVector<int64_t, 1> expectedShape{1 << tEltwidth};
|
||||
if (!lutTy.hasStaticShape(expectedShape) || !lutEltTy.isInteger(64)) {
|
||||
op.emitOpError()
|
||||
<< "should have as operand #2 a tensor<2^pxi64>, where p is the width "
|
||||
"of the encrypted integer of the operand #1,"
|
||||
<< "expect tensor <" << expectedShape[0] << "xi64>";
|
||||
return mlir::failure();
|
||||
}
|
||||
if (!resultTy.hasStaticShape(tTy.getShape())) {
|
||||
op.emitOpError()
|
||||
<< " should have same shapes for operand #1 and the result";
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace HLFHELinalg
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.cpp.inc"
|
||||
|
||||
@@ -118,4 +118,30 @@ func @main(%a0: tensor<2x3x4x!HLFHE.eint<2>>, %a1: tensor<2x3x4xi4>) -> tensor<2
|
||||
return %1 : tensor<2x3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
// -----
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.apply_lookup_table
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi32>) -> tensor<2x3x4x!HLFHE.eint<2>> {
|
||||
// expected-error @+1 {{'HLFHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}}
|
||||
%1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi32>) -> (tensor<2x3x4x!HLFHE.eint<2>>)
|
||||
return %1: tensor<2x3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<12xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> {
|
||||
// expected-error @+1 {{'HLFHELinalg.apply_lookup_table' op should have as operand #2 a tensor<2^pxi64>, where p is the width of the encrypted integer of the operand #1,expect tensor <4xi64>}}
|
||||
%1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!HLFHE.eint<2>>, tensor<12xi64>) -> (tensor<2x3x4x!HLFHE.eint<2>>)
|
||||
return %1: tensor<2x3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @apply_lookup_table(%arg0: tensor<3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> {
|
||||
// expected-error @+1 {{'HLFHELinalg.apply_lookup_table' op should have same shapes for operand #1 and the result}}
|
||||
%1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!HLFHE.eint<2>>)
|
||||
return %1: tensor<2x3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
@@ -216,4 +216,17 @@ func @mul_eint_int_broadcast_1(%a0: tensor<1x4x5x!HLFHE.eint<2>>, %a1: tensor<3x
|
||||
func @mul_eint_int_broadcast_2(%a0: tensor<4x!HLFHE.eint<2>>, %a1: tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>> {
|
||||
%1 ="HLFHELinalg.mul_eint_int"(%a0, %a1) : (tensor<4x!HLFHE.eint<2>>, tensor<3x4xi3>) -> tensor<3x4x!HLFHE.eint<2>>
|
||||
return %1: tensor<3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////
|
||||
// HLFHELinalg.apply_lookup_table
|
||||
/////////////////////////////////////////////////
|
||||
|
||||
// CHECK-LABEL: func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>>
|
||||
func @apply_lookup_table(%arg0: tensor<2x3x4x!HLFHE.eint<2>>, %arg1: tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1) : (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> tensor<2x3x4x!HLFHE.eint<2>>
|
||||
// CHECK-NEXT: return %[[V1]] : tensor<2x3x4x!HLFHE.eint<2>>
|
||||
|
||||
%1 = "HLFHELinalg.apply_lookup_table"(%arg0, %arg1): (tensor<2x3x4x!HLFHE.eint<2>>, tensor<4xi64>) -> (tensor<2x3x4x!HLFHE.eint<2>>)
|
||||
return %1: tensor<2x3x4x!HLFHE.eint<2>>
|
||||
}
|
||||
Reference in New Issue
Block a user