mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
refactor: remove as ir imports to avoid breaking sphinx docs links
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import List, Optional
|
||||
from .data_types.integers import Integer
|
||||
from .debugging import custom_assert
|
||||
from .operator_graph import OPGraph
|
||||
from .representation import intermediate as ir
|
||||
from .representation.intermediate import IntermediateNode
|
||||
|
||||
|
||||
def is_a_power_of_2(x: int) -> bool:
|
||||
@@ -22,11 +22,11 @@ def is_a_power_of_2(x: int) -> bool:
|
||||
return x > 0 and (x & (x - 1)) == 0
|
||||
|
||||
|
||||
def ir_nodes_has_integer_input_and_output(node: ir.IntermediateNode) -> bool:
|
||||
def ir_nodes_has_integer_input_and_output(node: IntermediateNode) -> bool:
|
||||
"""Check if an ir node has Integer inputs and outputs.
|
||||
|
||||
Args:
|
||||
node (ir.IntermediateNode): Node to check
|
||||
node (IntermediateNode): Node to check
|
||||
|
||||
Returns:
|
||||
bool: True if all input and output values hold Integers
|
||||
@@ -40,13 +40,13 @@ def ir_nodes_has_integer_input_and_output(node: ir.IntermediateNode) -> bool:
|
||||
# long run probably
|
||||
def check_op_graph_is_integer_program(
|
||||
op_graph: OPGraph,
|
||||
offending_nodes_out: Optional[List[ir.IntermediateNode]] = None,
|
||||
offending_nodes_out: Optional[List[IntermediateNode]] = None,
|
||||
) -> bool:
|
||||
"""Check if an op_graph inputs, outputs and intermediate values are Integers.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph to check
|
||||
offending_nodes_out (Optional[List[ir.IntermediateNode]]): Optionally pass a list that will
|
||||
offending_nodes_out (Optional[List[IntermediateNode]]): Optionally pass a list that will
|
||||
be populated with offending nodes, the list will be cleared before being filled
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -12,7 +12,7 @@ from PIL import Image
|
||||
|
||||
from ..debugging import custom_assert, draw_graph, get_printable_graph
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import IntermediateNode
|
||||
from ..values import BaseValue
|
||||
|
||||
DEFAULT_OUTPUT_DIRECTORY: Path = Path(".artifacts")
|
||||
@@ -30,7 +30,7 @@ class CompilationArtifacts:
|
||||
textual_representations_of_operation_graphs: Dict[str, str]
|
||||
|
||||
final_operation_graph: Optional[OPGraph]
|
||||
bounds_of_the_final_operation_graph: Optional[Dict[ir.IntermediateNode, Dict[str, Any]]]
|
||||
bounds_of_the_final_operation_graph: Optional[Dict[IntermediateNode, Dict[str, Any]]]
|
||||
mlir_of_the_final_operation_graph: Optional[str]
|
||||
|
||||
def __init__(self, output_directory: Path = DEFAULT_OUTPUT_DIRECTORY):
|
||||
@@ -92,11 +92,11 @@ class CompilationArtifacts:
|
||||
|
||||
self.final_operation_graph = operation_graph
|
||||
|
||||
def add_final_operation_graph_bounds(self, bounds: Dict[ir.IntermediateNode, Dict[str, Any]]):
|
||||
def add_final_operation_graph_bounds(self, bounds: Dict[IntermediateNode, Dict[str, Any]]):
|
||||
"""Add the bounds of the final operation graph to the artifacts.
|
||||
|
||||
Args:
|
||||
bounds (Dict[ir.IntermediateNode, Dict[str, Any]]): the bound dictionary
|
||||
bounds (Dict[IntermediateNode, Dict[str, Any]]): the bound dictionary
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
@@ -11,17 +11,25 @@ from PIL import Image
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import ALL_IR_NODES
|
||||
from ..representation.intermediate import (
|
||||
ALL_IR_NODES,
|
||||
Add,
|
||||
ArbitraryFunction,
|
||||
Constant,
|
||||
Dot,
|
||||
Input,
|
||||
Mul,
|
||||
Sub,
|
||||
)
|
||||
|
||||
IR_NODE_COLOR_MAPPING = {
|
||||
ir.Input: "blue",
|
||||
ir.Constant: "cyan",
|
||||
ir.Add: "red",
|
||||
ir.Sub: "yellow",
|
||||
ir.Mul: "green",
|
||||
ir.ArbitraryFunction: "orange",
|
||||
ir.Dot: "purple",
|
||||
Input: "blue",
|
||||
Constant: "cyan",
|
||||
Add: "red",
|
||||
Sub: "yellow",
|
||||
Mul: "green",
|
||||
ArbitraryFunction: "orange",
|
||||
Dot: "purple",
|
||||
"ArbitraryFunction": "orange",
|
||||
"TLU": "grey",
|
||||
"output": "magenta",
|
||||
@@ -63,7 +71,7 @@ def draw_graph(
|
||||
value_to_return = IR_NODE_COLOR_MAPPING[type(node)]
|
||||
if node in output_nodes:
|
||||
value_to_return = IR_NODE_COLOR_MAPPING["output"]
|
||||
elif isinstance(node, ir.ArbitraryFunction):
|
||||
elif isinstance(node, ArbitraryFunction):
|
||||
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
|
||||
return value_to_return
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import networkx as nx
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import ArbitraryFunction, Constant, Input
|
||||
|
||||
|
||||
def output_data_type_to_string(node):
|
||||
@@ -49,15 +49,15 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
# they only are done by incrementing i
|
||||
custom_assert(len(node.outputs) == 1)
|
||||
|
||||
if isinstance(node, ir.Input):
|
||||
if isinstance(node, Input):
|
||||
what_to_print = node.input_name
|
||||
elif isinstance(node, ir.Constant):
|
||||
elif isinstance(node, Constant):
|
||||
what_to_print = f"Constant({node.constant_data})"
|
||||
else:
|
||||
|
||||
base_name = node.__class__.__name__
|
||||
|
||||
if isinstance(node, ir.ArbitraryFunction):
|
||||
if isinstance(node, ArbitraryFunction):
|
||||
base_name = node.op_name
|
||||
|
||||
what_to_print = base_name + "("
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Iterable, Tuple, Union
|
||||
from ..common_helpers import is_a_power_of_2
|
||||
from ..data_types.base import BaseDataType
|
||||
from ..data_types.integers import make_integer_to_hold
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import ArbitraryFunction
|
||||
from ..tracing.base_tracer import BaseTracer
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ class LookupTable:
|
||||
# we need to create an `ArbitraryFunction` node
|
||||
# because the result will be determined during the runtime
|
||||
if isinstance(key, BaseTracer):
|
||||
traced_computation = ir.ArbitraryFunction(
|
||||
traced_computation = ArbitraryFunction(
|
||||
input_base_value=key.output,
|
||||
arbitrary_func=LookupTable._checked_indexing,
|
||||
output_dtype=self.output_dtype,
|
||||
|
||||
@@ -22,7 +22,7 @@ from ..data_types.dtypes_helpers import (
|
||||
value_is_encrypted_tensor_integer,
|
||||
)
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import Add, ArbitraryFunction, Constant, Dot, Mul, Sub
|
||||
|
||||
|
||||
def add(node, preds, ir_to_mlir_node, ctx):
|
||||
@@ -189,12 +189,12 @@ def dot(node, preds, ir_to_mlir_node, ctx):
|
||||
|
||||
|
||||
V0_OPSET_CONVERSION_FUNCTIONS = {
|
||||
ir.Add: add,
|
||||
ir.Sub: sub,
|
||||
ir.Mul: mul,
|
||||
ir.Constant: constant,
|
||||
ir.ArbitraryFunction: apply_lut,
|
||||
ir.Dot: dot,
|
||||
Add: add,
|
||||
Sub: sub,
|
||||
Mul: mul,
|
||||
Constant: constant,
|
||||
ArbitraryFunction: apply_lut,
|
||||
Dot: dot,
|
||||
}
|
||||
|
||||
# pylint: enable=no-name-in-module,no-member
|
||||
|
||||
@@ -20,7 +20,7 @@ from ..data_types.dtypes_helpers import (
|
||||
)
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import Input
|
||||
|
||||
|
||||
class MLIRConverter:
|
||||
@@ -151,7 +151,7 @@ class MLIRConverter:
|
||||
for arg_num, node in op_graph.input_nodes.items():
|
||||
ir_to_mlir_node[node] = arg[arg_num]
|
||||
for node in nx.topological_sort(op_graph.graph):
|
||||
if isinstance(node, ir.Input):
|
||||
if isinstance(node, Input):
|
||||
continue
|
||||
mlir_op = self.conversion_functions.get(type(node), None)
|
||||
if mlir_op is None: # pragma: no cover
|
||||
|
||||
@@ -13,7 +13,7 @@ from .data_types.dtypes_helpers import (
|
||||
from .data_types.floats import Float
|
||||
from .data_types.integers import Integer, make_integer_to_hold
|
||||
from .debugging.custom_assert import custom_assert
|
||||
from .representation import intermediate as ir
|
||||
from .representation.intermediate import Input, IntermediateNode
|
||||
from .tracing import BaseTracer
|
||||
from .tracing.tracing_helpers import create_graph_from_output_tracers
|
||||
|
||||
@@ -22,25 +22,25 @@ class OPGraph:
|
||||
"""Class to make work with nx graphs easier."""
|
||||
|
||||
graph: nx.MultiDiGraph
|
||||
input_nodes: Dict[int, ir.Input]
|
||||
output_nodes: Dict[int, ir.IntermediateNode]
|
||||
input_nodes: Dict[int, Input]
|
||||
output_nodes: Dict[int, IntermediateNode]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: nx.MultiDiGraph,
|
||||
input_nodes: Dict[int, ir.Input],
|
||||
output_nodes: Dict[int, ir.IntermediateNode],
|
||||
input_nodes: Dict[int, Input],
|
||||
output_nodes: Dict[int, IntermediateNode],
|
||||
) -> None:
|
||||
custom_assert(
|
||||
len(input_nodes) > 0, "Got a graph without input nodes which is not supported"
|
||||
)
|
||||
custom_assert(
|
||||
all(isinstance(node, ir.Input) for node in input_nodes.values()),
|
||||
"Got input nodes that were not ir.Input, which is not supported",
|
||||
all(isinstance(node, Input) for node in input_nodes.values()),
|
||||
"Got input nodes that were not Input, which is not supported",
|
||||
)
|
||||
custom_assert(
|
||||
all(isinstance(node, ir.IntermediateNode) for node in output_nodes.values()),
|
||||
"Got output nodes which were not ir.IntermediateNode, which is not supported",
|
||||
all(isinstance(node, IntermediateNode) for node in output_nodes.values()),
|
||||
"Got output nodes which were not IntermediateNode, which is not supported",
|
||||
)
|
||||
|
||||
self.graph = graph
|
||||
@@ -75,7 +75,7 @@ class OPGraph:
|
||||
input_nodes = {
|
||||
node.program_input_idx: node
|
||||
for node in graph.nodes()
|
||||
if len(graph.pred[node]) == 0 and isinstance(node, ir.Input)
|
||||
if len(graph.pred[node]) == 0 and isinstance(node, Input)
|
||||
}
|
||||
output_nodes = {
|
||||
output_idx: tracer.traced_computation
|
||||
@@ -86,50 +86,50 @@ class OPGraph:
|
||||
@staticmethod
|
||||
def from_graph(
|
||||
graph: nx.MultiDiGraph,
|
||||
input_nodes: Iterable[ir.Input],
|
||||
output_nodes: Iterable[ir.IntermediateNode],
|
||||
input_nodes: Iterable[Input],
|
||||
output_nodes: Iterable[IntermediateNode],
|
||||
) -> "OPGraph":
|
||||
"""Construct OPGraph from an existing networkx MultiDiGraph.
|
||||
|
||||
Args:
|
||||
graph (nx.MultiDiGraph): The networkx MultiDiGraph to use.
|
||||
input_nodes (Iterable[ir.Input]): The input nodes of the MultiDiGraph.
|
||||
output_nodes (Iterable[ir.IntermediateNode]): The output nodes of the MultiDiGraph.
|
||||
input_nodes (Iterable[Input]): The input nodes of the MultiDiGraph.
|
||||
output_nodes (Iterable[IntermediateNode]): The output nodes of the MultiDiGraph.
|
||||
|
||||
Returns:
|
||||
OPGraph: The resulting OPGraph.
|
||||
"""
|
||||
return OPGraph(graph, dict(enumerate(input_nodes)), dict(enumerate(output_nodes)))
|
||||
|
||||
def get_ordered_inputs(self) -> List[ir.Input]:
|
||||
def get_ordered_inputs(self) -> List[Input]:
|
||||
"""Get the input nodes of the graph, ordered by their index.
|
||||
|
||||
Returns:
|
||||
List[ir.Input]: ordered input nodes
|
||||
List[Input]: ordered input nodes
|
||||
"""
|
||||
return [self.input_nodes[idx] for idx in range(len(self.input_nodes))]
|
||||
|
||||
def get_ordered_outputs(self) -> List[ir.IntermediateNode]:
|
||||
def get_ordered_outputs(self) -> List[IntermediateNode]:
|
||||
"""Get the output nodes of the graph, ordered by their index.
|
||||
|
||||
Returns:
|
||||
List[ir.IntermediateNode]: ordered input nodes
|
||||
List[IntermediateNode]: ordered input nodes
|
||||
"""
|
||||
return [self.output_nodes[idx] for idx in range(len(self.output_nodes))]
|
||||
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Dict[ir.IntermediateNode, Any]:
|
||||
def evaluate(self, inputs: Dict[int, Any]) -> Dict[IntermediateNode, Any]:
|
||||
"""Evaluate a graph and get intermediate values for all nodes.
|
||||
|
||||
Args:
|
||||
inputs (Dict[int, Any]): The inputs to the program
|
||||
|
||||
Returns:
|
||||
Dict[ir.IntermediateNode, Any]: Dictionary with node as keys and resulting values
|
||||
Dict[IntermediateNode, Any]: Dictionary with node as keys and resulting values
|
||||
"""
|
||||
node_results: Dict[ir.IntermediateNode, Any] = {}
|
||||
node_results: Dict[IntermediateNode, Any] = {}
|
||||
|
||||
for node in nx.topological_sort(self.graph):
|
||||
if not isinstance(node, ir.Input):
|
||||
if not isinstance(node, Input):
|
||||
curr_inputs = {}
|
||||
for pred_node in self.graph.pred[node]:
|
||||
edges = self.graph.get_edge_data(pred_node, node)
|
||||
@@ -168,7 +168,7 @@ class OPGraph:
|
||||
callback function to determine the type constructor of the data encountered while
|
||||
updating the graph bounds. Defaults to get_type_constructor_python_constant_data.
|
||||
"""
|
||||
node: ir.IntermediateNode
|
||||
node: IntermediateNode
|
||||
|
||||
for node in self.graph.nodes():
|
||||
current_node_bounds = node_bounds[node]
|
||||
@@ -193,7 +193,7 @@ class OPGraph:
|
||||
|
||||
data_type_constructor = max_data_type_constructor
|
||||
|
||||
if not isinstance(node, ir.Input):
|
||||
if not isinstance(node, Input):
|
||||
for output_value in node.outputs:
|
||||
if isinstance(min_data_type, Integer) and isinstance(max_data_type, Integer):
|
||||
output_value.data_type = make_integer_to_hold(
|
||||
@@ -242,9 +242,9 @@ class OPGraph:
|
||||
"""Remove unreachable nodes from outputs."""
|
||||
|
||||
current_nodes = set(self.output_nodes.values())
|
||||
useful_nodes: Set[ir.IntermediateNode] = set()
|
||||
useful_nodes: Set[IntermediateNode] = set()
|
||||
while current_nodes:
|
||||
next_nodes: Set[ir.IntermediateNode] = set()
|
||||
next_nodes: Set[IntermediateNode] = set()
|
||||
useful_nodes.update(current_nodes)
|
||||
for node in current_nodes:
|
||||
next_nodes.update(self.graph.pred[node])
|
||||
|
||||
@@ -9,7 +9,7 @@ from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import ArbitraryFunction, Constant, Input, IntermediateNode
|
||||
|
||||
|
||||
def fuse_float_operations(
|
||||
@@ -26,7 +26,7 @@ def fuse_float_operations(
|
||||
"""
|
||||
|
||||
nx_graph = op_graph.graph
|
||||
processed_terminal_nodes: Set[ir.IntermediateNode] = set()
|
||||
processed_terminal_nodes: Set[IntermediateNode] = set()
|
||||
number_of_fuse = 0
|
||||
while True:
|
||||
float_subgraph_search_result = find_float_subgraph_with_unique_terminal_node(
|
||||
@@ -56,7 +56,7 @@ def fuse_float_operations(
|
||||
if terminal_node in op_graph.output_nodes.values():
|
||||
# Output value replace it
|
||||
# As the graph changes recreate the output_node_to_idx dict
|
||||
output_node_to_idx: Dict[ir.IntermediateNode, List[int]] = {
|
||||
output_node_to_idx: Dict[IntermediateNode, List[int]] = {
|
||||
out_node: [] for out_node in op_graph.output_nodes.values()
|
||||
}
|
||||
for output_idx, output_node in op_graph.output_nodes.items():
|
||||
@@ -87,21 +87,21 @@ def fuse_float_operations(
|
||||
|
||||
def convert_float_subgraph_to_fused_node(
|
||||
op_graph: OPGraph,
|
||||
float_subgraph_start_nodes: Set[ir.IntermediateNode],
|
||||
terminal_node: ir.IntermediateNode,
|
||||
subgraph_all_nodes: Set[ir.IntermediateNode],
|
||||
) -> Optional[Tuple[ir.ArbitraryFunction, ir.IntermediateNode]]:
|
||||
float_subgraph_start_nodes: Set[IntermediateNode],
|
||||
terminal_node: IntermediateNode,
|
||||
subgraph_all_nodes: Set[IntermediateNode],
|
||||
) -> Optional[Tuple[ArbitraryFunction, IntermediateNode]]:
|
||||
"""Convert a float subgraph to an equivalent fused ArbitraryFunction node.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph the float subgraph is part of.
|
||||
float_subgraph_start_nodes (Set[ir.IntermediateNode]): The nodes starting the float subgraph
|
||||
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph
|
||||
in `op_graph`.
|
||||
terminal_node (ir.IntermediateNode): The node ending the float subgraph.
|
||||
subgraph_all_nodes (Set[ir.IntermediateNode]): All the nodes in the float subgraph.
|
||||
terminal_node (IntermediateNode): The node ending the float subgraph.
|
||||
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[ir.ArbitraryFunction, ir.IntermediateNode]]: None if the float subgraph
|
||||
Optional[Tuple[ArbitraryFunction, IntermediateNode]]: None if the float subgraph
|
||||
cannot be fused, otherwise returns a tuple containing the fused node and the node whose
|
||||
output must be plugged as the input to the subgraph.
|
||||
"""
|
||||
@@ -111,7 +111,7 @@ def convert_float_subgraph_to_fused_node(
|
||||
|
||||
# Only one variable input node, find which node feeds its input
|
||||
non_constant_start_nodes = [
|
||||
node for node in float_subgraph_start_nodes if not isinstance(node, ir.Constant)
|
||||
node for node in float_subgraph_start_nodes if not isinstance(node, Constant)
|
||||
]
|
||||
custom_assert(len(non_constant_start_nodes) == 1)
|
||||
|
||||
@@ -126,7 +126,7 @@ def convert_float_subgraph_to_fused_node(
|
||||
|
||||
float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes))
|
||||
|
||||
new_subgraph_variable_input = ir.Input(new_input_value, "float_subgraph_input", 0)
|
||||
new_subgraph_variable_input = Input(new_input_value, "float_subgraph_input", 0)
|
||||
float_subgraph.add_node(new_subgraph_variable_input)
|
||||
|
||||
for node_after_input in nodes_after_input_set:
|
||||
@@ -155,7 +155,7 @@ def convert_float_subgraph_to_fused_node(
|
||||
)
|
||||
|
||||
# Create fused_node
|
||||
fused_node = ir.ArbitraryFunction(
|
||||
fused_node = ArbitraryFunction(
|
||||
deepcopy(new_subgraph_variable_input.inputs[0]),
|
||||
lambda x, float_op_subgraph, terminal_node: float_op_subgraph.evaluate({0: x})[
|
||||
terminal_node
|
||||
@@ -176,8 +176,8 @@ def convert_float_subgraph_to_fused_node(
|
||||
|
||||
def find_float_subgraph_with_unique_terminal_node(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
processed_terminal_nodes: Set[ir.IntermediateNode],
|
||||
) -> Optional[Tuple[Set[ir.IntermediateNode], ir.IntermediateNode, Set[ir.IntermediateNode]]]:
|
||||
processed_terminal_nodes: Set[IntermediateNode],
|
||||
) -> Optional[Tuple[Set[IntermediateNode], IntermediateNode, Set[IntermediateNode]]]:
|
||||
"""Find a subgraph of the graph with float computations.
|
||||
|
||||
The subgraph has a single terminal node with a single Integer output and has a single variable
|
||||
@@ -185,24 +185,24 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph): The networkx graph to search in.
|
||||
processed_terminal_nodes (Set[ir.IntermediateNode]): The set of terminal nodes for which
|
||||
processed_terminal_nodes (Set[IntermediateNode]): The set of terminal nodes for which
|
||||
subgraphs have already been searched, those will be skipped.
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[Set[ir.IntermediateNode], ir.IntermediateNode, Set[ir.IntermediateNode]]]:
|
||||
Optional[Tuple[Set[IntermediateNode], IntermediateNode, Set[IntermediateNode]]]:
|
||||
None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a tuple
|
||||
containing the set of nodes beginning a float subgraph, the terminal node of the
|
||||
subgraph and the set of all the nodes in the subgraph.
|
||||
"""
|
||||
|
||||
def is_float_to_single_int_node(node: ir.IntermediateNode) -> bool:
|
||||
def is_float_to_single_int_node(node: IntermediateNode) -> bool:
|
||||
return (
|
||||
any(isinstance(input_.data_type, Float) for input_ in node.inputs)
|
||||
and len(node.outputs) == 1
|
||||
and isinstance(node.outputs[0].data_type, Integer)
|
||||
)
|
||||
|
||||
def single_int_output_node(node: ir.IntermediateNode) -> bool:
|
||||
def single_int_output_node(node: IntermediateNode) -> bool:
|
||||
return len(node.outputs) == 1 and isinstance(node.outputs[0].data_type, Integer)
|
||||
|
||||
float_subgraphs_terminal_nodes = (
|
||||
@@ -211,7 +211,7 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
if is_float_to_single_int_node(node) and node not in processed_terminal_nodes
|
||||
)
|
||||
|
||||
terminal_node: ir.IntermediateNode
|
||||
terminal_node: IntermediateNode
|
||||
|
||||
try:
|
||||
terminal_node = next(float_subgraphs_terminal_nodes)
|
||||
@@ -220,10 +220,10 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
|
||||
# Use dict as ordered set
|
||||
current_nodes = {terminal_node: None}
|
||||
float_subgraph_start_nodes: Set[ir.IntermediateNode] = set()
|
||||
subgraph_all_nodes: Set[ir.IntermediateNode] = set()
|
||||
float_subgraph_start_nodes: Set[IntermediateNode] = set()
|
||||
subgraph_all_nodes: Set[IntermediateNode] = set()
|
||||
while current_nodes:
|
||||
next_nodes: Dict[ir.IntermediateNode, None] = {}
|
||||
next_nodes: Dict[IntermediateNode, None] = {}
|
||||
for node in current_nodes:
|
||||
subgraph_all_nodes.add(node)
|
||||
predecessors = nx_graph.pred[node]
|
||||
@@ -240,16 +240,16 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
|
||||
|
||||
def subgraph_has_unique_variable_input(
|
||||
float_subgraph_start_nodes: Set[ir.IntermediateNode],
|
||||
float_subgraph_start_nodes: Set[IntermediateNode],
|
||||
) -> bool:
|
||||
"""Check that only one of the nodes starting the subgraph is variable.
|
||||
|
||||
Args:
|
||||
float_subgraph_start_nodes (Set[ir.IntermediateNode]): The nodes starting the subgraph.
|
||||
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the subgraph.
|
||||
|
||||
Returns:
|
||||
bool: True if only one of the nodes is not an ir.Constant
|
||||
bool: True if only one of the nodes is not an Constant
|
||||
"""
|
||||
# Only one input to the subgraph where computations are done in floats is variable, this
|
||||
# is the only case we can manage with ArbitraryFunction fusing
|
||||
return sum(not isinstance(node, ir.Constant) for node in float_subgraph_start_nodes) == 1
|
||||
return sum(not isinstance(node, Constant) for node in float_subgraph_start_nodes) == 1
|
||||
|
||||
@@ -4,8 +4,13 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Iterable, List, Tuple, Type, Union
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import IR_MIX_VALUES_FUNC_ARG_NAME
|
||||
from ..representation.intermediate import (
|
||||
IR_MIX_VALUES_FUNC_ARG_NAME,
|
||||
Add,
|
||||
IntermediateNode,
|
||||
Mul,
|
||||
Sub,
|
||||
)
|
||||
from ..values import BaseValue
|
||||
|
||||
|
||||
@@ -13,14 +18,14 @@ class BaseTracer(ABC):
|
||||
"""Base class for implementing tracers."""
|
||||
|
||||
inputs: List["BaseTracer"]
|
||||
traced_computation: ir.IntermediateNode
|
||||
traced_computation: IntermediateNode
|
||||
output: BaseValue
|
||||
_mix_values_func: Callable[..., BaseValue]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inputs: Iterable["BaseTracer"],
|
||||
traced_computation: ir.IntermediateNode,
|
||||
traced_computation: IntermediateNode,
|
||||
output_index: int,
|
||||
) -> None:
|
||||
self.inputs = list(inputs)
|
||||
@@ -62,14 +67,14 @@ class BaseTracer(ABC):
|
||||
def instantiate_output_tracers(
|
||||
self,
|
||||
inputs: Iterable[Union["BaseTracer", Any]],
|
||||
computation_to_trace: Type[ir.IntermediateNode],
|
||||
computation_to_trace: Type[IntermediateNode],
|
||||
) -> Tuple["BaseTracer", ...]:
|
||||
"""Instantiate all output BaseTracer for a given computation.
|
||||
|
||||
Args:
|
||||
inputs (Iterable[Union[BaseTracer, Any]]): Previous BaseTracer or data used as inputs
|
||||
for a new node.
|
||||
computation_to_trace (Type[ir.IntermediateNode]): The IntermediateNode class
|
||||
computation_to_trace (Type[IntermediateNode]): The IntermediateNode class
|
||||
to instantiate for the computation being traced
|
||||
|
||||
Returns:
|
||||
@@ -103,7 +108,7 @@ class BaseTracer(ABC):
|
||||
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[self, other],
|
||||
ir.Add,
|
||||
Add,
|
||||
)
|
||||
|
||||
custom_assert(len(result_tracer) == 1)
|
||||
@@ -120,7 +125,7 @@ class BaseTracer(ABC):
|
||||
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[self, other],
|
||||
ir.Sub,
|
||||
Sub,
|
||||
)
|
||||
|
||||
custom_assert(len(result_tracer) == 1)
|
||||
@@ -132,7 +137,7 @@ class BaseTracer(ABC):
|
||||
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[other, self],
|
||||
ir.Sub,
|
||||
Sub,
|
||||
)
|
||||
|
||||
custom_assert(len(result_tracer) == 1)
|
||||
@@ -144,7 +149,7 @@ class BaseTracer(ABC):
|
||||
|
||||
result_tracer = self.instantiate_output_tracers(
|
||||
[self, other],
|
||||
ir.Mul,
|
||||
Mul,
|
||||
)
|
||||
|
||||
custom_assert(len(result_tracer) == 1)
|
||||
|
||||
@@ -7,7 +7,7 @@ import networkx as nx
|
||||
from networkx.algorithms.dag import is_directed_acyclic_graph
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..representation import intermediate as ir
|
||||
from ..representation.intermediate import Input
|
||||
from ..values import BaseValue
|
||||
from .base_tracer import BaseTracer
|
||||
|
||||
@@ -50,7 +50,7 @@ def make_input_tracer(
|
||||
Returns:
|
||||
BaseTracer: The BaseTracer for that input value
|
||||
"""
|
||||
return tracer_class([], ir.Input(input_value, input_name, input_idx), 0)
|
||||
return tracer_class([], Input(input_value, input_name, input_idx), 0)
|
||||
|
||||
|
||||
def prepare_function_parameters(
|
||||
|
||||
@@ -19,7 +19,7 @@ from ..common.mlir.utils import (
|
||||
)
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.optimization.topological import fuse_float_operations
|
||||
from ..common.representation import intermediate as ir
|
||||
from ..common.representation.intermediate import IntermediateNode
|
||||
from ..common.values import BaseValue
|
||||
from ..numpy.tracing import trace_numpy_function
|
||||
from .np_dtypes_helpers import (
|
||||
@@ -99,7 +99,7 @@ def _compile_numpy_function_into_op_graph_internal(
|
||||
fuse_float_operations(op_graph, compilation_artifacts)
|
||||
|
||||
# TODO: To be removed once we support more than integers
|
||||
offending_non_integer_nodes: List[ir.IntermediateNode] = []
|
||||
offending_non_integer_nodes: List[IntermediateNode] = []
|
||||
op_grap_is_int_prog = check_op_graph_is_integer_program(op_graph, offending_non_integer_nodes)
|
||||
if not op_grap_is_int_prog:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user