feat(compiler): allow forcing encoding from python

This commit is contained in:
youben11
2023-08-31 10:46:39 +01:00
committed by Ayoub Benaissa
parent cba3847c92
commit 4e8b9a199c
3 changed files with 25 additions and 1 deletions

View File

@@ -64,6 +64,12 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.value("DAG_MULTI", optimizer::Strategy::DAG_MULTI)
.export_values();
pybind11::enum_<concrete_optimizer::Encoding>(m, "Encoding")
.value("AUTO", concrete_optimizer::Encoding::Auto)
.value("CRT", concrete_optimizer::Encoding::Crt)
.value("NATIVE", concrete_optimizer::Encoding::Native)
.export_values();
pybind11::class_<CompilationOptions>(m, "CompilationOptions")
.def(pybind11::init(
[](std::string funcname) { return CompilationOptions(funcname); }))
@@ -134,6 +140,11 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
brLevel, brLogBase, ksLevel,
ksLogBase, largeIntegerParam};
})
.def("force_encoding",
[](CompilationOptions &options,
concrete_optimizer::Encoding encoding) {
options.optimizerConfig.encoding = encoding;
})
.def("simulation", [](CompilationOptions &options, bool simulate) {
options.simulate = simulate;
});

View File

@@ -17,7 +17,7 @@ from mlir._mlir_libs._concretelang._compiler import (
# pylint: enable=no-name-in-module,import-error
from .compilation_options import CompilationOptions
from .compilation_options import CompilationOptions, Encoding
from .compilation_context import CompilationContext
from .key_set_cache import KeySetCache
from .client_parameters import ClientParameters

View File

@@ -9,6 +9,7 @@ from typing import List
from mlir._mlir_libs._concretelang._compiler import (
CompilationOptions as _CompilationOptions,
OptimizerStrategy as _OptimizerStrategy,
Encoding,
)
from .wrapper import WrapperCpp
@@ -349,6 +350,18 @@ class CompilationOptions(WrapperCpp):
# pylint: enable=too-many-arguments,too-many-branches
def force_encoding(self, encoding: Encoding):
"""Force the compiler to use a specific encoding.
Args:
encoding (Encoding): the encoding to force the compiler to use
Raises:
TypeError: if encoding is not of type Encoding
"""
if not isinstance(encoding, Encoding):
raise TypeError("encoding need to be of type Encoding")
self.cpp().force_encoding(encoding)
def simulation(self, simulate: bool):
"""Enable or disable simulation.