mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(compiler/frontend-python): Expose default GPU CompilerOptions set and use it in concrete-python
This commit is contained in:
committed by
Antoniu Pop
parent
6691c8f107
commit
1dec886770
@@ -60,6 +60,11 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
|
||||
m.def("init_df_parallelization", &initDataflowParallelization);
|
||||
|
||||
pybind11::enum_<mlir::concretelang::Backend>(m, "Backend")
|
||||
.value("CPU", mlir::concretelang::Backend::CPU)
|
||||
.value("GPU", mlir::concretelang::Backend::GPU)
|
||||
.export_values();
|
||||
|
||||
pybind11::enum_<optimizer::Strategy>(m, "OptimizerStrategy")
|
||||
.value("V0", optimizer::Strategy::V0)
|
||||
.value("DAG_MONO", optimizer::Strategy::DAG_MONO)
|
||||
@@ -74,7 +79,9 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
|
||||
pybind11::class_<CompilationOptions>(m, "CompilationOptions")
|
||||
.def(pybind11::init(
|
||||
[](std::string funcname) { return CompilationOptions(funcname); }))
|
||||
[](std::string funcname, mlir::concretelang::Backend backend) {
|
||||
return CompilationOptions(funcname, backend);
|
||||
}))
|
||||
.def("set_funcname",
|
||||
[](CompilationOptions &options, std::string funcname) {
|
||||
options.mainFuncName = funcname;
|
||||
|
||||
@@ -10,6 +10,7 @@ from mlir._mlir_libs._concretelang._compiler import (
|
||||
CompilationOptions as _CompilationOptions,
|
||||
OptimizerStrategy as _OptimizerStrategy,
|
||||
Encoding,
|
||||
Backend as _Backend,
|
||||
)
|
||||
from .wrapper import WrapperCpp
|
||||
|
||||
@@ -41,7 +42,7 @@ class CompilationOptions(WrapperCpp):
|
||||
|
||||
@staticmethod
|
||||
# pylint: disable=arguments-differ
|
||||
def new(function_name="main") -> "CompilationOptions":
|
||||
def new(function_name="main", backend=_Backend.CPU) -> "CompilationOptions":
|
||||
"""Build a CompilationOptions.
|
||||
|
||||
Args:
|
||||
@@ -57,7 +58,11 @@ class CompilationOptions(WrapperCpp):
|
||||
raise TypeError(
|
||||
f"function_name must be of type str not {type(function_name)}"
|
||||
)
|
||||
return CompilationOptions.wrap(_CompilationOptions(function_name))
|
||||
if not isinstance(backend, _Backend):
|
||||
raise TypeError(
|
||||
f"backend must be of type Backend not {type(function_name)}"
|
||||
)
|
||||
return CompilationOptions.wrap(_CompilationOptions(function_name, backend))
|
||||
|
||||
# pylint: enable=arguments-differ
|
||||
|
||||
|
||||
@@ -25,7 +25,12 @@ from concrete.compiler import (
|
||||
set_compiler_logging,
|
||||
set_llvm_debug_flag,
|
||||
)
|
||||
from mlir._mlir_libs._concretelang._compiler import KeyType, OptimizerStrategy, PrimitiveOperation
|
||||
from mlir._mlir_libs._concretelang._compiler import (
|
||||
Backend,
|
||||
KeyType,
|
||||
OptimizerStrategy,
|
||||
PrimitiveOperation,
|
||||
)
|
||||
from mlir.ir import Module as MlirModule
|
||||
|
||||
from ..internal.utils import assert_that
|
||||
@@ -106,7 +111,9 @@ class Server:
|
||||
context to use for the Compiler
|
||||
"""
|
||||
|
||||
options = CompilationOptions.new("main")
|
||||
backend = Backend.GPU if configuration.use_gpu else Backend.CPU
|
||||
options = CompilationOptions.new("main", backend)
|
||||
|
||||
options.simulation(is_simulated)
|
||||
|
||||
options.set_loop_parallelize(configuration.loop_parallelize)
|
||||
@@ -114,8 +121,6 @@ class Server:
|
||||
options.set_auto_parallelize(configuration.auto_parallelize)
|
||||
options.set_compress_inputs(configuration.compress_inputs)
|
||||
options.set_composable(configuration.composable)
|
||||
options.set_emit_gpu_ops(configuration.use_gpu)
|
||||
options.set_batch_tfhe_ops(False)
|
||||
|
||||
if configuration.auto_parallelize or configuration.dataflow_parallelize:
|
||||
# pylint: disable=c-extension-no-member,no-member
|
||||
|
||||
@@ -282,7 +282,6 @@ class Helpers:
|
||||
only_simulation (bool, default = False):
|
||||
whether to just check simulation but not execution
|
||||
"""
|
||||
|
||||
if not isinstance(sample, list):
|
||||
sample = [sample]
|
||||
|
||||
@@ -330,6 +329,11 @@ class Helpers:
|
||||
raise AssertionError(message)
|
||||
|
||||
circuit.enable_fhe_simulation()
|
||||
|
||||
# Skip simulation for GPU
|
||||
if circuit.configuration.use_gpu:
|
||||
return
|
||||
|
||||
for i in range(retries):
|
||||
expected = sanitize(function(*deepcopy(sample)))
|
||||
actual = sanitize(circuit.simulate(*deepcopy(sample)))
|
||||
|
||||
Reference in New Issue
Block a user