diff --git a/compiler/include/concretelang/Support/LambdaArgument.h b/compiler/include/concretelang/Support/LambdaArgument.h index 67358b339..62615e333 100644 --- a/compiler/include/concretelang/Support/LambdaArgument.h +++ b/compiler/include/concretelang/Support/LambdaArgument.h @@ -65,6 +65,16 @@ public: unsigned int getPrecision() const { return this->precision; } BackingIntType getValue() const { return this->value; } + template + bool operator==(const IntLambdaArgument &other) const { + return getValue() == other.getValue(); + } + + template + bool operator!=(const IntLambdaArgument &other) const { + return !(*this == other); + } + static char ID; protected: @@ -177,6 +187,26 @@ public: return this->value.data(); } + template + bool + operator==(const TensorLambdaArgument &other) const { + if (getDimensions() != other.getDimensions()) + return false; + + for (auto pair : llvm::zip(value, other.value)) { + if (std::get<0>(pair) != std::get<1>(pair)) + return false; + } + + return true; + } + + template + bool + operator!=(const TensorLambdaArgument &other) const { + return !(*this == other); + } + static char ID; protected: