diff --git a/compiler/include/concretelang/Support/LambdaArgument.h b/compiler/include/concretelang/Support/LambdaArgument.h index 7fe0cf5f8..f20b497c2 100644 --- a/compiler/include/concretelang/Support/LambdaArgument.h +++ b/compiler/include/concretelang/Support/LambdaArgument.h @@ -226,6 +226,77 @@ protected: template char TensorLambdaArgument::ID = 0; +namespace { +template struct NameOfFundamentalType { + static const char *get(); +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "uint8_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "int8_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "uint16_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "int16_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "uint32_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "int32_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "uint64_t"; } +}; + +template <> struct NameOfFundamentalType { + static const char *get() { return "int64_t"; } +}; + +template struct LambdaArgumentTypeName; + +template <> struct LambdaArgumentTypeName<> { + static const char *get(const mlir::concretelang::LambdaArgument &arg) { + assert(false && "No name implemented for this lambda argument type"); + return nullptr; + } +}; + +template struct LambdaArgumentTypeName { + static const std::string get(const mlir::concretelang::LambdaArgument &arg) { + if (arg.dyn_cast>()) { + return NameOfFundamentalType::get(); + } else if (arg.dyn_cast>()) { + return std::string("encrypted ") + NameOfFundamentalType::get(); + } else if (arg.dyn_cast< + const TensorLambdaArgument>>()) { + return std::string("tensor<") + NameOfFundamentalType::get() + ">"; + } else if (arg.dyn_cast< + const TensorLambdaArgument>>()) { + return std::string("tensor::get() + ">"; + } + + return LambdaArgumentTypeName::get(arg); + } +}; +} // namespace + +const std::string getLambdaArgumentTypeAsString(const LambdaArgument &arg) { + return LambdaArgumentTypeName::get(arg); +} + } // namespace concretelang } // namespace mlir