mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
* feat(mlir): conversion from HDKIR to MLIR * feat(mlir): support ir.Sub and ir.Mul - better type conversion from HDK to MLIR - Context management inside the converter class - better handling of input type in conversion functions * refactor(mlir): use input and output from OPGraph Co-authored-by: Arthur Meyre <arthur.meyre@zama.ai> * feat(mlir): eint-int subtractions * feat(mlir): adhere to spec for supported ops * feat(OPGraph): getters for ordered inputs/outputs + formatting * tests(mlir): test conversion via compiler roundtrip * fix(mlir): flip operands on int_eint sym ops * feat(mlir): check that the outputs are unsigned * feat(mlir): set bit_width of all nodes to the max This is currently required as the compiler is already assuming this. Could be removed from HDK when the compiler can do it on its own * feat: value_is_integer + CRs disable some linting errors * tests: update compile tests + coverage * refactor: reorganize mlir package + better doc * doc: conformance with pydocstyle Co-authored-by: Arthur Meyre <arthur.meyre@zama.ai>
62 lines
2.4 KiB
Python
62 lines
2.4 KiB
Python
"""hnumpy compilation function."""
|
|
|
|
from typing import Any, Callable, Dict, Iterator, Optional, Tuple
|
|
|
|
from hdk.common.bounds_measurement.dataset_eval import eval_op_graph_bounds_on_dataset
|
|
from hdk.hnumpy.tracing import trace_numpy_function
|
|
|
|
from ..common.compilation import CompilationArtifacts
|
|
from ..common.data_types import BaseValue
|
|
from ..common.mlir.utils import (
|
|
is_graph_values_compatible_with_mlir,
|
|
update_bit_width_for_mlir,
|
|
)
|
|
from ..common.operator_graph import OPGraph
|
|
from ..hnumpy.tracing import trace_numpy_function
|
|
|
|
|
|
def compile_numpy_function(
    function_to_trace: Callable,
    function_parameters: Dict[str, BaseValue],
    dataset: Iterator[Tuple[Any, ...]],
    compilation_artifacts: Optional[CompilationArtifacts] = None,
) -> OPGraph:
    """Compile a numpy-flavored function into an MLIR-compatible operation graph.

    This is the main entry point of hnumpy for compiling a homomorphic program.

    Args:
        function_to_trace (Callable): the function to trace and compile
        function_parameters (Dict[str, BaseValue]): description of each function input,
            e.g. an EncryptedValue holding a 7-bit unsigned integer
        dataset (Iterator[Tuple[Any, ...]]): data over which the graph is evaluated to
            measure value bounds; each tuple must have one entry per function parameter,
            in the same order as those parameters
        compilation_artifacts (Optional[CompilationArtifacts]): if provided, filled with
            intermediate results during compilation

    Returns:
        OPGraph: currently a compilable graph; later versions will return an MLIR form
            accepted by the compiler, and eventually the compilation result itself

    Raises:
        TypeError: if the graph's values cannot be lowered to MLIR (signed integers)
    """
    # Build the operation graph from the function definition.
    graph = trace_numpy_function(function_to_trace, function_parameters)

    # Evaluate the graph over the dataset to discover per-node value bounds,
    # then fold those bounds back into the graph's values; the result is the
    # compilable graph.
    bounds = eval_op_graph_bounds_on_dataset(graph, dataset)
    graph.update_values_with_bounds(bounds)

    # MLIR lowering only accepts unsigned integer values; reject anything else early.
    if not is_graph_values_compatible_with_mlir(graph):
        raise TypeError("signed integers aren't supported for MLIR lowering")

    # Align node bit widths as the MLIR compiler currently expects.
    update_bit_width_for_mlir(graph)

    # Optionally record intermediate results for inspection/debugging.
    if compilation_artifacts is not None:
        compilation_artifacts.operation_graph = graph
        compilation_artifacts.bounds = bounds

    return graph
|