mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
enhance(compiler/runtime): Add runtime tools to handle tensor inputs and outputs
This commit is contained in:
@@ -14,12 +14,25 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
Precision precision,
|
||||
mlir::Type type) {
|
||||
if (type.isIntOrIndex()) {
|
||||
// TODO - The index type is dependant of the target architecture, so
|
||||
// actually we assume we target only 64 bits, we need to have some the size
|
||||
// of the word of the target system.
|
||||
size_t width = 64;
|
||||
if (!type.isIndex()) {
|
||||
width = type.getIntOrFloatBitWidth();
|
||||
}
|
||||
return CircuitGate{
|
||||
.encryption = llvm::None,
|
||||
.shape = {.size = 0},
|
||||
.shape =
|
||||
{
|
||||
.width = width,
|
||||
.size = 0,
|
||||
},
|
||||
};
|
||||
}
|
||||
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
|
||||
// TODO - Get the width from the LWECiphertextType instead of global
|
||||
// precision (could be possible after merge lowlfhe-ciphertext-parameter)
|
||||
return CircuitGate{
|
||||
.encryption = llvm::Optional<EncryptionGate>({
|
||||
.secretKeyID = secretKeyID,
|
||||
@@ -27,17 +40,17 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
|
||||
.variance = 0.,
|
||||
.encoding = {.precision = precision},
|
||||
}),
|
||||
.shape = {.size = 0},
|
||||
.shape = {.width = precision, .size = 0},
|
||||
};
|
||||
}
|
||||
auto memref = type.dyn_cast_or_null<mlir::MemRefType>();
|
||||
if (memref != nullptr) {
|
||||
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
|
||||
if (tensor != nullptr) {
|
||||
auto gate =
|
||||
gateFromMLIRType(secretKeyID, precision, memref.getElementType());
|
||||
gateFromMLIRType(secretKeyID, precision, tensor.getElementType());
|
||||
if (auto err = gate.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
gate->shape.size = memref.getDimSize(0);
|
||||
gate->shape.size = tensor.getDimSize(0);
|
||||
return gate;
|
||||
}
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
|
||||
Reference in New Issue
Block a user