mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
enhance(compiler): Add function returning LambdaArgument type as string
This adds a new function `getLambdaArgumentTypeAsString(const LambdaArgument&)` returning the name of a lambda argument type as a string, e.g., `"uint8_t"` for an `IntLambdaArgument<uint8_t>` or `"tensor<uint8_t>"` for a `TensorLambdaArgument<IntLambdaArgument<uint8_t>>`. Note that, due to the static inheritance scheme for Lambda Arguments and explicit instantiation, this is only implemented for the common backing integer types `uint8_t`, `int8_t`, `uint16_t`, `int16_t`, `uint32_t`, `int32_t`, `uint64_t`, and `int64_t`.
This commit is contained in:
@@ -226,6 +226,77 @@ protected:
|
||||
template <typename ScalarArgumentT>
|
||||
char TensorLambdaArgument<ScalarArgumentT>::ID = 0;
|
||||
|
||||
namespace {
|
||||
template <typename T> struct NameOfFundamentalType {
|
||||
static const char *get();
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<uint8_t> {
|
||||
static const char *get() { return "uint8_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<int8_t> {
|
||||
static const char *get() { return "int8_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<uint16_t> {
|
||||
static const char *get() { return "uint16_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<int16_t> {
|
||||
static const char *get() { return "int16_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<uint32_t> {
|
||||
static const char *get() { return "uint32_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<int32_t> {
|
||||
static const char *get() { return "int32_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<uint64_t> {
|
||||
static const char *get() { return "uint64_t"; }
|
||||
};
|
||||
|
||||
template <> struct NameOfFundamentalType<int64_t> {
|
||||
static const char *get() { return "int64_t"; }
|
||||
};
|
||||
|
||||
template <typename... Ts> struct LambdaArgumentTypeName;
|
||||
|
||||
template <> struct LambdaArgumentTypeName<> {
|
||||
static const char *get(const mlir::concretelang::LambdaArgument &arg) {
|
||||
assert(false && "No name implemented for this lambda argument type");
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, typename... Ts> struct LambdaArgumentTypeName<T, Ts...> {
|
||||
static const std::string get(const mlir::concretelang::LambdaArgument &arg) {
|
||||
if (arg.dyn_cast<const IntLambdaArgument<T>>()) {
|
||||
return NameOfFundamentalType<T>::get();
|
||||
} else if (arg.dyn_cast<const EIntLambdaArgument<T>>()) {
|
||||
return std::string("encrypted ") + NameOfFundamentalType<T>::get();
|
||||
} else if (arg.dyn_cast<
|
||||
const TensorLambdaArgument<IntLambdaArgument<T>>>()) {
|
||||
return std::string("tensor<") + NameOfFundamentalType<T>::get() + ">";
|
||||
} else if (arg.dyn_cast<
|
||||
const TensorLambdaArgument<EIntLambdaArgument<T>>>()) {
|
||||
return std::string("tensor<encrypted ") +
|
||||
NameOfFundamentalType<T>::get() + ">";
|
||||
}
|
||||
|
||||
return LambdaArgumentTypeName<Ts...>::get(arg);
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
const std::string getLambdaArgumentTypeAsString(const LambdaArgument &arg) {
|
||||
return LambdaArgumentTypeName<int8_t, uint8_t, int16_t, uint16_t, int32_t,
|
||||
uint32_t, int64_t, uint64_t>::get(arg);
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
|
||||
Reference in New Issue
Block a user