mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
refactor: remove chunked_eint
the pass transforming operations on chunked_eint will operate now on eint
This commit is contained in:
@@ -25,11 +25,6 @@ def FheIntegerInterface : TypeInterface<"FheIntegerInterface"> {
|
||||
/*description=*/"Get whether the integer is unsigned.",
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"isUnsigned"
|
||||
>,
|
||||
InterfaceMethod<
|
||||
/*description=*/"Get whether the integer is chunked (composed of multiple smaller integers).",
|
||||
/*retTy=*/"bool",
|
||||
/*methodName=*/"isChunked"
|
||||
>
|
||||
];
|
||||
}
|
||||
|
||||
@@ -33,7 +33,6 @@ def FHE_EncryptedIntegerType : FHE_Type<"EncryptedInteger",
|
||||
let extraClassDeclaration = [{
|
||||
bool isSigned() const { return false; }
|
||||
bool isUnsigned() const { return true; }
|
||||
bool isChunked() const { return false; }
|
||||
}];
|
||||
}
|
||||
|
||||
@@ -62,43 +61,12 @@ def FHE_EncryptedSignedIntegerType : FHE_Type<"EncryptedSignedInteger",
|
||||
let extraClassDeclaration = [{
|
||||
bool isSigned() const { return true; }
|
||||
bool isUnsigned() const { return false; }
|
||||
bool isChunked() const { return false; }
|
||||
}];
|
||||
}
|
||||
|
||||
def FHE_ChunkedEncryptedIntegerType : FHE_Type<"ChunkedEncryptedInteger",
|
||||
[MemRefElementTypeInterface, FheIntegerInterface]> {
|
||||
let mnemonic = "chunked_eint";
|
||||
|
||||
let summary = "An encrypted integer composed of multiple chunks";
|
||||
|
||||
let description = [{
|
||||
An encrypted integer composed of multiple chunks.
|
||||
|
||||
Examples:
|
||||
```mlir
|
||||
!FHE.chunked_eint<64>
|
||||
!FHE.chunked_eint<32>
|
||||
```
|
||||
}];
|
||||
|
||||
let parameters = (ins "unsigned":$width);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
let genVerifyDecl = true;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
bool isSigned() const { return false; }
|
||||
bool isUnsigned() const { return true; }
|
||||
bool isChunked() const { return true; }
|
||||
}];
|
||||
}
|
||||
|
||||
def FHE_AnyEncryptedInteger : Type<Or<[
|
||||
FHE_EncryptedIntegerType.predicate,
|
||||
FHE_EncryptedSignedIntegerType.predicate,
|
||||
FHE_ChunkedEncryptedIntegerType.predicate
|
||||
FHE_EncryptedSignedIntegerType.predicate
|
||||
]>>;
|
||||
|
||||
def FHE_EncryptedBooleanType : FHE_Type<"EncryptedBoolean",
|
||||
|
||||
@@ -89,33 +89,3 @@ mlir::Type EncryptedSignedIntegerType::parse(mlir::AsmParser &p) {
|
||||
|
||||
return getChecked(loc, loc.getContext(), width);
|
||||
}
|
||||
|
||||
mlir::LogicalResult ChunkedEncryptedIntegerType::verify(
|
||||
llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, unsigned p) {
|
||||
if (p == 0) {
|
||||
emitError() << "FHE.chunked_eint doesn't support precision of 0";
|
||||
return mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
void ChunkedEncryptedIntegerType::print(mlir::AsmPrinter &p) const {
|
||||
p << "<" << getWidth() << ">";
|
||||
}
|
||||
|
||||
mlir::Type ChunkedEncryptedIntegerType::parse(mlir::AsmParser &p) {
|
||||
if (p.parseLess())
|
||||
return mlir::Type();
|
||||
|
||||
int width;
|
||||
|
||||
if (p.parseInteger(width))
|
||||
return mlir::Type();
|
||||
|
||||
if (p.parseGreater())
|
||||
return mlir::Type();
|
||||
|
||||
mlir::Location loc = p.getEncodedSourceLoc(p.getNameLoc());
|
||||
|
||||
return getChecked(loc, loc.getContext(), width);
|
||||
}
|
||||
|
||||
@@ -29,25 +29,13 @@ bool verifyEncryptedIntegerInputAndResultConsistency(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (input.isChunked() != result.isChunked()) {
|
||||
op.emitOpError("should have the same composition (chunked or not) of "
|
||||
"encrypted input and result");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool verifyEncryptedIntegerAndIntegerInputsConsistency(mlir::Operation &op,
|
||||
FheIntegerInterface &a,
|
||||
IntegerType &b) {
|
||||
if (a.isChunked()) {
|
||||
if (a.getWidth() != b.getWidth()) {
|
||||
op.emitOpError("should have the width of plain input equal to width of "
|
||||
"encrypted input");
|
||||
return false;
|
||||
}
|
||||
} else if (a.getWidth() + 1 != b.getWidth()) {
|
||||
if (a.getWidth() + 1 != b.getWidth()) {
|
||||
op.emitOpError("should have the width of plain input equal to width of "
|
||||
"encrypted input + 1");
|
||||
return false;
|
||||
@@ -69,12 +57,6 @@ bool verifyEncryptedIntegerInputsConsistency(mlir::Operation &op,
|
||||
return false;
|
||||
}
|
||||
|
||||
if (a.isChunked() != b.isChunked()) {
|
||||
op.emitOpError("should have the same composition (chunked or not) of "
|
||||
"encrypted inputs");
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
@@ -49,10 +49,10 @@ namespace typing {
|
||||
|
||||
/// Converts `FHE::ChunkedEncryptedInteger` into a tensor of
|
||||
/// `FHE::EncryptedInteger`.
|
||||
mlir::RankedTensorType
|
||||
convertChunkedEint(mlir::MLIRContext *context,
|
||||
FHE::ChunkedEncryptedIntegerType chunkedEint,
|
||||
unsigned int chunkSize, unsigned int chunkWidth) {
|
||||
mlir::RankedTensorType convertChunkedEint(mlir::MLIRContext *context,
|
||||
FHE::EncryptedIntegerType chunkedEint,
|
||||
unsigned int chunkSize,
|
||||
unsigned int chunkWidth) {
|
||||
auto eint = FHE::EncryptedIntegerType::get(context, chunkSize);
|
||||
auto bigIntWidth = chunkedEint.getWidth();
|
||||
assert(bigIntWidth % chunkWidth == 0 &&
|
||||
@@ -68,9 +68,13 @@ class TypeConverter : public mlir::TypeConverter {
|
||||
public:
|
||||
TypeConverter(unsigned int chunkSize, unsigned int chunkWidth) {
|
||||
addConversion([](mlir::Type type) { return type; });
|
||||
addConversion([chunkSize,
|
||||
chunkWidth](FHE::ChunkedEncryptedIntegerType type) {
|
||||
return convertChunkedEint(type.getContext(), type, chunkSize, chunkWidth);
|
||||
addConversion([chunkSize, chunkWidth](FHE::EncryptedIntegerType type) {
|
||||
if (type.getWidth() > chunkSize) {
|
||||
return (mlir::Type)convertChunkedEint(type.getContext(), type,
|
||||
chunkSize, chunkWidth);
|
||||
} else {
|
||||
return (mlir::Type)type;
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
@@ -296,11 +296,13 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) {
|
||||
return StreamStringError("Rewriting of encrypted mul failed");
|
||||
}
|
||||
|
||||
if (mlir::concretelang::pipeline::transformFHEBigInt(
|
||||
mlirContext, module, enablePass, options.chunkSize,
|
||||
options.chunkWidth)
|
||||
.failed()) {
|
||||
return errorDiag("Transforming FHE big integer ops failed");
|
||||
if (options.chunkIntegers) {
|
||||
if (mlir::concretelang::pipeline::transformFHEBigInt(
|
||||
mlirContext, module, enablePass, options.chunkSize,
|
||||
options.chunkWidth)
|
||||
.failed()) {
|
||||
return errorDiag("Transforming FHE big integer ops failed");
|
||||
}
|
||||
}
|
||||
|
||||
// FHE High level pass to determine FHE parameters
|
||||
|
||||
@@ -98,46 +98,6 @@ gateFromMLIRType(V0FHEContext fheContext, LweSecretKeyID secretKeyID,
|
||||
/*.chunkInfo = */ chunkInfo,
|
||||
};
|
||||
}
|
||||
// TODO: this a duplicate of the last if: should be removed when we remove
|
||||
// chinked eint
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::ChunkedEncryptedIntegerType>()) {
|
||||
bool sign = lweTy.isSignedInteger();
|
||||
std::vector<int64_t> crt;
|
||||
if (fheContext.parameter.largeInteger.has_value()) {
|
||||
crt = fheContext.parameter.largeInteger.value().crtDecomposition;
|
||||
}
|
||||
size_t width;
|
||||
uint64_t size;
|
||||
std::vector<int64_t> dims;
|
||||
if (chunkInfo.hasValue()) {
|
||||
width = chunkInfo->size;
|
||||
assert(lweTy.getWidth() % chunkInfo->width == 0);
|
||||
size = lweTy.getWidth() / chunkInfo->width;
|
||||
dims.push_back(size);
|
||||
} else {
|
||||
width = (size_t)lweTy.getWidth();
|
||||
}
|
||||
return CircuitGate{
|
||||
/* .encryption = */ llvm::Optional<EncryptionGate>({
|
||||
/* .secretKeyID = */ secretKeyID,
|
||||
/* .variance = */ variance,
|
||||
/* .encoding = */
|
||||
{
|
||||
/* .precision = */ width,
|
||||
/* .crt = */ crt,
|
||||
},
|
||||
}),
|
||||
/*.shape = */
|
||||
{
|
||||
/*.width = */ width,
|
||||
/*.dimensions = */ dims,
|
||||
/*.size = */ size,
|
||||
/*.sign = */ sign,
|
||||
},
|
||||
/*.chunkInfo = */ chunkInfo,
|
||||
};
|
||||
}
|
||||
if (auto lweTy = type.dyn_cast_or_null<
|
||||
mlir::concretelang::FHE::EncryptedBooleanType>()) {
|
||||
size_t width = mlir::concretelang::FHE::EncryptedBooleanType::getWidth();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: concretecompiler --chunk-size 4 --chunk-width 2 --passes fhe-big-int-transform --action=dump-fhe %s 2>&1| FileCheck %s
|
||||
// RUN: concretecompiler --chunk-integers --chunk-size 4 --chunk-width 2 --passes fhe-big-int-transform --action=dump-fhe %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @add_chunked_eint(%arg0: tensor<32x!FHE.eint<4>>, %arg1: tensor<32x!FHE.eint<4>>) -> tensor<32x!FHE.eint<4>>
|
||||
func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
|
||||
func.func @add_chunked_eint(%arg0: !FHE.eint<64>, %arg1: !FHE.eint<64>) -> !FHE.eint<64> {
|
||||
// CHECK-NEXT: %[[V0:.*]] = "FHE.zero"() : () -> !FHE.eint<4>
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.zero_tensor"() : () -> tensor<32x!FHE.eint<4>>
|
||||
// CHECK-NEXT: %[[c4_i5:.*]] = arith.constant 4 : i5
|
||||
@@ -19,6 +19,6 @@ func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_ei
|
||||
// CHECK-NEXT: }
|
||||
// CHECK-NEXT: return %2 : tensor<32x!FHE.eint<4>>
|
||||
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>)
|
||||
return %1: !FHE.chunked_eint<64>
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<64>, !FHE.eint<64>) -> (!FHE.eint<64>)
|
||||
return %1: !FHE.eint<64>
|
||||
}
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
// RUN: concretecompiler --action=roundtrip %s 2>&1| FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
|
||||
func.func @mul_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.mul_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64>
|
||||
// CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64>
|
||||
|
||||
%0 = arith.constant 1 : i64
|
||||
%1 = "FHE.mul_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>)
|
||||
return %1: !FHE.chunked_eint<64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
|
||||
func.func @add_chunked_eint_int(%arg0: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = arith.constant 1 : i64
|
||||
// CHECK-NEXT: %[[V2:.*]] = "FHE.add_eint_int"(%arg0, %[[V1]]) : (!FHE.chunked_eint<64>, i64) -> !FHE.chunked_eint<64>
|
||||
// CHECK-NEXT: return %[[V2]] : !FHE.chunked_eint<64>
|
||||
|
||||
%0 = arith.constant 1 : i64
|
||||
%1 = "FHE.add_eint_int"(%arg0, %0): (!FHE.chunked_eint<64>, i64) -> (!FHE.chunked_eint<64>)
|
||||
return %1: !FHE.chunked_eint<64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
|
||||
func.func @add_chunked_eint(%arg0: !FHE.chunked_eint<64>, %arg1: !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64> {
|
||||
// CHECK-NEXT: %[[V1:.*]] = "FHE.add_eint"(%arg0, %arg1) : (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> !FHE.chunked_eint<64>
|
||||
// CHECK-NEXT: return %[[V1]] : !FHE.chunked_eint<64>
|
||||
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.chunked_eint<64>, !FHE.chunked_eint<64>) -> (!FHE.chunked_eint<64>)
|
||||
return %1: !FHE.chunked_eint<64>
|
||||
}
|
||||
|
||||
// TODO: max/min
|
||||
Reference in New Issue
Block a user