enhance(compiler): Add function returning LambdaArgument type as string

This adds a new function `getLambdaArgumentTypeAsString(const
LambdaArgument&)` returning the name of a lambda argument type as a
string, e.g., `"uint8_t"` for an `IntLambdaArgument<uint8_t>` or
`"tensor<uint8_t>"` for a
`TensorLambdaArgument<IntLambdaArgument<uint8_t>>`.

Note that, due to the static inheritance scheme for Lambda Arguments
and explicit instantiation, this is only implemented for the common
backing integer types `uint8_t`, `int8_t`, `uint16_t`, `int16_t`,
`uint32_t`, `int32_t`, `uint64_t`, and `int64_t`.
This commit is contained in:
Andi Drebes
2022-09-20 15:58:26 +02:00
committed by rudy-6-4
parent 7a3cf64171
commit 710dd7a88c

View File

@@ -226,6 +226,77 @@ protected:
template <typename ScalarArgumentT>
char TensorLambdaArgument<ScalarArgumentT>::ID = 0;
namespace {
template <typename T> struct NameOfFundamentalType {
static const char *get();
};
template <> struct NameOfFundamentalType<uint8_t> {
static const char *get() { return "uint8_t"; }
};
template <> struct NameOfFundamentalType<int8_t> {
static const char *get() { return "int8_t"; }
};
template <> struct NameOfFundamentalType<uint16_t> {
static const char *get() { return "uint16_t"; }
};
template <> struct NameOfFundamentalType<int16_t> {
static const char *get() { return "int16_t"; }
};
template <> struct NameOfFundamentalType<uint32_t> {
static const char *get() { return "uint32_t"; }
};
template <> struct NameOfFundamentalType<int32_t> {
static const char *get() { return "int32_t"; }
};
template <> struct NameOfFundamentalType<uint64_t> {
static const char *get() { return "uint64_t"; }
};
template <> struct NameOfFundamentalType<int64_t> {
static const char *get() { return "int64_t"; }
};
template <typename... Ts> struct LambdaArgumentTypeName;
template <> struct LambdaArgumentTypeName<> {
static const char *get(const mlir::concretelang::LambdaArgument &arg) {
assert(false && "No name implemented for this lambda argument type");
return nullptr;
}
};
template <typename T, typename... Ts> struct LambdaArgumentTypeName<T, Ts...> {
static const std::string get(const mlir::concretelang::LambdaArgument &arg) {
if (arg.dyn_cast<const IntLambdaArgument<T>>()) {
return NameOfFundamentalType<T>::get();
} else if (arg.dyn_cast<const EIntLambdaArgument<T>>()) {
return std::string("encrypted ") + NameOfFundamentalType<T>::get();
} else if (arg.dyn_cast<
const TensorLambdaArgument<IntLambdaArgument<T>>>()) {
return std::string("tensor<") + NameOfFundamentalType<T>::get() + ">";
} else if (arg.dyn_cast<
const TensorLambdaArgument<EIntLambdaArgument<T>>>()) {
return std::string("tensor<encrypted ") +
NameOfFundamentalType<T>::get() + ">";
}
return LambdaArgumentTypeName<Ts...>::get(arg);
}
};
} // namespace
const std::string getLambdaArgumentTypeAsString(const LambdaArgument &arg) {
return LambdaArgumentTypeName<int8_t, uint8_t, int16_t, uint16_t, int32_t,
uint32_t, int64_t, uint64_t>::get(arg);
}
} // namespace concretelang
} // namespace mlir