From 4883eebfa3dee95c428d1ddfbb9e987e7e5230ba Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 23 Nov 2021 11:03:48 +0100 Subject: [PATCH] feat(compiler): Add HLFHELinalg.zero operation Add a new operation `HLFHELinalg.zero`, broadcasting an encrypted, zero-valued integer into a tensor of encrypted integers with static shape. Example creating a one-dimensional tensor with five elements all initialized to an encrypted zero: %tensor = "HLFHELinalg.zero"() : () -> tensor<5x!HLFHE.eint<4>> --- .../Dialect/HLFHELinalg/IR/HLFHELinalgOps.td | 16 ++++++++++++++++ compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp | 1 + 2 files changed, 17 insertions(+) diff --git a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td index 729b1e801..88455ba63 100644 --- a/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td +++ b/compiler/include/zamalang/Dialect/HLFHELinalg/IR/HLFHELinalgOps.td @@ -444,5 +444,21 @@ def MatMulIntEintOp : HLFHELinalg_Op<"matmul_int_eint", [TensorBinaryIntEint]> { }]; } +def ZeroOp : HLFHELinalg_Op<"zero", []> { + let summary = "Creates a new tensor with all elements initialized to an encrypted zero."; + + let description = [{ + Creates a new tensor with the shape specified in the result type and initializes its elements with an encrypted zero. + + Example: + ```mlir + %tensor = "HLFHELinalg.zero"() : () -> tensor<5x!HLFHE.eint<4>> + ``` + }]; + + let arguments = (ins); + + let results = (outs Type.predicate, HasStaticShapePred]>>:$aggregate); +} #endif diff --git a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp index aa8845fd6..bc35a076a 100644 --- a/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/HLFHE/Analysis/MANP.cpp @@ -859,6 +859,7 @@ struct MANPAnalysis : public mlir::ForwardDataFlowAnalysis { llvm::dyn_cast(op)) { norm2SqEquiv = getSqMANP(mulEintIntOp, operands); } else if (llvm::isa(op) || + llvm::isa(op) || llvm::isa(op)) { norm2SqEquiv = llvm::APInt{1, 1, false}; }