mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(compiler): pass dimensions at TensorLambdaArg creation
This commit is contained in:
@@ -1,47 +0,0 @@
|
||||
#ifndef ZAMALANG_SUPPORT_EXECUTION_ARGUMENT_H
|
||||
#define ZAMALANG_SUPPORT_EXECUTION_ARGUMENT_H
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
// Frontend object to abstract the different types of possible arguments,
|
||||
// namely, integers, and tensors.
|
||||
class ExecutionArgument {
|
||||
public:
|
||||
// There are two possible underlying types for the execution argument, either
|
||||
// and int, or a tensor
|
||||
bool isTensor() { return isTensorArg; }
|
||||
bool isInt() { return !isTensorArg; }
|
||||
|
||||
uint8_t *getTensorArgument() { return tensorArg.data(); }
|
||||
|
||||
size_t getTensorSize() { return tensorArg.size(); }
|
||||
|
||||
uint64_t getIntegerArgument() { return intArg; }
|
||||
|
||||
// Create an execution argument from an integer
|
||||
static std::shared_ptr<ExecutionArgument> create(uint64_t arg) {
|
||||
return std::shared_ptr<ExecutionArgument>(new ExecutionArgument(arg));
|
||||
}
|
||||
// Create an execution argument from a tensor
|
||||
static std::shared_ptr<ExecutionArgument> create(std::vector<uint8_t> arg) {
|
||||
return std::shared_ptr<ExecutionArgument>(new ExecutionArgument(arg));
|
||||
}
|
||||
|
||||
private:
|
||||
ExecutionArgument(int arg) : isTensorArg(false), intArg(arg) {}
|
||||
|
||||
ExecutionArgument(std::vector<uint8_t> tensor)
|
||||
: isTensorArg(true), tensorArg(tensor) {}
|
||||
|
||||
uint64_t intArg;
|
||||
std::vector<uint8_t> tensorArg;
|
||||
bool isTensorArg;
|
||||
};
|
||||
|
||||
} // namespace zamalang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -60,6 +60,10 @@ public:
|
||||
// `pos` or an error if the result is a scalar value
|
||||
llvm::Expected<size_t> getResultVectorSize(size_t pos);
|
||||
|
||||
// Returns the dimensions of the result tensor at position `pos` or
|
||||
// an error if the result is a scalar value
|
||||
llvm::Expected<std::vector<int64_t>> getResultDimensions(size_t pos);
|
||||
|
||||
private:
|
||||
llvm::Error setArg(size_t pos, size_t width, const void *data,
|
||||
llvm::ArrayRef<int64_t> shape);
|
||||
|
||||
@@ -79,9 +79,14 @@ typedResult(JITLambda::Argument &arguments) {
|
||||
if (!tensorOrError)
|
||||
return std::move(tensorOrError.takeError());
|
||||
|
||||
llvm::Expected<std::vector<int64_t>> tensorDimOrError =
|
||||
arguments.getResultDimensions(0);
|
||||
if (!tensorDimOrError)
|
||||
return std::move(tensorDimOrError.takeError());
|
||||
|
||||
return std::move(
|
||||
std::make_unique<TensorLambdaArgument<IntLambdaArgument<uint64_t>>>(
|
||||
*tensorOrError));
|
||||
*tensorOrError, *tensorDimOrError));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -345,6 +345,21 @@ llvm::Expected<size_t> JITLambda::Argument::getResultVectorSize(size_t pos) {
|
||||
return info.shape.size;
|
||||
}
|
||||
|
||||
// Returns the dimensions of the result tensor at position `pos` or
|
||||
// an error if the result is a scalar value
|
||||
llvm::Expected<std::vector<int64_t>>
|
||||
JITLambda::Argument::getResultDimensions(size_t pos) {
|
||||
auto gate = outputGates[pos];
|
||||
auto info = std::get<0>(gate);
|
||||
|
||||
if (info.shape.size == 0) {
|
||||
return llvm::createStringError(llvm::inconvertibleErrorCode(),
|
||||
"Result at pos %zu is not a tensor", pos);
|
||||
}
|
||||
|
||||
return info.shape.dimensions;
|
||||
}
|
||||
|
||||
llvm::Expected<enum JITLambda::Argument::ResultType>
|
||||
JITLambda::Argument::getResultType(size_t pos) {
|
||||
if (pos >= outputGates.size()) {
|
||||
|
||||
Reference in New Issue
Block a user