Files
concrete/hdk/common/mlir/utils.py
Ayoub Benaissa f6c9618b5a feat(mlir): MLIR Conversion (#103)
* feat(mlir): conversion from HDKIR to MLIR

* feat(mlir): support ir.Sub and ir.Mul

- better type conversion from HDK to MLIR
- Context management inside the converter class
- better handling of input type in conversion functions

* refactor(mlir): use input and output from OPGraph

Co-authored-by: Arthur Meyre <arthur.meyre@zama.ai>

* feat(mlir): eint-int subtractions

* feat(mlir): adhere to spec for supported ops

* feat(OPGraph): getters for ordered inputs/outputs

+ formatting

* tests(mlir): test conversion via compiler roundtrip

* fix(mlir): flip operands on int_eint sym ops

* feat(mlir): check that the outputs are unsigned

* feat(mlir): set bit_width of all nodes to the max

This is currently required because the compiler already assumes it.
It could be removed from HDK once the compiler can do this on its own.

* feat: value_is_integer + CRs

disable some linting errors

* tests: update compile tests + coverage

* refactor: reorganize mlir package + better doc

* doc: conformance with pydocstyle

Co-authored-by: Arthur Meyre <arthur.meyre@zama.ai>
2021-08-13 12:50:31 +01:00

60 lines
1.9 KiB
Python

"""Utilities for MLIR conversion."""
from typing import cast
from ..data_types import Integer
from ..data_types.dtypes_helpers import (
value_is_clear_integer,
value_is_encrypted_integer,
value_is_integer,
)
from ..operator_graph import OPGraph
def is_graph_values_compatible_with_mlir(op_graph: OPGraph) -> bool:
    """Check that every output value of the graph is an unsigned integer.

    Unsigned integer outputs are what the compiler supports, so any other output
    (non-integer, or signed integer) makes the graph incompatible.

    Args:
        op_graph: computation graph to check

    Returns:
        bool: True if all outputs of all output nodes are unsigned integers, False otherwise
    """
    for output_node in op_graph.output_nodes.values():
        for output_value in output_node.outputs:
            # Reject anything that is not an integer before inspecting its signedness.
            if not value_is_integer(output_value):
                return False
            if cast(Integer, output_value.data_type).is_signed:
                return False
    return True
def _set_all_bit_width(op_graph: OPGraph, p: int) -> None:
    """Set all bit_width in the graph to `p + 1` for clear values and `p` for encrypted values.

    Every input and output value of every node is updated in place. Clear integer values
    receive one more bit than encrypted integer values; values that are neither clear nor
    encrypted integers are left untouched.

    Args:
        op_graph: graph to set bit_width for
        p: bit_width to set on encrypted integer values (clear integer values get `p + 1`)
    """
    for node in op_graph.graph.nodes:
        # Inputs are updated as well as outputs so that every value seen by a node agrees.
        for value in node.outputs + node.inputs:
            if value_is_clear_integer(value):
                # One extra bit for clear values: the inverse of update_bit_width_for_mlir,
                # which counts a clear value's width as bit_width - 1 when computing `p`.
                value.data_type.bit_width = p + 1
            elif value_is_encrypted_integer(value):
                value.data_type.bit_width = p
def update_bit_width_for_mlir(op_graph: OPGraph):
    """Give every integer value in the graph the same bit_width, taken as the graph's maximum.

    The maximum is measured in "encrypted" bits: an encrypted integer contributes its own
    bit_width, and a clear integer contributes its bit_width minus one. The result (never
    below 0) is then applied to the whole graph via `_set_all_bit_width`.

    Args:
        op_graph: graph whose bit_width values are updated in place
    """
    # Seeded with 0 so the maximum is 0 when the graph holds no integer outputs.
    candidate_widths = [0]
    for node in op_graph.graph.nodes:
        for output_value in node.outputs:
            if value_is_clear_integer(output_value):
                candidate_widths.append(output_value.data_type.bit_width - 1)
            elif value_is_encrypted_integer(output_value):
                candidate_widths.append(output_value.data_type.bit_width)
    _set_all_bit_width(op_graph, max(candidate_widths))