From 710dd7a88cc2a021f75981593d497c4401dfcd75 Mon Sep 17 00:00:00 2001 From: Andi Drebes Date: Tue, 20 Sep 2022 15:58:26 +0200 Subject: [PATCH] enhance(compiler): Add function returning LambdaArgument type as string This adds a new function `getLambdaArgumentTypeAsString(const LambdaArgument&)` returning the name of a lambda argument type as a string, e.g., `"uint8_t"` for an `IntLambdaArgument` or `"tensor"` for a `TensorLambdaArgument>`. Note that, due to the static inheritance scheme for Lambda Arguments and explicit instantiation, this is only implemented for the common backing integer types `uint8_t`, `int8_t`, `uint16_t`, `int16_t`, `uint32_t`, `int32_t`, `uint64_t`, and `int64_t`. --- .../concretelang/Support/LambdaArgument.h | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/compiler/include/concretelang/Support/LambdaArgument.h b/compiler/include/concretelang/Support/LambdaArgument.h index 7fe0cf5f8..f20b497c2 100644 --- a/compiler/include/concretelang/Support/LambdaArgument.h +++ b/compiler/include/concretelang/Support/LambdaArgument.h @@ -226,6 +226,77 @@ protected: template char TensorLambdaArgument::ID = 0; +namespace { +template struct NameOfFundamentalType { + static const char *get(); +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "uint8_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "int8_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "uint16_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "int16_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "uint32_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "int32_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "uint64_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "int64_t"; } +}; + +template struct LambdaArgumentTypeName; + +template <> struct LambdaArgumentTypeName<> { + static const char *get(const mlir::concretelang::LambdaArgument &arg) { + assert(false && "No name implemented for this lambda argument type"); + return nullptr; + } +}; + +template struct LambdaArgumentTypeName { + static const std::string get(const mlir::concretelang::LambdaArgument &arg) { + if (arg.dyn_cast>()) { + return NameOfFundamentalType::get(); + } else if (arg.dyn_cast>()) { + return std::string("encrypted ") + NameOfFundamentalType::get(); + } else if (arg.dyn_cast< + const TensorLambdaArgument>>()) { + return std::string("tensor<") + NameOfFundamentalType::get() + ">"; + } else if (arg.dyn_cast< + const TensorLambdaArgument>>()) { + return std::string("tensor::get() + ">"; + } + + return LambdaArgumentTypeName::get(arg); + } +}; +} // namespace + +const std::string getLambdaArgumentTypeAsString(const LambdaArgument &arg) { + return LambdaArgumentTypeName::get(arg); +} + } // namespace concretelang } // namespace mlir