enhance(compiler/runtime): Add runtime tools to handle tensor inputs and outputs

This commit is contained in:
Quentin Bourgerie
2021-08-24 15:02:45 +02:00
parent dba76a1e1b
commit af0789f128
11 changed files with 701 additions and 73 deletions

View File

@@ -14,12 +14,25 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
Precision precision,
mlir::Type type) {
if (type.isIntOrIndex()) {
// TODO - The index type is dependant of the target architecture, so
// actually we assume we target only 64 bits, we need to have some the size
// of the word of the target system.
size_t width = 64;
if (!type.isIndex()) {
width = type.getIntOrFloatBitWidth();
}
return CircuitGate{
.encryption = llvm::None,
.shape = {.size = 0},
.shape =
{
.width = width,
.size = 0,
},
};
}
if (type.isa<mlir::zamalang::LowLFHE::LweCiphertextType>()) {
// TODO - Get the width from the LWECiphertextType instead of global
// precision (could be possible after merge lowlfhe-ciphertext-parameter)
return CircuitGate{
.encryption = llvm::Optional<EncryptionGate>({
.secretKeyID = secretKeyID,
@@ -27,17 +40,17 @@ llvm::Expected<CircuitGate> gateFromMLIRType(std::string secretKeyID,
.variance = 0.,
.encoding = {.precision = precision},
}),
.shape = {.size = 0},
.shape = {.width = precision, .size = 0},
};
}
auto memref = type.dyn_cast_or_null<mlir::MemRefType>();
if (memref != nullptr) {
auto tensor = type.dyn_cast_or_null<mlir::RankedTensorType>();
if (tensor != nullptr) {
auto gate =
gateFromMLIRType(secretKeyID, precision, memref.getElementType());
gateFromMLIRType(secretKeyID, precision, tensor.getElementType());
if (auto err = gate.takeError()) {
return std::move(err);
}
gate->shape.size = memref.getDimSize(0);
gate->shape.size = tensor.getDimSize(0);
return gate;
}
return llvm::make_error<llvm::StringError>(