feat(optimization): support more fusing topologies

- corrected docstring that was mistaken on what was returned
- updated pyproject.toml to ignore warnings that happened naturally in
networkx and that was blocking proper test execution (no way around that
this is code from networkx that triggered the warning)
- add a test case for the newly supported fusing topology

closes #499
This commit is contained in:
Arthur Meyre
2021-11-17 11:53:35 +01:00
parent bc145e21e1
commit ff03bc2220
4 changed files with 153 additions and 32 deletions

View File

@@ -27,6 +27,34 @@ def no_fuse_unhandled(x, y):
return intermediate.astype(numpy.int32)
def fusable_with_bigger_search(x, y):
"""fusable with bigger search"""
x = x + 1
x_1 = x.astype(numpy.int32)
x_1 = x_1 + 1.5
x_2 = x.astype(numpy.int32)
x_2 = x_2 + 3.4
add = x_1 + x_2
add_int = add.astype(numpy.int32)
return add_int + y
def fusable_with_bigger_search_needs_second_iteration(x, y):
"""fusable with bigger search and triggers a second iteration in the fusing"""
x = x + 1
x = x + 0.5
x = numpy.cos(x)
x_1 = x.astype(numpy.int32)
x_1 = x_1 + 1.5
x_p = x + 1
x_p2 = x_p + 1
x_2 = (x_p + x_p2).astype(numpy.int32)
x_2 = x_2 + 3.4
add = x_1 + x_2
add_int = add.astype(numpy.int32)
return add_int + y
def no_fuse_big_constant_3_10_10(x):
"""Pass an array x with size < 100 to trigger a no fuse condition."""
x = x.astype(numpy.float64)
@@ -177,6 +205,20 @@ return %7
""".strip(), # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_unhandled",
),
pytest.param(
fusable_with_bigger_search,
True,
get_func_params_int32(fusable_with_bigger_search),
None,
id="fusable_with_bigger_search",
),
pytest.param(
fusable_with_bigger_search_needs_second_iteration,
True,
get_func_params_int32(fusable_with_bigger_search_needs_second_iteration),
None,
id="fusable_with_bigger_search",
),
pytest.param(
no_fuse_dot,
False,