feat: use signed weights in optimizer dot

This commit is contained in:
rudy
2022-09-19 15:07:40 +02:00
committed by rudy-6-4
parent f2cbb1e719
commit 637a004529
3 changed files with 17 additions and 7 deletions

View File

@@ -154,7 +154,7 @@ struct FunctionToDag {
}
void addDot(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs,
std::vector<std::uint64_t> &weights_vector) {
std::vector<std::int64_t> &weights_vector) {
assert(encrypted_inputs.size() == 1);
auto weights = concrete_optimizer::weights::vector(slice(weights_vector));
index[val] = dag->add_dot(slice(encrypted_inputs), std::move(weights));
@@ -228,19 +228,19 @@ struct FunctionToDag {
return value.isa<mlir::BlockArgument>();
}
std::vector<std::uint64_t>
std::vector<std::int64_t>
resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) {
std::vector<std::uint64_t> values;
std::vector<std::int64_t> values;
mlir::DenseIntElementsAttr denseVals =
cstOp->getAttrOfType<mlir::DenseIntElementsAttr>("value");
for (llvm::APInt val : denseVals.getValues<llvm::APInt>()) {
values.push_back(val.getZExtValue());
values.push_back(val.getSExtValue());
}
return values;
}
llvm::Optional<std::vector<std::uint64_t>>
llvm::Optional<std::vector<std::int64_t>>
resolveConstantWeights(mlir::Value &value) {
if (auto cstOp = llvm::dyn_cast_or_null<mlir::arith::ConstantOp>(
value.getDefiningOp())) {
@@ -258,7 +258,7 @@ struct FunctionToDag {
}
}
llvm::Optional<std::vector<std::uint64_t>>
llvm::Optional<std::vector<std::int64_t>>
dotWeights(mlir::concretelang::FHELinalg::Dot &dot) {
if (dot.getOperands().size() != 2) {
return llvm::None;

View File

@@ -0,0 +1,10 @@
// RUN: concretecompiler --verbose --split-input-file --action=dump-fhe %s 2>&1| FileCheck %s
func.func @main(%arg0: tensor<5x!FHE.eint<5>>) -> !FHE.eint<5> {
%weights = arith.constant dense<[-1, -1, -1, -1, -1]> : tensor<5xi6>
%tlu = arith.constant dense<[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]> : tensor<32xi64>
%0 = "FHELinalg.apply_lookup_table"(%arg0, %tlu) : (tensor<5x!FHE.eint<5>>, tensor<32xi64>) -> tensor<5x!FHE.eint<5>>
// CHECK: Dot { [[a:.*]], weights: ClearTensor { shape: Shape { dimensions_size: [5] }, values: [-1, -1, -1, -1, -1] }, [[b:.*]]}
%1 = "FHELinalg.dot_eint_int"(%0, %weights) : (tensor<5x!FHE.eint<5>>, tensor<5xi6>) -> !FHE.eint<5>
return %1 : !FHE.eint<5>
}