mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat(compiler): Add method getResultVectorSize to JITLambda::Argument
Add method `JITLambda::Argument::getResultVectorSize` that returns the number of elements of the result if the result is a vector.
This commit is contained in:
@@ -57,6 +57,10 @@ public:
|
||||
// Fill the result.
|
||||
llvm::Error getResult(size_t pos, uint64_t *res, size_t size);
|
||||
|
||||
// Returns the number of elements of the result vector at position
|
||||
// `pos` or an error if the result is a scalar value
|
||||
llvm::Expected<size_t> getResultVectorSize(size_t pos);
|
||||
|
||||
private:
|
||||
llvm::Error setArg(size_t pos, size_t width, void *data, size_t size);
|
||||
|
||||
|
||||
@@ -344,6 +344,20 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
// Returns the number of elements of the result vector at position
|
||||
// `pos` or an error if the result is a scalar value
|
||||
llvm::Expected<size_t> JITLambda::Argument::getResultVectorSize(size_t pos) {
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
|
||||
if (info.shape.size == 0) {
|
||||
return llvm::createStringError(llvm::inconvertibleErrorCode(),
|
||||
"Result at pos %zu is not a tensor", pos);
|
||||
}
|
||||
|
||||
return info.shape.size;
|
||||
}
|
||||
|
||||
llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t *res,
|
||||
size_t size) {
|
||||
auto gate = outputGates[pos];
|
||||
|
||||
Reference in New Issue
Block a user