mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
enhance(compiler): JitCompilerEngine: Add support for arbitrary int result tensors
Add support for result tensors composed of uint8_t, uint16_t, uint32_t and uint64_t elements, replacing the current implementation, which only supports uint64_t tensors.
This commit is contained in:
committed by
Quentin Bourgerie
parent
3118983287
commit
f54c0dd8d8
@@ -32,18 +32,15 @@ inline llvm::Expected<uint64_t> typedResult(JITLambda::Argument &arguments) {
|
||||
return res;
|
||||
}
|
||||
|
||||
// Specialization of `typedResult()` for vector results, initializing
|
||||
// an `std::vector` of the right size with the results and forwarding
|
||||
// it to the caller with move semantics.
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint64_t>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
template <typename T>
|
||||
inline llvm::Expected<std::vector<T>>
|
||||
typedVectorResult(JITLambda::Argument &arguments) {
|
||||
llvm::Expected<size_t> n = arguments.getResultVectorSize(0);
|
||||
|
||||
if (auto err = n.takeError())
|
||||
return std::move(err);
|
||||
|
||||
std::vector<uint64_t> res(*n);
|
||||
std::vector<T> res(*n);
|
||||
|
||||
if (auto err = arguments.getResult(0, res.data(), res.size()))
|
||||
return StreamStringError() << "Cannot retrieve result:" << err;
|
||||
@@ -51,6 +48,54 @@ typedResult(JITLambda::Argument &arguments) {
|
||||
return std::move(res);
|
||||
}
|
||||
|
||||
// Specializations of `typedResult()` for vector results, initializing
|
||||
// an `std::vector` of the right size with the results and forwarding
|
||||
// it to the caller with move semantics.
|
||||
//
|
||||
// Cannot factor out into a template template <typename T> inline
|
||||
// llvm::Expected<std::vector<uint8_t>>
|
||||
// typedResult(JITLambda::Argument &arguments); due to ambiguity with
|
||||
// scalar template
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint8_t>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
return std::move(typedVectorResult<uint8_t>(arguments));
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint16_t>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
return std::move(typedVectorResult<uint16_t>(arguments));
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint32_t>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
return std::move(typedVectorResult<uint32_t>(arguments));
|
||||
}
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint64_t>>
|
||||
typedResult(JITLambda::Argument &arguments) {
|
||||
return std::move(typedVectorResult<uint64_t>(arguments));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
llvm::Expected<std::unique_ptr<LambdaArgument>>
|
||||
buildTensorLambdaResult(JITLambda::Argument &arguments) {
|
||||
llvm::Expected<std::vector<T>> tensorOrError =
|
||||
typedResult<std::vector<T>>(arguments);
|
||||
|
||||
if (!tensorOrError)
|
||||
return std::move(tensorOrError.takeError());
|
||||
|
||||
llvm::Expected<std::vector<int64_t>> tensorDimOrError =
|
||||
arguments.getResultDimensions(0);
|
||||
|
||||
if (!tensorDimOrError)
|
||||
return std::move(tensorDimOrError.takeError());
|
||||
|
||||
return std::move(std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
|
||||
*tensorOrError, *tensorDimOrError));
|
||||
}
|
||||
|
||||
// Specialization of `typedResult()` for a single result wrapped into
|
||||
// a `LambdaArgument`.
|
||||
template <>
|
||||
@@ -73,25 +118,26 @@ typedResult(JITLambda::Argument &arguments) {
|
||||
}
|
||||
|
||||
case JITLambda::Argument::ResultType::TENSOR: {
|
||||
llvm::Expected<std::vector<uint64_t>> tensorOrError =
|
||||
typedResult<std::vector<uint64_t>>(arguments);
|
||||
llvm::Expected<size_t> width = arguments.getResultWidth(0);
|
||||
|
||||
if (!tensorOrError)
|
||||
return std::move(tensorOrError.takeError());
|
||||
if (!width)
|
||||
return std::move(width.takeError());
|
||||
|
||||
llvm::Expected<std::vector<int64_t>> tensorDimOrError =
|
||||
arguments.getResultDimensions(0);
|
||||
if (!tensorDimOrError)
|
||||
return std::move(tensorDimOrError.takeError());
|
||||
|
||||
return std::move(
|
||||
std::make_unique<TensorLambdaArgument<IntLambdaArgument<uint64_t>>>(
|
||||
*tensorOrError, *tensorDimOrError));
|
||||
if (*width > 64)
|
||||
return StreamStringError("Cannot handle scalars with more than 64 bits");
|
||||
if (*width > 32)
|
||||
return buildTensorLambdaResult<uint64_t>(arguments);
|
||||
else if (*width > 16)
|
||||
return buildTensorLambdaResult<uint32_t>(arguments);
|
||||
else if (*width > 8)
|
||||
return buildTensorLambdaResult<uint16_t>(arguments);
|
||||
else if (*width <= 8)
|
||||
return buildTensorLambdaResult<uint8_t>(arguments);
|
||||
}
|
||||
}
|
||||
|
||||
return StreamStringError("Unknown result type");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Adaptor class that adds arguments specified as instances of
|
||||
// `LambdaArgument` to `JitLambda::Argument`.
|
||||
|
||||
Reference in New Issue
Block a user