mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat(compiler): Add method JITLambda::Argument::getResultType(size_t)
Add a method `JITLambda::Argument::getResultType(size_t pos)` that returns the type of the result with the index `pos` of a `JITLambda::Argument`.
This commit is contained in:
committed by
Ayoub Benaissa
parent
97ee8134ed
commit
1ad3d57f66
@@ -41,6 +41,13 @@ public:
|
||||
// Get the result at the given pos as an uint64_t.
|
||||
llvm::Error getResult(size_t pos, uint64_t &res);
|
||||
|
||||
// Specifies the type of a result
|
||||
enum ResultType { SCALAR, TENSOR };
|
||||
|
||||
// Returns the result type at position `pos`. If pos is invalid,
|
||||
// an error is returned.
|
||||
llvm::Expected<enum ResultType> getResultType(size_t pos);
|
||||
|
||||
// Get a result for tensors, fill the `res` buffer with the value of the
|
||||
// tensor result.
|
||||
// Returns an error:
|
||||
|
||||
@@ -344,6 +344,25 @@ llvm::Expected<size_t> JITLambda::Argument::getResultVectorSize(size_t pos) {
|
||||
return info.shape.size;
|
||||
}
|
||||
|
||||
llvm::Expected<enum JITLambda::Argument::ResultType>
|
||||
JITLambda::Argument::getResultType(size_t pos) {
|
||||
if (pos >= outputGates.size()) {
|
||||
return llvm::createStringError(llvm::inconvertibleErrorCode(),
|
||||
"Requesting type for result at index %zu, "
|
||||
"but lambda only generates %zu results",
|
||||
pos, outputGates.size());
|
||||
}
|
||||
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
|
||||
if (info.shape.size == 0) {
|
||||
return ResultType::SCALAR;
|
||||
} else {
|
||||
return ResultType::TENSOR;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
|
||||
size_t size) {
|
||||
|
||||
|
||||
Reference in New Issue
Block a user