Files
concrete/compiler/python/zamalang/compiler.py
2021-09-10 10:00:04 +01:00

94 lines
3.0 KiB
Python

"""Compiler submodule"""
from typing import List, Union
from _zamalang._compiler import CompilerEngine as _CompilerEngine
from _zamalang._compiler import ExecutionArgument as _ExecutionArgument
from _zamalang._compiler import round_trip as _round_trip
def round_trip(mlir_str: str) -> str:
    """Parse an MLIR string with the native parser and return the result.

    Args:
        mlir_str (str): MLIR source code to parse.

    Raises:
        TypeError: if ``mlir_str`` is not a ``str``.

    Returns:
        str: the MLIR as produced by the native round-trip (parse, then print).
    """
    is_text = isinstance(mlir_str, str)
    if is_text:
        return _round_trip(mlir_str)
    raise TypeError("input must be an `str`")
def create_execution_argument(value: Union[int, List[int]]) -> "_ExecutionArgument":
    """Create an execution argument holding either an int or tensor value.

    Args:
        value (Union[int, List[int]]): value of the argument, either an int in the
            uint64 range [0, 2**64 - 1], or a list of ints each in the uint8 range
            [0, 255].

    Raises:
        TypeError: if ``value`` is not an int or a list of ints, or if a value is
            outside its allowed range (out-of-range values also raise
            ``TypeError`` so existing callers catching it keep working).

    Returns:
        _ExecutionArgument: execution argument holding the appropriate value.
    """
    if isinstance(value, int):
        # Inclusive upper bound: 2**64 - 1 is itself a valid uint64.
        if not 0 <= value <= 2 ** 64 - 1:
            raise TypeError("single integer must be in the range [0, 2**64 - 1] (uint64)")
    elif isinstance(value, list):
        for elem in value:
            # Reject non-integers (e.g. floats) explicitly; they would otherwise
            # pass the numeric range comparison below.
            if not isinstance(elem, int):
                raise TypeError("value of execution argument must be either int or list[int]")
            # Inclusive upper bound: 255 is itself a valid uint8.
            if not 0 <= elem <= 2 ** 8 - 1:
                raise TypeError("values of the list must be in the range [0, 255] (uint8)")
    else:
        raise TypeError("value of execution argument must be either int or list[int]")
    return _ExecutionArgument.create(value)
class CompilerEngine:
    """Python wrapper around the native ``_CompilerEngine``.

    Validates inputs before forwarding them to the native engine. Execution
    arguments are converted with ``create_execution_argument``.
    """

    def __init__(self, mlir_str: Union[str, None] = None):
        """Create an engine, optionally compiling ``mlir_str`` immediately.

        Args:
            mlir_str (Union[str, None]): MLIR to compile right away. When None
                (the default), nothing is compiled until ``compile_fhe`` is called.

        Raises:
            TypeError: if ``mlir_str`` is neither None nor an ``str``.
        """
        self._engine = _CompilerEngine()
        if mlir_str is not None:
            self.compile_fhe(mlir_str)

    def compile_fhe(self, mlir_str: str) -> "CompilerEngine":
        """Compile the MLIR input with the underlying native engine.

        Args:
            mlir_str (str): MLIR to compile.

        Raises:
            TypeError: if the argument is not an str.

        Returns:
            CompilerEngine: this engine (``self``), used for execution. This
            allows chaining, e.g. ``CompilerEngine().compile_fhe(src).run(1, 2)``.
        """
        if not isinstance(mlir_str, str):
            raise TypeError("input must be an `str`")
        # The native call cannot return this Python wrapper, so return ``self``
        # to honour the annotated/documented return type.
        self._engine.compile_fhe(mlir_str)
        return self

    def run(self, *args: Union[int, List[int]]) -> int:
        """Run the compiled code.

        Args:
            *args: arguments for execution. Each argument is either an int
                (uint64 range) or a list of ints (uint8 range). Each one is
                converted with ``create_execution_argument``.

        Raises:
            TypeError: if execution arguments can't be constructed.

        Returns:
            int: result of execution.
        """
        execution_arguments = [create_execution_argument(arg) for arg in args]
        return self._engine.run(execution_arguments)

    def get_compiled_module(self) -> str:
        """Return the compiled module in printable form.

        Returns:
            str: Compiled module in printable form, as reported by the native engine.
        """
        return self._engine.get_compiled_module()